From 6ba72be7ea7584207291d86057bfd3e22916dc53 Mon Sep 17 00:00:00 2001 From: Fabiennne Date: Wed, 28 Oct 2020 15:08:22 +0100 Subject: [PATCH] fix: restrict additional scopes --- example/internal/mock/storage.go | 12 ++++--- pkg/op/client.go | 4 +-- pkg/op/mock/client.mock.go | 56 ++++++++++++++++---------------- pkg/op/mock/storage.mock.impl.go | 12 ++++--- pkg/op/token.go | 14 ++++---- 5 files changed, 53 insertions(+), 45 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 2c5988b..aee9802 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -284,12 +284,16 @@ func (c *ConfClient) AllowedScopes() []string { return nil } -func (c *ConfClient) AssertAdditionalIdTokenScopes() bool { - return false +func (c *ConfClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } } -func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool { - return false +func (c *ConfClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } } func (c *ConfClient) IsScopeAllowed(scope string) bool { diff --git a/pkg/op/client.go b/pkg/op/client.go index 0e77627..ceca8b0 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -34,8 +34,8 @@ type Client interface { AccessTokenType() AccessTokenType IDTokenLifetime() time.Duration DevMode() bool - AssertAdditionalIdTokenScopes() bool - AssertAdditionalAccessTokenScopes() bool + RestrictAdditionalIdTokenScopes() func(scopes []string) []string + RestrictAdditionalAccessTokenScopes() func(scopes []string) []string IsScopeAllowed(scope string) bool } diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 606b278..df80bd0 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -63,34 +63,6 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) } -// AssertAdditionalAccessTokenScopes mocks base method -func (m *MockClient) AssertAdditionalAccessTokenScopes() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes") - ret0, _ := ret[0].(bool) - return ret0 -} - -// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes -func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes)) -} - -// AssertAdditionalIdTokenScopes mocks base method -func (m *MockClient) AssertAdditionalIdTokenScopes() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes") - ret0, _ := ret[0].(bool) - return ret0 -} - -// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes -func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes)) -} - // AuthMethod mocks base method func (m *MockClient) AuthMethod() op.AuthMethod { m.ctrl.T.Helper() @@ -216,3 +188,31 @@ func (mr *MockClientMockRecorder) ResponseTypes() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseTypes", reflect.TypeOf((*MockClient)(nil).ResponseTypes)) } + +// RestrictAdditionalAccessTokenScopes mocks base method +func (m *MockClient) RestrictAdditionalAccessTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalAccessTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalAccessTokenScopes indicates an expected call of RestrictAdditionalAccessTokenScopes +func (mr *MockClientMockRecorder) RestrictAdditionalAccessTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).RestrictAdditionalAccessTokenScopes)) +} + +// RestrictAdditionalIdTokenScopes mocks base method +func (m *MockClient) RestrictAdditionalIdTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalIdTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalIdTokenScopes indicates an expected call of RestrictAdditionalIdTokenScopes +func (mr *MockClientMockRecorder) RestrictAdditionalIdTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).RestrictAdditionalIdTokenScopes)) +} diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 693e489..92d5ad7 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -171,11 +171,15 @@ func (c *ConfClient) DevMode() bool { func (c *ConfClient) AllowedScopes() []string { return nil } -func (c *ConfClient) AssertAdditionalIdTokenScopes() bool { - return false +func (c *ConfClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } } -func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool { - return false +func (c *ConfClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } } func (c *ConfClient) IsScopeAllowed(scope string) bool { return false diff --git a/pkg/op/token.go b/pkg/op/token.go index 2d66ef5..aff5bcb 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -31,7 +31,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client return nil, err } } - idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes()) + idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.RestrictAdditionalIdTokenScopes()) if err != nil { return nil, err } @@ -84,8 +84,9 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) { func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) { claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id) - if client != nil && client.AssertAdditionalAccessTokenScopes() { - privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes())) + if client != nil { + restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes()) + privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes)) if err != nil { return "", err } @@ -94,7 +95,7 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex return utils.Sign(claims, signer.Signer()) } -func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) { +func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, restictAdditionalScopesFunc func([]string) []string) (string, error) { exp := time.Now().UTC().Add(validity) claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID()) scopes := authReq.GetScopes() @@ -107,9 +108,8 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali claims.SetAccessTokenHash(atHash) scopes = removeUserinfoScopes(scopes) } - if !additonalScopes { - scopes = removeAdditionalScopes(scopes) - } + scopes = restictAdditionalScopesFunc(scopes) + if len(scopes) > 0 { userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes) if err != nil {