fix: handle code separately (#30)
This commit is contained in:
parent
303fdfc421
commit
58545a1710
7 changed files with 69 additions and 24 deletions
|
@ -110,6 +110,7 @@ func (a *AuthRequest) Done() bool {
|
||||||
var (
|
var (
|
||||||
a = &AuthRequest{}
|
a = &AuthRequest{}
|
||||||
t bool
|
t bool
|
||||||
|
c string
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *AuthStorage) Health(ctx context.Context) error {
|
func (s *AuthStorage) Health(ctx context.Context) error {
|
||||||
|
@ -127,9 +128,19 @@ func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthReq
|
||||||
t = false
|
t = false
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
|
func (s *AuthStorage) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) {
|
||||||
|
if code != c {
|
||||||
|
return nil, errors.New("invalid code")
|
||||||
|
}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
func (s *AuthStorage) SaveAuthCode(_ context.Context, id, code string) error {
|
||||||
|
if a.ID != id {
|
||||||
|
return errors.New("not found")
|
||||||
|
}
|
||||||
|
c = code
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
|
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
|
||||||
t = true
|
t = true
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -173,7 +173,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||||
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
|
@ -201,6 +201,17 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
|
||||||
http.Redirect(w, r, callback, http.StatusFound)
|
http.Redirect(w, r, callback, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) {
|
||||||
|
code, err := BuildAuthRequestCode(authReq, crypto)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := storage.SaveAuthCode(ctx, authReq.GetID(), code); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return code, nil
|
||||||
|
}
|
||||||
|
|
||||||
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
|
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||||
return crypto.Encrypt(authReq.GetID())
|
return crypto.Encrypt(authReq.GetID())
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
context "context"
|
context "context"
|
||||||
oidc "github.com/caos/oidc/pkg/oidc"
|
oidc "github.com/caos/oidc/pkg/oidc"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -80,10 +80,10 @@ func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignatureAlgorithm mocks base method
|
// SignatureAlgorithm mocks base method
|
||||||
func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm {
|
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "SignatureAlgorithm")
|
ret := m.ctrl.Call(m, "SignatureAlgorithm")
|
||||||
ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm)
|
ret0, _ := ret[0].(jose.SignatureAlgorithm)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
oidc "github.com/caos/oidc/pkg/oidc"
|
oidc "github.com/caos/oidc/pkg/oidc"
|
||||||
op "github.com/caos/oidc/pkg/op"
|
op "github.com/caos/oidc/pkg/op"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
time "time"
|
time "time"
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,21 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthRequestByCode mocks base method
|
||||||
|
func (m *MockStorage) AuthRequestByCode(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(op.AuthRequest)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthRequestByCode indicates an expected call of AuthRequestByCode
|
||||||
|
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
// AuthRequestByID mocks base method
|
// AuthRequestByID mocks base method
|
||||||
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -127,10 +142,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKeySet mocks base method
|
// GetKeySet mocks base method
|
||||||
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
|
func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||||
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
ret0, _ := ret[0].(*jose.JSONWebKeySet)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
@ -142,7 +157,7 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSigningKey mocks base method
|
// GetSigningKey mocks base method
|
||||||
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
|
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3)
|
m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
@ -197,6 +212,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveAuthCode mocks base method
|
||||||
|
func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SaveAuthCode", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAuthCode indicates an expected call of SaveAuthCode
|
||||||
|
func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthCode", reflect.TypeOf((*MockStorage)(nil).SaveAuthCode), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
// SaveNewKeyPair mocks base method
|
// SaveNewKeyPair mocks base method
|
||||||
func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error {
|
func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -12,6 +12,8 @@ import (
|
||||||
type AuthStorage interface {
|
type AuthStorage interface {
|
||||||
CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error)
|
CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error)
|
||||||
AuthRequestByID(context.Context, string) (AuthRequest, error)
|
AuthRequestByID(context.Context, string) (AuthRequest, error)
|
||||||
|
AuthRequestByCode(context.Context, string) (AuthRequest, error)
|
||||||
|
SaveAuthCode(context.Context, string, string) error
|
||||||
DeleteAuthRequest(context.Context, string) error
|
DeleteAuthRequest(context.Context, string) error
|
||||||
|
|
||||||
CreateToken(context.Context, AuthRequest) (string, time.Time, error)
|
CreateToken(context.Context, AuthRequest) (string, time.Time, error)
|
||||||
|
|
|
@ -29,6 +29,11 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = creator.Storage().DeleteAuthRequest(ctx, authReq.GetID())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
exp := uint64(validity.Seconds())
|
exp := uint64(validity.Seconds())
|
||||||
return &oidc.AccessTokenResponse{
|
return &oidc.AccessTokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
|
|
|
@ -34,11 +34,6 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
|
||||||
if err != nil {
|
|
||||||
RequestError(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
|
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
|
@ -96,7 +91,7 @@ func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, ErrInvalidRequest("invalid code")
|
return nil, nil, ErrInvalidRequest("invalid code")
|
||||||
}
|
}
|
||||||
|
@ -111,7 +106,7 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
|
||||||
if tokenReq.CodeVerifier == "" {
|
if tokenReq.CodeVerifier == "" {
|
||||||
return nil, ErrInvalidRequest("code_challenge required")
|
return nil, ErrInvalidRequest("code_challenge required")
|
||||||
}
|
}
|
||||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidRequest("invalid code")
|
return nil, ErrInvalidRequest("invalid code")
|
||||||
}
|
}
|
||||||
|
@ -121,14 +116,6 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
|
||||||
return authReq, nil
|
return authReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
|
||||||
id, err := crypto.Decrypt(code)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return storage.AuthRequestByID(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
tokenRequest, err := ParseTokenExchangeRequest(w, r)
|
tokenRequest, err := ParseTokenExchangeRequest(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue