fixes for token endpoint

This commit is contained in:
Livio Amstutz 2019-12-16 14:10:43 +01:00
parent 20a90c71d9
commit a21f6745f7
12 changed files with 192 additions and 146 deletions

View file

@ -38,7 +38,7 @@ func main() {
// cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure())
provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler))
if err != nil {
logrus.Panic("error creating provider")
logrus.Panicf("error creating provider %s", err.Error())
}
// state := "foobar"

View file

@ -31,11 +31,12 @@ func NewAuthStorage() op.AuthStorage {
}
type AuthRequest struct {
ID string
ResponseType oidc.ResponseType
RedirectURI string
Nonce string
ClientID string
ID string
ResponseType oidc.ResponseType
RedirectURI string
Nonce string
ClientID string
CodeChallenge *oidc.CodeChallenge
}
func (a *AuthRequest) GetACR() string {
@ -66,6 +67,10 @@ func (a *AuthRequest) GetCode() string {
return "code"
}
func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
return a.CodeChallenge
}
func (a *AuthRequest) GetID() string {
return a.ID
}
@ -105,38 +110,23 @@ var (
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
if authReq.CodeChallenge != "" {
a.CodeChallenge = &oidc.CodeChallenge{
Challenge: authReq.CodeChallenge,
Method: authReq.CodeChallengeMethod,
}
}
return a, nil
}
func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) {
if id == "none" {
return nil, errors.New("not found")
}
var appType op.ApplicationType
if id == "web" {
appType = op.ApplicationTypeWeb
} else if id == "native" {
appType = op.ApplicationTypeNative
} else {
appType = op.ApplicationTypeUserAgent
}
return &ConfClient{applicationType: appType}, nil
}
func (s *AuthStorage) AuthRequestByCode(op.Client, string, string) (op.AuthRequest, error) {
func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) {
return a, nil
}
func (s *OPStorage) AuthorizeClientIDSecret(string, string) (op.Client, error) {
return &ConfClient{}, nil
}
func (s *OPStorage) AuthorizeClientIDCodeVerifier(string, string) (op.Client, error) {
return &ConfClient{}, nil
}
func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error {
return nil
}
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
return a, nil
}
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
}
@ -152,53 +142,61 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
}, nil
}
func (s *OPStorage) GetUserinfoFromScopes([]string) (interface{}, error) {
return &oidc.Test{
Userinfo: oidc.Userinfo{
Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{
StreetAddress: "Hjkhkj 789\ndsf",
},
UserinfoEmail: oidc.UserinfoEmail{
Email: "test",
EmailVerified: true,
},
UserinfoPhone: oidc.UserinfoPhone{
PhoneNumber: "sadsa",
PhoneNumberVerified: true,
},
UserinfoProfile: oidc.UserinfoProfile{
UpdatedAt: time.Now(),
},
// Claims: map[string]interface{}{
// "test": "test",
// "hkjh": "",
// },
},
Add: "jkhnkj",
}, nil
}
type info struct {
Subject string
}
func (i *info) GetSubject() string {
return i.Subject
}
func (i *info) Claims() map[string]interface{} {
return map[string]interface{}{
"hodor": "hoidoir",
"email": "asdfd",
"emailVerfied": true,
func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) {
if id == "none" {
return nil, errors.New("not found")
}
var appType op.ApplicationType
var authMethod op.AuthMethod
if id == "web" {
appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic
} else if id == "native" {
appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone
} else {
appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodNone
}
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
}
func (s *OPStorage) AuthorizeClientIDSecret(id string, _ string) error {
return nil
}
func (s *OPStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
return &oidc.Userinfo{
Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{
StreetAddress: "Hjkhkj 789\ndsf",
},
UserinfoEmail: oidc.UserinfoEmail{
Email: "test",
EmailVerified: true,
},
UserinfoPhone: oidc.UserinfoPhone{
PhoneNumber: "sadsa",
PhoneNumberVerified: true,
},
UserinfoProfile: oidc.UserinfoProfile{
UpdatedAt: time.Now(),
},
// Claims: map[string]interface{}{
// "test": "test",
// "hkjh": "",
// },
}, nil
}
type ConfClient struct {
applicationType op.ApplicationType
authMethod op.AuthMethod
ID string
}
func (c *ConfClient) GetID() string {
return c.ID
}
func (c *ConfClient) RedirectURIs() []string {
return []string{
"https://registered.com/callback",
@ -218,3 +216,7 @@ func (c *ConfClient) LoginURL(id string) string {
func (c *ConfClient) ApplicationType() op.ApplicationType {
return c.applicationType
}
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
return c.authMethod
}

View file

@ -58,6 +58,9 @@ type AuthRequest struct {
IDTokenHint string `schema:"id_token_hint"`
LoginHint string `schema:"login_hint"`
ACRValues []string `schema:"acr_values"`
CodeChallenge string `schema:"code_challenge"`
CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"`
}
// func (a *AuthRequest) GetID() string {

View file

@ -10,6 +10,7 @@ type Client interface {
GetID() string
RedirectURIs() []string
ApplicationType() ApplicationType
GetAuthMethod() AuthMethod
LoginURL(string) string
}
@ -18,3 +19,5 @@ func IsConfidentialType(c Client) bool {
}
type ApplicationType int
type AuthMethod string

View file

@ -16,8 +16,9 @@ const (
defaultUserinfoEndpoint = "userinfo"
defaultKeysEndpoint = "keys"
AuthMethodBasic = "client_secret_basic"
AuthMethodPost = "client_secret_post"
AuthMethodBasic AuthMethod = "client_secret_basic"
AuthMethodPost = "client_secret_post"
AuthMethodNone = "none"
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
)

View file

@ -110,10 +110,10 @@ func SubjectTypes(c Configuration) []string {
func AuthMethods(c Configuration) []string {
authMethods := []string{
AuthMethodBasic,
string(AuthMethodBasic),
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, AuthMethodPost)
authMethods = append(authMethods, string(AuthMethodPost))
}
return authMethods
}

View file

@ -214,7 +214,7 @@ func Test_AuthMethods(t *testing.T) {
m.EXPECT().AuthMethodPostSupported().Return(false)
return m
}()},
[]string{op.AuthMethodBasic},
[]string{string(op.AuthMethodBasic)},
},
{
"basic and post",
@ -222,7 +222,7 @@ func Test_AuthMethods(t *testing.T) {
m.EXPECT().AuthMethodPostSupported().Return(true)
return m
}()},
[]string{op.AuthMethodBasic, op.AuthMethodPost},
[]string{string(op.AuthMethodBasic), string(op.AuthMethodPost)},
},
}
for _, tt := range tests {

View file

@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
}
// GetAuthMethod mocks base method
func (m *MockClient) GetAuthMethod() op.AuthMethod {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthMethod")
ret0, _ := ret[0].(op.AuthMethod)
return ret0
}
// GetAuthMethod indicates an expected call of GetAuthMethod
func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod))
}
// GetID mocks base method
func (m *MockClient) GetID() string {
m.ctrl.T.Helper()

View file

@ -65,28 +65,12 @@ func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0)
}
// AuthorizeClientIDCodeVerifier mocks base method
func (m *MockStorage) AuthorizeClientIDCodeVerifier(arg0, arg1 string) (op.Client, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDCodeVerifier", arg0, arg1)
ret0, _ := ret[0].(op.Client)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthorizeClientIDCodeVerifier indicates an expected call of AuthorizeClientIDCodeVerifier
func (mr *MockStorageMockRecorder) AuthorizeClientIDCodeVerifier(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDCodeVerifier", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDCodeVerifier), arg0, arg1)
}
// AuthorizeClientIDSecret mocks base method
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) (op.Client, error) {
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1)
ret0, _ := ret[0].(op.Client)
ret1, _ := ret[1].(error)
return ret0, ret1
ret0, _ := ret[0].(error)
return ret0
}
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret

