fix: handle code separately
This commit is contained in:
parent
303fdfc421
commit
a56c3f92c3
7 changed files with 69 additions and 24 deletions
|
@ -110,6 +110,7 @@ func (a *AuthRequest) Done() bool {
|
|||
var (
|
||||
a = &AuthRequest{}
|
||||
t bool
|
||||
c string
|
||||
)
|
||||
|
||||
func (s *AuthStorage) Health(ctx context.Context) error {
|
||||
|
@ -127,9 +128,19 @@ func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthReq
|
|||
t = false
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
|
||||
func (s *AuthStorage) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) {
|
||||
if code != c {
|
||||
return nil, errors.New("invalid code")
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) SaveAuthCode(_ context.Context, id, code string) error {
|
||||
if a.ID != id {
|
||||
return errors.New("not found")
|
||||
}
|
||||
c = code
|
||||
return nil
|
||||
}
|
||||
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
|
||||
t = true
|
||||
return nil
|
||||
|
|
|
@ -173,7 +173,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
|
|||
}
|
||||
|
||||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
||||
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
|
@ -201,6 +201,17 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
|
|||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
||||
func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) {
|
||||
code, err := BuildAuthRequestCode(authReq, crypto)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := storage.SaveAuthCode(ctx, authReq.GetID(), code); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||
return crypto.Encrypt(authReq.GetID())
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
context "context"
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
|
@ -80,10 +80,10 @@ func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
|
|||
}
|
||||
|
||||
// SignatureAlgorithm mocks base method
|
||||
func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm {
|
||||
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignatureAlgorithm")
|
||||
ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm)
|
||||
ret0, _ := ret[0].(jose.SignatureAlgorithm)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
)
|
||||
|
@ -37,6 +37,21 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// AuthRequestByCode mocks base method
|
||||
func (m *MockStorage) AuthRequestByCode(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AuthRequestByCode indicates an expected call of AuthRequestByCode
|
||||
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1)
|
||||
}
|
||||
|
||||
// AuthRequestByID mocks base method
|
||||
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -127,10 +142,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *
|
|||
}
|
||||
|
||||
// GetKeySet mocks base method
|
||||
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
|
||||
func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
||||
ret0, _ := ret[0].(*jose.JSONWebKeySet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -142,7 +157,7 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
|||
}
|
||||
|
||||
// GetSigningKey mocks base method
|
||||
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
|
||||
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
@ -197,6 +212,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
|
||||
}
|
||||
|
||||
// SaveAuthCode mocks base method
|
||||
func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveAuthCode", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveAuthCode indicates an expected call of SaveAuthCode
|
||||
func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthCode", reflect.TypeOf((*MockStorage)(nil).SaveAuthCode), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SaveNewKeyPair mocks base method
|
||||
func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -12,6 +12,8 @@ import (
|
|||
type AuthStorage interface {
|
||||
CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error)
|
||||
AuthRequestByID(context.Context, string) (AuthRequest, error)
|
||||
AuthRequestByCode(context.Context, string) (AuthRequest, error)
|
||||
SaveAuthCode(context.Context, string, string) error
|
||||
DeleteAuthRequest(context.Context, string) error
|
||||
|
||||
CreateToken(context.Context, AuthRequest) (string, time.Time, error)
|
||||
|
|
|
@ -29,6 +29,11 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = creator.Storage().DeleteAuthRequest(ctx, authReq.GetID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exp := uint64(validity.Seconds())
|
||||
return &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
|
|
|
@ -34,11 +34,6 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
RequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
|
@ -96,7 +91,7 @@ func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exc
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
|
@ -111,7 +106,7 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
|
|||
if tokenReq.CodeVerifier == "" {
|
||||
return nil, ErrInvalidRequest("code_challenge required")
|
||||
}
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
|
@ -121,14 +116,6 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
|
|||
return authReq, nil
|
||||
}
|
||||
|
||||
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||
id, err := crypto.Decrypt(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storage.AuthRequestByID(ctx, id)
|
||||
}
|
||||
|
||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenRequest, err := ParseTokenExchangeRequest(w, r)
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue