From 671b13b9c6116e18b0449af82c21d6c59d5e3d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Wed, 22 Feb 2023 20:11:42 +0100 Subject: [PATCH] implement RFC 8628: Device authorization grant WIP Related #264 --- pkg/client/client.go | 17 +++ pkg/client/rp/device.go | 20 +++ pkg/client/rp/relying_party.go | 8 ++ pkg/oidc/device_authorization.go | 31 ++++ pkg/oidc/discovery.go | 2 + pkg/oidc/error.go | 34 +++++ pkg/oidc/token_request.go | 5 +- pkg/op/config.go | 4 + pkg/op/device.go | 232 ++++++++++++++++++++++++++++++ pkg/op/device_test.go | 204 ++++++++++++++++++++++++++ pkg/op/discovery.go | 4 + pkg/op/mock/configuration.mock.go | 56 ++++++++ pkg/op/op.go | 54 +++++-- pkg/op/storage.go | 32 +++++ pkg/op/token_request.go | 6 + 15 files changed, 693 insertions(+), 16 deletions(-) create mode 100644 pkg/client/rp/device.go create mode 100644 pkg/oidc/device_authorization.go create mode 100644 pkg/op/device.go create mode 100644 pkg/op/device_test.go diff --git a/pkg/client/client.go b/pkg/client/client.go index 077baf2..08e16fd 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -186,3 +186,20 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti IssuedAt: oidc.Time(iat), }, signer) } + +type DeviceAuthorizationCaller interface { + GetDeviceCodeEndpoint() string + HttpClient() *http.Client +} + +func CallDeviceAuthorizationEndpoint(request interface{}, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { + req, err := httphelper.FormRequest(caller.GetDeviceCodeEndpoint(), request, Encoder, nil) + if err != nil { + return nil, err + } + resp := new(oidc.DeviceAuthorizationResponse) + if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil { + return nil, err + } + return resp, nil +} diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go new file mode 100644 index 0000000..c7d0be8 --- /dev/null +++ b/pkg/client/rp/device.go @@ -0,0 +1,20 @@ +package rp + +import ( + "github.com/zitadel/oidc/v2/pkg/client" + "github.com/zitadel/oidc/v2/pkg/oidc" +) + +func DeviceAuthorization(clientID string, scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { + req := &oidc.DeviceAuthorizationRequest{ + Scopes: scopes, + ClientID: clientID, + } + return client.CallDeviceAuthorizationEndpoint(req, rp) +} + +/* +func DeviceAccessToken() (*oauth2.Token, error) { + req := &oidc.DeviceAccessTokenRequest{} +} +*/ diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index d2e3cf7..d44f78a 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -59,6 +59,8 @@ type RelyingParty interface { // UserinfoEndpoint returns the userinfo UserinfoEndpoint() string + GetDeviceCodeEndpoint() string + // IDTokenVerifier returns the verifier interface used for oidc id_token verification IDTokenVerifier() IDTokenVerifier // ErrorHandler returns the handler used for callback errors @@ -121,6 +123,10 @@ func (rp *relyingParty) UserinfoEndpoint() string { return rp.endpoints.UserinfoURL } +func (rp *relyingParty) GetDeviceCodeEndpoint() string { + return rp.endpoints.DeviceCodeURL +} + func (rp *relyingParty) GetEndSessionEndpoint() string { return rp.endpoints.EndSessionURL } @@ -500,6 +506,7 @@ type Endpoints struct { JKWsURL string EndSessionURL string RevokeURL string + DeviceCodeURL string } func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { @@ -514,6 +521,7 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { JKWsURL: discoveryConfig.JwksURI, EndSessionURL: discoveryConfig.EndSessionEndpoint, RevokeURL: discoveryConfig.RevocationEndpoint, + DeviceCodeURL: discoveryConfig.DeviceAuthorizationEndpoint, } } diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go new file mode 100644 index 0000000..58244cd --- /dev/null +++ b/pkg/oidc/device_authorization.go @@ -0,0 +1,31 @@ +package oidc + +// DeviceAuthorizationRequest implements +// https://www.rfc-editor.org/rfc/rfc8628#section-3.1, +// 3.1 Device Authorization Request. +type DeviceAuthorizationRequest struct { + Scopes SpaceDelimitedArray `schema:"scope"` + ClientID string `schema:"client_id"` +} + +// DeviceAuthorizationResponse implements +// https://www.rfc-editor.org/rfc/rfc8628#section-3.2 +// 3.2. Device Authorization Response. +type DeviceAuthorizationResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval,omitempty"` +} + +// DeviceAccessTokenRequest implements +// https://www.rfc-editor.org/rfc/rfc8628#section-3.4, +// Device Access Token Request. +type DeviceAccessTokenRequest struct { + JWTTokenRequest + GrantType string `json:"grant_type"` + DeviceCode string `json:"device_code"` + ClientID string `json:"client_id"` // required, how?? +} diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go index fbc417b..3574101 100644 --- a/pkg/oidc/discovery.go +++ b/pkg/oidc/discovery.go @@ -30,6 +30,8 @@ type DiscoveryConfiguration struct { // EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP. EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"` + // CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client. CheckSessionIframe string `json:"check_session_iframe,omitempty"` diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 5797a59..b84b7f2 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -18,6 +18,14 @@ const ( InteractionRequired errorType = "interaction_required" LoginRequired errorType = "login_required" RequestNotSupported errorType = "request_not_supported" + + // Additional error codes as defined in + // https://www.rfc-editor.org/rfc/rfc8628#section-3.5 + // Device Access Token Response + AuthorizationPending errorType = "authorization_pending" + SlowDown errorType = "slow_down" + AccessDenied errorType = "access_denied" + ExpiredToken errorType = "expired_token" ) var ( @@ -77,6 +85,32 @@ var ( ErrorType: RequestNotSupported, } } + + // Device Access Token errors: + ErrAuthorizationPending = func() *Error { + return &Error{ + ErrorType: AuthorizationPending, + Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.", + } + } + ErrSlowDown = func() *Error { + return &Error{ + ErrorType: SlowDown, + Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.", + } + } + ErrAccessDenied = func() *Error { + return &Error{ + ErrorType: AccessDenied, + Description: "The authorization request was denied.", + } + } + ErrExpiredToken = func() *Error { + return &Error{ + ErrorType: ExpiredToken, + Description: "The \"device_code\" has expired.", + } + } ) type Error struct { diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 6d8f186..78bd658 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -27,6 +27,9 @@ const ( // GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code GrantTypeImplicit GrantType = "implicit" + // GrantTypeDeviceCode + GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code" + // ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer` // used for the OAuth JWT Profile Client Authentication ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" @@ -35,7 +38,7 @@ const ( var AllGrantTypes = []GrantType{ GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials, GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit, - ClientAssertionTypeJWTAssertion, + GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion, } type GrantType string diff --git a/pkg/op/config.go b/pkg/op/config.go index c40fa2d..d339a30 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -27,6 +27,8 @@ type Configuration interface { RevocationEndpoint() Endpoint EndSessionEndpoint() Endpoint KeysEndpoint() Endpoint + DeviceAuthorizationEndpoint() Endpoint + UserCodeFormEndpoint() Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool @@ -36,6 +38,7 @@ type Configuration interface { GrantTypeTokenExchangeSupported() bool GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool + GrantTypeDeviceCodeSupported() bool IntrospectionAuthMethodPrivateKeyJWTSupported() bool IntrospectionEndpointSigningAlgorithmsSupported() []string RevocationAuthMethodPrivateKeyJWTSupported() bool @@ -44,6 +47,7 @@ type Configuration interface { RequestObjectSigningAlgorithmsSupported() []string SupportedUILocales() []language.Tag + DeviceAuthorization() DeviceAuthorizationConfig } type IssuerFromRequest func(r *http.Request) string diff --git a/pkg/op/device.go b/pkg/op/device.go new file mode 100644 index 0000000..438b78a --- /dev/null +++ b/pkg/op/device.go @@ -0,0 +1,232 @@ +package op + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "math/big" + "net/http" + "net/url" + "strings" + + httphelper "github.com/zitadel/oidc/v2/pkg/http" + "github.com/zitadel/oidc/v2/pkg/oidc" +) + +type DeviceAuthorizationConfig struct { + Lifetime int + PollInterval int + UserCode UserCodeConfig +} + +type UserCodeConfig struct { + CharSet string + CharAmount int + DashInterval int + QueryKey string + FormHTML []byte +} + +const ( + CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ" + CharSetDigits = "0123456789" +) + +var ( + UserCodeBase20 = UserCodeConfig{ + CharSet: CharSetBase20, + CharAmount: 8, + DashInterval: 4, + QueryKey: "user_code", + } + UserCodeDigits = UserCodeConfig{ + CharSet: CharSetDigits, + CharAmount: 9, + DashInterval: 3, + QueryKey: "user_code", + } +) + +func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + DeviceAuthorization(w, r, o) + } +} + +func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { + storage, ok := o.Storage().(DeviceCodeStorage) + if !ok { + // unimplemented error? + } + req, err := ParseDeviceCodeRequest(r, o.Decoder()) + if err != nil { + RequestError(w, r, err) + return + } + + config := o.DeviceAuthorization() + + deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) + if err != nil { + RequestError(w, r, err) + return + } + userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount) + if err != nil { + RequestError(w, r, err) + return + } + err = storage.StoreDeviceAuthorizationRequest(r.Context(), req, deviceCode, userCode) + if err != nil { + RequestError(w, r, err) + return + } + + endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context())) + + response := &oidc.DeviceAuthorizationResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: endpoint, + } + + if key := config.UserCode.QueryKey; key != "" { + vals := make(url.Values, 1) + vals.Set(key, userCode) + response.VerificationURIComplete = strings.Join([]string{endpoint, vals.Encode()}, "?") + } + + httphelper.MarshalJSON(w, response) +} + +func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.DeviceAuthorizationRequest, error) { + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) + } + + devReq := new(oidc.DeviceAuthorizationRequest) + if err := decoder.Decode(devReq, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse dev auth request").WithParent(err) + } + + return devReq, nil +} + +// 16 bytes gives 128 bit of entropy. +// results in a 22 character base64 encoded string. +const RecommendedDeviceCodeBytes = 16 + +func NewDeviceCode(nBytes int) (string, error) { + bytes := make([]byte, nBytes) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("%w getting entropy for device code", err) + } + return base64.RawURLEncoding.EncodeToString(bytes), nil +} + +func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) { + var buf strings.Builder + if dashInterval > 0 { + buf.Grow(charAmount + charAmount/dashInterval - 1) + } else { + buf.Grow(charAmount) + } + + max := big.NewInt(int64(len(charSet))) + + for i := 0; i < charAmount; i++ { + if dashInterval != 0 && i != 0 && i%dashInterval == 0 { + buf.WriteByte('-') + } + + bi, err := rand.Int(rand.Reader, max) + if err != nil { + return "", fmt.Errorf("%w getting entropy for user code", err) + } + + buf.WriteRune(charSet[int(bi.Int64())]) + } + + return buf.String(), nil +} + +func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + req := new(oidc.DeviceAccessTokenRequest) + if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil { + RequestError(w, r, err) + } + + storage, ok := exchanger.Storage().(DeviceCodeStorage) + if !ok { + // unimplemented error? + } + + client, err := storage.DeviceAccessPoll(r.Context(), req.DeviceCode) + if err != nil { + RequestError(w, r, err) + } + + resp, err := CreateDeviceTokenResponse(r.Context(), req, exchanger, client) + if err != nil { + RequestError(w, r, err) + return + } + httphelper.MarshalJSON(w, resp) +} + +func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { + tokenType := AccessTokenTypeBearer // not sure if this is the correct type? + + accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") + if err != nil { + return nil, err + } + + return &oidc.AccessTokenResponse{ + AccessToken: accessToken, + TokenType: oidc.BearerToken, + ExpiresIn: uint64(validity.Seconds()), + }, nil +} + +func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + UserCodeForm(w, r, o) + } +} + +func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { + // check cookie, or what?? + + config := o.DeviceAuthorization().UserCode + userCode, err := UserCodeFromRequest(r, config.QueryKey) + if err != nil { + RequestError(w, r, err) + return + } + if userCode == "" { + w.Write(config.FormHTML) + return + } + + storage, ok := o.Storage().(DeviceCodeStorage) + if !ok { + // unimplemented error? + } + + if err := storage.ReleaseDeviceAccessToken(r.Context(), userCode); err != nil { + RequestError(w, r, err) + return + } + + fmt.Fprintln(w, "Authorization successfull, please return to your device") +} + +func UserCodeFromRequest(r *http.Request, key string) (string, error) { + if err := r.ParseForm(); err != nil { + return "", oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) + } + + return r.Form.Get(key), nil +} diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go new file mode 100644 index 0000000..6eea1e3 --- /dev/null +++ b/pkg/op/device_test.go @@ -0,0 +1,204 @@ +package op + +import ( + "crypto/rand" + "encoding/base64" + "io" + mr "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type errReader struct { +} + +func (errReader) Read([]byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +func runWithRandReader(r io.Reader, f func()) { + originalReader := rand.Reader + rand.Reader = r + defer func() { + rand.Reader = originalReader + }() + + f() +} + +func TestNewDeviceCode(t *testing.T) { + t.Run("reader error", func(t *testing.T) { + runWithRandReader(errReader{}, func() { + _, err := NewDeviceCode(16) + require.Error(t, err) + }) + }) + + t.Run("dirrent lengths, rand reader", func(t *testing.T) { + for i := 1; i <= 32; i++ { + got, err := NewDeviceCode(i) + require.NoError(t, err) + assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i)) + } + }) + +} + +func TestNewUserCode(t *testing.T) { + type args struct { + charset []rune + charAmount int + dashInterval int + } + tests := []struct { + name string + args args + reader io.Reader + want string + wantErr bool + }{ + { + name: "reader error", + args: args{ + charset: []rune(CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: errReader{}, + wantErr: true, + }, + { + name: "base20", + args: args{ + charset: []rune(CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: mr.New(mr.NewSource(1)), + want: "XKCD-HTTD", + }, + { + name: "digits", + args: args{ + charset: []rune(CharSetDigits), + charAmount: 9, + dashInterval: 3, + }, + reader: mr.New(mr.NewSource(1)), + want: "271-256-225", + }, + { + name: "no dashes", + args: args{ + charset: []rune(CharSetDigits), + charAmount: 9, + }, + reader: mr.New(mr.NewSource(1)), + want: "271256225", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + runWithRandReader(tt.reader, func() { + got, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) + if tt.wantErr { + require.ErrorIs(t, err, io.ErrUnexpectedEOF) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + + }) + } + + t.Run("crypto/rand", func(t *testing.T) { + const testN = 100000 + + for _, c := range []UserCodeConfig{UserCodeBase20, UserCodeDigits} { + t.Run(c.CharSet, func(t *testing.T) { + results := make(map[string]int) + + for i := 0; i < testN; i++ { + code, err := NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval) + require.NoError(t, err) + results[code]++ + } + + t.Log(results) + + var duplicates int + for code, count := range results { + assert.Less(t, count, 3, code) + if count == 2 { + duplicates++ + } + } + + }) + } + }) +} + +func BenchmarkNewUserCode(b *testing.B) { + type args struct { + charset []rune + charAmount int + dashInterval int + } + tests := []struct { + name string + args args + reader io.Reader + }{ + { + name: "math rand, base20", + args: args{ + charset: []rune(CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: mr.New(mr.NewSource(1)), + }, + { + name: "math rand, digits", + args: args{ + charset: []rune(CharSetDigits), + charAmount: 9, + dashInterval: 3, + }, + reader: mr.New(mr.NewSource(1)), + }, + { + name: "crypto rand, base20", + args: args{ + charset: []rune(CharSetBase20), + charAmount: 8, + dashInterval: 4, + }, + reader: rand.Reader, + }, + { + name: "crypto rand, digits", + args: args{ + charset: []rune(CharSetDigits), + charAmount: 9, + dashInterval: 3, + }, + reader: rand.Reader, + }, + } + for _, tt := range tests { + runWithRandReader(tt.reader, func() { + b.Run(tt.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) + require.NoError(b, err) + } + }) + + }) + } +} diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 9a25afc..26f89eb 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -44,6 +44,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer), EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer), JwksURI: config.KeysEndpoint().Absolute(issuer), + DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer), ScopesSupported: Scopes(config), ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), @@ -92,6 +93,9 @@ func GrantTypes(c Configuration) []oidc.GrantType { if c.GrantTypeJWTAuthorizationSupported() { grantTypes = append(grantTypes, oidc.GrantTypeBearer) } + if c.GrantTypeDeviceCodeSupported() { + grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode) + } return grantTypes } diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index fc3158a..44b5ceb 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -92,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported)) } +// DeviceAuthorization mocks base method. +func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeviceAuthorization") + ret0, _ := ret[0].(op.DeviceAuthorizationConfig) + return ret0 +} + +// DeviceAuthorization indicates an expected call of DeviceAuthorization. +func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization)) +} + +// DeviceAuthorizationEndpoint mocks base method. +func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint. +func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint)) +} + // EndSessionEndpoint mocks base method. func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { m.ctrl.T.Helper() @@ -120,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported)) } +// GrantTypeDeviceCodeSupported mocks base method. +func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported. +func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported)) +} + // GrantTypeJWTAuthorizationSupported mocks base method. func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool { m.ctrl.T.Helper() @@ -358,6 +400,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) } +// UserCodeFormEndpoint mocks base method. +func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UserCodeFormEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint. +func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint)) +} + // UserinfoEndpoint mocks base method. func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index 699fb45..2256ca7 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -27,17 +27,21 @@ const ( defaultRevocationEndpoint = "revoke" defaultEndSessionEndpoint = "end_session" defaultKeysEndpoint = "keys" + defaultDeviceAuthzEndpoint = "/device_authorization" + defaultUserCodeFormEndpoint = "/device" ) var ( DefaultEndpoints = &endpoints{ - Authorization: NewEndpoint(defaultAuthorizationEndpoint), - Token: NewEndpoint(defaultTokenEndpoint), - Introspection: NewEndpoint(defaultIntrospectEndpoint), - Userinfo: NewEndpoint(defaultUserinfoEndpoint), - Revocation: NewEndpoint(defaultRevocationEndpoint), - EndSession: NewEndpoint(defaultEndSessionEndpoint), - JwksURI: NewEndpoint(defaultKeysEndpoint), + Authorization: NewEndpoint(defaultAuthorizationEndpoint), + Token: NewEndpoint(defaultTokenEndpoint), + Introspection: NewEndpoint(defaultIntrospectEndpoint), + Userinfo: NewEndpoint(defaultUserinfoEndpoint), + Revocation: NewEndpoint(defaultRevocationEndpoint), + EndSession: NewEndpoint(defaultEndSessionEndpoint), + JwksURI: NewEndpoint(defaultKeysEndpoint), + DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), + UserCodeForm: NewEndpoint(defaultUserCodeFormEndpoint), } defaultCORSOptions = cors.Options{ @@ -95,6 +99,8 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) + router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o)) + router.HandleFunc(o.UserCodeFormEndpoint().Relative(), userCodeFormHandler(o)) return router } @@ -121,14 +127,16 @@ type Config struct { } type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint + Authorization Endpoint + Token Endpoint + Introspection Endpoint + Userinfo Endpoint + Revocation Endpoint + EndSession Endpoint + CheckSessionIframe Endpoint + JwksURI Endpoint + DeviceAuthorization Endpoint + UserCodeForm Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -242,6 +250,14 @@ func (o *Provider) EndSessionEndpoint() Endpoint { return o.endpoints.EndSession } +func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { + return o.endpoints.DeviceAuthorization +} + +func (o *Provider) UserCodeFormEndpoint() Endpoint { + return o.endpoints.UserCodeForm +} + func (o *Provider) KeysEndpoint() Endpoint { return o.endpoints.JwksURI } @@ -275,6 +291,10 @@ func (o *Provider) GrantTypeJWTAuthorizationSupported() bool { return true } +func (o *Provider) GrantTypeDeviceCodeSupported() bool { + return true +} + func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool { return true } @@ -308,6 +328,10 @@ func (o *Provider) SupportedUILocales() []language.Tag { return o.config.SupportedUILocales } +func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig { + return DeviceAuthorizationConfig{} +} + func (o *Provider) Storage() Storage { return o.storage } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 1e19c76..69b05b7 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -151,3 +151,35 @@ type EndSessionRequest struct { ClientID string RedirectURI string } + +var ErrDuplicateUserCode = errors.New("user code already exists") + +type DeviceCodeStorage interface { + // StoreDeviceAuthorizationRequest stores a new device authorization request in the database. + // User code will be used by the user to complete the login flow and must be unique. + // ErrDuplicateUserCode signals the caller should try again with a new code. + // + // Note that user codes are low entropy keys and when many exist in the + // database, the change for collisions increases. Therefore implementers + // of this interface must make sure that user codes of completed or expired + // authentication flows are deleted. + StoreDeviceAuthorizationRequest(ctx context.Context, req *oidc.DeviceAuthorizationRequest, deviceCode, userCode string) error + + // DeviceAccessPoll is called by the device untill the authorization flow is + // completed or expired. + // + // The following errors are defined for the Device Authorization workflow, + // that can be returned by this method: + // - oidc.ErrAuthorizationPending should be returned on each poll, while the flow is not completed by the user. + // - oidc.ErrSlowDown signals to the device that the polling interval is to be increased by 5 seconds. + // - oidc.ErrAccessDenied when the authorization request is denied. + // - oidc.ErrExpiredToken when the device code has expired. + // + // A token should be returned once the authorization flow is completed + // by the user. + DeviceAccessPoll(ctx context.Context, deviceCode string) (Client, error) + + // ReleaseDeviceAccessToken releases DeviceAccessPoll to return the Access Token, + // destined for a user code. + ReleaseDeviceAccessToken(ctx context.Context, userCode string) error +} diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 3d65ea0..b9e9805 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -19,6 +19,7 @@ type Exchanger interface { GrantTypeTokenExchangeSupported() bool GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool + GrantTypeDeviceCodeSupported() bool AccessTokenVerifier(context.Context) AccessTokenVerifier IDTokenHintVerifier(context.Context) IDTokenHintVerifier } @@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { ClientCredentialsExchange(w, r, exchanger) return } + case string(oidc.GrantTypeDeviceCode): + if exchanger.GrantTypeDeviceCodeSupported() { + DeviceAccessToken(w, r, exchanger) + return + } case "": RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) return