error handling
This commit is contained in:
parent
6993769f06
commit
f4dac05713
3 changed files with 91 additions and 58 deletions
|
@ -2,6 +2,8 @@ package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||||
|
@ -68,13 +70,16 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo
|
||||||
httphelper.MarshalJSONWithStatus(w, e, status)
|
httphelper.MarshalJSONWithStatus(w, e, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TryErrorRedirect tries to handle an error by redirecting a client.
|
||||||
|
// If this attempt fails, an error is returned that must be returned
|
||||||
|
// to the client instead.
|
||||||
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
|
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
|
||||||
e := oidc.DefaultToServerError(parent, parent.Error())
|
e := oidc.DefaultToServerError(parent, parent.Error())
|
||||||
logger = logger.With("oidc_error", e)
|
logger = logger.With("oidc_error", e)
|
||||||
|
|
||||||
if authReq == nil {
|
if authReq == nil {
|
||||||
logger.Log(ctx, e.LogLevel(), "auth request")
|
logger.Log(ctx, e.LogLevel(), "auth request")
|
||||||
return nil, NewStatusError(parent, http.StatusBadRequest)
|
return nil, AsStatusError(e, http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
|
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
|
||||||
|
@ -83,7 +88,7 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error,
|
||||||
|
|
||||||
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
|
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
|
||||||
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
|
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
|
||||||
return nil, NewStatusError(parent, http.StatusBadRequest)
|
return nil, AsStatusError(e, http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.State = authReq.GetState()
|
e.State = authReq.GetState()
|
||||||
|
@ -94,7 +99,71 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error,
|
||||||
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
|
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.ErrorContext(ctx, "auth response URL", "error", err)
|
logger.ErrorContext(ctx, "auth response URL", "error", err)
|
||||||
return nil, NewStatusError(err, http.StatusBadRequest)
|
return nil, AsStatusError(err, http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
|
||||||
return NewRedirect(url), nil
|
return NewRedirect(url), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StatusError wraps an error with a HTTP status code.
|
||||||
|
// The status code is passed to the handler's writer.
|
||||||
|
type StatusError struct {
|
||||||
|
parent error
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStatusError sets the parent and statusCode to a new StatusError.
|
||||||
|
// It is recommended for parent to be an [oidc.Error].
|
||||||
|
//
|
||||||
|
// Typically implementations should only use this to signal something
|
||||||
|
// very specific, like an internal server error.
|
||||||
|
// If a returned error is not a StatusError, the framework
|
||||||
|
// will set a statusCode based on what the standard specifies,
|
||||||
|
// which is [http.StatusBadRequest] for most of the time.
|
||||||
|
// If the error encountered can described clearly with a [oidc.Error],
|
||||||
|
// do not use this function, as it might break standard rules!
|
||||||
|
func NewStatusError(parent error, statusCode int) StatusError {
|
||||||
|
return StatusError{
|
||||||
|
parent: parent,
|
||||||
|
statusCode: statusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AsStatusError unwraps a StatusError from err
|
||||||
|
// and returns it unmodified if found.
|
||||||
|
// If no StatuError was found, a new one is returned
|
||||||
|
// with statusCode set to it as a default.
|
||||||
|
func AsStatusError(err error, statusCode int) (target StatusError) {
|
||||||
|
if errors.As(err, &target) {
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
return NewStatusError(err, statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e StatusError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e StatusError) Unwrap() error {
|
||||||
|
return e.parent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e StatusError) Is(err error) bool {
|
||||||
|
var target StatusError
|
||||||
|
if !errors.As(err, &target) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return errors.Is(e.parent, target.parent) &&
|
||||||
|
e.statusCode == target.statusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteError asserts for a StatusError containing an [oidc.Error].
|
||||||
|
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
|
||||||
|
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
|
||||||
|
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
||||||
|
statusError := AsStatusError(err, http.StatusBadRequest)
|
||||||
|
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
|
||||||
|
|
||||||
|
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
||||||
|
httphelper.MarshalJSONWithStatus(w, oidc.DefaultToServerError(e, e.Error()), statusError.statusCode)
|
||||||
|
}
|
||||||
|
|
|
@ -2,45 +2,14 @@ package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/muhlemmer/gu"
|
"github.com/muhlemmer/gu"
|
||||||
|
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StatusError struct {
|
|
||||||
parent error
|
|
||||||
statusCode int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStatusError(parent error, statusCode int) StatusError {
|
|
||||||
return StatusError{
|
|
||||||
parent: parent,
|
|
||||||
statusCode: statusCode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e StatusError) Error() string {
|
|
||||||
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e StatusError) Unwrap() error {
|
|
||||||
return e.parent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e StatusError) Is(err error) bool {
|
|
||||||
var target StatusError
|
|
||||||
if !errors.As(err, &target) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return errors.Is(e.parent, target.parent) &&
|
|
||||||
e.statusCode == target.statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server describes the interface that needs to be implemented to serve
|
// Server describes the interface that needs to be implemented to serve
|
||||||
// OpenID Connect and Oauth2 standard requests.
|
// OpenID Connect and Oauth2 standard requests.
|
||||||
//
|
//
|
||||||
|
@ -177,6 +146,10 @@ type Request[T any] struct {
|
||||||
Data *T
|
Data *T
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Request[_]) path() string {
|
||||||
|
return r.URL.Path
|
||||||
|
}
|
||||||
|
|
||||||
func newRequest[T any](r *http.Request, data *T) *Request[T] {
|
func newRequest[T any](r *http.Request, data *T) *Request[T] {
|
||||||
return &Request[T]{
|
return &Request[T]{
|
||||||
Method: r.Method,
|
Method: r.Method,
|
||||||
|
@ -226,7 +199,7 @@ func NewResponse(data any) *Response {
|
||||||
|
|
||||||
func (resp *Response) writeOut(w http.ResponseWriter) {
|
func (resp *Response) writeOut(w http.ResponseWriter) {
|
||||||
gu.MapMerge(resp.Header, w.Header())
|
gu.MapMerge(resp.Header, w.Header())
|
||||||
json.NewEncoder(w).Encode(resp.Data)
|
httphelper.MarshalJSON(w, resp.Data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect is a special response type which will
|
// Redirect is a special response type which will
|
||||||
|
@ -253,12 +226,9 @@ type UnimplementedServer struct{}
|
||||||
// and not http methods covered by "501 Not Implemented".
|
// and not http methods covered by "501 Not Implemented".
|
||||||
var UnimplementedStatusCode = http.StatusNotFound
|
var UnimplementedStatusCode = http.StatusNotFound
|
||||||
|
|
||||||
func unimplementedError[T any](r *Request[T]) StatusError {
|
func unimplementedError(r interface{ path() string }) StatusError {
|
||||||
err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.URL.Path)
|
err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path())
|
||||||
return StatusError{
|
return NewStatusError(err, UnimplementedStatusCode)
|
||||||
parent: err,
|
|
||||||
statusCode: UnimplementedStatusCode,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func unimplementedGrantError(gt oidc.GrantType) StatusError {
|
func unimplementedGrantError(gt oidc.GrantType) StatusError {
|
||||||
|
@ -289,7 +259,7 @@ func (UnimplementedServer) Authorize(ctx context.Context, r *Request[oidc.AuthRe
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
||||||
return nil, unimplementedError(r.Request)
|
return nil, unimplementedError(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
||||||
|
@ -321,7 +291,7 @@ func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
||||||
return nil, unimplementedError(r.Request)
|
return nil, unimplementedError(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
|
func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
|
||||||
|
@ -329,7 +299,7 @@ func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
|
func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
|
||||||
return nil, unimplementedError(r.Request)
|
return nil, unimplementedError(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
|
func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
|
||||||
|
|
|
@ -61,37 +61,31 @@ func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) {
|
||||||
func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
client, err := s.verifyRequestClient(r)
|
client, err := s.verifyRequestClient(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err, slog.Default())
|
WriteError(w, r, err, slog.Default())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
grantType := oidc.GrantType(r.Form.Get("grant_type"))
|
grantType := oidc.GrantType(r.Form.Get("grant_type"))
|
||||||
var handle func(w http.ResponseWriter, r *http.Request, client Client)
|
|
||||||
switch grantType {
|
switch grantType {
|
||||||
case oidc.GrantTypeCode:
|
case oidc.GrantTypeCode:
|
||||||
handle = s.handleCodeExchange
|
s.handleCodeExchange(w, r, client)
|
||||||
case oidc.GrantTypeRefreshToken:
|
case oidc.GrantTypeRefreshToken:
|
||||||
handle = s.handleRefreshToken
|
s.handleRefreshToken(w, r, client)
|
||||||
case "":
|
case "":
|
||||||
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
|
||||||
return
|
|
||||||
default:
|
default:
|
||||||
RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), slog.Default())
|
WriteError(w, r, unimplementedGrantError(grantType), slog.Default())
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
handle(w, r, client)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
request, err := decodeRequest[*oidc.AccessTokenRequest](s.decoder, r.Form)
|
request, err := decodeRequest[*oidc.AccessTokenRequest](s.decoder, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err, s.logger)
|
WriteError(w, r, err, s.logger)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err, s.logger)
|
WriteError(w, r, err, s.logger)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue