error handling

This commit is contained in:
Tim Möhlmann 2023-09-12 11:17:59 +03:00
parent 6993769f06
commit f4dac05713
3 changed files with 91 additions and 58 deletions

View file

@ -2,6 +2,8 @@ package op
import (
"context"
"errors"
"fmt"
"net/http"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
@ -68,13 +70,16 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo
httphelper.MarshalJSONWithStatus(w, e, status)
}
// TryErrorRedirect tries to handle an error by redirecting a client.
// If this attempt fails, an error is returned that must be returned
// to the client instead.
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
e := oidc.DefaultToServerError(parent, parent.Error())
logger = logger.With("oidc_error", e)
if authReq == nil {
logger.Log(ctx, e.LogLevel(), "auth request")
return nil, NewStatusError(parent, http.StatusBadRequest)
return nil, AsStatusError(e, http.StatusBadRequest)
}
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
@ -83,7 +88,7 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error,
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
return nil, NewStatusError(parent, http.StatusBadRequest)
return nil, AsStatusError(e, http.StatusBadRequest)
}
e.State = authReq.GetState()
@ -94,7 +99,71 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error,
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
if err != nil {
logger.ErrorContext(ctx, "auth response URL", "error", err)
return nil, NewStatusError(err, http.StatusBadRequest)
return nil, AsStatusError(err, http.StatusBadRequest)
}
logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
return NewRedirect(url), nil
}
// StatusError wraps an error with a HTTP status code.
// The status code is passed to the handler's writer.
type StatusError struct {
parent error
statusCode int
}
// NewStatusError sets the parent and statusCode to a new StatusError.
// It is recommended for parent to be an [oidc.Error].
//
// Typically implementations should only use this to signal something
// very specific, like an internal server error.
// If a returned error is not a StatusError, the framework
// will set a statusCode based on what the standard specifies,
// which is [http.StatusBadRequest] for most of the time.
// If the error encountered can described clearly with a [oidc.Error],
// do not use this function, as it might break standard rules!
func NewStatusError(parent error, statusCode int) StatusError {
return StatusError{
parent: parent,
statusCode: statusCode,
}
}
// AsStatusError unwraps a StatusError from err
// and returns it unmodified if found.
// If no StatuError was found, a new one is returned
// with statusCode set to it as a default.
func AsStatusError(err error, statusCode int) (target StatusError) {
if errors.As(err, &target) {
return target
}
return NewStatusError(err, statusCode)
}
func (e StatusError) Error() string {
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
}
func (e StatusError) Unwrap() error {
return e.parent
}
func (e StatusError) Is(err error) bool {
var target StatusError
if !errors.As(err, &target) {
return false
}
return errors.Is(e.parent, target.parent) &&
e.statusCode == target.statusCode
}
// WriteError asserts for a StatusError containing an [oidc.Error].
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
statusError := AsStatusError(err, http.StatusBadRequest)
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, oidc.DefaultToServerError(e, e.Error()), statusError.statusCode)
}