use AuthRequest code flow to create device tokens

This commit is contained in:
Tim Möhlmann 2023-02-27 08:18:33 +01:00
parent 65cd4528e4
commit b885398466
5 changed files with 89 additions and 134 deletions

View file

@ -28,7 +28,7 @@ type Configuration interface {
EndSessionEndpoint() Endpoint EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint DeviceAuthorizationEndpoint() Endpoint
UserCodeFormEndpoint() Endpoint UserCodeVerificationEndpoint() Endpoint
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
CodeMethodS256Supported() bool CodeMethodS256Supported() bool

View file

@ -88,7 +88,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
VerificationURI: config.UserFormURL, VerificationURI: config.UserFormURL,
} }
endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context())) endpoint := o.UserCodeVerificationEndpoint().Absolute(IssuerFromContext(r.Context()))
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode) response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
@ -148,24 +148,6 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
return buf.String(), nil return buf.String(), nil
} }
type deviceAccessTokenRequest struct {
subject string
audience []string
scopes []string
}
func (r *deviceAccessTokenRequest) GetSubject() string {
return r.subject
}
func (r *deviceAccessTokenRequest) GetAudience() []string {
return r.audience
}
func (r *deviceAccessTokenRequest) GetScopes() []string {
return r.scopes
}
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
if err := deviceAccessToken(w, r, exchanger); err != nil { if err := deviceAccessToken(w, r, exchanger); err != nil {
RequestError(w, r, err) RequestError(w, r, err)
@ -179,7 +161,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
defer cancel() defer cancel()
r = r.WithContext(ctx) r = r.WithContext(ctx)
clientID, authenticated, err := ClientIDFromRequest(r, exchanger) clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
if err != nil { if err != nil {
return err return err
} }
@ -188,7 +170,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
if err != nil { if err != nil {
return err return err
} }
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) state, authReq, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
if err != nil { if err != nil {
return err return err
} }
@ -197,19 +179,14 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
if err != nil { if err != nil {
return err return err
} }
if !authenticated { if !clientAuthenticated {
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client? if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
WithDescription(fmt.Sprintf("required client auth method: %s", m)) WithDescription(fmt.Sprintf("required client auth method: %s", m))
} }
} }
tokenRequest := &deviceAccessTokenRequest{ resp, err := CreateTokenResponse(ctx, authReq, client, exchanger, true, state.AuthCode, "")
subject: state.Subject,
audience: []string{clientID},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
if err != nil { if err != nil {
return err return err
} }
@ -226,108 +203,88 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.
return req, nil return req, nil
} }
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, AuthRequest, error) {
storage, err := assertDeviceStorage(exchanger.Storage()) storage, err := assertDeviceStorage(exchanger.Storage())
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode) state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err) return nil, nil, oidc.ErrSlowDown().WithParent(err)
} }
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if state.Denied { if state.Denied {
return state, oidc.ErrAccessDenied() return state, nil, oidc.ErrAccessDenied()
} }
if state.Completed { if state.AuthCode != "" {
return state, nil return state, nil, nil
} }
if time.Now().After(state.Expires) { if time.Now().After(state.Expires) {
return state, oidc.ErrExpiredDeviceCode() return state, nil, oidc.ErrExpiredDeviceCode()
} }
return state, oidc.ErrAuthorizationPending() authReq, err := AuthRequestByCode(ctx, exchanger.Storage(), state.AuthCode)
return state, authReq, err
} }
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) { func userCodeVerificationHandler(o OpenIDProvider) http.HandlerFunc {
tokenType := AccessTokenTypeBearer // not sure if this is the correct type?
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
if err != nil {
return nil, err
}
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
}, nil
}
func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
UserCodeForm(w, r, o) UserCodeVerification(w, r, o)
} }
} }
type UserCodeFormData struct { type UserCodeVerificationRequest struct {
AccesssToken string `schema:"access_token"` Code string `schema:"code"`
UserCode string `schema:"user_code"` UserCode string `schema:"user_code"`
RedirectURL string `schema:"redirect_url"` RedirectURL string `schema:"redirect_url"`
} }
func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { func UserCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
data, err := ParseUserCodeFormData(r, o.Decoder()) if err := userCodeVerification(w, r, o); err != nil {
if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return }
}
func userCodeVerification(w http.ResponseWriter, r *http.Request, o OpenIDProvider) (err error) {
req, err := ParseUserCodeVerificationRequest(r, o.Decoder())
if err != nil {
return err
} }
storage, err := assertDeviceStorage(o.Storage()) storage, err := assertDeviceStorage(o.Storage())
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
ctx := r.Context() ctx := r.Context()
token, err := VerifyAccessToken(ctx, data.AccesssToken, o.AccessTokenVerifier(ctx)) if err := storage.CompleteDeviceAuthorization(ctx, req.Code, req.UserCode); err != nil {
if err != nil { return err
if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil {
err = se
}
RequestError(w, r, err)
return
} }
if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.GetSubject()); err != nil { if req.RedirectURL != "" {
RequestError(w, r, err) http.Redirect(w, r, req.RedirectURL, http.StatusSeeOther)
return
}
if data.RedirectURL != "" {
http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther)
} }
fmt.Fprintln(w, "Authorization successfull, please return to your device") fmt.Fprintln(w, "Authorization successfull, please return to your device")
return nil
} }
func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) { func ParseUserCodeVerificationRequest(r *http.Request, decoder httphelper.Decoder) (*UserCodeVerificationRequest, error) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
} }
req := new(UserCodeFormData) req := new(UserCodeVerificationRequest)
if err := decoder.Decode(req, r.Form); err != nil { if err := decoder.Decode(req, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err)
} }
if req.AccesssToken == "" { if req.Code == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form") return nil, oidc.ErrInvalidRequest().WithDescription("\"code\" missing in form")
} }
if req.UserCode == "" { if req.UserCode == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form") return nil, oidc.ErrInvalidRequest().WithDescription("\"user_code\" missing in form")
} }
return req, nil return req, nil

