code exchange fixes
This commit is contained in:
parent
85814fb69a
commit
20a90c71d9
9 changed files with 107 additions and 36 deletions
|
@ -7,6 +7,7 @@ const (
|
|||
)
|
||||
|
||||
type Client interface {
|
||||
GetID() string
|
||||
RedirectURIs() []string
|
||||
ApplicationType() ApplicationType
|
||||
LoginURL(string) string
|
||||
|
|
|
@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
||||
}
|
||||
|
||||
// GetID mocks base method
|
||||
func (m *MockClient) GetID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetID indicates an expected call of GetID
|
||||
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
|
||||
}
|
||||
|
||||
// LoginURL mocks base method
|
||||
func (m *MockClient) LoginURL(arg0 string) string {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -36,18 +36,18 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
|||
}
|
||||
|
||||
// AuthRequestByCode mocks base method
|
||||
func (m *MockStorage) AuthRequestByCode(arg0 op.Client, arg1, arg2 string) (op.AuthRequest, error) {
|
||||
func (m *MockStorage) AuthRequestByCode(arg0 string) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1, arg2)
|
||||
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AuthRequestByCode indicates an expected call of AuthRequestByCode
|
||||
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0)
|
||||
}
|
||||
|
||||
// AuthRequestByID mocks base method
|
||||
|
@ -140,10 +140,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock
|
|||
}
|
||||
|
||||
// GetKeySet mocks base method
|
||||
func (m *MockStorage) GetKeySet() (go_jose_v2.JSONWebKeySet, error) {
|
||||
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeySet")
|
||||
ret0, _ := ret[0].(go_jose_v2.JSONWebKeySet)
|
||||
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
|
|
@ -11,11 +11,11 @@ import (
|
|||
type AuthStorage interface {
|
||||
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
|
||||
AuthRequestByID(string) (AuthRequest, error)
|
||||
AuthRequestByCode(Client, string, string) (AuthRequest, error)
|
||||
AuthRequestByCode(string) (AuthRequest, error)
|
||||
DeleteAuthRequestAndCode(string, string) error
|
||||
|
||||
GetSigningKey() (*jose.SigningKey, error)
|
||||
GetKeySet() (jose.JSONWebKeySet, error)
|
||||
GetKeySet() (*jose.JSONWebKeySet, error)
|
||||
}
|
||||
|
||||
type OPStorage interface {
|
||||
|
@ -38,6 +38,7 @@ type AuthRequest interface {
|
|||
GetAuthTime() time.Time
|
||||
GetClientID() string
|
||||
GetCode() string
|
||||
GetCodeChallenge() *oidc.CodeChallenge
|
||||
GetNonce() string
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
|
|
|
@ -39,12 +39,17 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
return
|
||||
}
|
||||
|
||||
client, err := AuthorizeClient(r, tokenReq, exchanger)
|
||||
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
authReq, err := exchanger.Storage().AuthRequestByCode(client, tokenReq.Code, tokenReq.RedirectURI)
|
||||
client, err := AuthorizeClient(r, tokenReq, authReq, exchanger)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
err = ValidateAccessTokenRequest(tokenReq, client, authReq)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
|
@ -74,7 +79,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
utils.MarshalJSON(w, resp)
|
||||
}
|
||||
|
||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
|
||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) {
|
||||
if tokenReq.ClientID == "" {
|
||||
if !exchanger.AuthMethodBasicSupported() {
|
||||
return nil, errors.New("basic not supported")
|
||||
|
@ -92,11 +97,24 @@ func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchang
|
|||
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret)
|
||||
}
|
||||
if tokenReq.CodeVerifier != "" {
|
||||
return exchanger.Storage().AuthorizeClientIDCodeVerifier(tokenReq.ClientID, tokenReq.CodeVerifier)
|
||||
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
|
||||
return nil, ErrInvalidRequest("code_challenge invalid")
|
||||
}
|
||||
return exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
|
||||
}
|
||||
return nil, errors.New("Unimplemented") //TODO: impl
|
||||
}
|
||||
|
||||
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error {
|
||||
if client.GetID() != authReq.GetClientID() {
|
||||
return ErrInvalidRequest("invalid auth code")
|
||||
}
|
||||
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
||||
return ErrInvalidRequest("redirect_uri does no correspond")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
|
||||
return nil, errors.New("Unimplemented") //TODO: impl
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue