fix: improve error handling

This commit is contained in:
Livio Amstutz 2021-08-20 07:47:07 +02:00
parent 6cc3c91d07
commit d2d3395c25
9 changed files with 364 additions and 239 deletions

133
pkg/oidc/error.go Normal file
View file

@ -0,0 +1,133 @@
package oidc
import (
"errors"
"fmt"
)
type Error struct {
Parent error `json:"-" schema:"-"`
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
redirectDisabled bool `schema:"-"`
}
func (e *Error) Error() string {
message := "ErrorType=" + string(e.ErrorType)
if e.Description != "" {
message += " Description=" + e.Description
}
if e.Parent != nil {
message += " Parent=" + e.Parent.Error()
}
return message
}
func (e *Error) Unwrap() error {
return e.Parent
}
func (e *Error) Is(target error) bool {
t, ok := target.(*Error)
if !ok {
return false
}
return e.ErrorType == t.ErrorType &&
(e.Description == t.Description || t.Description == "") &&
(e.State == t.State || t.State == "")
}
func (e *Error) WithParent(err error) *Error {
e.Parent = err
return e
}
func (e *Error) WithDescription(desc string, args ...interface{}) *Error {
e.Description = fmt.Sprintf(desc, args...)
return e
}
func (e *Error) IsRedirectDisabled() bool {
return e.redirectDisabled
}
var (
ErrInvalidRequest = func() *Error {
return &Error{
ErrorType: InvalidRequest,
}
}
ErrInvalidRequestRedirectURI = func() *Error {
return &Error{
ErrorType: InvalidRequest,
redirectDisabled: true,
}
}
ErrInvalidScope = func() *Error {
return &Error{
ErrorType: InvalidScope,
}
}
ErrInvalidClient = func() *Error {
return &Error{
ErrorType: InvalidClient,
}
}
ErrInvalidGrant = func() *Error {
return &Error{
ErrorType: InvalidGrant,
}
}
ErrUnauthorizedClient = func() *Error {
return &Error{
ErrorType: UnauthorizedClient,
}
}
ErrUnsupportedGrantType = func() *Error {
return &Error{
ErrorType: UnsupportedGrantType,
}
}
ErrServerError = func() *Error {
return &Error{
ErrorType: ServerError,
}
}
ErrInteractionRequired = func() *Error {
return &Error{
ErrorType: InteractionRequired,
}
}
ErrLoginRequired = func() *Error {
return &Error{
ErrorType: LoginRequired,
}
}
)
// DefaultToServerError checks if the error is an Error
// if not the provided error will be wrapped into a ServerError
func DefaultToServerError(err error, description string) *Error {
oauth := new(Error)
if ok := errors.As(err, &oauth); !ok {
oauth.ErrorType = ServerError
oauth.Description = description
oauth.Parent = err
}
return oauth
}
type errorType string
const (
InvalidRequest errorType = "invalid_request"
InvalidScope errorType = "invalid_scope"
InvalidClient errorType = "invalid_client"
InvalidGrant errorType = "invalid_grant"
UnauthorizedClient errorType = "unauthorized_client"
UnsupportedGrantType errorType = "unsupported_grant_type"
ServerError errorType = "server_error"
InteractionRequired errorType = "interaction_required"
LoginRequired errorType = "login_required"
)

View file

