use AuthRequest code flow to create device tokens

This commit is contained in:
Tim Möhlmann 2023-02-27 08:18:33 +01:00
parent 65cd4528e4
commit b885398466
5 changed files with 89 additions and 134 deletions

View file

@ -28,7 +28,7 @@ type Configuration interface {
EndSessionEndpoint() Endpoint EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint DeviceAuthorizationEndpoint() Endpoint
UserCodeFormEndpoint() Endpoint UserCodeVerificationEndpoint() Endpoint
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
CodeMethodS256Supported() bool CodeMethodS256Supported() bool

View file

@ -88,7 +88,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
VerificationURI: config.UserFormURL, VerificationURI: config.UserFormURL,
} }
endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context())) endpoint := o.UserCodeVerificationEndpoint().Absolute(IssuerFromContext(r.Context()))
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode) response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
@ -148,24 +148,6 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
return buf.String(), nil return buf.String(), nil
} }
type deviceAccessTokenRequest struct {
subject string
audience []string
scopes []string
}
func (r *deviceAccessTokenRequest) GetSubject() string {
return r.subject
}
func (r *deviceAccessTokenRequest) GetAudience() []string {
return r.audience
}
func (r *deviceAccessTokenRequest) GetScopes() []string {
return r.scopes
}
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
if err := deviceAccessToken(w, r, exchanger); err != nil { if err := deviceAccessToken(w, r, exchanger); err != nil {
RequestError(w, r, err) RequestError(w, r, err)
@ -179,7 +161,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
defer cancel() defer cancel()
r = r.WithContext(ctx) r = r.WithContext(ctx)
clientID, authenticated, err := ClientIDFromRequest(r, exchanger) clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
if err != nil { if err != nil {
return err return err
} }
@ -188,7 +170,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
if err != nil { if err != nil {
return err return err
} }
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) state, authReq, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
if err != nil { if err != nil {
return err return err
} }
@ -197,19 +179,14 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
if err != nil { if err != nil {
return err return err
} }
if !authenticated { if !clientAuthenticated {
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client? if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
WithDescription(fmt.Sprintf("required client auth method: %s", m)) WithDescription(fmt.Sprintf("required client auth method: %s", m))
} }
} }
tokenRequest := &deviceAccessTokenRequest{ resp, err := CreateTokenResponse(ctx, authReq, client, exchanger, true, state.AuthCode, "")
subject: state.Subject,
audience: []string{clientID},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
if err != nil { if err != nil {
return err return err
} }
@ -226,108 +203,88 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.
return req, nil return req, nil
} }
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, AuthRequest, error) {
storage, err := assertDeviceStorage(exchanger.Storage()) storage, err := assertDeviceStorage(exchanger.Storage())
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode) state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err) return nil, nil, oidc.ErrSlowDown().WithParent(err)
} }
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if state.Denied { if state.Denied {
return state, oidc.ErrAccessDenied() return state, nil, oidc.ErrAccessDenied()
} }
if state.Completed { if state.AuthCode != "" {
return state, nil return state, nil, nil
} }
if time.Now().After(state.Expires) { if time.Now().After(state.Expires) {
return state, oidc.ErrExpiredDeviceCode() return state, nil, oidc.ErrExpiredDeviceCode()
} }
return state, oidc.ErrAuthorizationPending() authReq, err := AuthRequestByCode(ctx, exchanger.Storage(), state.AuthCode)
return state, authReq, err
} }
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) { func userCodeVerificationHandler(o OpenIDProvider) http.HandlerFunc {
tokenType := AccessTokenTypeBearer // not sure if this is the correct type?
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
if err != nil {
return nil, err
}
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
}, nil
}
func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
UserCodeForm(w, r, o) UserCodeVerification(w, r, o)
} }
} }
type UserCodeFormData struct { type UserCodeVerificationRequest struct {
AccesssToken string `schema:"access_token"` Code string `schema:"code"`
UserCode string `schema:"user_code"` UserCode string `schema:"user_code"`
RedirectURL string `schema:"redirect_url"` RedirectURL string `schema:"redirect_url"`
} }
func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { func UserCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
data, err := ParseUserCodeFormData(r, o.Decoder()) if err := userCodeVerification(w, r, o); err != nil {
if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return }
}
func userCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) (err error) {
req, err := ParseUserCodeVerificationRequest(r, o.Decoder())
if err != nil {
return err
} }
storage, err := assertDeviceStorage(o.Storage()) storage, err := assertDeviceStorage(o.Storage())
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
ctx := r.Context() ctx := r.Context()
token, err := VerifyAccessToken(ctx, data.AccesssToken, o.AccessTokenVerifier(ctx)) if err := storage.CompleteDeviceAuthorization(ctx, req.Code, req.UserCode); err != nil {
if err != nil { return err
if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil {
err = se
}
RequestError(w, r, err)
return
} }
if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.GetSubject()); err != nil { if req.RedirectURL != "" {
RequestError(w, r, err) http.Redirect(w, r, req.RedirectURL, http.StatusSeeOther)
return
}
if data.RedirectURL != "" {
http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther)
} }
fmt.Fprintln(w, "Authorization successfull, please return to your device") fmt.Fprintln(w, "Authorization successfull, please return to your device")
return nil
} }
func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) { func ParseUserCodeVerificationRequest(r *http.Request, decoder httphelper.Decoder) (*UserCodeVerificationRequest, error) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
} }
req := new(UserCodeFormData) req := new(UserCodeVerificationRequest)
if err := decoder.Decode(req, r.Form); err != nil { if err := decoder.Decode(req, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err)
} }
if req.AccesssToken == "" { if req.Code == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form") return nil, oidc.ErrInvalidRequest().WithDescription("\"code\" missing in form")
} }
if req.UserCode == "" { if req.UserCode == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form") return nil, oidc.ErrInvalidRequest().WithDescription("\"user_code\" missing in form")
} }
return req, nil return req, nil

