diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index c1dd080..799d3d0 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -105,14 +105,14 @@ func ValidateAuthReqScopes(scopes []string) error { func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error { if uri == "" { - return ErrInvalidRequest("redirect_uri must not be empty") + return ErrInvalidRequestRedirectURI("redirect_uri must not be empty") } client, err := storage.GetClientByClientID(client_id) if err != nil { return ErrServerError(err.Error()) } if !utils.Contains(client.RedirectURIs(), uri) { - return ErrInvalidRequest("redirect_uri not allowed") + return ErrInvalidRequestRedirectURI("redirect_uri not allowed") } if strings.HasPrefix(uri, "https://") { return nil @@ -127,10 +127,10 @@ func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.Respons return ErrInvalidRequest("redirect_uri not allowed 2") } else { if client.ApplicationType() != ApplicationTypeNative { - return ErrInvalidRequest("redirect_uri not allowed 3") + return ErrInvalidRequestRedirectURI("redirect_uri not allowed 3") } if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) { - return ErrInvalidRequest("redirect_uri not allowed 4") + return ErrInvalidRequestRedirectURI("redirect_uri not allowed 4") } } return nil diff --git a/pkg/op/error.go b/pkg/op/error.go index 0a3b4ab..1e84c1a 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -22,6 +22,13 @@ var ( Description: description, } } + ErrInvalidRequestRedirectURI = func(description string) *OAuthError { + return &OAuthError{ + ErrorType: InvalidRequest, + Description: description, + redirectDisabled: true, + } + } ErrServerError = func(description string) *OAuthError { return &OAuthError{ ErrorType: ServerError, @@ -43,10 +50,6 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq http.Error(w, err.Error(), http.StatusBadRequest) return } - if authReq.GetRedirectURI() == "" { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } e, ok := err.(*OAuthError) if !ok { e = new(OAuthError) @@ -54,6 +57,10 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq e.Description = err.Error() } e.state = authReq.GetState() + if authReq.GetRedirectURI() == "" || e.redirectDisabled { + http.Error(w, e.Description, http.StatusBadRequest) + return + } params, err := utils.URLEncodeResponse(e, encoder) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -81,9 +88,10 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) { } type OAuthError struct { - ErrorType errorType `json:"error" schema:"error"` - Description string `json:"description" schema:"description"` - state string `json:"state" schema:"state"` + ErrorType errorType `json:"error" schema:"error"` + Description string `json:"description" schema:"description"` + state string `json:"state" schema:"state"` + redirectDisabled bool } func (e *OAuthError) Error() string {