@ -43,7 +43,7 @@ type Authorizer interface {
} }
//AuthorizeValidator is an extension of Authorizer interface //AuthorizeValidator is an extension of Authorizer interface
//implementing it's own validation mechanism for the auth request //implementing its own validation mechanism for the auth request
type AuthorizeValidator interface { type AuthorizeValidator interface {
Authorizer Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
@ -80,12 +80,12 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
} }
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID) req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID)
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
return return
} }
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, req, err, authorizer.Encoder()) AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
return return
} }
RedirectToLogin(req.GetID(), client, w, r) RedirectToLogin(req.GetID(), client, w, r)
@ -95,12 +95,12 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) { func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("cannot parse form") return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
} }
authReq := new(oidc.AuthRequest) authReq := new(oidc.AuthRequest)
err = decoder.Decode(authReq, r.Form) err = decoder.Decode(authReq, r.Form)
if err != nil { if err != nil {
return nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse auth request").WithParent(err)
} }
return authReq, nil return authReq, nil
} }
@ -113,7 +113,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
} }
client, err := storage.GetClientByClientID(ctx, authReq.ClientID) client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil { if err != nil {
return "", ErrServerError(err.Error()) return "", oidc.DefaultToServerError(err, "unable to retrieve client by id")
} }
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes) authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil { if err != nil {
@ -132,7 +132,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) { func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) {
for _, prompt := range prompts { for _, prompt := range prompts {
if prompt == oidc.PromptNone && len(prompts) > 1 { if prompt == oidc.PromptNone && len(prompts) > 1 {
return nil, ErrInvalidRequest("The prompt parameter `none` must only be used as a single value") return nil, oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value")
} }
if prompt == oidc.PromptLogin { if prompt == oidc.PromptLogin {
maxAge = oidc.NewMaxAge(0) maxAge = oidc.NewMaxAge(0)
@ -144,7 +144,9 @@ func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error)
//ValidateAuthReqScopes validates the passed scopes //ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 { if len(scopes) == 0 {
return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") return nil, oidc.ErrInvalidRequest().
WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " +
"If you have any questions, you may contact the administrator of the application.")
} }
openID := false openID := false
for i := len(scopes) - 1; i >= 0; i-- { for i := len(scopes) - 1; i >= 0; i-- {
@ -165,7 +167,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
} }
} }
if !openID { if !openID {
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") return nil, oidc.ErrInvalidScope().WithDescription("The scope openid is missing in your request. " +
"Please ensure the scope openid is added to the request. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return scopes, nil return scopes, nil
@ -174,11 +178,14 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type //ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error { func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error {
if uri == "" { if uri == "" {
return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The redirect_uri is missing in the request. " +
"Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
} }
if strings.HasPrefix(uri, "https://") { if strings.HasPrefix(uri, "https://") {
if !utils.Contains(client.RedirectURIs(), uri) { if !utils.Contains(client.RedirectURIs(), uri) {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().
WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return nil return nil
} }
@ -186,7 +193,8 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
return validateAuthReqRedirectURINative(client, uri, responseType) return validateAuthReqRedirectURINative(client, uri, responseType)
} }
if !utils.Contains(client.RedirectURIs(), uri) { if !utils.Contains(client.RedirectURIs(), uri) {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
if strings.HasPrefix(uri, "http://") { if strings.HasPrefix(uri, "http://") {
if client.DevMode() { if client.DevMode() {
@ -195,9 +203,11 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
if responseType == oidc.ResponseTypeCode && IsConfidentialType(client) { if responseType == oidc.ResponseTypeCode && IsConfidentialType(client) {
return nil return nil
} }
return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequest().WithDescription("This client's redirect_uri is http and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return ErrInvalidRequest("This client's redirect_uri is using a custom schema and is not allowed. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequest().WithDescription("This client's redirect_uri is using a custom schema and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
} }
//ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type //ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type
@ -208,10 +218,12 @@ func validateAuthReqRedirectURINative(client Client, uri string, responseType oi
if isLoopback || isCustomSchema { if isLoopback || isCustomSchema {
return nil return nil
} }
return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequest().WithDescription("This client's redirect_uri is http and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
} }
if !isLoopback { if !isLoopback {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
for _, uri := range client.RedirectURIs() { for _, uri := range client.RedirectURIs() {
redirectURI, ok := HTTPLoopbackOrLocalhost(uri) redirectURI, ok := HTTPLoopbackOrLocalhost(uri)
@ -219,7 +231,8 @@ func validateAuthReqRedirectURINative(client Client, uri string, responseType oi
return nil return nil
} }
} }
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration." +
" If you have any questions, you may contact the administrator of the application.")
} }
func equalURI(url1, url2 *url.URL) bool { func equalURI(url1, url2 *url.URL) bool {
@ -241,10 +254,12 @@ func HTTPLoopbackOrLocalhost(rawurl string) (*url.URL, bool) {
//ValidateAuthReqResponseType validates the passed response_type to the registered response types //ValidateAuthReqResponseType validates the passed response_type to the registered response types
func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error { func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error {
if responseType == "" { if responseType == "" {
return ErrInvalidRequest("The response type is missing in your request. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequest().WithDescription("The response type is missing in your request. " +
"If you have any questions, you may contact the administrator of the application.")
} }
if !ContainsResponseType(client.ResponseTypes(), responseType) { if !ContainsResponseType(client.ResponseTypes(), responseType) {
return ErrInvalidRequest("The requested response type is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return nil return nil
} }
@ -257,7 +272,8 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
} }
claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier) claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier)
if err != nil { if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return claims.GetSubject(), nil return claims.GetSubject(), nil
} }
@ -279,7 +295,9 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
return return
} }
if !authReq.Done() { if !authReq.Done() {
AuthRequestError(w, r, authReq, ErrInteractionRequired("Unfortunately, the user may is not logged in and/or additional interaction is required."), authorizer.Encoder()) AuthRequestError(w, r, authReq,
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
authorizer.Encoder())
return return
} }
AuthResponse(authReq, authorizer, w, r) AuthResponse(authReq, authorizer, w, r)

View file

@ -1,6 +1,8 @@
package op_test package op_test
import ( import (
"context"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -141,6 +143,7 @@ func TestParseAuthorizeRequest(t *testing.T) {
} }
} }
//TODO: extend cases
func TestValidateAuthRequest(t *testing.T) { func TestValidateAuthRequest(t *testing.T) {
type args struct { type args struct {
authRequest *oidc.AuthRequest authRequest *oidc.AuthRequest
@ -150,7 +153,7 @@ func TestValidateAuthRequest(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
wantErr bool wantErr error
}{ }{
//TODO: //TODO:
// { // {
@ -159,39 +162,45 @@ func TestValidateAuthRequest(t *testing.T) {
{ {
"scope missing fails", "scope missing fails",
args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
{ {
"scope openid missing fails", "scope openid missing fails",
args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidScope(),
}, },
{ {
"response_type missing fails", "response_type missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
{ {
"client_id missing fails", "client_id missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
{ {
"redirect_uri missing fails", "redirect_uri missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage, tt.args.verifier) _, err := op.ValidateAuthRequest(context.TODO(), tt.args.authRequest, tt.args.storage, tt.args.verifier)
if (err != nil) != tt.wantErr { if tt.wantErr == nil && err != nil {
t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
}
if tt.wantErr != nil && !errors.Is(err, tt.wantErr) {
t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.wantErr)
} }
}) })
} }
} }
//TODO: implement
func TestValidateAuthReqPrompt(t *testing.T) {}
func TestValidateAuthReqScopes(t *testing.T) { func TestValidateAuthReqScopes(t *testing.T) {
type args struct { type args struct {
client op.Client client op.Client
@ -465,115 +474,8 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
} }
} }
func TestValidateAuthReqResponseType(t *testing.T) { //TODO: test not parsable url
type args struct { func TestLoopbackOrLocalhost(t *testing.T) {
responseType oidc.ResponseType
client op.Client
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"empty response type",
args{"",
mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true)},
true,
},
{
"response type missing in client config",
args{oidc.ResponseTypeIDToken,
mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true)},
true,
},
{
"valid response type",
args{oidc.ResponseTypeCode,
mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true)},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqResponseType(tt.args.client, tt.args.responseType); (err != nil) != tt.wantErr {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestRedirectToLogin(t *testing.T) {
type args struct {
authReqID string
client op.Client
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
{
"redirect ok",
args{
"id",
mock.NewClientExpectAny(t, op.ApplicationTypeNative),
httptest.NewRecorder(),
httptest.NewRequest("GET", "/authorize", nil),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusFound, rec.Code)
require.Equal(t, "/login?id=id", rec.Header().Get("location"))
})
}
}
func TestAuthorizeCallback(t *testing.T) {
type args struct {
w http.ResponseWriter
r *http.Request
authorizer op.Authorizer
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
})
}
}
func TestAuthResponse(t *testing.T) {
type args struct {
authReq op.AuthRequest
authorizer op.Authorizer
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
})
}
}
func Test_LoopbackOrLocalhost(t *testing.T) {
type args struct { type args struct {
url string url string
} }
@ -631,3 +533,128 @@ func Test_LoopbackOrLocalhost(t *testing.T) {
}) })
} }
} }
func TestValidateAuthReqResponseType(t *testing.T) {
type args struct {
responseType oidc.ResponseType
client op.Client
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"empty response type",
args{"",
mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true)},
true,
},
{
"response type missing in client config",
args{oidc.ResponseTypeIDToken,
mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true)},
true,
},
{
"valid response type",
args{oidc.ResponseTypeCode,
mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true)},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqResponseType(tt.args.client, tt.args.responseType); (err != nil) != tt.wantErr {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
//TODO: implement
func TestValidateAuthReqIDTokenHint(t *testing.T) {}
func TestRedirectToLogin(t *testing.T) {
type args struct {
authReqID string
client op.Client
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
{
"redirect ok",
args{
"id",
mock.NewClientExpectAny(t, op.ApplicationTypeNative),
httptest.NewRecorder(),
httptest.NewRequest("GET", "/authorize", nil),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusFound, rec.Code)
require.Equal(t, "/login?id=id", rec.Header().Get("location"))
})
}
}
//TODO: implement
func TestAuthorizeCallback(t *testing.T) {
type args struct {
w http.ResponseWriter
r *http.Request
authorizer op.Authorizer
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
})
}
}
//TODO: implement
func TestAuthResponse(t *testing.T) {
type args struct {
authReq op.AuthRequest
authorizer op.Authorizer
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
})
}
}
//TODO: implement
func TestAuthResponseCode(t *testing.T) {}
//TODO: implement
func TestAuthResponseToken(t *testing.T) {}
//TODO: implement
func TestCreateAuthRequestCode(t *testing.T) {}
//TODO: implement
func TestBuildAuthRequestCode(t *testing.T) {}

View file

@ -1,50 +1,12 @@
package op package op
import ( import (
"fmt"
"net/http" "net/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
const (
InvalidRequest errorType = "invalid_request"
InvalidRequestURI errorType = "invalid_request_uri"
InteractionRequired errorType = "interaction_required"
ServerError errorType = "server_error"
)
var (
ErrInvalidRequest = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrInvalidRequestRedirectURI = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequestURI,
Description: description,
redirectDisabled: true,
}
}
ErrInteractionRequired = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InteractionRequired,
Description: description,
}
}
ErrServerError = func(description string) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
type errorType string
type ErrAuthRequest interface { type ErrAuthRequest interface {
GetRedirectURI() string GetRedirectURI() string
GetResponseType() oidc.ResponseType GetResponseType() oidc.ResponseType
@ -56,14 +18,9 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
e, ok := err.(*OAuthError) e := oidc.DefaultToServerError(err, err.Error()) //TODO: desc?
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
e.State = authReq.GetState() e.State = authReq.GetState()
if authReq.GetRedirectURI() == "" || e.redirectDisabled { if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
http.Error(w, e.Description, http.StatusBadRequest) http.Error(w, e.Description, http.StatusBadRequest)
return return
} }
@ -83,23 +40,10 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
} }
func RequestError(w http.ResponseWriter, r *http.Request, err error) { func RequestError(w http.ResponseWriter, r *http.Request, err error) {
e, ok := err.(*OAuthError) e := oidc.DefaultToServerError(err, err.Error()) //TODO: desc?
if !ok { status := http.StatusBadRequest
e = new(OAuthError) if e.ErrorType == oidc.InvalidClient {
e.ErrorType = ServerError status = 401
e.Description = err.Error()
} }
w.WriteHeader(http.StatusBadRequest) utils.MarshalJSONWithStatus(w, e, status)
utils.MarshalJSON(w, e)
}
type OAuthError struct {
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
redirectDisabled bool `json:"-" schema:"-"`
}
func (e *OAuthError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
} }