View file

@ -400,18 +400,18 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
} }
// UserCodeFormEndpoint mocks base method. // UserCodeVerificationEndpoint mocks base method.
func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint { func (m *MockConfiguration) UserCodeVerificationEndpoint() op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserCodeFormEndpoint") ret := m.ctrl.Call(m, "UserCodeVerificationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(op.Endpoint)
return ret0 return ret0
} }
// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint. // UserCodeVerificationEndpoint indicates an expected call of UserCodeVerificationEndpoint.
func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call { func (mr *MockConfigurationMockRecorder) UserCodeVerificationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeVerificationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeVerificationEndpoint))
} }
// UserinfoEndpoint mocks base method. // UserinfoEndpoint mocks base method.

View file

@ -28,7 +28,7 @@ const (
defaultEndSessionEndpoint = "end_session" defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys" defaultKeysEndpoint = "keys"
defaultDeviceAuthzEndpoint = "/device_authorization" defaultDeviceAuthzEndpoint = "/device_authorization"
defaultUserCodeFormEndpoint = "/submit_user_code" defaultUserCodeVerificationEndpoint = "/user_code"
) )
var ( var (
@ -41,7 +41,7 @@ var (
EndSession: NewEndpoint(defaultEndSessionEndpoint), EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint), JwksURI: NewEndpoint(defaultKeysEndpoint),
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
UserCodeForm: NewEndpoint(defaultUserCodeFormEndpoint), UserCodeVerification: NewEndpoint(defaultUserCodeVerificationEndpoint),
} }
defaultCORSOptions = cors.Options{ defaultCORSOptions = cors.Options{
@ -100,7 +100,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o)) router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o))
router.HandleFunc(o.UserCodeFormEndpoint().Relative(), userCodeFormHandler(o)) router.HandleFunc(o.UserCodeVerificationEndpoint().Relative(), userCodeVerificationHandler(o))
return router return router
} }
@ -137,7 +137,7 @@ type endpoints struct {
CheckSessionIframe Endpoint CheckSessionIframe Endpoint
JwksURI Endpoint JwksURI Endpoint
DeviceAuthorization Endpoint DeviceAuthorization Endpoint
UserCodeForm Endpoint UserCodeVerification Endpoint
} }
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -256,8 +256,8 @@ func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
return o.endpoints.DeviceAuthorization return o.endpoints.DeviceAuthorization
} }
func (o *Provider) UserCodeFormEndpoint() Endpoint { func (o *Provider) UserCodeVerificationEndpoint() Endpoint {
return o.endpoints.UserCodeForm return o.endpoints.UserCodeVerification
} }
func (o *Provider) KeysEndpoint() Endpoint { func (o *Provider) KeysEndpoint() Endpoint {

View file

@ -157,8 +157,7 @@ var ErrDuplicateUserCode = errors.New("user code already exists")
type DeviceAuthorizationState struct { type DeviceAuthorizationState struct {
Scopes []string Scopes []string
Expires time.Time Expires time.Time
Completed bool AuthCode string
Subject string
Denied bool Denied bool
} }
@ -177,10 +176,9 @@ type DeviceAuthorizationStorage interface {
// The method is polled untill the the authorization is eighter Completed, Expired or Denied. // The method is polled untill the the authorization is eighter Completed, Expired or Denied.
GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
// CompleteDeviceAuthorization marks a device authorization entry as Completed, // CompleteDeviceAuthorization marks a device authorization entry identified by userCode
// identified by userCode. The Subject is added to the state, so that // as completed, by setting the related authCode from an AuthRequest.
// GetDeviceAuthorizatonState can use it to create a new Access Token. CompleteDeviceAuthorization(ctx context.Context, authCode, userCode string) error
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
// DenyDeviceAuthorization marks a device authorization entry as Denied. // DenyDeviceAuthorization marks a device authorization entry as Denied.
DenyDeviceAuthorization(ctx context.Context, userCode string) error DenyDeviceAuthorization(ctx context.Context, userCode string) error