From 58545a1710cd9720289ac9529be0fe36aca7efc7 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Fri, 29 May 2020 09:40:34 +0200 Subject: [PATCH] fix: handle code separately (#30) --- example/internal/mock/storage.go | 13 ++++++++++- pkg/op/authrequest.go | 13 ++++++++++- pkg/op/mock/signer.mock.go | 6 +++--- pkg/op/mock/storage.mock.go | 37 ++++++++++++++++++++++++++++---- pkg/op/storage.go | 2 ++ pkg/op/token.go | 5 +++++ pkg/op/tokenrequest.go | 17 ++------------- 7 files changed, 69 insertions(+), 24 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 96f9b45..5fb823b 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -110,6 +110,7 @@ func (a *AuthRequest) Done() bool { var ( a = &AuthRequest{} t bool + c string ) func (s *AuthStorage) Health(ctx context.Context) error { @@ -127,9 +128,19 @@ func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthReq t = false return a, nil } -func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) { +func (s *AuthStorage) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) { + if code != c { + return nil, errors.New("invalid code") + } return a, nil } +func (s *AuthStorage) SaveAuthCode(_ context.Context, id, code string) error { + if a.ID != id { + return errors.New("not found") + } + c = code + return nil +} func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error { t = true return nil diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index e01de51..c25f60d 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -173,7 +173,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri } func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { - code, err := BuildAuthRequestCode(authReq, authorizer.Crypto()) + code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return @@ -201,6 +201,17 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque http.Redirect(w, r, callback, http.StatusFound) } +func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) { + code, err := BuildAuthRequestCode(authReq, crypto) + if err != nil { + return "", err + } + if err := storage.SaveAuthCode(ctx, authReq.GetID(), code); err != nil { + return "", err + } + return code, nil +} + func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { return crypto.Encrypt(authReq.GetID()) } diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index c780752..a7d909c 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -8,7 +8,7 @@ import ( context "context" oidc "github.com/caos/oidc/pkg/oidc" gomock "github.com/golang/mock/gomock" - go_jose_v2 "gopkg.in/square/go-jose.v2" + jose "gopkg.in/square/go-jose.v2" reflect "reflect" ) @@ -80,10 +80,10 @@ func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call { } // SignatureAlgorithm mocks base method -func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm { +func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SignatureAlgorithm") - ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm) + ret0, _ := ret[0].(jose.SignatureAlgorithm) return ret0 } diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 405865c..ac8ba27 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -9,7 +9,7 @@ import ( oidc "github.com/caos/oidc/pkg/oidc" op "github.com/caos/oidc/pkg/op" gomock "github.com/golang/mock/gomock" - go_jose_v2 "gopkg.in/square/go-jose.v2" + jose "gopkg.in/square/go-jose.v2" reflect "reflect" time "time" ) @@ -37,6 +37,21 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder { return m.recorder } +// AuthRequestByCode mocks base method +func (m *MockStorage) AuthRequestByCode(arg0 context.Context, arg1 string) (op.AuthRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1) + ret0, _ := ret[0].(op.AuthRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AuthRequestByCode indicates an expected call of AuthRequestByCode +func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1) +} + // AuthRequestByID mocks base method func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) { m.ctrl.T.Helper() @@ -127,10 +142,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) * } // GetKeySet mocks base method -func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) { +func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetKeySet", arg0) - ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet) + ret0, _ := ret[0].(*jose.JSONWebKeySet) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -142,7 +157,7 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { } // GetSigningKey mocks base method -func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) { +func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3) } @@ -197,6 +212,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0) } +// SaveAuthCode mocks base method +func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveAuthCode", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveAuthCode indicates an expected call of SaveAuthCode +func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthCode", reflect.TypeOf((*MockStorage)(nil).SaveAuthCode), arg0, arg1, arg2) +} + // SaveNewKeyPair mocks base method func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error { m.ctrl.T.Helper() diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 4655b88..e3ef5ff 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -12,6 +12,8 @@ import ( type AuthStorage interface { CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error) AuthRequestByID(context.Context, string) (AuthRequest, error) + AuthRequestByCode(context.Context, string) (AuthRequest, error) + SaveAuthCode(context.Context, string, string) error DeleteAuthRequest(context.Context, string) error CreateToken(context.Context, AuthRequest) (string, time.Time, error) diff --git a/pkg/op/token.go b/pkg/op/token.go index 06e9f9c..9d37788 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -29,6 +29,11 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client return nil, err } + err = creator.Storage().DeleteAuthRequest(ctx, authReq.GetID()) + if err != nil { + return nil, err + } + exp := uint64(validity.Seconds()) return &oidc.AccessTokenResponse{ AccessToken: accessToken, diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index cce3564..5ef4b22 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -34,11 +34,6 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { RequestError(w, r, err) return } - err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID()) - if err != nil { - RequestError(w, r, err) - return - } resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code) if err != nil { RequestError(w, r, err) @@ -96,7 +91,7 @@ func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exc if err != nil { return nil, nil, err } - authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) + authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code) if err != nil { return nil, nil, ErrInvalidRequest("invalid code") } @@ -111,7 +106,7 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque if tokenReq.CodeVerifier == "" { return nil, ErrInvalidRequest("code_challenge required") } - authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) + authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code) if err != nil { return nil, ErrInvalidRequest("invalid code") } @@ -121,14 +116,6 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque return authReq, nil } -func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) { - id, err := crypto.Decrypt(code) - if err != nil { - return nil, err - } - return storage.AuthRequestByID(ctx, id) -} - func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenRequest, err := ParseTokenExchangeRequest(w, r) if err != nil {