View file

@ -31,7 +31,7 @@ func NewMockStorageAny(t *testing.T) op.Storage {
m := NewStorage(t)
mockS := m.(*MockStorage)
mockS.EXPECT().GetClientByClientID(gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
return m
}
@ -62,15 +62,19 @@ func ExpectValidClientID(s op.Storage) {
mockS.EXPECT().GetClientByClientID(gomock.Any()).DoAndReturn(
func(id string) (op.Client, error) {
var appType op.ApplicationType
var authMethod op.AuthMethod
switch id {
case "web_client":
appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic
case "native_client":
appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone
case "useragent_client":
appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodBasic
}
return &ConfClient{appType: appType}, nil
return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil
})
}
@ -90,7 +94,9 @@ func ExpectSigningKey(s op.Storage) {
}
type ConfClient struct {
appType op.ApplicationType
id string
appType op.ApplicationType
authMethod op.AuthMethod
}
func (c *ConfClient) RedirectURIs() []string {
@ -109,3 +115,11 @@ func (c *ConfClient) LoginURL(id string) string {
func (c *ConfClient) ApplicationType() op.ApplicationType {
return c.appType
}
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
return c.authMethod
}
func (c *ConfClient) GetID() string {
return c.id
}

View file

@ -20,8 +20,7 @@ type AuthStorage interface {
type OPStorage interface {
GetClientByClientID(string) (Client, error)
AuthorizeClientIDSecret(string, string) (Client, error)
AuthorizeClientIDCodeVerifier(string, string) (Client, error)
AuthorizeClientIDSecret(string, string) error
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error)
}

View file

@ -22,38 +22,21 @@ type Exchanger interface {
}
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
err := r.ParseForm()
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
if err != nil {
ExchangeRequestError(w, r, ErrInvalidRequest("error parsing form"))
return
}
tokenReq := new(oidc.AccessTokenRequest)
err = exchanger.Decoder().Decode(tokenReq, r.Form)
if err != nil {
ExchangeRequestError(w, r, ErrInvalidRequest("error decoding form"))
return
ExchangeRequestError(w, r, err)
}
if tokenReq.Code == "" {
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
return
}
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
client, err := AuthorizeClient(r, tokenReq, authReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = ValidateAccessTokenRequest(tokenReq, client, authReq)
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
if err != nil {
ExchangeRequestError(w, r, err)
@ -79,40 +62,84 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp)
}
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) {
if tokenReq.ClientID == "" {
if !exchanger.AuthMethodBasicSupported() {
return nil, errors.New("basic not supported")
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret)
}
func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) {
err := r.ParseForm()
if err != nil {
return nil, ErrInvalidRequest("error parsing form")
}
tokenReq := new(oidc.AccessTokenRequest)
err = decoder.Decode(tokenReq, r.Form)
if err != nil {
return nil, ErrInvalidRequest("error decoding form")
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
tokenReq.ClientID = clientID
tokenReq.ClientSecret = clientSecret
}
if tokenReq.ClientSecret != "" {
if !exchanger.AuthMethodPostSupported() {
return nil, errors.New("post not supported")
}
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret)
}
if tokenReq.CodeVerifier != "" {
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
return nil, ErrInvalidRequest("code_challenge invalid")
}
return exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
}
return nil, errors.New("Unimplemented") //TODO: impl
return tokenReq, nil
}
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error {
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
if err != nil {
return nil, err
}
if client.GetID() != authReq.GetClientID() {
return ErrInvalidRequest("invalid auth code")
return nil, ErrInvalidRequest("invalid auth code")
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return ErrInvalidRequest("redirect_uri does no correspond")
return nil, ErrInvalidRequest("redirect_uri does no correspond")
}
return nil
return authReq, nil
}
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
if err != nil {
return nil, nil, err
}
switch client.GetAuthMethod() {
case AuthMethodNone:
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger)
return authReq, client, err
case AuthMethodPost:
if !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("basic not supported")
}
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
case AuthMethodBasic:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
default:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
}
if err != nil {
return nil, nil, err
}
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
if err != nil {
return nil, nil, err
}
return authReq, client, nil
}
func AuthorizeClientIDSecret(clientID, clientSecret string, exchanger Exchanger) error {
return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret)
}
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required")
}
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
if err != nil {
return nil, ErrInvalidRequest("invalid code")
}
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
return nil, ErrInvalidRequest("code_challenge invalid")
}
return authReq, nil
}
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
@ -120,6 +147,5 @@ func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.Tok
}
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
return errors.New("Unimplemented") //TODO: impl
}