From b311610d063e02d27d7dc7771b4d8b4563376209 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Wed, 7 Oct 2020 08:44:26 +0200 Subject: [PATCH] feat: check allowed scopes (and pass clientID to GetUserinfoFromScopes) --- example/internal/mock/storage.go | 8 +++-- pkg/op/authrequest.go | 32 ++++++++++++++++---- pkg/op/authrequest_test.go | 52 +++++++++++++++++++++++++++----- pkg/op/client.go | 1 + pkg/op/mock/client.go | 1 + pkg/op/mock/client.mock.go | 14 +++++++++ pkg/op/mock/storage.mock.go | 8 ++--- pkg/op/mock/storage.mock.impl.go | 3 ++ pkg/op/storage.go | 2 +- pkg/op/token.go | 2 +- 10 files changed, 101 insertions(+), 22 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 1c33906..e3a4e1a 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -211,9 +211,9 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st } func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (oidc.UserInfoSetter, error) { - return s.GetUserinfoFromScopes(ctx, "", []string{}) + return s.GetUserinfoFromScopes(ctx, "", "", []string{}) } -func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (oidc.UserInfoSetter, error) { +func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfoSetter, error) { userinfo := oidc.NewUserInfo() userinfo.SetSubject(a.GetSubject()) userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", "")) @@ -276,3 +276,7 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType { func (c *ConfClient) DevMode() bool { return c.devMode } + +func (c *ConfClient) AllowedScopes() []string { + return nil +} diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index cee6184..86e2275 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage if err != nil { return "", ErrServerError(err.Error()) } - if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { + authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes) + if err != nil { return "", err } if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil { @@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage } //ValidateAuthReqScopes validates the passed scopes -func ValidateAuthReqScopes(scopes []string) error { +func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { if len(scopes) == 0 { - return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") + return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") } - if !utils.Contains(scopes, oidc.ScopeOpenID) { - return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") + openID := false + for i := len(scopes) - 1; i >= 0; i-- { + switch scopes[i] { + case oidc.ScopeOpenID: + openID = true + case oidc.ScopeProfile, + oidc.ScopeEmail, + oidc.ScopePhone, + oidc.ScopeAddress, + oidc.ScopeOfflineAccess: + default: + if !utils.Contains(client.AllowedScopes(), scopes[i]) { + scopes[i] = scopes[len(scopes)-1] + scopes[len(scopes)-1] = "" + scopes = scopes[:len(scopes)-1] + } + } } - return nil + if !openID { + return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") + } + + return scopes, nil } //ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go index d74d365..3856acd 100644 --- a/pkg/op/authrequest_test.go +++ b/pkg/op/authrequest_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/gorilla/schema" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/caos/oidc/pkg/oidc" @@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) { func TestValidateAuthReqScopes(t *testing.T) { type args struct { + client op.Client + scopes []string + } + type res struct { + err bool scopes []string } tests := []struct { - name string - args args - wantErr bool + name string + args args + res res }{ { - "scopes missing fails", args{}, true, + "scopes missing fails", + args{}, + res{ + err: true, + }, }, { - "scope openid missing fails", args{[]string{"email"}}, true, + "scope openid missing fails", + args{ + mock.NewClientExpectAny(t, op.ApplicationTypeWeb), + []string{"email"}, + }, + res{ + err: true, + }, }, { - "scope ok", args{[]string{"openid"}}, false, + "scope ok", + args{ + mock.NewClientExpectAny(t, op.ApplicationTypeWeb), + []string{"openid"}, + }, + res{ + scopes: []string{"openid"}, + }, + }, + { + "scope with drop ok", + args{ + mock.NewClientExpectAny(t, op.ApplicationTypeWeb), + []string{"openid", "email", "unknown"}, + }, + res{ + scopes: []string{"openid", "email"}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr { - t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr) + scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes) + if (err != nil) != tt.res.err { + t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err) } + assert.ElementsMatch(t, scopes, tt.res.scopes) }) } } diff --git a/pkg/op/client.go b/pkg/op/client.go index 3184b90..258ce6e 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -32,6 +32,7 @@ type Client interface { AccessTokenType() AccessTokenType IDTokenLifetime() time.Duration DevMode() bool + AllowedScopes() []string } func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool { diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go index eed21d5..12c00cc 100644 --- a/pkg/op/mock/client.go +++ b/pkg/op/mock/client.go @@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client { func(id string) string { return "login?id=" + id }) + m.EXPECT().AllowedScopes().AnyTimes().Return(nil) return c } diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 4007347..8e18d56 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType)) } +// AllowedScopes mocks base method +func (m *MockClient) AllowedScopes() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AllowedScopes") + ret0, _ := ret[0].([]string) + return ret0 +} + +// AllowedScopes indicates an expected call of AllowedScopes +func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes)) +} + // ApplicationType mocks base method func (m *MockClient) ApplicationType() op.ApplicationType { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 1bcd1a6..973f58b 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -184,18 +184,18 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (oidc.UserInfoSetter, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfoSetter, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(oidc.UserInfoSetter) ret1, _ := ret[1].(error) return ret0, ret1 } // GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes -func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3) } // GetUserinfoFromToken mocks base method diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 6fd2760..54cd059 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -168,3 +168,6 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType { func (c *ConfClient) DevMode() bool { return c.devMode } +func (c *ConfClient) AllowedScopes() []string { + return nil +} diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 69784ee..1c266d7 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -28,7 +28,7 @@ type AuthStorage interface { type OPStorage interface { GetClientByClientID(context.Context, string) (Client, error) AuthorizeClientIDSecret(context.Context, string, string) error - GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error) + GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfoSetter, error) GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } diff --git a/pkg/op/token.go b/pkg/op/token.go index a2236d4..670fca7 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -98,7 +98,7 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali } claims.SetAccessTokenHash(atHash) } else { - userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) + userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), authReq.GetScopes()) if err != nil { return "", err }