code exchange fixes

This commit is contained in:
Livio Amstutz 2019-12-12 16:04:34 +01:00
parent 85814fb69a
commit 20a90c71d9
9 changed files with 107 additions and 36 deletions

View file

@ -7,6 +7,7 @@ const (
)
type Client interface {
GetID() string
RedirectURIs() []string
ApplicationType() ApplicationType
LoginURL(string) string

View file

@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
}
// GetID mocks base method
func (m *MockClient) GetID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetID")
ret0, _ := ret[0].(string)
return ret0
}
// GetID indicates an expected call of GetID
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
}
// LoginURL mocks base method
func (m *MockClient) LoginURL(arg0 string) string {
m.ctrl.T.Helper()

View file

@ -36,18 +36,18 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
}
// AuthRequestByCode mocks base method
func (m *MockStorage) AuthRequestByCode(arg0 op.Client, arg1, arg2 string) (op.AuthRequest, error) {
func (m *MockStorage) AuthRequestByCode(arg0 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthRequestByCode indicates an expected call of AuthRequestByCode
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0)
}
// AuthRequestByID mocks base method
@ -140,10 +140,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock
}
// GetKeySet mocks base method
func (m *MockStorage) GetKeySet() (go_jose_v2.JSONWebKeySet, error) {
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet")
ret0, _ := ret[0].(go_jose_v2.JSONWebKeySet)
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View file

@ -11,11 +11,11 @@ import (
type AuthStorage interface {
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(string) (AuthRequest, error)
AuthRequestByCode(Client, string, string) (AuthRequest, error)
AuthRequestByCode(string) (AuthRequest, error)
DeleteAuthRequestAndCode(string, string) error
GetSigningKey() (*jose.SigningKey, error)
GetKeySet() (jose.JSONWebKeySet, error)
GetKeySet() (*jose.JSONWebKeySet, error)
}
type OPStorage interface {
@ -38,6 +38,7 @@ type AuthRequest interface {
GetAuthTime() time.Time
GetClientID() string
GetCode() string
GetCodeChallenge() *oidc.CodeChallenge
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType

View file

@ -39,12 +39,17 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return
}
client, err := AuthorizeClient(r, tokenReq, exchanger)
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
authReq, err := exchanger.Storage().AuthRequestByCode(client, tokenReq.Code, tokenReq.RedirectURI)
client, err := AuthorizeClient(r, tokenReq, authReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = ValidateAccessTokenRequest(tokenReq, client, authReq)
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -74,7 +79,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp)
}
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) {
if tokenReq.ClientID == "" {
if !exchanger.AuthMethodBasicSupported() {
return nil, errors.New("basic not supported")
@ -92,11 +97,24 @@ func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchang
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret)
}
if tokenReq.CodeVerifier != "" {
return exchanger.Storage().AuthorizeClientIDCodeVerifier(tokenReq.ClientID, tokenReq.CodeVerifier)
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
return nil, ErrInvalidRequest("code_challenge invalid")
}
return exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
}
return nil, errors.New("Unimplemented") //TODO: impl
}
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error {
if client.GetID() != authReq.GetClientID() {
return ErrInvalidRequest("invalid auth code")
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return ErrInvalidRequest("redirect_uri does no correspond")
}
return nil
}
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
return nil, errors.New("Unimplemented") //TODO: impl
}