fixes for token endpoint
This commit is contained in:
parent
20a90c71d9
commit
a21f6745f7
12 changed files with 192 additions and 146 deletions
|
@ -38,7 +38,7 @@ func main() {
|
||||||
// cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure())
|
// cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure())
|
||||||
provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler))
|
provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Panic("error creating provider")
|
logrus.Panicf("error creating provider %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// state := "foobar"
|
// state := "foobar"
|
||||||
|
|
|
@ -31,11 +31,12 @@ func NewAuthStorage() op.AuthStorage {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthRequest struct {
|
type AuthRequest struct {
|
||||||
ID string
|
ID string
|
||||||
ResponseType oidc.ResponseType
|
ResponseType oidc.ResponseType
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
Nonce string
|
Nonce string
|
||||||
ClientID string
|
ClientID string
|
||||||
|
CodeChallenge *oidc.CodeChallenge
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetACR() string {
|
func (a *AuthRequest) GetACR() string {
|
||||||
|
@ -66,6 +67,10 @@ func (a *AuthRequest) GetCode() string {
|
||||||
return "code"
|
return "code"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
|
||||||
|
return a.CodeChallenge
|
||||||
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetID() string {
|
func (a *AuthRequest) GetID() string {
|
||||||
return a.ID
|
return a.ID
|
||||||
}
|
}
|
||||||
|
@ -105,38 +110,23 @@ var (
|
||||||
|
|
||||||
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||||
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
|
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
|
||||||
|
if authReq.CodeChallenge != "" {
|
||||||
|
a.CodeChallenge = &oidc.CodeChallenge{
|
||||||
|
Challenge: authReq.CodeChallenge,
|
||||||
|
Method: authReq.CodeChallengeMethod,
|
||||||
|
}
|
||||||
|
}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) {
|
func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) {
|
||||||
if id == "none" {
|
|
||||||
return nil, errors.New("not found")
|
|
||||||
}
|
|
||||||
var appType op.ApplicationType
|
|
||||||
if id == "web" {
|
|
||||||
appType = op.ApplicationTypeWeb
|
|
||||||
} else if id == "native" {
|
|
||||||
appType = op.ApplicationTypeNative
|
|
||||||
} else {
|
|
||||||
appType = op.ApplicationTypeUserAgent
|
|
||||||
}
|
|
||||||
return &ConfClient{applicationType: appType}, nil
|
|
||||||
}
|
|
||||||
func (s *AuthStorage) AuthRequestByCode(op.Client, string, string) (op.AuthRequest, error) {
|
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
func (s *OPStorage) AuthorizeClientIDSecret(string, string) (op.Client, error) {
|
|
||||||
return &ConfClient{}, nil
|
|
||||||
}
|
|
||||||
func (s *OPStorage) AuthorizeClientIDCodeVerifier(string, string) (op.Client, error) {
|
|
||||||
return &ConfClient{}, nil
|
|
||||||
}
|
|
||||||
func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error {
|
func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
|
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
||||||
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
||||||
}
|
}
|
||||||
|
@ -152,53 +142,61 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OPStorage) GetUserinfoFromScopes([]string) (interface{}, error) {
|
func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) {
|
||||||
return &oidc.Test{
|
if id == "none" {
|
||||||
Userinfo: oidc.Userinfo{
|
return nil, errors.New("not found")
|
||||||
Subject: a.GetSubject(),
|
|
||||||
Address: &oidc.UserinfoAddress{
|
|
||||||
StreetAddress: "Hjkhkj 789\ndsf",
|
|
||||||
},
|
|
||||||
UserinfoEmail: oidc.UserinfoEmail{
|
|
||||||
Email: "test",
|
|
||||||
EmailVerified: true,
|
|
||||||
},
|
|
||||||
UserinfoPhone: oidc.UserinfoPhone{
|
|
||||||
PhoneNumber: "sadsa",
|
|
||||||
PhoneNumberVerified: true,
|
|
||||||
},
|
|
||||||
UserinfoProfile: oidc.UserinfoProfile{
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
},
|
|
||||||
// Claims: map[string]interface{}{
|
|
||||||
// "test": "test",
|
|
||||||
// "hkjh": "",
|
|
||||||
// },
|
|
||||||
},
|
|
||||||
Add: "jkhnkj",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type info struct {
|
|
||||||
Subject string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *info) GetSubject() string {
|
|
||||||
return i.Subject
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *info) Claims() map[string]interface{} {
|
|
||||||
return map[string]interface{}{
|
|
||||||
"hodor": "hoidoir",
|
|
||||||
"email": "asdfd",
|
|
||||||
"emailVerfied": true,
|
|
||||||
}
|
}
|
||||||
|
var appType op.ApplicationType
|
||||||
|
var authMethod op.AuthMethod
|
||||||
|
if id == "web" {
|
||||||
|
appType = op.ApplicationTypeWeb
|
||||||
|
authMethod = op.AuthMethodBasic
|
||||||
|
} else if id == "native" {
|
||||||
|
appType = op.ApplicationTypeNative
|
||||||
|
authMethod = op.AuthMethodNone
|
||||||
|
} else {
|
||||||
|
appType = op.ApplicationTypeUserAgent
|
||||||
|
authMethod = op.AuthMethodNone
|
||||||
|
}
|
||||||
|
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OPStorage) AuthorizeClientIDSecret(id string, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *OPStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
|
||||||
|
return &oidc.Userinfo{
|
||||||
|
Subject: a.GetSubject(),
|
||||||
|
Address: &oidc.UserinfoAddress{
|
||||||
|
StreetAddress: "Hjkhkj 789\ndsf",
|
||||||
|
},
|
||||||
|
UserinfoEmail: oidc.UserinfoEmail{
|
||||||
|
Email: "test",
|
||||||
|
EmailVerified: true,
|
||||||
|
},
|
||||||
|
UserinfoPhone: oidc.UserinfoPhone{
|
||||||
|
PhoneNumber: "sadsa",
|
||||||
|
PhoneNumberVerified: true,
|
||||||
|
},
|
||||||
|
UserinfoProfile: oidc.UserinfoProfile{
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
// Claims: map[string]interface{}{
|
||||||
|
// "test": "test",
|
||||||
|
// "hkjh": "",
|
||||||
|
// },
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfClient struct {
|
type ConfClient struct {
|
||||||
applicationType op.ApplicationType
|
applicationType op.ApplicationType
|
||||||
|
authMethod op.AuthMethod
|
||||||
|
ID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) GetID() string {
|
||||||
|
return c.ID
|
||||||
|
}
|
||||||
func (c *ConfClient) RedirectURIs() []string {
|
func (c *ConfClient) RedirectURIs() []string {
|
||||||
return []string{
|
return []string{
|
||||||
"https://registered.com/callback",
|
"https://registered.com/callback",
|
||||||
|
@ -218,3 +216,7 @@ func (c *ConfClient) LoginURL(id string) string {
|
||||||
func (c *ConfClient) ApplicationType() op.ApplicationType {
|
func (c *ConfClient) ApplicationType() op.ApplicationType {
|
||||||
return c.applicationType
|
return c.applicationType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||||
|
return c.authMethod
|
||||||
|
}
|
||||||
|
|
|
@ -58,6 +58,9 @@ type AuthRequest struct {
|
||||||
IDTokenHint string `schema:"id_token_hint"`
|
IDTokenHint string `schema:"id_token_hint"`
|
||||||
LoginHint string `schema:"login_hint"`
|
LoginHint string `schema:"login_hint"`
|
||||||
ACRValues []string `schema:"acr_values"`
|
ACRValues []string `schema:"acr_values"`
|
||||||
|
|
||||||
|
CodeChallenge string `schema:"code_challenge"`
|
||||||
|
CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (a *AuthRequest) GetID() string {
|
// func (a *AuthRequest) GetID() string {
|
||||||
|
|
|
@ -10,6 +10,7 @@ type Client interface {
|
||||||
GetID() string
|
GetID() string
|
||||||
RedirectURIs() []string
|
RedirectURIs() []string
|
||||||
ApplicationType() ApplicationType
|
ApplicationType() ApplicationType
|
||||||
|
GetAuthMethod() AuthMethod
|
||||||
LoginURL(string) string
|
LoginURL(string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,3 +19,5 @@ func IsConfidentialType(c Client) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApplicationType int
|
type ApplicationType int
|
||||||
|
|
||||||
|
type AuthMethod string
|
||||||
|
|
|
@ -16,8 +16,9 @@ const (
|
||||||
defaultUserinfoEndpoint = "userinfo"
|
defaultUserinfoEndpoint = "userinfo"
|
||||||
defaultKeysEndpoint = "keys"
|
defaultKeysEndpoint = "keys"
|
||||||
|
|
||||||
AuthMethodBasic = "client_secret_basic"
|
AuthMethodBasic AuthMethod = "client_secret_basic"
|
||||||
AuthMethodPost = "client_secret_post"
|
AuthMethodPost = "client_secret_post"
|
||||||
|
AuthMethodNone = "none"
|
||||||
|
|
||||||
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
|
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
|
||||||
)
|
)
|
||||||
|
|
|
@ -110,10 +110,10 @@ func SubjectTypes(c Configuration) []string {
|
||||||
|
|
||||||
func AuthMethods(c Configuration) []string {
|
func AuthMethods(c Configuration) []string {
|
||||||
authMethods := []string{
|
authMethods := []string{
|
||||||
AuthMethodBasic,
|
string(AuthMethodBasic),
|
||||||
}
|
}
|
||||||
if c.AuthMethodPostSupported() {
|
if c.AuthMethodPostSupported() {
|
||||||
authMethods = append(authMethods, AuthMethodPost)
|
authMethods = append(authMethods, string(AuthMethodPost))
|
||||||
}
|
}
|
||||||
return authMethods
|
return authMethods
|
||||||
}
|
}
|
||||||
|
|
|
@ -214,7 +214,7 @@ func Test_AuthMethods(t *testing.T) {
|
||||||
m.EXPECT().AuthMethodPostSupported().Return(false)
|
m.EXPECT().AuthMethodPostSupported().Return(false)
|
||||||
return m
|
return m
|
||||||
}()},
|
}()},
|
||||||
[]string{op.AuthMethodBasic},
|
[]string{string(op.AuthMethodBasic)},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"basic and post",
|
"basic and post",
|
||||||
|
@ -222,7 +222,7 @@ func Test_AuthMethods(t *testing.T) {
|
||||||
m.EXPECT().AuthMethodPostSupported().Return(true)
|
m.EXPECT().AuthMethodPostSupported().Return(true)
|
||||||
return m
|
return m
|
||||||
}()},
|
}()},
|
||||||
[]string{op.AuthMethodBasic, op.AuthMethodPost},
|
[]string{string(op.AuthMethodBasic), string(op.AuthMethodPost)},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
|
@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAuthMethod mocks base method
|
||||||
|
func (m *MockClient) GetAuthMethod() op.AuthMethod {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetAuthMethod")
|
||||||
|
ret0, _ := ret[0].(op.AuthMethod)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthMethod indicates an expected call of GetAuthMethod
|
||||||
|
func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod))
|
||||||
|
}
|
||||||
|
|
||||||
// GetID mocks base method
|
// GetID mocks base method
|
||||||
func (m *MockClient) GetID() string {
|
func (m *MockClient) GetID() string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -65,28 +65,12 @@ func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Cal
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeClientIDCodeVerifier mocks base method
|
|
||||||
func (m *MockStorage) AuthorizeClientIDCodeVerifier(arg0, arg1 string) (op.Client, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "AuthorizeClientIDCodeVerifier", arg0, arg1)
|
|
||||||
ret0, _ := ret[0].(op.Client)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthorizeClientIDCodeVerifier indicates an expected call of AuthorizeClientIDCodeVerifier
|
|
||||||
func (mr *MockStorageMockRecorder) AuthorizeClientIDCodeVerifier(arg0, arg1 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDCodeVerifier", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDCodeVerifier), arg0, arg1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthorizeClientIDSecret mocks base method
|
// AuthorizeClientIDSecret mocks base method
|
||||||
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) (op.Client, error) {
|
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1)
|
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1)
|
||||||
ret0, _ := ret[0].(op.Client)
|
ret0, _ := ret[0].(error)
|
||||||
ret1, _ := ret[1].(error)
|
return ret0
|
||||||
return ret0, ret1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
|
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
|
||||||
|
|
|
@ -31,7 +31,7 @@ func NewMockStorageAny(t *testing.T) op.Storage {
|
||||||
m := NewStorage(t)
|
m := NewStorage(t)
|
||||||
mockS := m.(*MockStorage)
|
mockS := m.(*MockStorage)
|
||||||
mockS.EXPECT().GetClientByClientID(gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
|
mockS.EXPECT().GetClientByClientID(gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
|
||||||
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
|
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,15 +62,19 @@ func ExpectValidClientID(s op.Storage) {
|
||||||
mockS.EXPECT().GetClientByClientID(gomock.Any()).DoAndReturn(
|
mockS.EXPECT().GetClientByClientID(gomock.Any()).DoAndReturn(
|
||||||
func(id string) (op.Client, error) {
|
func(id string) (op.Client, error) {
|
||||||
var appType op.ApplicationType
|
var appType op.ApplicationType
|
||||||
|
var authMethod op.AuthMethod
|
||||||
switch id {
|
switch id {
|
||||||
case "web_client":
|
case "web_client":
|
||||||
appType = op.ApplicationTypeWeb
|
appType = op.ApplicationTypeWeb
|
||||||
|
authMethod = op.AuthMethodBasic
|
||||||
case "native_client":
|
case "native_client":
|
||||||
appType = op.ApplicationTypeNative
|
appType = op.ApplicationTypeNative
|
||||||
|
authMethod = op.AuthMethodNone
|
||||||
case "useragent_client":
|
case "useragent_client":
|
||||||
appType = op.ApplicationTypeUserAgent
|
appType = op.ApplicationTypeUserAgent
|
||||||
|
authMethod = op.AuthMethodBasic
|
||||||
}
|
}
|
||||||
return &ConfClient{appType: appType}, nil
|
return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,7 +94,9 @@ func ExpectSigningKey(s op.Storage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfClient struct {
|
type ConfClient struct {
|
||||||
appType op.ApplicationType
|
id string
|
||||||
|
appType op.ApplicationType
|
||||||
|
authMethod op.AuthMethod
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConfClient) RedirectURIs() []string {
|
func (c *ConfClient) RedirectURIs() []string {
|
||||||
|
@ -109,3 +115,11 @@ func (c *ConfClient) LoginURL(id string) string {
|
||||||
func (c *ConfClient) ApplicationType() op.ApplicationType {
|
func (c *ConfClient) ApplicationType() op.ApplicationType {
|
||||||
return c.appType
|
return c.appType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||||
|
return c.authMethod
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) GetID() string {
|
||||||
|
return c.id
|
||||||
|
}
|
||||||
|
|
|
@ -20,8 +20,7 @@ type AuthStorage interface {
|
||||||
|
|
||||||
type OPStorage interface {
|
type OPStorage interface {
|
||||||
GetClientByClientID(string) (Client, error)
|
GetClientByClientID(string) (Client, error)
|
||||||
AuthorizeClientIDSecret(string, string) (Client, error)
|
AuthorizeClientIDSecret(string, string) error
|
||||||
AuthorizeClientIDCodeVerifier(string, string) (Client, error)
|
|
||||||
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error)
|
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,38 +22,21 @@ type Exchanger interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
err := r.ParseForm()
|
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, ErrInvalidRequest("error parsing form"))
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
|
||||||
}
|
|
||||||
tokenReq := new(oidc.AccessTokenRequest)
|
|
||||||
|
|
||||||
err = exchanger.Decoder().Decode(tokenReq, r.Form)
|
|
||||||
if err != nil {
|
|
||||||
ExchangeRequestError(w, r, ErrInvalidRequest("error decoding form"))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if tokenReq.Code == "" {
|
if tokenReq.Code == "" {
|
||||||
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
|
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
|
||||||
if err != nil {
|
|
||||||
ExchangeRequestError(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
client, err := AuthorizeClient(r, tokenReq, authReq, exchanger)
|
|
||||||
if err != nil {
|
|
||||||
ExchangeRequestError(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = ValidateAccessTokenRequest(tokenReq, client, authReq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
|
err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
|
@ -79,40 +62,84 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
utils.MarshalJSON(w, resp)
|
utils.MarshalJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) {
|
func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) {
|
||||||
if tokenReq.ClientID == "" {
|
err := r.ParseForm()
|
||||||
if !exchanger.AuthMethodBasicSupported() {
|
if err != nil {
|
||||||
return nil, errors.New("basic not supported")
|
return nil, ErrInvalidRequest("error parsing form")
|
||||||
}
|
}
|
||||||
clientID, clientSecret, ok := r.BasicAuth()
|
tokenReq := new(oidc.AccessTokenRequest)
|
||||||
if ok {
|
err = decoder.Decode(tokenReq, r.Form)
|
||||||
return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret)
|
if err != nil {
|
||||||
}
|
return nil, ErrInvalidRequest("error decoding form")
|
||||||
|
}
|
||||||
|
clientID, clientSecret, ok := r.BasicAuth()
|
||||||
|
if ok {
|
||||||
|
tokenReq.ClientID = clientID
|
||||||
|
tokenReq.ClientSecret = clientSecret
|
||||||
|
|
||||||
}
|
}
|
||||||
if tokenReq.ClientSecret != "" {
|
return tokenReq, nil
|
||||||
if !exchanger.AuthMethodPostSupported() {
|
|
||||||
return nil, errors.New("post not supported")
|
|
||||||
}
|
|
||||||
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret)
|
|
||||||
}
|
|
||||||
if tokenReq.CodeVerifier != "" {
|
|
||||||
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
|
|
||||||
return nil, ErrInvalidRequest("code_challenge invalid")
|
|
||||||
}
|
|
||||||
return exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
|
|
||||||
}
|
|
||||||
return nil, errors.New("Unimplemented") //TODO: impl
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error {
|
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||||
|
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if client.GetID() != authReq.GetClientID() {
|
if client.GetID() != authReq.GetClientID() {
|
||||||
return ErrInvalidRequest("invalid auth code")
|
return nil, ErrInvalidRequest("invalid auth code")
|
||||||
}
|
}
|
||||||
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
||||||
return ErrInvalidRequest("redirect_uri does no correspond")
|
return nil, ErrInvalidRequest("redirect_uri does no correspond")
|
||||||
}
|
}
|
||||||
return nil
|
return authReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||||
|
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
switch client.GetAuthMethod() {
|
||||||
|
case AuthMethodNone:
|
||||||
|
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger)
|
||||||
|
return authReq, client, err
|
||||||
|
case AuthMethodPost:
|
||||||
|
if !exchanger.AuthMethodPostSupported() {
|
||||||
|
return nil, nil, errors.New("basic not supported")
|
||||||
|
}
|
||||||
|
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
|
||||||
|
case AuthMethodBasic:
|
||||||
|
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
|
||||||
|
default:
|
||||||
|
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return authReq, client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuthorizeClientIDSecret(clientID, clientSecret string, exchanger Exchanger) error {
|
||||||
|
return exchanger.Storage().AuthorizeClientIDSecret(clientID, clientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||||
|
if tokenReq.CodeVerifier == "" {
|
||||||
|
return nil, ErrInvalidRequest("code_challenge required")
|
||||||
|
}
|
||||||
|
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidRequest("invalid code")
|
||||||
|
}
|
||||||
|
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
|
||||||
|
return nil, ErrInvalidRequest("code_challenge invalid")
|
||||||
|
}
|
||||||
|
return authReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
|
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
|
||||||
|
@ -120,6 +147,5 @@ func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.Tok
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
|
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
|
||||||
|
|
||||||
return errors.New("Unimplemented") //TODO: impl
|
return errors.New("Unimplemented") //TODO: impl
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue