From f4dac05713d0458db2fd8d9cc0fabac31c8d3d12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 12 Sep 2023 11:17:59 +0300 Subject: [PATCH] error handling --- pkg/op/error.go | 75 +++++++++++++++++++++++++++++++++++++++++-- pkg/op/server.go | 54 +++++++------------------------ pkg/op/server_http.go | 20 ++++-------- 3 files changed, 91 insertions(+), 58 deletions(-) diff --git a/pkg/op/error.go b/pkg/op/error.go index 67278c6..6c5a04a 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -2,6 +2,8 @@ package op import ( "context" + "errors" + "fmt" "net/http" httphelper "github.com/zitadel/oidc/v3/pkg/http" @@ -68,13 +70,16 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo httphelper.MarshalJSONWithStatus(w, e, status) } +// TryErrorRedirect tries to handle an error by redirecting a client. +// If this attempt fails, an error is returned that must be returned +// to the client instead. func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) { e := oidc.DefaultToServerError(parent, parent.Error()) logger = logger.With("oidc_error", e) if authReq == nil { logger.Log(ctx, e.LogLevel(), "auth request") - return nil, NewStatusError(parent, http.StatusBadRequest) + return nil, AsStatusError(e, http.StatusBadRequest) } if logAuthReq, ok := authReq.(LogAuthRequest); ok { @@ -83,7 +88,7 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { logger.Log(ctx, e.LogLevel(), "auth request: not redirecting") - return nil, NewStatusError(parent, http.StatusBadRequest) + return nil, AsStatusError(e, http.StatusBadRequest) } e.State = authReq.GetState() @@ -94,7 +99,71 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) if err != nil { logger.ErrorContext(ctx, "auth response URL", "error", err) - return nil, NewStatusError(err, http.StatusBadRequest) + return nil, AsStatusError(err, http.StatusBadRequest) } + logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url) return NewRedirect(url), nil } + +// StatusError wraps an error with a HTTP status code. +// The status code is passed to the handler's writer. +type StatusError struct { + parent error + statusCode int +} + +// NewStatusError sets the parent and statusCode to a new StatusError. +// It is recommended for parent to be an [oidc.Error]. +// +// Typically implementations should only use this to signal something +// very specific, like an internal server error. +// If a returned error is not a StatusError, the framework +// will set a statusCode based on what the standard specifies, +// which is [http.StatusBadRequest] for most of the time. +// If the error encountered can described clearly with a [oidc.Error], +// do not use this function, as it might break standard rules! +func NewStatusError(parent error, statusCode int) StatusError { + return StatusError{ + parent: parent, + statusCode: statusCode, + } +} + +// AsStatusError unwraps a StatusError from err +// and returns it unmodified if found. +// If no StatuError was found, a new one is returned +// with statusCode set to it as a default. +func AsStatusError(err error, statusCode int) (target StatusError) { + if errors.As(err, &target) { + return target + } + return NewStatusError(err, statusCode) +} + +func (e StatusError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error()) +} + +func (e StatusError) Unwrap() error { + return e.parent +} + +func (e StatusError) Is(err error) bool { + var target StatusError + if !errors.As(err, &target) { + return false + } + return errors.Is(e.parent, target.parent) && + e.statusCode == target.statusCode +} + +// WriteError asserts for a StatusError containing an [oidc.Error]. +// If no StatusError is found, the status code will default to [http.StatusBadRequest]. +// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError]. +func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { + statusError := AsStatusError(err, http.StatusBadRequest) + e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()) + + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) + httphelper.MarshalJSONWithStatus(w, oidc.DefaultToServerError(e, e.Error()), statusError.statusCode) +} diff --git a/pkg/op/server.go b/pkg/op/server.go index d2419dd..a9be613 100644 --- a/pkg/op/server.go +++ b/pkg/op/server.go @@ -2,45 +2,14 @@ package op import ( "context" - "encoding/json" - "errors" - "fmt" "net/http" "net/url" "github.com/muhlemmer/gu" + httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" ) -type StatusError struct { - parent error - statusCode int -} - -func NewStatusError(parent error, statusCode int) StatusError { - return StatusError{ - parent: parent, - statusCode: statusCode, - } -} - -func (e StatusError) Error() string { - return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error()) -} - -func (e StatusError) Unwrap() error { - return e.parent -} - -func (e StatusError) Is(err error) bool { - var target StatusError - if !errors.As(err, &target) { - return false - } - return errors.Is(e.parent, target.parent) && - e.statusCode == target.statusCode -} - // Server describes the interface that needs to be implemented to serve // OpenID Connect and Oauth2 standard requests. // @@ -177,6 +146,10 @@ type Request[T any] struct { Data *T } +func (r *Request[_]) path() string { + return r.URL.Path +} + func newRequest[T any](r *http.Request, data *T) *Request[T] { return &Request[T]{ Method: r.Method, @@ -226,7 +199,7 @@ func NewResponse(data any) *Response { func (resp *Response) writeOut(w http.ResponseWriter) { gu.MapMerge(resp.Header, w.Header()) - json.NewEncoder(w).Encode(resp.Data) + httphelper.MarshalJSON(w, resp.Data) } // Redirect is a special response type which will @@ -253,12 +226,9 @@ type UnimplementedServer struct{} // and not http methods covered by "501 Not Implemented". var UnimplementedStatusCode = http.StatusNotFound -func unimplementedError[T any](r *Request[T]) StatusError { - err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.URL.Path) - return StatusError{ - parent: err, - statusCode: UnimplementedStatusCode, - } +func unimplementedError(r interface{ path() string }) StatusError { + err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path()) + return NewStatusError(err, UnimplementedStatusCode) } func unimplementedGrantError(gt oidc.GrantType) StatusError { @@ -289,7 +259,7 @@ func (UnimplementedServer) Authorize(ctx context.Context, r *Request[oidc.AuthRe } func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedError(r) } func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { @@ -321,7 +291,7 @@ func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oid } func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedError(r) } func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { @@ -329,7 +299,7 @@ func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInf } func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedError(r) } func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index e727626..4016db3 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -61,37 +61,31 @@ func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) { func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) { client, err := s.verifyRequestClient(r) if err != nil { - RequestError(w, r, err, slog.Default()) + WriteError(w, r, err, slog.Default()) return } - grantType := oidc.GrantType(r.Form.Get("grant_type")) - var handle func(w http.ResponseWriter, r *http.Request, client Client) switch grantType { case oidc.GrantTypeCode: - handle = s.handleCodeExchange + s.handleCodeExchange(w, r, client) case oidc.GrantTypeRefreshToken: - handle = s.handleRefreshToken + s.handleRefreshToken(w, r, client) case "": - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) - return + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) default: - RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), slog.Default()) - return + WriteError(w, r, unimplementedGrantError(grantType), slog.Default()) } - - handle(w, r, client) } func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[*oidc.AccessTokenRequest](s.decoder, r.Form) if err != nil { - RequestError(w, r, err, s.logger) + WriteError(w, r, err, s.logger) return } resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) if err != nil { - RequestError(w, r, err, s.logger) + WriteError(w, r, err, s.logger) return } resp.writeOut(w)