diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 9671ec7..aee9802 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -284,10 +284,18 @@ func (c *ConfClient) AllowedScopes() []string { return nil } -func (c *ConfClient) AssertAdditionalIdTokenScopes() bool { - return false +func (c *ConfClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } } -func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool { +func (c *ConfClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *ConfClient) IsScopeAllowed(scope string) bool { return false } diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 4d6118c..9e320f8 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -121,7 +121,7 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { scope == oidc.ScopePhone || scope == oidc.ScopeAddress || scope == oidc.ScopeOfflineAccess) && - !utils.Contains(client.AllowedScopes(), scope) { + !client.IsScopeAllowed(scope) { scopes[i] = scopes[len(scopes)-1] scopes[len(scopes)-1] = "" scopes = scopes[:len(scopes)-1] diff --git a/pkg/op/client.go b/pkg/op/client.go index 790933e..ceca8b0 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -34,9 +34,9 @@ type Client interface { AccessTokenType() AccessTokenType IDTokenLifetime() time.Duration DevMode() bool - AllowedScopes() []string - AssertAdditionalIdTokenScopes() bool - AssertAdditionalAccessTokenScopes() bool + RestrictAdditionalIdTokenScopes() func(scopes []string) []string + RestrictAdditionalAccessTokenScopes() func(scopes []string) []string + IsScopeAllowed(scope string) bool } func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool { diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go index 12c00cc..b7ac3e8 100644 --- a/pkg/op/mock/client.go +++ b/pkg/op/mock/client.go @@ -26,7 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client { func(id string) string { return "login?id=" + id }) - m.EXPECT().AllowedScopes().AnyTimes().Return(nil) + m.EXPECT().IsScopeAllowed(gomock.Any()).AnyTimes().Return(false) return c } diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 0780623..df80bd0 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -49,20 +49,6 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType)) } -// AllowedScopes mocks base method -func (m *MockClient) AllowedScopes() []string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AllowedScopes") - ret0, _ := ret[0].([]string) - return ret0 -} - -// AllowedScopes indicates an expected call of AllowedScopes -func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes)) -} - // ApplicationType mocks base method func (m *MockClient) ApplicationType() op.ApplicationType { m.ctrl.T.Helper() @@ -77,34 +63,6 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) } -// AssertAdditionalAccessTokenScopes mocks base method -func (m *MockClient) AssertAdditionalAccessTokenScopes() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes") - ret0, _ := ret[0].(bool) - return ret0 -} - -// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes -func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes)) -} - -// AssertAdditionalIdTokenScopes mocks base method -func (m *MockClient) AssertAdditionalIdTokenScopes() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes") - ret0, _ := ret[0].(bool) - return ret0 -} - -// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes -func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes)) -} - // AuthMethod mocks base method func (m *MockClient) AuthMethod() op.AuthMethod { m.ctrl.T.Helper() @@ -161,6 +119,20 @@ func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime)) } +// IsScopeAllowed mocks base method +func (m *MockClient) IsScopeAllowed(arg0 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsScopeAllowed", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsScopeAllowed indicates an expected call of IsScopeAllowed +func (mr *MockClientMockRecorder) IsScopeAllowed(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsScopeAllowed", reflect.TypeOf((*MockClient)(nil).IsScopeAllowed), arg0) +} + // LoginURL mocks base method func (m *MockClient) LoginURL(arg0 string) string { m.ctrl.T.Helper() @@ -216,3 +188,31 @@ func (mr *MockClientMockRecorder) ResponseTypes() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseTypes", reflect.TypeOf((*MockClient)(nil).ResponseTypes)) } + +// RestrictAdditionalAccessTokenScopes mocks base method +func (m *MockClient) RestrictAdditionalAccessTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalAccessTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalAccessTokenScopes indicates an expected call of RestrictAdditionalAccessTokenScopes +func (mr *MockClientMockRecorder) RestrictAdditionalAccessTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).RestrictAdditionalAccessTokenScopes)) +} + +// RestrictAdditionalIdTokenScopes mocks base method +func (m *MockClient) RestrictAdditionalIdTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalIdTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalIdTokenScopes indicates an expected call of RestrictAdditionalIdTokenScopes +func (mr *MockClientMockRecorder) RestrictAdditionalIdTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).RestrictAdditionalIdTokenScopes)) +} diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index de9dee9..92d5ad7 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -171,9 +171,16 @@ func (c *ConfClient) DevMode() bool { func (c *ConfClient) AllowedScopes() []string { return nil } -func (c *ConfClient) AssertAdditionalIdTokenScopes() bool { - return false -} -func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool { +func (c *ConfClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} +func (c *ConfClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} +func (c *ConfClient) IsScopeAllowed(scope string) bool { return false } diff --git a/pkg/op/token.go b/pkg/op/token.go index 2d66ef5..4fd4c0a 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -31,7 +31,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client return nil, err } } - idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes()) + idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.RestrictAdditionalIdTokenScopes()) if err != nil { return nil, err } @@ -84,8 +84,9 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) { func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) { claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id) - if client != nil && client.AssertAdditionalAccessTokenScopes() { - privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes())) + if client != nil { + restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes()) + privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes)) if err != nil { return "", err } @@ -94,11 +95,10 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex return utils.Sign(claims, signer.Signer()) } -func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) { +func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, restictAdditionalScopesFunc func([]string) []string) (string, error) { exp := time.Now().UTC().Add(validity) claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID()) - scopes := authReq.GetScopes() - + scopes := restictAdditionalScopesFunc(authReq.GetScopes()) if accessToken != "" { atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) if err != nil { @@ -107,9 +107,6 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali claims.SetAccessTokenHash(atHash) scopes = removeUserinfoScopes(scopes) } - if !additonalScopes { - scopes = removeAdditionalScopes(scopes) - } if len(scopes) > 0 { userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes) if err != nil { @@ -142,19 +139,3 @@ func removeUserinfoScopes(scopes []string) []string { } return scopes } - -func removeAdditionalScopes(scopes []string) []string { - for i := len(scopes) - 1; i >= 0; i-- { - if !(scopes[i] == oidc.ScopeOpenID || - scopes[i] == oidc.ScopeProfile || - scopes[i] == oidc.ScopeEmail || - scopes[i] == oidc.ScopeAddress || - scopes[i] == oidc.ScopePhone) { - - scopes[i] = scopes[len(scopes)-1] - scopes[len(scopes)-1] = "" - scopes = scopes[:len(scopes)-1] - } - } - return scopes -}