feat: check allowed scopes (and pass clientID to GetUserinfoFromScopes)

This commit is contained in:
Livio Amstutz 2020-10-07 08:44:26 +02:00
parent b2903212ab
commit b311610d06
10 changed files with 101 additions and 22 deletions

View file

@ -211,9 +211,9 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
} }
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (oidc.UserInfoSetter, error) { func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (oidc.UserInfoSetter, error) {
return s.GetUserinfoFromScopes(ctx, "", []string{}) return s.GetUserinfoFromScopes(ctx, "", "", []string{})
} }
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (oidc.UserInfoSetter, error) { func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfoSetter, error) {
userinfo := oidc.NewUserInfo() userinfo := oidc.NewUserInfo()
userinfo.SetSubject(a.GetSubject()) userinfo.SetSubject(a.GetSubject())
userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", "")) userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
@ -276,3 +276,7 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool { func (c *ConfClient) DevMode() bool {
return c.devMode return c.devMode
} }
func (c *ConfClient) AllowedScopes() []string {
return nil
}

View file

@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
if err != nil { if err != nil {
return "", ErrServerError(err.Error()) return "", ErrServerError(err.Error())
} }
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
return "", err return "", err
} }
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil { if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
} }
//ValidateAuthReqScopes validates the passed scopes //ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(scopes []string) error { func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 { if len(scopes) == 0 {
return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
} }
if !utils.Contains(scopes, oidc.ScopeOpenID) { openID := false
return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") for i := len(scopes) - 1; i >= 0; i-- {
switch scopes[i] {
case oidc.ScopeOpenID:
openID = true
case oidc.ScopeProfile,
oidc.ScopeEmail,
oidc.ScopePhone,
oidc.ScopeAddress,
oidc.ScopeOfflineAccess:
default:
if !utils.Contains(client.AllowedScopes(), scopes[i]) {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
} }
return nil if !openID {
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
}
return scopes, nil
} }
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type //ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/gorilla/schema" "github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) {
func TestValidateAuthReqScopes(t *testing.T) { func TestValidateAuthReqScopes(t *testing.T) {
type args struct { type args struct {
client op.Client
scopes []string
}
type res struct {
err bool
scopes []string scopes []string
} }
tests := []struct { tests := []struct {
name string name string
args args args args
wantErr bool res res
}{ }{
{ {
"scopes missing fails", args{}, true, "scopes missing fails",
args{},
res{
err: true,
},
}, },
{ {
"scope openid missing fails", args{[]string{"email"}}, true, "scope openid missing fails",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"email"},
},
res{
err: true,
},
}, },
{ {
"scope ok", args{[]string{"openid"}}, false, "scope ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid"},
},
res{
scopes: []string{"openid"},
},
},
{
"scope with drop ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid", "email", "unknown"},
},
res{
scopes: []string{"openid", "email"},
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr { scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr) if (err != nil) != tt.res.err {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
} }
assert.ElementsMatch(t, scopes, tt.res.scopes)
}) })
} }
} }

View file

@ -32,6 +32,7 @@ type Client interface {
AccessTokenType() AccessTokenType AccessTokenType() AccessTokenType
IDTokenLifetime() time.Duration IDTokenLifetime() time.Duration
DevMode() bool DevMode() bool
AllowedScopes() []string
} }
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool { func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {

View file

@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
func(id string) string { func(id string) string {
return "login?id=" + id return "login?id=" + id
}) })
m.EXPECT().AllowedScopes().AnyTimes().Return(nil)
return c return c
} }

View file

@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
} }
// AllowedScopes mocks base method
func (m *MockClient) AllowedScopes() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AllowedScopes")
ret0, _ := ret[0].([]string)
return ret0
}
// AllowedScopes indicates an expected call of AllowedScopes
func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes))
}
// ApplicationType mocks base method // ApplicationType mocks base method
func (m *MockClient) ApplicationType() op.ApplicationType { func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -184,18 +184,18 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
} }
// GetUserinfoFromScopes mocks base method // GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (oidc.UserInfoSetter, error) { func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfoSetter, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(oidc.UserInfoSetter) ret0, _ := ret[0].(oidc.UserInfoSetter)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes // GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3)
} }
// GetUserinfoFromToken mocks base method // GetUserinfoFromToken mocks base method

View file

@ -168,3 +168,6 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool { func (c *ConfClient) DevMode() bool {
return c.devMode return c.devMode
} }
func (c *ConfClient) AllowedScopes() []string {
return nil
}

View file

@ -28,7 +28,7 @@ type AuthStorage interface {
type OPStorage interface { type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error) GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error) GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfoSetter, error)
GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error) GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
} }

View file

@ -98,7 +98,7 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali
} }
claims.SetAccessTokenHash(atHash) claims.SetAccessTokenHash(atHash)
} else { } else {
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), authReq.GetScopes())
if err != nil { if err != nil {
return "", err return "", err
} }