View file

@ -38,7 +38,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
} }
err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID) err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID)
if err != nil { if err != nil {
RequestError(w, r, ErrServerError("error terminating session")) RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
return return
} }
http.Redirect(w, r, session.RedirectURI, http.StatusFound) http.Redirect(w, r, session.RedirectURI, http.StatusFound)
@ -47,12 +47,12 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) { func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error parsing form") return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
} }
req := new(oidc.EndSessionRequest) req := new(oidc.EndSessionRequest)
err = decoder.Decode(req, r.Form) err = decoder.Decode(req, r.Form)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error decoding form") return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
} }
return req, nil return req, nil
} }
@ -64,12 +64,12 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
} }
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier()) claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier())
if err != nil { if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid") return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
} }
session.UserID = claims.GetSubject() session.UserID = claims.GetSubject()
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty()) session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil { if err != nil {
return nil, ErrServerError("") return nil, oidc.DefaultToServerError(err, "")
} }
if req.PostLogoutRedirectURI == "" { if req.PostLogoutRedirectURI == "" {
session.RedirectURI = ender.DefaultLogoutRedirectURI() session.RedirectURI = ender.DefaultLogoutRedirectURI()
@ -81,5 +81,5 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
return session, nil return session, nil
} }
} }
return nil, ErrInvalidRequest("post_logout_redirect_uri invalid") return nil, oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
} }

View file

@ -2,7 +2,6 @@ package op
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
@ -17,7 +16,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, err) RequestError(w, r, err)
} }
if tokenReq.Code == "" { if tokenReq.Code == "" {
RequestError(w, r, ErrInvalidRequest("code missing")) RequestError(w, r, oidc.ErrInvalidGrant()) //TODO: ErrInvalidRequest("code missing")?
return return
} }
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
@ -51,13 +50,13 @@ func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenR
return nil, nil, err return nil, nil, err
} }
if client.GetID() != authReq.GetClientID() { if client.GetID() != authReq.GetClientID() {
return nil, nil, ErrInvalidRequest("invalid auth code") return nil, nil, oidc.ErrInvalidGrant()
} }
if !ValidateGrantType(client, oidc.GrantTypeCode) { if !ValidateGrantType(client, oidc.GrantTypeCode) {
return nil, nil, ErrInvalidRequest("invalid_grant") return nil, nil, oidc.ErrUnauthorizedClient()
} }
if tokenReq.RedirectURI != authReq.GetRedirectURI() { if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return nil, nil, ErrInvalidRequest("redirect_uri does not correspond") return nil, nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
} }
return authReq, client, nil return authReq, client, nil
} }
@ -68,7 +67,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger) jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger)
if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() { if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() {
return nil, nil, errors.New("auth_method private_key_jwt not supported") return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
} }
client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger) client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger)
if err != nil { if err != nil {
@ -79,10 +78,10 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
} }
client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, oidc.ErrInvalidClient().WithParent(err)
} }
if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
return nil, nil, errors.New("invalid_grant") return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
} }
if client.AuthMethod() == oidc.AuthMethodNone { if client.AuthMethod() == oidc.AuthMethodNone {
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
@ -93,9 +92,12 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
return request, client, err return request, client, err
} }
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("auth_method post not supported") return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
} }
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
if err != nil {
return nil, nil, oidc.ErrInvalidClient().WithDescription("invalid client_id / client_secret").WithParent(err)
}
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
return request, client, err return request, client, err
} }
@ -104,7 +106,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) { func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) {
authReq, err := storage.AuthRequestByCode(ctx, code) authReq, err := storage.AuthRequestByCode(ctx, code)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("invalid code") return nil, oidc.ErrInvalidGrant().WithDescription("invalid code").WithParent(err)
} }
return authReq, nil return authReq, nil
} }

View file

@ -43,12 +43,12 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati
func ParseJWTProfileGrantRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) { func ParseJWTProfileGrantRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error parsing form") return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
} }
tokenReq := new(oidc.JWTProfileGrantRequest) tokenReq := new(oidc.JWTProfileGrantRequest)
err = decoder.Decode(tokenReq, r.Form) err = decoder.Decode(tokenReq, r.Form)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error decoding form") return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
} }
return tokenReq, nil return tokenReq, nil
} }

View file

@ -54,14 +54,14 @@ func ParseRefreshTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.Ref
//and returns the data representing the original auth request corresponding to the refresh_token //and returns the data representing the original auth request corresponding to the refresh_token
func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) { func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) {
if tokenReq.RefreshToken == "" { if tokenReq.RefreshToken == "" {
return nil, nil, ErrInvalidRequest("code missing") return nil, nil, oidc.ErrInvalidGrant() //TODO: ErrInvalidRequest("refresh_token missing")?
} }
request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger) request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if client.GetID() != request.GetClientID() { if client.GetID() != request.GetClientID() {
return nil, nil, ErrInvalidRequest("invalid auth code") return nil, nil, oidc.ErrInvalidGrant()
} }
if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil { if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil {
return nil, nil, err return nil, nil, err
@ -78,7 +78,7 @@ func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTok
} }
for _, scope := range requestedScopes { for _, scope := range requestedScopes {
if !utils.Contains(authRequest.GetScopes(), scope) { if !utils.Contains(authRequest.GetScopes(), scope) {
return errors.New("invalid_scope") return oidc.ErrInvalidScope()
} }
} }
authRequest.SetCurrentScopes(requestedScopes) authRequest.SetCurrentScopes(requestedScopes)
@ -98,7 +98,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
return nil, nil, err return nil, nil, err
} }
if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) { if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
return nil, nil, ErrInvalidRequest("invalid_grant") return nil, nil, oidc.ErrUnauthorizedClient()
} }
request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
return request, client, err return request, client, err
@ -108,17 +108,17 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
return nil, nil, err return nil, nil, err
} }
if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) { if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
return nil, nil, ErrInvalidRequest("invalid_grant") return nil, nil, oidc.ErrUnauthorizedClient()
} }
if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
return nil, nil, errors.New("invalid_grant") return nil, nil, oidc.ErrInvalidClient()
} }
if client.AuthMethod() == oidc.AuthMethodNone { if client.AuthMethod() == oidc.AuthMethodNone {
request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
return request, client, err return request, client, err
} }
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("auth_method post not supported") return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
} }
if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil { if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil {
return nil, nil, err return nil, nil, err
@ -132,7 +132,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) { func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) {
request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken) request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("invalid refreshToken") return nil, oidc.ErrInvalidGrant().WithParent(err)
} }
return request, nil return request, nil
} }

