feat: check allowed scopes (and pass clientID to GetUserinfoFromScopes)

This commit is contained in:
Livio Amstutz 2020-10-07 08:44:26 +02:00
parent b2903212ab
commit b311610d06
10 changed files with 101 additions and 22 deletions

View file

@ -211,9 +211,9 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
}
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (oidc.UserInfoSetter, error) {
return s.GetUserinfoFromScopes(ctx, "", []string{})
return s.GetUserinfoFromScopes(ctx, "", "", []string{})
}
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (oidc.UserInfoSetter, error) {
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfoSetter, error) {
userinfo := oidc.NewUserInfo()
userinfo.SetSubject(a.GetSubject())
userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
@ -276,3 +276,7 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool {
return c.devMode
}
func (c *ConfClient) AllowedScopes() []string {
return nil
}

View file

@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
if err != nil {
return "", ErrServerError(err.Error())
}
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
return "", err
}
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
}
//ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(scopes []string) error {
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 {
return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
}
if !utils.Contains(scopes, oidc.ScopeOpenID) {
return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
openID := false
for i := len(scopes) - 1; i >= 0; i-- {
switch scopes[i] {
case oidc.ScopeOpenID:
openID = true
case oidc.ScopeProfile,
oidc.ScopeEmail,
oidc.ScopePhone,
oidc.ScopeAddress,
oidc.ScopeOfflineAccess:
default:
if !utils.Contains(client.AllowedScopes(), scopes[i]) {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
}
return nil
if !openID {
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
}
return scopes, nil
}
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type

View file

@ -8,6 +8,7 @@ import (
"testing"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc"
@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) {
func TestValidateAuthReqScopes(t *testing.T) {
type args struct {
client op.Client
scopes []string
}
type res struct {
err bool
scopes []string
}
tests := []struct {
name string
args args
wantErr bool
name string
args args
res res
}{
{
"scopes missing fails", args{}, true,
"scopes missing fails",
args{},
res{
err: true,
},
},
{
"scope openid missing fails", args{[]string{"email"}}, true,
"scope openid missing fails",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"email"},
},
res{
err: true,
},
},
{
"scope ok", args{[]string{"openid"}}, false,
"scope ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid"},
},
res{
scopes: []string{"openid"},
},
},
{
"scope with drop ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid", "email", "unknown"},
},
res{
scopes: []string{"openid", "email"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
if (err != nil) != tt.res.err {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
}
assert.ElementsMatch(t, scopes, tt.res.scopes)
})
}
}

View file

@ -32,6 +32,7 @@ type Client interface {
AccessTokenType() AccessTokenType
IDTokenLifetime() time.Duration
DevMode() bool
AllowedScopes() []string
}
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {

View file

@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
func(id string) string {
return "login?id=" + id
})
m.EXPECT().AllowedScopes().AnyTimes().Return(nil)
return c
}

View file

@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
}
// AllowedScopes mocks base method
func (m *MockClient) AllowedScopes() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AllowedScopes")
ret0, _ := ret[0].([]string)
return ret0
}
// AllowedScopes indicates an expected call of AllowedScopes
func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes))
}
// ApplicationType mocks base method
func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper()

View file

@ -184,18 +184,18 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
}
// GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (oidc.UserInfoSetter, error) {
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfoSetter, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(oidc.UserInfoSetter)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3)
}
// GetUserinfoFromToken mocks base method

View file

@ -168,3 +168,6 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool {
return c.devMode
}
func (c *ConfClient) AllowedScopes() []string {
return nil
}

View file

@ -28,7 +28,7 @@ type AuthStorage interface {
type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error)
GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfoSetter, error)
GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
}

View file

@ -98,7 +98,7 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali
}
claims.SetAccessTokenHash(atHash)
} else {
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), authReq.GetScopes())
if err != nil {
return "", err
}