fix: handle code separately (#30)

This commit is contained in:
Livio Amstutz 2020-05-29 09:40:34 +02:00 committed by GitHub
parent 303fdfc421
commit 58545a1710
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 69 additions and 24 deletions

View file

@ -110,6 +110,7 @@ func (a *AuthRequest) Done() bool {
var (
a = &AuthRequest{}
t bool
c string
)
func (s *AuthStorage) Health(ctx context.Context) error {
@ -127,9 +128,19 @@ func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthReq
t = false
return a, nil
}
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
func (s *AuthStorage) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) {
if code != c {
return nil, errors.New("invalid code")
}
return a, nil
}
func (s *AuthStorage) SaveAuthCode(_ context.Context, id, code string) error {
if a.ID != id {
return errors.New("not found")
}
c = code
return nil
}
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
t = true
return nil

View file

@ -173,7 +173,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
}
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
@ -201,6 +201,17 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
http.Redirect(w, r, callback, http.StatusFound)
}
func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) {
code, err := BuildAuthRequestCode(authReq, crypto)
if err != nil {
return "", err
}
if err := storage.SaveAuthCode(ctx, authReq.GetID(), code); err != nil {
return "", err
}
return code, nil
}
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID())
}

View file

@ -8,7 +8,7 @@ import (
context "context"
oidc "github.com/caos/oidc/pkg/oidc"
gomock "github.com/golang/mock/gomock"
go_jose_v2 "gopkg.in/square/go-jose.v2"
jose "gopkg.in/square/go-jose.v2"
reflect "reflect"
)
@ -80,10 +80,10 @@ func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
}
// SignatureAlgorithm mocks base method
func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm {
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithm")
ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm)
ret0, _ := ret[0].(jose.SignatureAlgorithm)
return ret0
}

View file

@ -9,7 +9,7 @@ import (
oidc "github.com/caos/oidc/pkg/oidc"
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock"
go_jose_v2 "gopkg.in/square/go-jose.v2"
jose "gopkg.in/square/go-jose.v2"
reflect "reflect"
time "time"
)
@ -37,6 +37,21 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder
}
// AuthRequestByCode mocks base method
func (m *MockStorage) AuthRequestByCode(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthRequestByCode indicates an expected call of AuthRequestByCode
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1)
}
// AuthRequestByID mocks base method
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
@ -127,10 +142,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *
}
// GetKeySet mocks base method
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
ret0, _ := ret[0].(*jose.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -142,7 +157,7 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
}
// GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3)
}
@ -197,6 +212,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
}
// SaveAuthCode mocks base method
func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveAuthCode", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SaveAuthCode indicates an expected call of SaveAuthCode
func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthCode", reflect.TypeOf((*MockStorage)(nil).SaveAuthCode), arg0, arg1, arg2)
}
// SaveNewKeyPair mocks base method
func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error {
m.ctrl.T.Helper()

View file

@ -12,6 +12,8 @@ import (
type AuthStorage interface {
CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error)
AuthRequestByID(context.Context, string) (AuthRequest, error)
AuthRequestByCode(context.Context, string) (AuthRequest, error)
SaveAuthCode(context.Context, string, string) error
DeleteAuthRequest(context.Context, string) error
CreateToken(context.Context, AuthRequest) (string, time.Time, error)

View file

@ -29,6 +29,11 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
return nil, err
}
err = creator.Storage().DeleteAuthRequest(ctx, authReq.GetID())
if err != nil {
return nil, err
}
exp := uint64(validity.Seconds())
return &oidc.AccessTokenResponse{
AccessToken: accessToken,

View file

@ -34,11 +34,6 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, err)
return
}
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
if err != nil {
RequestError(w, r, err)
return
}
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
if err != nil {
RequestError(w, r, err)
@ -96,7 +91,7 @@ func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exc
if err != nil {
return nil, nil, err
}
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
if err != nil {
return nil, nil, ErrInvalidRequest("invalid code")
}
@ -111,7 +106,7 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required")
}
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
if err != nil {
return nil, ErrInvalidRequest("invalid code")
}
@ -121,14 +116,6 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
return authReq, nil
}
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code)
if err != nil {
return nil, err
}
return storage.AuthRequestByID(ctx, id)
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenRequest, err := ParseTokenExchangeRequest(w, r)
if err != nil {