From a80ad6df8a6faa12f8eb2af6f7f6da6cda9ae142 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 28 Feb 2023 12:03:28 +0200 Subject: [PATCH] Revert "use AuthRequest code flow to create device tokens" This reverts commit b88539846610e1ea7dd3f7f520e06be0d53a907e. --- pkg/op/config.go | 2 +- pkg/op/device.go | 135 ++++++++++++++++++++---------- pkg/op/mock/configuration.mock.go | 12 +-- pkg/op/op.go | 68 +++++++-------- pkg/op/storage.go | 16 ++-- 5 files changed, 139 insertions(+), 94 deletions(-) diff --git a/pkg/op/config.go b/pkg/op/config.go index 233d11f..d339a30 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -28,7 +28,7 @@ type Configuration interface { EndSessionEndpoint() Endpoint KeysEndpoint() Endpoint DeviceAuthorizationEndpoint() Endpoint - UserCodeVerificationEndpoint() Endpoint + UserCodeFormEndpoint() Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool diff --git a/pkg/op/device.go b/pkg/op/device.go index 622e27a..ab276b8 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -88,7 +88,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide VerificationURI: config.UserFormURL, } - endpoint := o.UserCodeVerificationEndpoint().Absolute(IssuerFromContext(r.Context())) + endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context())) response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode) httphelper.MarshalJSON(w, response) @@ -148,6 +148,24 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) { return buf.String(), nil } +type deviceAccessTokenRequest struct { + subject string + audience []string + scopes []string +} + +func (r *deviceAccessTokenRequest) GetSubject() string { + return r.subject +} + +func (r *deviceAccessTokenRequest) GetAudience() []string { + return r.audience +} + +func (r *deviceAccessTokenRequest) GetScopes() []string { + return r.scopes +} + func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { if err := deviceAccessToken(w, r, exchanger); err != nil { RequestError(w, r, err) @@ -161,7 +179,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang defer cancel() r = r.WithContext(ctx) - clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger) + clientID, authenticated, err := ClientIDFromRequest(r, exchanger) if err != nil { return err } @@ -170,7 +188,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang if err != nil { return err } - state, authReq, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) + state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) if err != nil { return err } @@ -179,14 +197,19 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang if err != nil { return err } - if !clientAuthenticated { + if !authenticated { if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client? return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). WithDescription(fmt.Sprintf("required client auth method: %s", m)) } } - resp, err := CreateTokenResponse(ctx, authReq, client, exchanger, true, state.AuthCode, "") + tokenRequest := &deviceAccessTokenRequest{ + subject: state.Subject, + audience: []string{clientID}, + scopes: state.Scopes, + } + resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client) if err != nil { return err } @@ -203,88 +226,108 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc. return req, nil } -func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, AuthRequest, error) { +func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { storage, err := assertDeviceStorage(exchanger.Storage()) if err != nil { - return nil, nil, err + return nil, err } state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode) if errors.Is(err, context.DeadlineExceeded) { - return nil, nil, oidc.ErrSlowDown().WithParent(err) + return nil, oidc.ErrSlowDown().WithParent(err) } if err != nil { - return nil, nil, err + return nil, err } if state.Denied { - return state, nil, oidc.ErrAccessDenied() + return state, oidc.ErrAccessDenied() } - if state.AuthCode != "" { - return state, nil, nil + if state.Completed { + return state, nil } if time.Now().After(state.Expires) { - return state, nil, oidc.ErrExpiredDeviceCode() + return state, oidc.ErrExpiredDeviceCode() } - authReq, err := AuthRequestByCode(ctx, exchanger.Storage(), state.AuthCode) - return state, authReq, err + return state, oidc.ErrAuthorizationPending() } -func userCodeVerificationHandler(o OpenIDProvider) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - UserCodeVerification(w, r, o) - } -} +func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) { + tokenType := AccessTokenTypeBearer // not sure if this is the correct type? -type UserCodeVerificationRequest struct { - Code string `schema:"code"` - UserCode string `schema:"user_code"` - RedirectURL string `schema:"redirect_url"` -} - -func UserCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { - if err := userCodeVerification(w, r, o); err != nil { - RequestError(w, r, err) - } -} - -func userCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) (err error) { - req, err := ParseUserCodeVerificationRequest(r, o.Decoder()) + accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") if err != nil { - return err + return nil, err + } + + return &oidc.AccessTokenResponse{ + AccessToken: accessToken, + RefreshToken: refreshToken, + TokenType: oidc.BearerToken, + ExpiresIn: uint64(validity.Seconds()), + }, nil +} + +func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + UserCodeForm(w, r, o) + } +} + +type UserCodeFormData struct { + AccesssToken string `schema:"access_token"` + UserCode string `schema:"user_code"` + RedirectURL string `schema:"redirect_url"` +} + +func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { + data, err := ParseUserCodeFormData(r, o.Decoder()) + if err != nil { + RequestError(w, r, err) + return } storage, err := assertDeviceStorage(o.Storage()) if err != nil { - return err + RequestError(w, r, err) + return } ctx := r.Context() - if err := storage.CompleteDeviceAuthorization(ctx, req.Code, req.UserCode); err != nil { - return err + token, err := VerifyAccessToken(ctx, data.AccesssToken, o.AccessTokenVerifier(ctx)) + if err != nil { + if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil { + err = se + } + RequestError(w, r, err) + return } - if req.RedirectURL != "" { - http.Redirect(w, r, req.RedirectURL, http.StatusSeeOther) + if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.GetSubject()); err != nil { + RequestError(w, r, err) + return + } + + if data.RedirectURL != "" { + http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther) } fmt.Fprintln(w, "Authorization successfull, please return to your device") - return nil } -func ParseUserCodeVerificationRequest(r *http.Request, decoder httphelper.Decoder) (*UserCodeVerificationRequest, error) { +func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) { if err := r.ParseForm(); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) } - req := new(UserCodeVerificationRequest) + req := new(UserCodeFormData) if err := decoder.Decode(req, r.Form); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err) } - if req.Code == "" { - return nil, oidc.ErrInvalidRequest().WithDescription("\"code\" missing in form") + if req.AccesssToken == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form") } if req.UserCode == "" { - return nil, oidc.ErrInvalidRequest().WithDescription("\"user_code\" missing in form") + return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form") } return req, nil diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index c17d243..44b5ceb 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -400,18 +400,18 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) } -// UserCodeVerificationEndpoint mocks base method. -func (m *MockConfiguration) UserCodeVerificationEndpoint() op.Endpoint { +// UserCodeFormEndpoint mocks base method. +func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UserCodeVerificationEndpoint") + ret := m.ctrl.Call(m, "UserCodeFormEndpoint") ret0, _ := ret[0].(op.Endpoint) return ret0 } -// UserCodeVerificationEndpoint indicates an expected call of UserCodeVerificationEndpoint. -func (mr *MockConfigurationMockRecorder) UserCodeVerificationEndpoint() *gomock.Call { +// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint. +func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeVerificationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeVerificationEndpoint)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint)) } // UserinfoEndpoint mocks base method. diff --git a/pkg/op/op.go b/pkg/op/op.go index b0b419e..a618dc0 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -17,31 +17,31 @@ import ( ) const ( - healthEndpoint = "/healthz" - readinessEndpoint = "/ready" - authCallbackPathSuffix = "/callback" - defaultAuthorizationEndpoint = "authorize" - defaultTokenEndpoint = "oauth/token" - defaultIntrospectEndpoint = "oauth/introspect" - defaultUserinfoEndpoint = "userinfo" - defaultRevocationEndpoint = "revoke" - defaultEndSessionEndpoint = "end_session" - defaultKeysEndpoint = "keys" - defaultDeviceAuthzEndpoint = "/device_authorization" - defaultUserCodeVerificationEndpoint = "/user_code" + healthEndpoint = "/healthz" + readinessEndpoint = "/ready" + authCallbackPathSuffix = "/callback" + defaultAuthorizationEndpoint = "authorize" + defaultTokenEndpoint = "oauth/token" + defaultIntrospectEndpoint = "oauth/introspect" + defaultUserinfoEndpoint = "userinfo" + defaultRevocationEndpoint = "revoke" + defaultEndSessionEndpoint = "end_session" + defaultKeysEndpoint = "keys" + defaultDeviceAuthzEndpoint = "/device_authorization" + defaultUserCodeFormEndpoint = "/submit_user_code" ) var ( DefaultEndpoints = &endpoints{ - Authorization: NewEndpoint(defaultAuthorizationEndpoint), - Token: NewEndpoint(defaultTokenEndpoint), - Introspection: NewEndpoint(defaultIntrospectEndpoint), - Userinfo: NewEndpoint(defaultUserinfoEndpoint), - Revocation: NewEndpoint(defaultRevocationEndpoint), - EndSession: NewEndpoint(defaultEndSessionEndpoint), - JwksURI: NewEndpoint(defaultKeysEndpoint), - DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), - UserCodeVerification: NewEndpoint(defaultUserCodeVerificationEndpoint), + Authorization: NewEndpoint(defaultAuthorizationEndpoint), + Token: NewEndpoint(defaultTokenEndpoint), + Introspection: NewEndpoint(defaultIntrospectEndpoint), + Userinfo: NewEndpoint(defaultUserinfoEndpoint), + Revocation: NewEndpoint(defaultRevocationEndpoint), + EndSession: NewEndpoint(defaultEndSessionEndpoint), + JwksURI: NewEndpoint(defaultKeysEndpoint), + DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), + UserCodeForm: NewEndpoint(defaultUserCodeFormEndpoint), } defaultCORSOptions = cors.Options{ @@ -100,7 +100,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o)) - router.HandleFunc(o.UserCodeVerificationEndpoint().Relative(), userCodeVerificationHandler(o)) + router.HandleFunc(o.UserCodeFormEndpoint().Relative(), userCodeFormHandler(o)) return router } @@ -128,16 +128,16 @@ type Config struct { } type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint - DeviceAuthorization Endpoint - UserCodeVerification Endpoint + Authorization Endpoint + Token Endpoint + Introspection Endpoint + Userinfo Endpoint + Revocation Endpoint + EndSession Endpoint + CheckSessionIframe Endpoint + JwksURI Endpoint + DeviceAuthorization Endpoint + UserCodeForm Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -256,8 +256,8 @@ func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { return o.endpoints.DeviceAuthorization } -func (o *Provider) UserCodeVerificationEndpoint() Endpoint { - return o.endpoints.UserCodeVerification +func (o *Provider) UserCodeFormEndpoint() Endpoint { + return o.endpoints.UserCodeForm } func (o *Provider) KeysEndpoint() Endpoint { diff --git a/pkg/op/storage.go b/pkg/op/storage.go index f0a5425..9054844 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -155,10 +155,11 @@ type EndSessionRequest struct { var ErrDuplicateUserCode = errors.New("user code already exists") type DeviceAuthorizationState struct { - Scopes []string - Expires time.Time - AuthCode string - Denied bool + Scopes []string + Expires time.Time + Completed bool + Subject string + Denied bool } type DeviceAuthorizationStorage interface { @@ -176,9 +177,10 @@ type DeviceAuthorizationStorage interface { // The method is polled untill the the authorization is eighter Completed, Expired or Denied. GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) - // CompleteDeviceAuthorization marks a device authorization entry identified by userCode - // as completed, by setting the related authCode from an AuthRequest. - CompleteDeviceAuthorization(ctx context.Context, authCode, userCode string) error + // CompleteDeviceAuthorization marks a device authorization entry as Completed, + // identified by userCode. The Subject is added to the state, so that + // GetDeviceAuthorizatonState can use it to create a new Access Token. + CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error // DenyDeviceAuthorization marks a device authorization entry as Denied. DenyDeviceAuthorization(ctx context.Context, userCode string) error