View file

@ -24,7 +24,8 @@ type Exchanger interface {
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
switch r.FormValue("grant_type") { grantType := r.FormValue("grant_type")
switch grantType {
case string(oidc.GrantTypeCode): case string(oidc.GrantTypeCode):
CodeExchange(w, r, exchanger) CodeExchange(w, r, exchanger)
return return
@ -44,10 +45,10 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque
return return
} }
case "": case "":
RequestError(w, r, ErrInvalidRequest("grant_type missing")) RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return return
} }
RequestError(w, r, ErrInvalidRequest("grant_type not supported")) RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType))
} }
} }
@ -63,21 +64,21 @@ type AuthenticatedTokenRequest interface {
func ParseAuthenticatedTokenRequest(r *http.Request, decoder utils.Decoder, request AuthenticatedTokenRequest) error { func ParseAuthenticatedTokenRequest(r *http.Request, decoder utils.Decoder, request AuthenticatedTokenRequest) error {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return ErrInvalidRequest("error parsing form") return oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
} }
err = decoder.Decode(request, r.Form) err = decoder.Decode(request, r.Form)
if err != nil { if err != nil {
return ErrInvalidRequest("error decoding form") return oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
} }
clientID, clientSecret, ok := r.BasicAuth() clientID, clientSecret, ok := r.BasicAuth()
if ok { if ok {
clientID, err = url.QueryUnescape(clientID) clientID, err = url.QueryUnescape(clientID)
if err != nil { if err != nil {
return ErrInvalidRequest("invalid basic auth header") return oidc.ErrInvalidRequest().WithDescription("invalid basic auth header").WithParent(err)
} }
clientSecret, err = url.QueryUnescape(clientSecret) clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil { if err != nil {
return ErrInvalidRequest("invalid basic auth header") return oidc.ErrInvalidRequest().WithDescription("invalid basic auth header").WithParent(err)
} }
request.SetClientID(clientID) request.SetClientID(clientID)
request.SetClientSecret(clientSecret) request.SetClientSecret(clientSecret)
@ -89,7 +90,7 @@ func ParseAuthenticatedTokenRequest(r *http.Request, decoder utils.Decoder, requ
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error { func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error {
err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
if err != nil { if err != nil {
return err //TODO: wrap? return oidc.ErrInvalidGrant().WithDescription("code_challenge required").WithParent(err)
} }
return nil return nil
} }
@ -98,10 +99,10 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string,
//code_challenge of the auth request (PKCE) //code_challenge of the auth request (PKCE)
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error {
if tokenReq.CodeVerifier == "" { if tokenReq.CodeVerifier == "" {
return ErrInvalidRequest("code_challenge required") return oidc.ErrInvalidGrant().WithDescription("code_challenge required") //TODO: ErrInvalidRequest("code_challenge required")
} }
if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) {
return ErrInvalidRequest("code_challenge invalid") return oidc.ErrInvalidGrant().WithDescription("invalid code challenge")
} }
return nil return nil
} }
@ -118,7 +119,7 @@ func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchang
return nil, err return nil, err
} }
if client.AuthMethod() != oidc.AuthMethodPrivateKeyJWT { if client.AuthMethod() != oidc.AuthMethodPrivateKeyJWT {
return nil, ErrInvalidRequest("invalid_client") return nil, oidc.ErrInvalidClient()
} }
return client, nil return client, nil
} }