feat: check allowed scopes (and pass clientID to GetUserinfoFromScopes)
This commit is contained in:
parent
b2903212ab
commit
b311610d06
10 changed files with 101 additions and 22 deletions
|
@ -211,9 +211,9 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
|
|||
}
|
||||
|
||||
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (oidc.UserInfoSetter, error) {
|
||||
return s.GetUserinfoFromScopes(ctx, "", []string{})
|
||||
return s.GetUserinfoFromScopes(ctx, "", "", []string{})
|
||||
}
|
||||
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (oidc.UserInfoSetter, error) {
|
||||
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfoSetter, error) {
|
||||
userinfo := oidc.NewUserInfo()
|
||||
userinfo.SetSubject(a.GetSubject())
|
||||
userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
|
||||
|
@ -276,3 +276,7 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
|
|||
func (c *ConfClient) DevMode() bool {
|
||||
return c.devMode
|
||||
}
|
||||
|
||||
func (c *ConfClient) AllowedScopes() []string {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
|
|||
if err != nil {
|
||||
return "", ErrServerError(err.Error())
|
||||
}
|
||||
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
|
||||
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
|
||||
|
@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
|
|||
}
|
||||
|
||||
//ValidateAuthReqScopes validates the passed scopes
|
||||
func ValidateAuthReqScopes(scopes []string) error {
|
||||
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
|
||||
if len(scopes) == 0 {
|
||||
return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
|
||||
return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
|
||||
}
|
||||
if !utils.Contains(scopes, oidc.ScopeOpenID) {
|
||||
return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
|
||||
openID := false
|
||||
for i := len(scopes) - 1; i >= 0; i-- {
|
||||
switch scopes[i] {
|
||||
case oidc.ScopeOpenID:
|
||||
openID = true
|
||||
case oidc.ScopeProfile,
|
||||
oidc.ScopeEmail,
|
||||
oidc.ScopePhone,
|
||||
oidc.ScopeAddress,
|
||||
oidc.ScopeOfflineAccess:
|
||||
default:
|
||||
if !utils.Contains(client.AllowedScopes(), scopes[i]) {
|
||||
scopes[i] = scopes[len(scopes)-1]
|
||||
scopes[len(scopes)-1] = ""
|
||||
scopes = scopes[:len(scopes)-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
if !openID {
|
||||
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
|
||||
}
|
||||
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
|
@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) {
|
|||
|
||||
func TestValidateAuthReqScopes(t *testing.T) {
|
||||
type args struct {
|
||||
client op.Client
|
||||
scopes []string
|
||||
}
|
||||
type res struct {
|
||||
err bool
|
||||
scopes []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"scopes missing fails", args{}, true,
|
||||
"scopes missing fails",
|
||||
args{},
|
||||
res{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"scope openid missing fails", args{[]string{"email"}}, true,
|
||||
"scope openid missing fails",
|
||||
args{
|
||||
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
|
||||
[]string{"email"},
|
||||
},
|
||||
res{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"scope ok", args{[]string{"openid"}}, false,
|
||||
"scope ok",
|
||||
args{
|
||||
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
|
||||
[]string{"openid"},
|
||||
},
|
||||
res{
|
||||
scopes: []string{"openid"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"scope with drop ok",
|
||||
args{
|
||||
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
|
||||
[]string{"openid", "email", "unknown"},
|
||||
},
|
||||
res{
|
||||
scopes: []string{"openid", "email"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
|
||||
scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
|
||||
if (err != nil) != tt.res.err {
|
||||
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
assert.ElementsMatch(t, scopes, tt.res.scopes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ type Client interface {
|
|||
AccessTokenType() AccessTokenType
|
||||
IDTokenLifetime() time.Duration
|
||||
DevMode() bool
|
||||
AllowedScopes() []string
|
||||
}
|
||||
|
||||
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {
|
||||
|
|
|
@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
|
|||
func(id string) string {
|
||||
return "login?id=" + id
|
||||
})
|
||||
m.EXPECT().AllowedScopes().AnyTimes().Return(nil)
|
||||
return c
|
||||
}
|
||||
|
||||
|
|
|
@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
|
||||
}
|
||||
|
||||
// AllowedScopes mocks base method
|
||||
func (m *MockClient) AllowedScopes() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AllowedScopes")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AllowedScopes indicates an expected call of AllowedScopes
|
||||
func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes))
|
||||
}
|
||||
|
||||
// ApplicationType mocks base method
|
||||
func (m *MockClient) ApplicationType() op.ApplicationType {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -184,18 +184,18 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
|
|||
}
|
||||
|
||||
// GetUserinfoFromScopes mocks base method
|
||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (oidc.UserInfoSetter, error) {
|
||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfoSetter, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
|
||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(oidc.UserInfoSetter)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// GetUserinfoFromToken mocks base method
|
||||
|
|
|
@ -168,3 +168,6 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
|
|||
func (c *ConfClient) DevMode() bool {
|
||||
return c.devMode
|
||||
}
|
||||
func (c *ConfClient) AllowedScopes() []string {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ type AuthStorage interface {
|
|||
type OPStorage interface {
|
||||
GetClientByClientID(context.Context, string) (Client, error)
|
||||
AuthorizeClientIDSecret(context.Context, string, string) error
|
||||
GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error)
|
||||
GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfoSetter, error)
|
||||
GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error)
|
||||
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
||||
}
|
||||
|
|
|
@ -98,7 +98,7 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali
|
|||
}
|
||||
claims.SetAccessTokenHash(atHash)
|
||||
} else {
|
||||
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
|
||||
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), authReq.GetScopes())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue