From 57d04e74651766c47097c8253bf08417a366d553 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Wed, 17 Jan 2024 17:06:45 +0200 Subject: [PATCH] fix: don't force server errors in legacy server (#517) * fix: don't force server errors in legacy server * fix tests and be more consistent with the returned status code --- pkg/op/auth_request.go | 10 +++++----- pkg/op/error.go | 32 ++++++++++++++++++++++++-------- pkg/op/error_test.go | 6 +++++- pkg/op/server_http_test.go | 6 +++--- pkg/op/server_legacy.go | 10 +++++----- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 02c820e..ed368eb 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -138,20 +138,20 @@ func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage } if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong client id in request") } if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong response type in request") } if requestObject.Issuer != requestObject.ClientID { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request") } if !str.Contains(requestObject.Audience, issuer) { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience") } keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { - return err + return oidc.ErrInvalidRequest().WithParent(err).WithDescription(err.Error()) } CopyRequestObjectToAuthRequest(authReq, requestObject) return nil diff --git a/pkg/op/error.go b/pkg/op/error.go index 0cac14b..e4580f6 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -157,13 +157,29 @@ func (e StatusError) Is(err error) bool { e.statusCode == target.statusCode } -// WriteError asserts for a StatusError containing an [oidc.Error]. -// If no StatusError is found, the status code will default to [http.StatusBadRequest]. -// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError]. +// WriteError asserts for a [StatusError] containing an [oidc.Error]. +// If no `StatusError` is found, the status code will default to [http.StatusBadRequest]. +// If no `oidc.Error` was found in the parent, the error type defaults to [oidc.ServerError]. +// When there was no `StatusError` and the `oidc.Error` is of type `oidc.ServerError`, +// the status code will be set to [http.StatusInternalServerError] func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { - statusError := AsStatusError(err, http.StatusBadRequest) - e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()) - - logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) - httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode) + var statusError StatusError + if errors.As(err, &statusError) { + writeError(w, r, + oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()), + statusError.statusCode, logger, + ) + return + } + statusCode := http.StatusBadRequest + e := oidc.DefaultToServerError(err, err.Error()) + if e.ErrorType == oidc.ServerError { + statusCode = http.StatusInternalServerError + } + writeError(w, r, e, statusCode, logger) +} + +func writeError(w http.ResponseWriter, r *http.Request, err *oidc.Error, statusCode int, logger *slog.Logger) { + logger.Log(r.Context(), err.LogLevel(), "request error", "oidc_error", err, "status_code", statusCode) + httphelper.MarshalJSONWithStatus(w, err, statusCode) } diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go index 689ee5a..50a9cbf 100644 --- a/pkg/op/error_test.go +++ b/pkg/op/error_test.go @@ -579,7 +579,7 @@ func TestWriteError(t *testing.T) { { name: "not a status or oidc error", err: io.ErrClosedPipe, - wantStatus: http.StatusBadRequest, + wantStatus: http.StatusInternalServerError, wantBody: `{ "error":"server_error", "error_description":"io: read/write on closed pipe" @@ -592,6 +592,7 @@ func TestWriteError(t *testing.T) { "parent":"io: read/write on closed pipe", "type":"server_error" }, + "status_code":500, "time":"not" }`, }, @@ -611,6 +612,7 @@ func TestWriteError(t *testing.T) { "parent":"io: read/write on closed pipe", "type":"server_error" }, + "status_code":500, "time":"not" }`, }, @@ -629,6 +631,7 @@ func TestWriteError(t *testing.T) { "description":"oops", "type":"invalid_request" }, + "status_code":400, "time":"not" }`, }, @@ -650,6 +653,7 @@ func TestWriteError(t *testing.T) { "description":"oops", "type":"unauthorized_client" }, + "status_code":401, "time":"not" }`, }, diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go index 4eac4a0..6cb268f 100644 --- a/pkg/op/server_http_test.go +++ b/pkg/op/server_http_test.go @@ -365,14 +365,14 @@ func Test_webServer_authorizeHandler(t *testing.T) { }, }, { - name: "authorize error", + name: "server error", fields: fields{ server: &requestVerifier{}, decoder: testDecoder, }, r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), want: webServerResult{ - wantStatus: http.StatusBadRequest, + wantStatus: http.StatusInternalServerError, wantBody: `{"error":"server_error"}`, }, }, @@ -1237,7 +1237,7 @@ func Test_webServer_simpleHandler(t *testing.T) { }, r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))), want: webServerResult{ - wantStatus: http.StatusBadRequest, + wantStatus: http.StatusInternalServerError, wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, }, }, diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index 089be6f..f99d15d 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -91,7 +91,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon for _, probe := range s.provider.Probes() { // shouldn't we run probes in Go routines? if err := probe(ctx); err != nil { - return nil, NewStatusError(err, http.StatusInternalServerError) + return nil, AsStatusError(err, http.StatusInternalServerError) } } return NewResponse(Status{Status: "ok"}), nil @@ -106,7 +106,7 @@ func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Re func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { keys, err := s.provider.Storage().KeySet(ctx) if err != nil { - return nil, NewStatusError(err, http.StatusInternalServerError) + return nil, AsStatusError(err, http.StatusInternalServerError) } return NewResponse(jsonWebKeySet(keys)), nil } @@ -127,7 +127,7 @@ func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.Au } } if r.Data.ClientID == "" { - return nil, ErrAuthReqMissingClientID + return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error()) } client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) if err != nil { @@ -155,7 +155,7 @@ func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.Auth func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) if err != nil { - return nil, NewStatusError(err, http.StatusInternalServerError) + return nil, AsStatusError(err, http.StatusInternalServerError) } return NewResponse(response), nil } @@ -248,7 +248,7 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil } tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) if err != nil { - return nil, err + return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid") } tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)