From b88539846610e1ea7dd3f7f520e06be0d53a907e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Mon, 27 Feb 2023 08:18:33 +0100 Subject: [PATCH] use AuthRequest code flow to create device tokens --- pkg/op/config.go | 2 +- pkg/op/device.go | 125 ++++++++++-------------------- pkg/op/mock/configuration.mock.go | 12 +-- pkg/op/op.go | 68 ++++++++-------- pkg/op/storage.go | 16 ++-- 5 files changed, 89 insertions(+), 134 deletions(-) diff --git a/pkg/op/config.go b/pkg/op/config.go index d339a30..233d11f 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -28,7 +28,7 @@ type Configuration interface { EndSessionEndpoint() Endpoint KeysEndpoint() Endpoint DeviceAuthorizationEndpoint() Endpoint - UserCodeFormEndpoint() Endpoint + UserCodeVerificationEndpoint() Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool diff --git a/pkg/op/device.go b/pkg/op/device.go index ab276b8..622e27a 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -88,7 +88,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide VerificationURI: config.UserFormURL, } - endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context())) + endpoint := o.UserCodeVerificationEndpoint().Absolute(IssuerFromContext(r.Context())) response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode) httphelper.MarshalJSON(w, response) @@ -148,24 +148,6 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) { return buf.String(), nil } -type deviceAccessTokenRequest struct { - subject string - audience []string - scopes []string -} - -func (r *deviceAccessTokenRequest) GetSubject() string { - return r.subject -} - -func (r *deviceAccessTokenRequest) GetAudience() []string { - return r.audience -} - -func (r *deviceAccessTokenRequest) GetScopes() []string { - return r.scopes -} - func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { if err := deviceAccessToken(w, r, exchanger); err != nil { RequestError(w, r, err) @@ -179,7 +161,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang defer cancel() r = r.WithContext(ctx) - clientID, authenticated, err := ClientIDFromRequest(r, exchanger) + clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger) if err != nil { return err } @@ -188,7 +170,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang if err != nil { return err } - state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) + state, authReq, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) if err != nil { return err } @@ -197,19 +179,14 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang if err != nil { return err } - if !authenticated { + if !clientAuthenticated { if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client? return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). WithDescription(fmt.Sprintf("required client auth method: %s", m)) } } - tokenRequest := &deviceAccessTokenRequest{ - subject: state.Subject, - audience: []string{clientID}, - scopes: state.Scopes, - } - resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client) + resp, err := CreateTokenResponse(ctx, authReq, client, exchanger, true, state.AuthCode, "") if err != nil { return err } @@ -226,108 +203,88 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc. return req, nil } -func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { +func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, AuthRequest, error) { storage, err := assertDeviceStorage(exchanger.Storage()) if err != nil { - return nil, err + return nil, nil, err } state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode) if errors.Is(err, context.DeadlineExceeded) { - return nil, oidc.ErrSlowDown().WithParent(err) + return nil, nil, oidc.ErrSlowDown().WithParent(err) } if err != nil { - return nil, err + return nil, nil, err } if state.Denied { - return state, oidc.ErrAccessDenied() + return state, nil, oidc.ErrAccessDenied() } - if state.Completed { - return state, nil + if state.AuthCode != "" { + return state, nil, nil } if time.Now().After(state.Expires) { - return state, oidc.ErrExpiredDeviceCode() + return state, nil, oidc.ErrExpiredDeviceCode() } - return state, oidc.ErrAuthorizationPending() + authReq, err := AuthRequestByCode(ctx, exchanger.Storage(), state.AuthCode) + return state, authReq, err } -func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) { - tokenType := AccessTokenTypeBearer // not sure if this is the correct type? - - accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") - if err != nil { - return nil, err - } - - return &oidc.AccessTokenResponse{ - AccessToken: accessToken, - RefreshToken: refreshToken, - TokenType: oidc.BearerToken, - ExpiresIn: uint64(validity.Seconds()), - }, nil -} - -func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc { +func userCodeVerificationHandler(o OpenIDProvider) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - UserCodeForm(w, r, o) + UserCodeVerification(w, r, o) } } -type UserCodeFormData struct { - AccesssToken string `schema:"access_token"` - UserCode string `schema:"user_code"` - RedirectURL string `schema:"redirect_url"` +type UserCodeVerificationRequest struct { + Code string `schema:"code"` + UserCode string `schema:"user_code"` + RedirectURL string `schema:"redirect_url"` } -func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { - data, err := ParseUserCodeFormData(r, o.Decoder()) - if err != nil { +func UserCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { + if err := userCodeVerification(w, r, o); err != nil { RequestError(w, r, err) - return + } +} + +func userCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) (err error) { + req, err := ParseUserCodeVerificationRequest(r, o.Decoder()) + if err != nil { + return err } storage, err := assertDeviceStorage(o.Storage()) if err != nil { - RequestError(w, r, err) - return + return err } ctx := r.Context() - token, err := VerifyAccessToken(ctx, data.AccesssToken, o.AccessTokenVerifier(ctx)) - if err != nil { - if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil { - err = se - } - RequestError(w, r, err) - return + if err := storage.CompleteDeviceAuthorization(ctx, req.Code, req.UserCode); err != nil { + return err } - if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.GetSubject()); err != nil { - RequestError(w, r, err) - return - } - - if data.RedirectURL != "" { - http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther) + if req.RedirectURL != "" { + http.Redirect(w, r, req.RedirectURL, http.StatusSeeOther) } fmt.Fprintln(w, "Authorization successfull, please return to your device") + return nil } -func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) { +func ParseUserCodeVerificationRequest(r *http.Request, decoder httphelper.Decoder) (*UserCodeVerificationRequest, error) { if err := r.ParseForm(); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) } - req := new(UserCodeFormData) + req := new(UserCodeVerificationRequest) if err := decoder.Decode(req, r.Form); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err) } - if req.AccesssToken == "" { - return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form") + if req.Code == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("\"code\" missing in form") } if req.UserCode == "" { - return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form") + return nil, oidc.ErrInvalidRequest().WithDescription("\"user_code\" missing in form") } return req, nil diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index 44b5ceb..c17d243 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -400,18 +400,18 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) } -// UserCodeFormEndpoint mocks base method. -func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint { +// UserCodeVerificationEndpoint mocks base method. +func (m *MockConfiguration) UserCodeVerificationEndpoint() op.Endpoint { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UserCodeFormEndpoint") + ret := m.ctrl.Call(m, "UserCodeVerificationEndpoint") ret0, _ := ret[0].(op.Endpoint) return ret0 } -// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint. -func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call { +// UserCodeVerificationEndpoint indicates an expected call of UserCodeVerificationEndpoint. +func (mr *MockConfigurationMockRecorder) UserCodeVerificationEndpoint() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeVerificationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeVerificationEndpoint)) } // UserinfoEndpoint mocks base method. diff --git a/pkg/op/op.go b/pkg/op/op.go index a618dc0..b0b419e 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -17,31 +17,31 @@ import ( ) const ( - healthEndpoint = "/healthz" - readinessEndpoint = "/ready" - authCallbackPathSuffix = "/callback" - defaultAuthorizationEndpoint = "authorize" - defaultTokenEndpoint = "oauth/token" - defaultIntrospectEndpoint = "oauth/introspect" - defaultUserinfoEndpoint = "userinfo" - defaultRevocationEndpoint = "revoke" - defaultEndSessionEndpoint = "end_session" - defaultKeysEndpoint = "keys" - defaultDeviceAuthzEndpoint = "/device_authorization" - defaultUserCodeFormEndpoint = "/submit_user_code" + healthEndpoint = "/healthz" + readinessEndpoint = "/ready" + authCallbackPathSuffix = "/callback" + defaultAuthorizationEndpoint = "authorize" + defaultTokenEndpoint = "oauth/token" + defaultIntrospectEndpoint = "oauth/introspect" + defaultUserinfoEndpoint = "userinfo" + defaultRevocationEndpoint = "revoke" + defaultEndSessionEndpoint = "end_session" + defaultKeysEndpoint = "keys" + defaultDeviceAuthzEndpoint = "/device_authorization" + defaultUserCodeVerificationEndpoint = "/user_code" ) var ( DefaultEndpoints = &endpoints{ - Authorization: NewEndpoint(defaultAuthorizationEndpoint), - Token: NewEndpoint(defaultTokenEndpoint), - Introspection: NewEndpoint(defaultIntrospectEndpoint), - Userinfo: NewEndpoint(defaultUserinfoEndpoint), - Revocation: NewEndpoint(defaultRevocationEndpoint), - EndSession: NewEndpoint(defaultEndSessionEndpoint), - JwksURI: NewEndpoint(defaultKeysEndpoint), - DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), - UserCodeForm: NewEndpoint(defaultUserCodeFormEndpoint), + Authorization: NewEndpoint(defaultAuthorizationEndpoint), + Token: NewEndpoint(defaultTokenEndpoint), + Introspection: NewEndpoint(defaultIntrospectEndpoint), + Userinfo: NewEndpoint(defaultUserinfoEndpoint), + Revocation: NewEndpoint(defaultRevocationEndpoint), + EndSession: NewEndpoint(defaultEndSessionEndpoint), + JwksURI: NewEndpoint(defaultKeysEndpoint), + DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), + UserCodeVerification: NewEndpoint(defaultUserCodeVerificationEndpoint), } defaultCORSOptions = cors.Options{ @@ -100,7 +100,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o)) - router.HandleFunc(o.UserCodeFormEndpoint().Relative(), userCodeFormHandler(o)) + router.HandleFunc(o.UserCodeVerificationEndpoint().Relative(), userCodeVerificationHandler(o)) return router } @@ -128,16 +128,16 @@ type Config struct { } type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint - DeviceAuthorization Endpoint - UserCodeForm Endpoint + Authorization Endpoint + Token Endpoint + Introspection Endpoint + Userinfo Endpoint + Revocation Endpoint + EndSession Endpoint + CheckSessionIframe Endpoint + JwksURI Endpoint + DeviceAuthorization Endpoint + UserCodeVerification Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -256,8 +256,8 @@ func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { return o.endpoints.DeviceAuthorization } -func (o *Provider) UserCodeFormEndpoint() Endpoint { - return o.endpoints.UserCodeForm +func (o *Provider) UserCodeVerificationEndpoint() Endpoint { + return o.endpoints.UserCodeVerification } func (o *Provider) KeysEndpoint() Endpoint { diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 9054844..f0a5425 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -155,11 +155,10 @@ type EndSessionRequest struct { var ErrDuplicateUserCode = errors.New("user code already exists") type DeviceAuthorizationState struct { - Scopes []string - Expires time.Time - Completed bool - Subject string - Denied bool + Scopes []string + Expires time.Time + AuthCode string + Denied bool } type DeviceAuthorizationStorage interface { @@ -177,10 +176,9 @@ type DeviceAuthorizationStorage interface { // The method is polled untill the the authorization is eighter Completed, Expired or Denied. GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) - // CompleteDeviceAuthorization marks a device authorization entry as Completed, - // identified by userCode. The Subject is added to the state, so that - // GetDeviceAuthorizatonState can use it to create a new Access Token. - CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error + // CompleteDeviceAuthorization marks a device authorization entry identified by userCode + // as completed, by setting the related authCode from an AuthRequest. + CompleteDeviceAuthorization(ctx context.Context, authCode, userCode string) error // DenyDeviceAuthorization marks a device authorization entry as Denied. DenyDeviceAuthorization(ctx context.Context, userCode string) error