From c12305457b3c1e31723cd158e439ceb58566982c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 23 Feb 2023 14:26:55 +0100 Subject: [PATCH] some updates after feedback --- pkg/oidc/device_authorization.go | 3 +- pkg/oidc/error.go | 2 +- pkg/op/device.go | 164 ++++++++++++++++++++++--------- pkg/op/storage.go | 50 ++++++---- 4 files changed, 150 insertions(+), 69 deletions(-) diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go index 58244cd..e8862d3 100644 --- a/pkg/oidc/device_authorization.go +++ b/pkg/oidc/device_authorization.go @@ -24,8 +24,7 @@ type DeviceAuthorizationResponse struct { // https://www.rfc-editor.org/rfc/rfc8628#section-3.4, // Device Access Token Request. type DeviceAccessTokenRequest struct { - JWTTokenRequest GrantType string `json:"grant_type"` DeviceCode string `json:"device_code"` - ClientID string `json:"client_id"` // required, how?? + ClientID string `json:"client_id"` } diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index b84b7f2..79acecd 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -105,7 +105,7 @@ var ( Description: "The authorization request was denied.", } } - ErrExpiredToken = func() *Error { + ErrExpiredDeviceCode = func() *Error { return &Error{ ErrorType: ExpiredToken, Description: "The \"device_code\" has expired.", diff --git a/pkg/op/device.go b/pkg/op/device.go index 438b78a..6bda9a6 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -4,11 +4,12 @@ import ( "context" "crypto/rand" "encoding/base64" + "errors" "fmt" "math/big" "net/http" - "net/url" "strings" + "time" httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" @@ -17,6 +18,7 @@ import ( type DeviceAuthorizationConfig struct { Lifetime int PollInterval int + UserFormURL string UserCode UserCodeConfig } @@ -24,8 +26,6 @@ type UserCodeConfig struct { CharSet string CharAmount int DashInterval int - QueryKey string - FormHTML []byte } const ( @@ -38,13 +38,11 @@ var ( CharSet: CharSetBase20, CharAmount: 8, DashInterval: 4, - QueryKey: "user_code", } UserCodeDigits = UserCodeConfig{ CharSet: CharSetDigits, CharAmount: 9, DashInterval: 3, - QueryKey: "user_code", } ) @@ -55,10 +53,12 @@ func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt } func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { - storage, ok := o.Storage().(DeviceCodeStorage) - if !ok { - // unimplemented error? + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { + RequestError(w, r, err) + return } + req, err := ParseDeviceCodeRequest(r, o.Decoder()) if err != nil { RequestError(w, r, err) @@ -77,25 +77,20 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide RequestError(w, r, err) return } - err = storage.StoreDeviceAuthorizationRequest(r.Context(), req, deviceCode, userCode) + err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, req.Scopes) if err != nil { RequestError(w, r, err) return } - endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context())) - response := &oidc.DeviceAuthorizationResponse{ DeviceCode: deviceCode, UserCode: userCode, - VerificationURI: endpoint, + VerificationURI: config.UserFormURL, } - if key := config.UserCode.QueryKey; key != "" { - vals := make(url.Values, 1) - vals.Set(key, userCode) - response.VerificationURIComplete = strings.Join([]string{endpoint, vals.Encode()}, "?") - } + endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context())) + response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode) httphelper.MarshalJSON(w, response) } @@ -107,7 +102,7 @@ func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc. devReq := new(oidc.DeviceAuthorizationRequest) if err := decoder.Decode(devReq, r.Form); err != nil { - return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse dev auth request").WithParent(err) + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err) } return devReq, nil @@ -151,23 +146,49 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) { return buf.String(), nil } +type deviceAccessTokenRequest struct { + subject string + audience []string + scopes []string +} + +func (r *deviceAccessTokenRequest) GetSubject() string { + return r.subject +} + +func (r *deviceAccessTokenRequest) GetAudience() []string { + return r.audience +} + +func (r *deviceAccessTokenRequest) GetScopes() []string { + return r.scopes +} + func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { req := new(oidc.DeviceAccessTokenRequest) if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil { RequestError(w, r, err) + return } - storage, ok := exchanger.Storage().(DeviceCodeStorage) - if !ok { - // unimplemented error? - } + // use a limited context timeout shorter as the default + // poll interval of 5 seconds. + ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second) + defer cancel() - client, err := storage.DeviceAccessPoll(r.Context(), req.DeviceCode) + state, err := CheckDeviceAuthorizationState(ctx, req, exchanger) if err != nil { RequestError(w, r, err) + return } - resp, err := CreateDeviceTokenResponse(r.Context(), req, exchanger, client) + tokenRequest := &deviceAccessTokenRequest{ + subject: state.Subject, + audience: []string{req.ClientID}, + scopes: state.Scopes, + } + + resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, &jwtProfileClient{id: req.ClientID}) if err != nil { RequestError(w, r, err) return @@ -175,18 +196,44 @@ func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang httphelper.MarshalJSON(w, resp) } -func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { +func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) { + storage, err := assertDeviceStorage(exchanger.Storage()) + if err != nil { + return nil, err + } + + state, err := storage.GetDeviceAuthorizatonState(ctx, req.ClientID, req.DeviceCode) + if errors.Is(err, context.DeadlineExceeded) { + return nil, oidc.ErrSlowDown().WithParent(err) + } + if err != nil { + return nil, err + } + if state.Denied { + return state, oidc.ErrAccessDenied() + } + if state.Completed { + return state, nil + } + if time.Now().After(state.Expires) { + return state, oidc.ErrExpiredDeviceCode() + } + return state, oidc.ErrAuthorizationPending() +} + +func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) { tokenType := AccessTokenTypeBearer // not sure if this is the correct type? - accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") + accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") if err != nil { return nil, err } return &oidc.AccessTokenResponse{ - AccessToken: accessToken, - TokenType: oidc.BearerToken, - ExpiresIn: uint64(validity.Seconds()), + AccessToken: accessToken, + RefreshToken: refreshToken, + TokenType: oidc.BearerToken, + ExpiresIn: uint64(validity.Seconds()), }, nil } @@ -196,37 +243,62 @@ func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc { } } -func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { - // check cookie, or what?? +type UserCodeFormData struct { + AccesssToken string `schema:"access_token"` + UserCode string `schema:"user_code"` + RedirectURL string `schema:"redirect_url"` +} - config := o.DeviceAuthorization().UserCode - userCode, err := UserCodeFromRequest(r, config.QueryKey) +func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { + data, err := ParseUserCodeFormData(r, o.Decoder()) if err != nil { RequestError(w, r, err) return } - if userCode == "" { - w.Write(config.FormHTML) - return - } - storage, ok := o.Storage().(DeviceCodeStorage) - if !ok { - // unimplemented error? - } - - if err := storage.ReleaseDeviceAccessToken(r.Context(), userCode); err != nil { + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { RequestError(w, r, err) return } + ctx := r.Context() + token, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, data.AccesssToken, o.AccessTokenVerifier(ctx)) + if err != nil { + if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil { + err = se + } + RequestError(w, r, err) + return + } + + if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.Subject); err != nil { + RequestError(w, r, err) + return + } + + if data.RedirectURL != "" { + http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther) + } + fmt.Fprintln(w, "Authorization successfull, please return to your device") } -func UserCodeFromRequest(r *http.Request, key string) (string, error) { +func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) { if err := r.ParseForm(); err != nil { - return "", oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) } - return r.Form.Get(key), nil + req := new(UserCodeFormData) + if err := decoder.Decode(req, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err) + } + if req.AccesssToken == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form") + } + if req.UserCode == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form") + } + + return req, nil } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 69b05b7..b0d31de 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -154,32 +154,42 @@ type EndSessionRequest struct { var ErrDuplicateUserCode = errors.New("user code already exists") -type DeviceCodeStorage interface { +type DeviceAuthorizationState struct { + Scopes []string + Expires time.Time + Completed bool + Subject string + Denied bool +} + +type DeviceAuthorizationStorage interface { // StoreDeviceAuthorizationRequest stores a new device authorization request in the database. // User code will be used by the user to complete the login flow and must be unique. // ErrDuplicateUserCode signals the caller should try again with a new code. // // Note that user codes are low entropy keys and when many exist in the // database, the change for collisions increases. Therefore implementers - // of this interface must make sure that user codes of completed or expired - // authentication flows are deleted. - StoreDeviceAuthorizationRequest(ctx context.Context, req *oidc.DeviceAuthorizationRequest, deviceCode, userCode string) error + // of this interface must make sure that user codes of expired authentication flows are purged, + // after some time. + StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, scopes []string) error - // DeviceAccessPoll is called by the device untill the authorization flow is - // completed or expired. - // - // The following errors are defined for the Device Authorization workflow, - // that can be returned by this method: - // - oidc.ErrAuthorizationPending should be returned on each poll, while the flow is not completed by the user. - // - oidc.ErrSlowDown signals to the device that the polling interval is to be increased by 5 seconds. - // - oidc.ErrAccessDenied when the authorization request is denied. - // - oidc.ErrExpiredToken when the device code has expired. - // - // A token should be returned once the authorization flow is completed - // by the user. - DeviceAccessPoll(ctx context.Context, deviceCode string) (Client, error) + // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database. + // The method is polled untill the the authorization is eighter Completed, Expired or Denied. + GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) - // ReleaseDeviceAccessToken releases DeviceAccessPoll to return the Access Token, - // destined for a user code. - ReleaseDeviceAccessToken(ctx context.Context, userCode string) error + // CompleteDeviceAuthorization marks a device authorization entry as Completed, + // identified by userCode. The Subject is added to the state, so that + // GetDeviceAuthorizatonState can use it to create a new Access Token. + CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error + + // DenyDeviceAuthorization marks a device authorization entry as Denied. + DenyDeviceAuthorization(ctx context.Context, userCode string) error +} + +func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) { + storage, ok := s.(DeviceAuthorizationStorage) + if !ok { + return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported") + } + return storage, nil }