fix: handle code separately

This commit is contained in:
Livio Amstutz 2020-05-29 09:35:20 +02:00
parent 303fdfc421
commit a56c3f92c3
7 changed files with 69 additions and 24 deletions

View file

@ -110,6 +110,7 @@ func (a *AuthRequest) Done() bool {
var ( var (
a = &AuthRequest{} a = &AuthRequest{}
t bool t bool
c string
) )
func (s *AuthStorage) Health(ctx context.Context) error { func (s *AuthStorage) Health(ctx context.Context) error {
@ -127,9 +128,19 @@ func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthReq
t = false t = false
return a, nil return a, nil
} }
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) { func (s *AuthStorage) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) {
if code != c {
return nil, errors.New("invalid code")
}
return a, nil return a, nil
} }
func (s *AuthStorage) SaveAuthCode(_ context.Context, id, code string) error {
if a.ID != id {
return errors.New("not found")
}
c = code
return nil
}
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error { func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
t = true t = true
return nil return nil

View file

@ -173,7 +173,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
} }
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto()) code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
@ -201,6 +201,17 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }
func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) {
code, err := BuildAuthRequestCode(authReq, crypto)
if err != nil {
return "", err
}
if err := storage.SaveAuthCode(ctx, authReq.GetID(), code); err != nil {
return "", err
}
return code, nil
}
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID()) return crypto.Encrypt(authReq.GetID())
} }

View file

@ -8,7 +8,7 @@ import (
context "context" context "context"
oidc "github.com/caos/oidc/pkg/oidc" oidc "github.com/caos/oidc/pkg/oidc"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
go_jose_v2 "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
reflect "reflect" reflect "reflect"
) )
@ -80,10 +80,10 @@ func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
} }
// SignatureAlgorithm mocks base method // SignatureAlgorithm mocks base method
func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm { func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithm") ret := m.ctrl.Call(m, "SignatureAlgorithm")
ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm) ret0, _ := ret[0].(jose.SignatureAlgorithm)
return ret0 return ret0
} }

View file

@ -9,7 +9,7 @@ import (
oidc "github.com/caos/oidc/pkg/oidc" oidc "github.com/caos/oidc/pkg/oidc"
op "github.com/caos/oidc/pkg/op" op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
go_jose_v2 "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
reflect "reflect" reflect "reflect"
time "time" time "time"
) )
@ -37,6 +37,21 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder return m.recorder
} }
// AuthRequestByCode mocks base method
func (m *MockStorage) AuthRequestByCode(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthRequestByCode indicates an expected call of AuthRequestByCode
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1)
}
// AuthRequestByID mocks base method // AuthRequestByID mocks base method
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) { func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -127,10 +142,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *
} }
// GetKeySet mocks base method // GetKeySet mocks base method
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) { func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0) ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet) ret0, _ := ret[0].(*jose.JSONWebKeySet)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@ -142,7 +157,7 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
} }
// GetSigningKey mocks base method // GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) { func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3) m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3)
} }
@ -197,6 +212,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
} }
// SaveAuthCode mocks base method
func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveAuthCode", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SaveAuthCode indicates an expected call of SaveAuthCode
func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthCode", reflect.TypeOf((*MockStorage)(nil).SaveAuthCode), arg0, arg1, arg2)
}
// SaveNewKeyPair mocks base method // SaveNewKeyPair mocks base method
func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error { func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -12,6 +12,8 @@ import (
type AuthStorage interface { type AuthStorage interface {
CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error) CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error)
AuthRequestByID(context.Context, string) (AuthRequest, error) AuthRequestByID(context.Context, string) (AuthRequest, error)
AuthRequestByCode(context.Context, string) (AuthRequest, error)
SaveAuthCode(context.Context, string, string) error
DeleteAuthRequest(context.Context, string) error DeleteAuthRequest(context.Context, string) error
CreateToken(context.Context, AuthRequest) (string, time.Time, error) CreateToken(context.Context, AuthRequest) (string, time.Time, error)

View file

@ -29,6 +29,11 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
return nil, err return nil, err
} }
err = creator.Storage().DeleteAuthRequest(ctx, authReq.GetID())
if err != nil {
return nil, err
}
exp := uint64(validity.Seconds()) exp := uint64(validity.Seconds())
return &oidc.AccessTokenResponse{ return &oidc.AccessTokenResponse{
AccessToken: accessToken, AccessToken: accessToken,

View file

@ -34,11 +34,6 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
if err != nil {
RequestError(w, r, err)
return
}
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code) resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
@ -96,7 +91,7 @@ func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exc
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
if err != nil { if err != nil {
return nil, nil, ErrInvalidRequest("invalid code") return nil, nil, ErrInvalidRequest("invalid code")
} }
@ -111,7 +106,7 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
if tokenReq.CodeVerifier == "" { if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required") return nil, ErrInvalidRequest("code_challenge required")
} }
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("invalid code") return nil, ErrInvalidRequest("invalid code")
} }
@ -121,14 +116,6 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
return authReq, nil return authReq, nil
} }
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code)
if err != nil {
return nil, err
}
return storage.AuthRequestByID(ctx, id)
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenRequest, err := ParseTokenExchangeRequest(w, r) tokenRequest, err := ParseTokenExchangeRequest(w, r)
if err != nil { if err != nil {