From a8d10945d8d15e68c61be34ccde212c44b16ff1d Mon Sep 17 00:00:00 2001 From: livio-a Date: Wed, 11 Mar 2020 09:41:54 +0100 Subject: [PATCH] feat: preselect user with id_token_hint (#16) * feat: preselect user with id_token_hint * fix tests --- example/internal/mock/storage.go | 2 +- pkg/op/authrequest.go | 36 ++++++++++++++++++----------- pkg/op/authrequest_test.go | 15 +++++++----- pkg/op/mock/authorizer.mock.go | 15 ++++++++++++ pkg/op/mock/authorizer.mock.impl.go | 28 ++++++++++++---------- pkg/op/mock/storage.mock.go | 8 +++---- pkg/op/storage.go | 2 +- 7 files changed, 68 insertions(+), 38 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index c5797db..96f9b45 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -116,7 +116,7 @@ func (s *AuthStorage) Health(ctx context.Context) error { return nil } -func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) { +func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) { a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} if authReq.CodeChallenge != "" { a.CodeChallenge = &oidc.CodeChallenge{ diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 0e13df2..e01de51 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/utils" ) @@ -19,13 +20,14 @@ type Authorizer interface { Decoder() *schema.Decoder Encoder() *schema.Encoder Signer() Signer + IDTokenVerifier() rp.Verifier Crypto() Crypto Issuer() string } type ValidationAuthorizer interface { Authorizer - ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error) } func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { @@ -44,11 +46,12 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { if validater, ok := authorizer.(ValidationAuthorizer); ok { validation = validater.ValidateAuthRequest } - if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil { + userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenVerifier()) + if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq) + req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID) if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return @@ -61,23 +64,17 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { RedirectToLogin(req.GetID(), client, w, r) } -func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error { +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) { if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { - return err + return "", err } if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil { - return err + return "", err } if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil { - return err + return "", err } - // if NeedsExistingSession(authReq) { - // session, err := storage.CheckSession(authReq.IDTokenHint) - // if err != nil { - // return err - // } - // } - return nil + return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier) } func ValidateAuthReqScopes(scopes []string) error { @@ -130,6 +127,17 @@ func ValidateAuthReqResponseType(responseType oidc.ResponseType) error { return nil } +func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) { + if idTokenHint == "" { + return "", nil + } + claims, err := verifier.Verify(ctx, "", idTokenHint) + if err != nil { + return "", ErrInvalidRequest("id_token_hint invalid") + } + return claims.Subject, nil +} + func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) { login := client.LoginURL(authReqID) http.Redirect(w, r, login, http.StatusFound) diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go index b0599c3..dca72fa 100644 --- a/pkg/op/authrequest_test.go +++ b/pkg/op/authrequest_test.go @@ -11,6 +11,7 @@ import ( "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op/mock" + "github.com/caos/oidc/pkg/rp" ) func TestAuthorize(t *testing.T) { @@ -70,6 +71,7 @@ func TestValidateAuthRequest(t *testing.T) { type args struct { authRequest *oidc.AuthRequest storage op.Storage + verifier rp.Verifier } tests := []struct { name string @@ -82,33 +84,34 @@ func TestValidateAuthRequest(t *testing.T) { // } { "scope missing fails", - args{&oidc.AuthRequest{}, nil}, + args{&oidc.AuthRequest{}, nil, nil}, true, }, { "scope openid missing fails", - args{&oidc.AuthRequest{Scopes: []string{"profile"}}, nil}, + args{&oidc.AuthRequest{Scopes: []string{"profile"}}, nil, nil}, true, }, { "response_type missing fails", - args{&oidc.AuthRequest{Scopes: []string{"openid"}}, nil}, + args{&oidc.AuthRequest{Scopes: []string{"openid"}}, nil, nil}, true, }, { "client_id missing fails", - args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, nil}, + args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, nil, nil}, true, }, { "redirect_uri missing fails", - args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, nil}, + args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, nil, nil}, true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage); (err != nil) != tt.wantErr { + _, err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage, tt.args.verifier) + if (err != nil) != tt.wantErr { t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index 48f9aed..dbfc2a6 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -6,6 +6,7 @@ package mock import ( op "github.com/caos/oidc/pkg/op" + rp "github.com/caos/oidc/pkg/rp" gomock "github.com/golang/mock/gomock" schema "github.com/gorilla/schema" reflect "reflect" @@ -76,6 +77,20 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder)) } +// IDTokenVerifier mocks base method +func (m *MockAuthorizer) IDTokenVerifier() rp.Verifier { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IDTokenVerifier") + ret0, _ := ret[0].(rp.Verifier) + return ret0 +} + +// IDTokenVerifier indicates an expected call of IDTokenVerifier +func (mr *MockAuthorizerMockRecorder) IDTokenVerifier() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenVerifier)) +} + // Issuer mocks base method func (m *MockAuthorizer) Issuer() string { m.ctrl.T.Helper() diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 93994ad..29c9354 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -8,8 +8,9 @@ import ( "github.com/gorilla/schema" "gopkg.in/square/go-jose.v2" - oidc "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" + "github.com/caos/oidc/pkg/rp" ) func NewAuthorizer(t *testing.T) op.Authorizer { @@ -22,6 +23,7 @@ func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer { ExpectEncoder(m) ExpectSigner(m, t) ExpectStorage(m, t) + ExpectVerifier(m, t) // ExpectErrorHandler(m, t, wantErr) return m } @@ -54,17 +56,19 @@ func ExpectSigner(a op.Authorizer, t *testing.T) { }) } -// func ExpectErrorHandler(a op.Authorizer, t *testing.T, wantErr bool) { -// mockA := a.(*MockAuthorizer) -// mockA.EXPECT().ErrorHandler().AnyTimes(). -// Return(func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { -// if wantErr { -// require.Error(t, err) -// return -// } -// require.NoError(t, err) -// }) -// } +func ExpectVerifier(a op.Authorizer, t *testing.T) { + mockA := a.(*MockAuthorizer) + mockA.EXPECT().IDTokenVerifier().DoAndReturn( + func() rp.Verifier { + return &Verifier{} + }) +} + +type Verifier struct{} + +func (v *Verifier) Verify(ctx context.Context, accessToken, idToken string) (*oidc.IDTokenClaims, error) { + return nil, nil +} type Sig struct{} diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 04316c3..405865c 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -67,18 +67,18 @@ func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 inte } // CreateAuthRequest mocks base method -func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) { +func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest, arg2 string) (op.AuthRequest, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1) + ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1, arg2) ret0, _ := ret[0].(op.AuthRequest) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateAuthRequest indicates an expected call of CreateAuthRequest -func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1, arg2) } // CreateToken mocks base method diff --git a/pkg/op/storage.go b/pkg/op/storage.go index f213618..4655b88 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -10,7 +10,7 @@ import ( ) type AuthStorage interface { - CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error) + CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error) AuthRequestByID(context.Context, string) (AuthRequest, error) DeleteAuthRequest(context.Context, string) error