diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 3d7bb63..b8a1648 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -251,6 +251,7 @@ type ConfClient struct { applicationType op.ApplicationType authMethod oidc.AuthMethod responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType ID string accessTokenType op.AccessTokenType devMode bool @@ -295,6 +296,9 @@ func (c *ConfClient) AccessTokenType() op.AccessTokenType { func (c *ConfClient) ResponseTypes() []oidc.ResponseType { return c.responseTypes } +func (c *ConfClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} func (c *ConfClient) DevMode() bool { return c.devMode diff --git a/pkg/op/client.go b/pkg/op/client.go index 79715b0..f1e18fa 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -30,6 +30,7 @@ type Client interface { ApplicationType() ApplicationType AuthMethod() oidc.AuthMethod ResponseTypes() []oidc.ResponseType + GrantTypes() []oidc.GrantType LoginURL(string) string AccessTokenType() AccessTokenType IDTokenLifetime() time.Duration diff --git a/pkg/op/config.go b/pkg/op/config.go index 7cb522a..0e5216b 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -21,6 +21,7 @@ type Configuration interface { AuthMethodPostSupported() bool CodeMethodS256Supported() bool AuthMethodPrivateKeyJWTSupported() bool + GrantTypeRefreshTokenSupported() bool GrantTypeTokenExchangeSupported() bool GrantTypeJWTAuthorizationSupported() bool } diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index d8ef7c3..d057042 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -66,6 +66,9 @@ func GrantTypes(c Configuration) []oidc.GrantType { oidc.GrantTypeCode, oidc.GrantTypeImplicit, } + if c.GrantTypeRefreshTokenSupported() { + grantTypes = append(grantTypes, oidc.GrantTypeRefreshToken) + } if c.GrantTypeTokenExchangeSupported() { grantTypes = append(grantTypes, oidc.GrantTypeTokenExchange) } diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index e03ae0c..f78754d 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -120,6 +120,20 @@ func (mr *MockClientMockRecorder) GetID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID)) } +// GrantTypes mocks base method. +func (m *MockClient) GrantTypes() []oidc.GrantType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GrantTypes") + ret0, _ := ret[0].([]oidc.GrantType) + return ret0 +} + +// GrantTypes indicates an expected call of GrantTypes. +func (mr *MockClientMockRecorder) GrantTypes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypes", reflect.TypeOf((*MockClient)(nil).GrantTypes)) +} + // IDTokenLifetime mocks base method. func (m *MockClient) IDTokenLifetime() time.Duration { m.ctrl.T.Helper() diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index f9f297e..da21751 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -118,6 +118,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeJWTAuthorizationSupported() *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeJWTAuthorizationSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeJWTAuthorizationSupported)) } +// GrantTypeRefreshTokenSupported mocks base method. +func (m *MockConfiguration) GrantTypeRefreshTokenSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GrantTypeRefreshTokenSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// GrantTypeRefreshTokenSupported indicates an expected call of GrantTypeRefreshTokenSupported. +func (mr *MockConfigurationMockRecorder) GrantTypeRefreshTokenSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeRefreshTokenSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeRefreshTokenSupported)) +} + // GrantTypeTokenExchangeSupported mocks base method. func (m *MockConfiguration) GrantTypeTokenExchangeSupported() bool { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index da04fc9..7689bd3 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -107,6 +107,7 @@ type ConfClient struct { authMethod oidc.AuthMethod accessTokenType op.AccessTokenType responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType devMode bool } @@ -150,6 +151,9 @@ func (c *ConfClient) AccessTokenType() op.AccessTokenType { func (c *ConfClient) ResponseTypes() []oidc.ResponseType { return c.responseTypes } +func (c *ConfClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} func (c *ConfClient) DevMode() bool { return c.devMode } diff --git a/pkg/op/op.go b/pkg/op/op.go index c91865d..03b053d 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -84,6 +84,7 @@ type Config struct { DefaultLogoutRedirectURI string CodeMethodS256 bool AuthMethodPrivateKeyJWT bool + GrantTypeRefreshToken bool } type endpoints struct { @@ -189,6 +190,10 @@ func (o *openidProvider) AuthMethodPrivateKeyJWTSupported() bool { return o.config.AuthMethodPrivateKeyJWT } +func (o *openidProvider) GrantTypeRefreshTokenSupported() bool { + return o.config.GrantTypeRefreshToken +} + func (o *openidProvider) GrantTypeTokenExchangeSupported() bool { return false } diff --git a/pkg/op/token.go b/pkg/op/token.go index 28bc011..1faffa8 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -53,18 +53,18 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli }, nil } -func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string) (id, newRefreshToken string, exp time.Time, err error) { - if needsRefreshToken(tokenRequest) { +func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client Client) (id, newRefreshToken string, exp time.Time, err error) { + if needsRefreshToken(tokenRequest, client) { return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken) } id, exp, err = storage.CreateAccessToken(ctx, tokenRequest) return } -func needsRefreshToken(tokenRequest TokenRequest) bool { +func needsRefreshToken(tokenRequest TokenRequest, client Client) bool { switch req := tokenRequest.(type) { case AuthRequest: - return utils.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode + return utils.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken) case RefreshTokenRequest: return true default: @@ -73,7 +73,7 @@ func needsRefreshToken(tokenRequest TokenRequest) bool { } func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) { - id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken) + id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client) if err != nil { return "", "", 0, err } diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 9aae67b..fa941df 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -53,6 +53,9 @@ func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenR if client.GetID() != authReq.GetClientID() { return nil, nil, ErrInvalidRequest("invalid auth code") } + if !ValidateGrantType(client, oidc.GrantTypeCode) { + return nil, nil, ErrInvalidRequest("invalid_grant") + } if tokenReq.RedirectURI != authReq.GetRedirectURI() { return nil, nil, ErrInvalidRequest("redirect_uri does not correspond") } diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index f8148d7..fd26f19 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -17,6 +17,7 @@ type Exchanger interface { Crypto() Crypto AuthMethodPostSupported() bool AuthMethodPrivateKeyJWTSupported() bool + GrantTypeRefreshTokenSupported() bool GrantTypeTokenExchangeSupported() bool GrantTypeJWTAuthorizationSupported() bool } @@ -28,8 +29,10 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque CodeExchange(w, r, exchanger) return case string(oidc.GrantTypeRefreshToken): - RefreshTokenExchange(w, r, exchanger) - return + if exchanger.GrantTypeRefreshTokenSupported() { + RefreshTokenExchange(w, r, exchanger) + return + } case string(oidc.GrantTypeBearer): if ex, ok := exchanger.(JWTAuthorizationGrantExchanger); ok && exchanger.GrantTypeJWTAuthorizationSupported() { JWTProfile(w, r, ex) @@ -119,3 +122,16 @@ func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchang } return client, nil } + +//ValidateGrantType ensures that the requested grant_type is allowed by the Client +func ValidateGrantType(client Client, grantType oidc.GrantType) bool { + if client == nil { + return false + } + for _, grant := range client.GrantTypes() { + if grantType == grant { + return true + } + } + return false +}