View file

@ -400,18 +400,18 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
} }
// UserCodeFormEndpoint mocks base method. // UserCodeVerificationEndpoint mocks base method.
func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint { func (m *MockConfiguration) UserCodeVerificationEndpoint() op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserCodeFormEndpoint") ret := m.ctrl.Call(m, "UserCodeVerificationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(op.Endpoint)
return ret0 return ret0
} }
// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint. // UserCodeVerificationEndpoint indicates an expected call of UserCodeVerificationEndpoint.
func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call { func (mr *MockConfigurationMockRecorder) UserCodeVerificationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeVerificationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeVerificationEndpoint))
} }
// UserinfoEndpoint mocks base method. // UserinfoEndpoint mocks base method.

View file

@ -17,31 +17,31 @@ import (
) )
const ( const (
healthEndpoint = "/healthz" healthEndpoint = "/healthz"
readinessEndpoint = "/ready" readinessEndpoint = "/ready"
authCallbackPathSuffix = "/callback" authCallbackPathSuffix = "/callback"
defaultAuthorizationEndpoint = "authorize" defaultAuthorizationEndpoint = "authorize"
defaultTokenEndpoint = "oauth/token" defaultTokenEndpoint = "oauth/token"
defaultIntrospectEndpoint = "oauth/introspect" defaultIntrospectEndpoint = "oauth/introspect"
defaultUserinfoEndpoint = "userinfo" defaultUserinfoEndpoint = "userinfo"
defaultRevocationEndpoint = "revoke" defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session" defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys" defaultKeysEndpoint = "keys"
defaultDeviceAuthzEndpoint = "/device_authorization" defaultDeviceAuthzEndpoint = "/device_authorization"
defaultUserCodeFormEndpoint = "/submit_user_code" defaultUserCodeVerificationEndpoint = "/user_code"
) )
var ( var (
DefaultEndpoints = &endpoints{ DefaultEndpoints = &endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint), Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint), Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint), Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint), Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint), EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint), JwksURI: NewEndpoint(defaultKeysEndpoint),
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
UserCodeForm: NewEndpoint(defaultUserCodeFormEndpoint), UserCodeVerification: NewEndpoint(defaultUserCodeVerificationEndpoint),
} }
defaultCORSOptions = cors.Options{ defaultCORSOptions = cors.Options{
@ -100,7 +100,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o)) router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o))
router.HandleFunc(o.UserCodeFormEndpoint().Relative(), userCodeFormHandler(o)) router.HandleFunc(o.UserCodeVerificationEndpoint().Relative(), userCodeVerificationHandler(o))
return router return router
} }
@ -128,16 +128,16 @@ type Config struct {
} }
type endpoints struct { type endpoints struct {
Authorization Endpoint Authorization Endpoint
Token Endpoint Token Endpoint
Introspection Endpoint Introspection Endpoint
Userinfo Endpoint Userinfo Endpoint
Revocation Endpoint Revocation Endpoint
EndSession Endpoint EndSession Endpoint
CheckSessionIframe Endpoint CheckSessionIframe Endpoint
JwksURI Endpoint JwksURI Endpoint
DeviceAuthorization Endpoint DeviceAuthorization Endpoint
UserCodeForm Endpoint UserCodeVerification Endpoint
} }
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -256,8 +256,8 @@ func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
return o.endpoints.DeviceAuthorization return o.endpoints.DeviceAuthorization
} }
func (o *Provider) UserCodeFormEndpoint() Endpoint { func (o *Provider) UserCodeVerificationEndpoint() Endpoint {
return o.endpoints.UserCodeForm return o.endpoints.UserCodeVerification
} }
func (o *Provider) KeysEndpoint() Endpoint { func (o *Provider) KeysEndpoint() Endpoint {

View file

@ -155,11 +155,10 @@ type EndSessionRequest struct {
var ErrDuplicateUserCode = errors.New("user code already exists") var ErrDuplicateUserCode = errors.New("user code already exists")
type DeviceAuthorizationState struct { type DeviceAuthorizationState struct {
Scopes []string Scopes []string
Expires time.Time Expires time.Time
Completed bool AuthCode string
Subject string Denied bool
Denied bool
} }
type DeviceAuthorizationStorage interface { type DeviceAuthorizationStorage interface {
@ -177,10 +176,9 @@ type DeviceAuthorizationStorage interface {
// The method is polled untill the the authorization is eighter Completed, Expired or Denied. // The method is polled untill the the authorization is eighter Completed, Expired or Denied.
GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
// CompleteDeviceAuthorization marks a device authorization entry as Completed, // CompleteDeviceAuthorization marks a device authorization entry identified by userCode
// identified by userCode. The Subject is added to the state, so that // as completed, by setting the related authCode from an AuthRequest.
// GetDeviceAuthorizatonState can use it to create a new Access Token. CompleteDeviceAuthorization(ctx context.Context, authCode, userCode string) error
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
// DenyDeviceAuthorization marks a device authorization entry as Denied. // DenyDeviceAuthorization marks a device authorization entry as Denied.
DenyDeviceAuthorization(ctx context.Context, userCode string) error DenyDeviceAuthorization(ctx context.Context, userCode string) error