diff --git a/example/client/app/app.go b/example/client/app/app.go index 51caf8b..c78921a 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -38,7 +38,7 @@ func main() { // cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure()) provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler)) if err != nil { - logrus.Panic("error creating provider") + logrus.Panicf("error creating provider %s", err.Error()) } // state := "foobar" diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 8cf9e80..690f9e2 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -31,11 +31,12 @@ func NewAuthStorage() op.AuthStorage { } type AuthRequest struct { - ID string - ResponseType oidc.ResponseType - RedirectURI string - Nonce string - ClientID string + ID string + ResponseType oidc.ResponseType + RedirectURI string + Nonce string + ClientID string + CodeChallenge *oidc.CodeChallenge } func (a *AuthRequest) GetACR() string { @@ -66,6 +67,10 @@ func (a *AuthRequest) GetCode() string { return "code" } +func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge { + return a.CodeChallenge +} + func (a *AuthRequest) GetID() string { return a.ID } @@ -105,38 +110,23 @@ var ( func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} + if authReq.CodeChallenge != "" { + a.CodeChallenge = &oidc.CodeChallenge{ + Challenge: authReq.CodeChallenge, + Method: authReq.CodeChallengeMethod, + } + } return a, nil } -func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) { - if id == "none" { - return nil, errors.New("not found") - } - var appType op.ApplicationType - if id == "web" { - appType = op.ApplicationTypeWeb - } else if id == "native" { - appType = op.ApplicationTypeNative - } else { - appType = op.ApplicationTypeUserAgent - } - return &ConfClient{applicationType: appType}, nil -} -func (s *AuthStorage) AuthRequestByCode(op.Client, string, string) (op.AuthRequest, error) { +func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) { return a, nil } -func (s *OPStorage) AuthorizeClientIDSecret(string, string) (op.Client, error) { - return &ConfClient{}, nil -} -func (s *OPStorage) AuthorizeClientIDCodeVerifier(string, string) (op.Client, error) { - return &ConfClient{}, nil -} func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error { return nil } func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) { return a, nil } - func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil } @@ -152,53 +142,61 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) { }, nil } -func (s *OPStorage) GetUserinfoFromScopes([]string) (interface{}, error) { - return &oidc.Test{ - Userinfo: oidc.Userinfo{ - Subject: a.GetSubject(), - Address: &oidc.UserinfoAddress{ - StreetAddress: "Hjkhkj 789\ndsf", - }, - UserinfoEmail: oidc.UserinfoEmail{ - Email: "test", - EmailVerified: true, - }, - UserinfoPhone: oidc.UserinfoPhone{ - PhoneNumber: "sadsa", - PhoneNumberVerified: true, - }, - UserinfoProfile: oidc.UserinfoProfile{ - UpdatedAt: time.Now(), - }, - // Claims: map[string]interface{}{ - // "test": "test", - // "hkjh": "", - // }, - }, - Add: "jkhnkj", - }, nil -} - -type info struct { - Subject string -} - -func (i *info) GetSubject() string { - return i.Subject -} - -func (i *info) Claims() map[string]interface{} { - return map[string]interface{}{ - "hodor": "hoidoir", - "email": "asdfd", - "emailVerfied": true, +func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) { + if id == "none" { + return nil, errors.New("not found") } + var appType op.ApplicationType + var authMethod op.AuthMethod + if id == "web" { + appType = op.ApplicationTypeWeb + authMethod = op.AuthMethodBasic + } else if id == "native" { + appType = op.ApplicationTypeNative + authMethod = op.AuthMethodNone + } else { + appType = op.ApplicationTypeUserAgent + authMethod = op.AuthMethodNone + } + return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil +} + +func (s *OPStorage) AuthorizeClientIDSecret(id string, _ string) error { + return nil +} +func (s *OPStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) { + return &oidc.Userinfo{ + Subject: a.GetSubject(), + Address: &oidc.UserinfoAddress{ + StreetAddress: "Hjkhkj 789\ndsf", + }, + UserinfoEmail: oidc.UserinfoEmail{ + Email: "test", + EmailVerified: true, + }, + UserinfoPhone: oidc.UserinfoPhone{ + PhoneNumber: "sadsa", + PhoneNumberVerified: true, + }, + UserinfoProfile: oidc.UserinfoProfile{ + UpdatedAt: time.Now(), + }, + // Claims: map[string]interface{}{ + // "test": "test", + // "hkjh": "", + // }, + }, nil } type ConfClient struct { applicationType op.ApplicationType + authMethod op.AuthMethod + ID string } +func (c *ConfClient) GetID() string { + return c.ID +} func (c *ConfClient) RedirectURIs() []string { return []string{ "https://registered.com/callback", @@ -218,3 +216,7 @@ func (c *ConfClient) LoginURL(id string) string { func (c *ConfClient) ApplicationType() op.ApplicationType { return c.applicationType } + +func (c *ConfClient) GetAuthMethod() op.AuthMethod { + return c.authMethod +} diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index 0ebd26d..c3245f6 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -58,6 +58,9 @@ type AuthRequest struct { IDTokenHint string `schema:"id_token_hint"` LoginHint string `schema:"login_hint"` ACRValues []string `schema:"acr_values"` + + CodeChallenge string `schema:"code_challenge"` + CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"` } // func (a *AuthRequest) GetID() string { diff --git a/pkg/op/client.go b/pkg/op/client.go index 41a6d60..cbd69fb 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -10,6 +10,7 @@ type Client interface { GetID() string RedirectURIs() []string ApplicationType() ApplicationType + GetAuthMethod() AuthMethod LoginURL(string) string } @@ -18,3 +19,5 @@ func IsConfidentialType(c Client) bool { } type ApplicationType int + +type AuthMethod string diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 7f832a5..9d4aeaf 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -16,8 +16,9 @@ const ( defaultUserinfoEndpoint = "userinfo" defaultKeysEndpoint = "keys" - AuthMethodBasic = "client_secret_basic" - AuthMethodPost = "client_secret_post" + AuthMethodBasic AuthMethod = "client_secret_basic" + AuthMethodPost = "client_secret_post" + AuthMethodNone = "none" DefaultIDTokenValidity = time.Duration(5 * time.Minute) ) diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 9a3d97e..3d4ea98 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -110,10 +110,10 @@ func SubjectTypes(c Configuration) []string { func AuthMethods(c Configuration) []string { authMethods := []string{ - AuthMethodBasic, + string(AuthMethodBasic), } if c.AuthMethodPostSupported() { - authMethods = append(authMethods, AuthMethodPost) + authMethods = append(authMethods, string(AuthMethodPost)) } return authMethods } diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 59e9daa..39b39bc 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -214,7 +214,7 @@ func Test_AuthMethods(t *testing.T) { m.EXPECT().AuthMethodPostSupported().Return(false) return m }()}, - []string{op.AuthMethodBasic}, + []string{string(op.AuthMethodBasic)}, }, { "basic and post", @@ -222,7 +222,7 @@ func Test_AuthMethods(t *testing.T) { m.EXPECT().AuthMethodPostSupported().Return(true) return m }()}, - []string{op.AuthMethodBasic, op.AuthMethodPost}, + []string{string(op.AuthMethodBasic), string(op.AuthMethodPost)}, }, } for _, tt := range tests { diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 9757457..e856860 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) } +// GetAuthMethod mocks base method +func (m *MockClient) GetAuthMethod() op.AuthMethod { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthMethod") + ret0, _ := ret[0].(op.AuthMethod) + return ret0 +} + +// GetAuthMethod indicates an expected call of GetAuthMethod +func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod)) +} + // GetID mocks base method func (m *MockClient) GetID() string { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 22cd49b..69133ba 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -65,28 +65,12 @@ func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0) } -// AuthorizeClientIDCodeVerifier mocks base method -func (m *MockStorage) AuthorizeClientIDCodeVerifier(arg0, arg1 string) (op.Client, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AuthorizeClientIDCodeVerifier", arg0, arg1) - ret0, _ := ret[0].(op.Client) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AuthorizeClientIDCodeVerifier indicates an expected call of AuthorizeClientIDCodeVerifier -func (mr *MockStorageMockRecorder) AuthorizeClientIDCodeVerifier(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDCodeVerifier", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDCodeVerifier), arg0, arg1) -} - // AuthorizeClientIDSecret mocks base method -func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) (op.Client, error) { +func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1) - ret0, _ := ret[0].(op.Client) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret0, _ := ret[0].(error) + return ret0 } // AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index fd9cd76..c52ace0 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -31,7 +31,7 @@ func NewMockStorageAny(t *testing.T) op.Storage { m := NewStorage(t) mockS := m.(*MockStorage) mockS.EXPECT().GetClientByClientID(gomock.Any()).AnyTimes().Return(&ConfClient{}, nil) - mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil) + mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) return m } @@ -62,15 +62,19 @@ func ExpectValidClientID(s op.Storage) { mockS.EXPECT().GetClientByClientID(gomock.Any()).DoAndReturn( func(id string) (op.Client, error) { var appType op.ApplicationType + var authMethod op.AuthMethod switch id { case "web_client": appType = op.ApplicationTypeWeb + authMethod = op.AuthMethodBasic case "native_client": appType = op.ApplicationTypeNative + authMethod = op.AuthMethodNone case "useragent_client": appType = op.ApplicationTypeUserAgent + authMethod = op.AuthMethodBasic } - return &ConfClient{appType: appType}, nil + return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil }) } @@ -90,7 +94,9 @@ func ExpectSigningKey(s op.Storage) { } type ConfClient struct { - appType op.ApplicationType + id string + appType op.ApplicationType + authMethod op.AuthMethod } func (c *ConfClient) RedirectURIs() []string { @@ -109,3 +115,11 @@ func (c *ConfClient) LoginURL(id string) string { func (c *ConfClient) ApplicationType() op.ApplicationType { return c.appType } + +func (c *ConfClient) GetAuthMethod() op.AuthMethod { + return c.authMethod +} + +func (c *ConfClient) GetID() string { + return c.id +} diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 81b5b68..8ec7aea 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -20,8 +20,7 @@ type AuthStorage interface { type OPStorage interface { GetClientByClientID(string) (Client, error) - AuthorizeClientIDSecret(string, string) (Client, error) - AuthorizeClientIDCodeVerifier(string, string) (Client, error) + AuthorizeClientIDSecret(string, string) error GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) } diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 5ece515..d934144 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -22,38 +22,21 @@ type Exchanger interface { } func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { - err := r.ParseForm() + tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { - ExchangeRequestError(w, r, ErrInvalidRequest("error parsing form")) - return - } - tokenReq := new(oidc.AccessTokenRequest) - - err = exchanger.Decoder().Decode(tokenReq, r.Form) - if err != nil { - ExchangeRequestError(w, r, ErrInvalidRequest("error decoding form")) - return + ExchangeRequestError(w, r, err) } if tokenReq.Code == "" { ExchangeRequestError(w, r, ErrInvalidRequest("code missing")) return } - authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code) - if err != nil { - ExchangeRequestError(w, r, err) - return - } - client, err := AuthorizeClient(r, tokenReq, authReq, exchanger) - if err != nil { - ExchangeRequestError(w, r, err) - return - } - err = ValidateAccessTokenRequest(tokenReq, client, authReq) + authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger) if err != nil { ExchangeRequestError(w, r, err) return } + err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code) if err != nil { ExchangeRequestError(w, r, err) @@ -79,40 +62,84 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { utils.MarshalJSON(w, resp) } -func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) { - if tokenReq.ClientID == "" { - if !exchanger.AuthMethodBasicSupported() { - return nil, errors.New("basic not supported") - } - clientID, clientSecret, ok := r.BasicAuth() - if ok { - return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret) - } +func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) { + err := r.ParseForm() + if err != nil { + return nil, ErrInvalidRequest("error parsing form") + } + tokenReq := new(oidc.AccessTokenRequest) + err = decoder.Decode(tokenReq, r.Form) + if err != nil { + return nil, ErrInvalidRequest("error decoding form") + } + clientID, clientSecret, ok := r.BasicAuth() + if ok { + tokenReq.ClientID = clientID + tokenReq.ClientSecret = clientSecret } - if tokenReq.ClientSecret != "" { - if !exchanger.AuthMethodPostSupported() { - return nil, errors.New("post not supported") - } - return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret) - } - if tokenReq.CodeVerifier != "" { - if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) { - return nil, ErrInvalidRequest("code_challenge invalid") - } - return exchanger.Storage().GetClientByClientID(tokenReq.ClientID) - } - return nil, errors.New("Unimplemented") //TODO: impl + return tokenReq, nil } -func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error { +func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) { + authReq, client, err := AuthorizeClient(tokenReq, exchanger) + if err != nil { + return nil, err + } if client.GetID() != authReq.GetClientID() { - return ErrInvalidRequest("invalid auth code") + return nil, ErrInvalidRequest("invalid auth code") } if tokenReq.RedirectURI != authReq.GetRedirectURI() { - return ErrInvalidRequest("redirect_uri does no correspond") + return nil, ErrInvalidRequest("redirect_uri does no correspond") } - return nil + return authReq, nil +} + +func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { + client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID) + if err != nil { + return nil, nil, err + } + switch client.GetAuthMethod() { + case AuthMethodNone: + authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger) + return authReq, client, err + case AuthMethodPost: + if !exchanger.AuthMethodPostSupported() { + return nil, nil, errors.New("basic not supported") + } + err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger) + case AuthMethodBasic: + err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger) + default: + err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger) + } + if err != nil { + return nil, nil, err + } + authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code) + if err != nil { + return nil, nil, err + } + return authReq, client, nil +} + +func AuthorizeClientIDSecret(clientID, clientSecret string, exchanger Exchanger) error { + return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret) +} + +func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) { + if tokenReq.CodeVerifier == "" { + return nil, ErrInvalidRequest("code_challenge required") + } + authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code) + if err != nil { + return nil, ErrInvalidRequest("invalid code") + } + if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) { + return nil, ErrInvalidRequest("code_challenge invalid") + } + return authReq, nil } func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) { @@ -120,6 +147,5 @@ func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.Tok } func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error { - return errors.New("Unimplemented") //TODO: impl }