refactoring
This commit is contained in:
parent
a793e77679
commit
310220d38e
17 changed files with 346 additions and 149 deletions
100
pkg/op/error.go
100
pkg/op/error.go
|
@ -1,8 +1,10 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
|
@ -13,6 +15,21 @@ const (
|
|||
ServerError errorType = "server_error"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidRequest = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
ErrServerError = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: ServerError,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type errorType string
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
|
@ -21,7 +38,7 @@ type ErrAuthRequest interface {
|
|||
GetState() string
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error) {
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
|
||||
if authReq == nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
|
@ -30,27 +47,23 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
|
|||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
e, ok := err.(*OAuthError)
|
||||
if !ok {
|
||||
e = new(OAuthError)
|
||||
e.ErrorType = ServerError
|
||||
e.Description = err.Error()
|
||||
}
|
||||
e.state = authReq.GetState()
|
||||
params, err := utils.URLEncodeResponse(e, encoder)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := authReq.GetRedirectURI()
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
url += "?"
|
||||
url += "?" + params
|
||||
} else {
|
||||
url += "#"
|
||||
}
|
||||
var errorType errorType
|
||||
var description string
|
||||
if e, ok := err.(*OAuthError); ok {
|
||||
errorType = e.ErrorType
|
||||
description = e.Description
|
||||
} else {
|
||||
errorType = ServerError
|
||||
description = err.Error()
|
||||
}
|
||||
url += "error=" + string(errorType)
|
||||
if description != "" {
|
||||
url += "&error_description=" + description
|
||||
}
|
||||
if authReq.GetState() != "" {
|
||||
url += "&state=" + authReq.GetState()
|
||||
url += "#" + params
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
@ -67,50 +80,11 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
|||
}
|
||||
|
||||
type OAuthError struct {
|
||||
ErrorType errorType `json:"error"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
ErrServerError = func(description string, args ...interface{}) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: ServerError,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest) {
|
||||
if authReq == nil {
|
||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if authReq.GetRedirectURI() == "" {
|
||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
callback := authReq.GetRedirectURI()
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
callback += "?"
|
||||
} else {
|
||||
callback += "#"
|
||||
}
|
||||
callback += "error=" + string(e.ErrorType)
|
||||
if e.Description != "" {
|
||||
callback += "&error_description=" + url.QueryEscape(e.Description)
|
||||
}
|
||||
if authReq.GetState() != "" {
|
||||
callback += "&state=" + authReq.GetState()
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
ErrorType errorType `json:"error" schema:"error"`
|
||||
Description string `json:"description" schema:"description"`
|
||||
state string `json:"state" schema:"state"`
|
||||
}
|
||||
|
||||
func (e *OAuthError) Error() string {
|
||||
return ""
|
||||
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue