fixes for token endpoint

This commit is contained in:
Livio Amstutz 2019-12-16 14:10:43 +01:00
parent 20a90c71d9
commit a21f6745f7
12 changed files with 192 additions and 146 deletions

View file

@ -38,7 +38,7 @@ func main() {
// cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure()) // cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure())
provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler)) provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler))
if err != nil { if err != nil {
logrus.Panic("error creating provider") logrus.Panicf("error creating provider %s", err.Error())
} }
// state := "foobar" // state := "foobar"

View file

@ -36,6 +36,7 @@ type AuthRequest struct {
RedirectURI string RedirectURI string
Nonce string Nonce string
ClientID string ClientID string
CodeChallenge *oidc.CodeChallenge
} }
func (a *AuthRequest) GetACR() string { func (a *AuthRequest) GetACR() string {
@ -66,6 +67,10 @@ func (a *AuthRequest) GetCode() string {
return "code" return "code"
} }
func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
return a.CodeChallenge
}
func (a *AuthRequest) GetID() string { func (a *AuthRequest) GetID() string {
return a.ID return a.ID
} }
@ -105,38 +110,23 @@ var (
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
if authReq.CodeChallenge != "" {
a.CodeChallenge = &oidc.CodeChallenge{
Challenge: authReq.CodeChallenge,
Method: authReq.CodeChallengeMethod,
}
}
return a, nil return a, nil
} }
func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) { func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) {
if id == "none" {
return nil, errors.New("not found")
}
var appType op.ApplicationType
if id == "web" {
appType = op.ApplicationTypeWeb
} else if id == "native" {
appType = op.ApplicationTypeNative
} else {
appType = op.ApplicationTypeUserAgent
}
return &ConfClient{applicationType: appType}, nil
}
func (s *AuthStorage) AuthRequestByCode(op.Client, string, string) (op.AuthRequest, error) {
return a, nil return a, nil
} }
func (s *OPStorage) AuthorizeClientIDSecret(string, string) (op.Client, error) {
return &ConfClient{}, nil
}
func (s *OPStorage) AuthorizeClientIDCodeVerifier(string, string) (op.Client, error) {
return &ConfClient{}, nil
}
func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error { func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error {
return nil return nil
} }
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) { func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
return a, nil return a, nil
} }
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
} }
@ -152,9 +142,30 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
}, nil }, nil
} }
func (s *OPStorage) GetUserinfoFromScopes([]string) (interface{}, error) { func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) {
return &oidc.Test{ if id == "none" {
Userinfo: oidc.Userinfo{ return nil, errors.New("not found")
}
var appType op.ApplicationType
var authMethod op.AuthMethod
if id == "web" {
appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic
} else if id == "native" {
appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone
} else {
appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodNone
}
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
}
func (s *OPStorage) AuthorizeClientIDSecret(id string, _ string) error {
return nil
}
func (s *OPStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
return &oidc.Userinfo{
Subject: a.GetSubject(), Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{ Address: &oidc.UserinfoAddress{
StreetAddress: "Hjkhkj 789\ndsf", StreetAddress: "Hjkhkj 789\ndsf",
@ -174,31 +185,18 @@ func (s *OPStorage) GetUserinfoFromScopes([]string) (interface{}, error) {
// "test": "test", // "test": "test",
// "hkjh": "", // "hkjh": "",
// }, // },
},
Add: "jkhnkj",
}, nil }, nil
} }
type info struct {
Subject string
}
func (i *info) GetSubject() string {
return i.Subject
}
func (i *info) Claims() map[string]interface{} {
return map[string]interface{}{
"hodor": "hoidoir",
"email": "asdfd",
"emailVerfied": true,
}
}
type ConfClient struct { type ConfClient struct {
applicationType op.ApplicationType applicationType op.ApplicationType
authMethod op.AuthMethod
ID string
} }
func (c *ConfClient) GetID() string {
return c.ID
}
func (c *ConfClient) RedirectURIs() []string { func (c *ConfClient) RedirectURIs() []string {
return []string{ return []string{
"https://registered.com/callback", "https://registered.com/callback",
@ -218,3 +216,7 @@ func (c *ConfClient) LoginURL(id string) string {
func (c *ConfClient) ApplicationType() op.ApplicationType { func (c *ConfClient) ApplicationType() op.ApplicationType {
return c.applicationType return c.applicationType
} }
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
return c.authMethod
}

View file

@ -58,6 +58,9 @@ type AuthRequest struct {
IDTokenHint string `schema:"id_token_hint"` IDTokenHint string `schema:"id_token_hint"`
LoginHint string `schema:"login_hint"` LoginHint string `schema:"login_hint"`
ACRValues []string `schema:"acr_values"` ACRValues []string `schema:"acr_values"`
CodeChallenge string `schema:"code_challenge"`
CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"`
} }
// func (a *AuthRequest) GetID() string { // func (a *AuthRequest) GetID() string {

View file

@ -10,6 +10,7 @@ type Client interface {
GetID() string GetID() string
RedirectURIs() []string RedirectURIs() []string
ApplicationType() ApplicationType ApplicationType() ApplicationType
GetAuthMethod() AuthMethod
LoginURL(string) string LoginURL(string) string
} }
@ -18,3 +19,5 @@ func IsConfidentialType(c Client) bool {
} }
type ApplicationType int type ApplicationType int
type AuthMethod string

View file

@ -16,8 +16,9 @@ const (
defaultUserinfoEndpoint = "userinfo" defaultUserinfoEndpoint = "userinfo"
defaultKeysEndpoint = "keys" defaultKeysEndpoint = "keys"
AuthMethodBasic = "client_secret_basic" AuthMethodBasic AuthMethod = "client_secret_basic"
AuthMethodPost = "client_secret_post" AuthMethodPost = "client_secret_post"
AuthMethodNone = "none"
DefaultIDTokenValidity = time.Duration(5 * time.Minute) DefaultIDTokenValidity = time.Duration(5 * time.Minute)
) )

View file

@ -110,10 +110,10 @@ func SubjectTypes(c Configuration) []string {
func AuthMethods(c Configuration) []string { func AuthMethods(c Configuration) []string {
authMethods := []string{ authMethods := []string{
AuthMethodBasic, string(AuthMethodBasic),
} }
if c.AuthMethodPostSupported() { if c.AuthMethodPostSupported() {
authMethods = append(authMethods, AuthMethodPost) authMethods = append(authMethods, string(AuthMethodPost))
} }
return authMethods return authMethods
} }

View file

@ -214,7 +214,7 @@ func Test_AuthMethods(t *testing.T) {
m.EXPECT().AuthMethodPostSupported().Return(false) m.EXPECT().AuthMethodPostSupported().Return(false)
return m return m
}()}, }()},
[]string{op.AuthMethodBasic}, []string{string(op.AuthMethodBasic)},
}, },
{ {
"basic and post", "basic and post",
@ -222,7 +222,7 @@ func Test_AuthMethods(t *testing.T) {
m.EXPECT().AuthMethodPostSupported().Return(true) m.EXPECT().AuthMethodPostSupported().Return(true)
return m return m
}()}, }()},
[]string{op.AuthMethodBasic, op.AuthMethodPost}, []string{string(op.AuthMethodBasic), string(op.AuthMethodPost)},
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
} }
// GetAuthMethod mocks base method
func (m *MockClient) GetAuthMethod() op.AuthMethod {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthMethod")
ret0, _ := ret[0].(op.AuthMethod)
return ret0
}
// GetAuthMethod indicates an expected call of GetAuthMethod
func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod))
}
// GetID mocks base method // GetID mocks base method
func (m *MockClient) GetID() string { func (m *MockClient) GetID() string {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -65,28 +65,12 @@ func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0)
} }
// AuthorizeClientIDCodeVerifier mocks base method
func (m *MockStorage) AuthorizeClientIDCodeVerifier(arg0, arg1 string) (op.Client, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDCodeVerifier", arg0, arg1)
ret0, _ := ret[0].(op.Client)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthorizeClientIDCodeVerifier indicates an expected call of AuthorizeClientIDCodeVerifier
func (mr *MockStorageMockRecorder) AuthorizeClientIDCodeVerifier(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDCodeVerifier", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDCodeVerifier), arg0, arg1)
}
// AuthorizeClientIDSecret mocks base method // AuthorizeClientIDSecret mocks base method
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) (op.Client, error) { func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1) ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1)
ret0, _ := ret[0].(op.Client) ret0, _ := ret[0].(error)
ret1, _ := ret[1].(error) return ret0
return ret0, ret1
} }
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret // AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret

View file

@ -31,7 +31,7 @@ func NewMockStorageAny(t *testing.T) op.Storage {
m := NewStorage(t) m := NewStorage(t)
mockS := m.(*MockStorage) mockS := m.(*MockStorage)
mockS.EXPECT().GetClientByClientID(gomock.Any()).AnyTimes().Return(&ConfClient{}, nil) mockS.EXPECT().GetClientByClientID(gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil) mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
return m return m
} }
@ -62,15 +62,19 @@ func ExpectValidClientID(s op.Storage) {
mockS.EXPECT().GetClientByClientID(gomock.Any()).DoAndReturn( mockS.EXPECT().GetClientByClientID(gomock.Any()).DoAndReturn(
func(id string) (op.Client, error) { func(id string) (op.Client, error) {
var appType op.ApplicationType var appType op.ApplicationType
var authMethod op.AuthMethod
switch id { switch id {
case "web_client": case "web_client":
appType = op.ApplicationTypeWeb appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic
case "native_client": case "native_client":
appType = op.ApplicationTypeNative appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone
case "useragent_client": case "useragent_client":
appType = op.ApplicationTypeUserAgent appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodBasic
} }
return &ConfClient{appType: appType}, nil return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil
}) })
} }
@ -90,7 +94,9 @@ func ExpectSigningKey(s op.Storage) {
} }
type ConfClient struct { type ConfClient struct {
id string
appType op.ApplicationType appType op.ApplicationType
authMethod op.AuthMethod
} }
func (c *ConfClient) RedirectURIs() []string { func (c *ConfClient) RedirectURIs() []string {
@ -109,3 +115,11 @@ func (c *ConfClient) LoginURL(id string) string {
func (c *ConfClient) ApplicationType() op.ApplicationType { func (c *ConfClient) ApplicationType() op.ApplicationType {
return c.appType return c.appType
} }
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
return c.authMethod
}
func (c *ConfClient) GetID() string {
return c.id
}

View file

@ -20,8 +20,7 @@ type AuthStorage interface {
type OPStorage interface { type OPStorage interface {
GetClientByClientID(string) (Client, error) GetClientByClientID(string) (Client, error)
AuthorizeClientIDSecret(string, string) (Client, error) AuthorizeClientIDSecret(string, string) error
AuthorizeClientIDCodeVerifier(string, string) (Client, error)
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error)
} }

View file

@ -22,38 +22,21 @@ type Exchanger interface {
} }
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
err := r.ParseForm() tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
if err != nil { if err != nil {
ExchangeRequestError(w, r, ErrInvalidRequest("error parsing form")) ExchangeRequestError(w, r, err)
return
}
tokenReq := new(oidc.AccessTokenRequest)
err = exchanger.Decoder().Decode(tokenReq, r.Form)
if err != nil {
ExchangeRequestError(w, r, ErrInvalidRequest("error decoding form"))
return
} }
if tokenReq.Code == "" { if tokenReq.Code == "" {
ExchangeRequestError(w, r, ErrInvalidRequest("code missing")) ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
return return
} }
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code) authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
client, err := AuthorizeClient(r, tokenReq, authReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = ValidateAccessTokenRequest(tokenReq, client, authReq)
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
} }
err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code) err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
@ -79,40 +62,84 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp) utils.MarshalJSON(w, resp)
} }
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) { func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) {
if tokenReq.ClientID == "" { err := r.ParseForm()
if !exchanger.AuthMethodBasicSupported() { if err != nil {
return nil, errors.New("basic not supported") return nil, ErrInvalidRequest("error parsing form")
}
tokenReq := new(oidc.AccessTokenRequest)
err = decoder.Decode(tokenReq, r.Form)
if err != nil {
return nil, ErrInvalidRequest("error decoding form")
} }
clientID, clientSecret, ok := r.BasicAuth() clientID, clientSecret, ok := r.BasicAuth()
if ok { if ok {
return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret) tokenReq.ClientID = clientID
} tokenReq.ClientSecret = clientSecret
} }
if tokenReq.ClientSecret != "" { return tokenReq, nil
}
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
if err != nil {
return nil, err
}
if client.GetID() != authReq.GetClientID() {
return nil, ErrInvalidRequest("invalid auth code")
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return nil, ErrInvalidRequest("redirect_uri does no correspond")
}
return authReq, nil
}
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
if err != nil {
return nil, nil, err
}
switch client.GetAuthMethod() {
case AuthMethodNone:
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger)
return authReq, client, err
case AuthMethodPost:
if !exchanger.AuthMethodPostSupported() { if !exchanger.AuthMethodPostSupported() {
return nil, errors.New("post not supported") return nil, nil, errors.New("basic not supported")
} }
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret) err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
case AuthMethodBasic:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
default:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
}
if err != nil {
return nil, nil, err
}
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
if err != nil {
return nil, nil, err
}
return authReq, client, nil
}
func AuthorizeClientIDSecret(clientID, clientSecret string, exchanger Exchanger) error {
return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret)
}
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required")
}
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
if err != nil {
return nil, ErrInvalidRequest("invalid code")
} }
if tokenReq.CodeVerifier != "" {
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) { if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
return nil, ErrInvalidRequest("code_challenge invalid") return nil, ErrInvalidRequest("code_challenge invalid")
} }
return exchanger.Storage().GetClientByClientID(tokenReq.ClientID) return authReq, nil
}
return nil, errors.New("Unimplemented") //TODO: impl
}
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error {
if client.GetID() != authReq.GetClientID() {
return ErrInvalidRequest("invalid auth code")
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return ErrInvalidRequest("redirect_uri does no correspond")
}
return nil
} }
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) { func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
@ -120,6 +147,5 @@ func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.Tok
} }
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error { func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
return errors.New("Unimplemented") //TODO: impl return errors.New("Unimplemented") //TODO: impl
} }