fix: don't force server errors in legacy server (#517)
* fix: don't force server errors in legacy server * fix tests and be more consistent with the returned status code
This commit is contained in:
parent
844e2337bb
commit
57d04e7465
5 changed files with 42 additions and 22 deletions
|
@ -138,20 +138,20 @@ func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage
|
||||||
}
|
}
|
||||||
|
|
||||||
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
|
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
|
||||||
return oidc.ErrInvalidRequest()
|
return oidc.ErrInvalidRequest().WithDescription("missing or wrong client id in request")
|
||||||
}
|
}
|
||||||
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
|
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
|
||||||
return oidc.ErrInvalidRequest()
|
return oidc.ErrInvalidRequest().WithDescription("missing or wrong response type in request")
|
||||||
}
|
}
|
||||||
if requestObject.Issuer != requestObject.ClientID {
|
if requestObject.Issuer != requestObject.ClientID {
|
||||||
return oidc.ErrInvalidRequest()
|
return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request")
|
||||||
}
|
}
|
||||||
if !str.Contains(requestObject.Audience, issuer) {
|
if !str.Contains(requestObject.Audience, issuer) {
|
||||||
return oidc.ErrInvalidRequest()
|
return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience")
|
||||||
}
|
}
|
||||||
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
|
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
|
||||||
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
|
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
|
||||||
return err
|
return oidc.ErrInvalidRequest().WithParent(err).WithDescription(err.Error())
|
||||||
}
|
}
|
||||||
CopyRequestObjectToAuthRequest(authReq, requestObject)
|
CopyRequestObjectToAuthRequest(authReq, requestObject)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -157,13 +157,29 @@ func (e StatusError) Is(err error) bool {
|
||||||
e.statusCode == target.statusCode
|
e.statusCode == target.statusCode
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteError asserts for a StatusError containing an [oidc.Error].
|
// WriteError asserts for a [StatusError] containing an [oidc.Error].
|
||||||
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
|
// If no `StatusError` is found, the status code will default to [http.StatusBadRequest].
|
||||||
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
|
// If no `oidc.Error` was found in the parent, the error type defaults to [oidc.ServerError].
|
||||||
|
// When there was no `StatusError` and the `oidc.Error` is of type `oidc.ServerError`,
|
||||||
|
// the status code will be set to [http.StatusInternalServerError]
|
||||||
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
||||||
statusError := AsStatusError(err, http.StatusBadRequest)
|
var statusError StatusError
|
||||||
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
|
if errors.As(err, &statusError) {
|
||||||
|
writeError(w, r,
|
||||||
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()),
|
||||||
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
|
statusError.statusCode, logger,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
statusCode := http.StatusBadRequest
|
||||||
|
e := oidc.DefaultToServerError(err, err.Error())
|
||||||
|
if e.ErrorType == oidc.ServerError {
|
||||||
|
statusCode = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
writeError(w, r, e, statusCode, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeError(w http.ResponseWriter, r *http.Request, err *oidc.Error, statusCode int, logger *slog.Logger) {
|
||||||
|
logger.Log(r.Context(), err.LogLevel(), "request error", "oidc_error", err, "status_code", statusCode)
|
||||||
|
httphelper.MarshalJSONWithStatus(w, err, statusCode)
|
||||||
}
|
}
|
||||||
|
|
|
@ -579,7 +579,7 @@ func TestWriteError(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "not a status or oidc error",
|
name: "not a status or oidc error",
|
||||||
err: io.ErrClosedPipe,
|
err: io.ErrClosedPipe,
|
||||||
wantStatus: http.StatusBadRequest,
|
wantStatus: http.StatusInternalServerError,
|
||||||
wantBody: `{
|
wantBody: `{
|
||||||
"error":"server_error",
|
"error":"server_error",
|
||||||
"error_description":"io: read/write on closed pipe"
|
"error_description":"io: read/write on closed pipe"
|
||||||
|
@ -592,6 +592,7 @@ func TestWriteError(t *testing.T) {
|
||||||
"parent":"io: read/write on closed pipe",
|
"parent":"io: read/write on closed pipe",
|
||||||
"type":"server_error"
|
"type":"server_error"
|
||||||
},
|
},
|
||||||
|
"status_code":500,
|
||||||
"time":"not"
|
"time":"not"
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
@ -611,6 +612,7 @@ func TestWriteError(t *testing.T) {
|
||||||
"parent":"io: read/write on closed pipe",
|
"parent":"io: read/write on closed pipe",
|
||||||
"type":"server_error"
|
"type":"server_error"
|
||||||
},
|
},
|
||||||
|
"status_code":500,
|
||||||
"time":"not"
|
"time":"not"
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
@ -629,6 +631,7 @@ func TestWriteError(t *testing.T) {
|
||||||
"description":"oops",
|
"description":"oops",
|
||||||
"type":"invalid_request"
|
"type":"invalid_request"
|
||||||
},
|
},
|
||||||
|
"status_code":400,
|
||||||
"time":"not"
|
"time":"not"
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
@ -650,6 +653,7 @@ func TestWriteError(t *testing.T) {
|
||||||
"description":"oops",
|
"description":"oops",
|
||||||
"type":"unauthorized_client"
|
"type":"unauthorized_client"
|
||||||
},
|
},
|
||||||
|
"status_code":401,
|
||||||
"time":"not"
|
"time":"not"
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
|
|
@ -365,14 +365,14 @@ func Test_webServer_authorizeHandler(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "authorize error",
|
name: "server error",
|
||||||
fields: fields{
|
fields: fields{
|
||||||
server: &requestVerifier{},
|
server: &requestVerifier{},
|
||||||
decoder: testDecoder,
|
decoder: testDecoder,
|
||||||
},
|
},
|
||||||
r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")),
|
r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")),
|
||||||
want: webServerResult{
|
want: webServerResult{
|
||||||
wantStatus: http.StatusBadRequest,
|
wantStatus: http.StatusInternalServerError,
|
||||||
wantBody: `{"error":"server_error"}`,
|
wantBody: `{"error":"server_error"}`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1237,7 +1237,7 @@ func Test_webServer_simpleHandler(t *testing.T) {
|
||||||
},
|
},
|
||||||
r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))),
|
r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))),
|
||||||
want: webServerResult{
|
want: webServerResult{
|
||||||
wantStatus: http.StatusBadRequest,
|
wantStatus: http.StatusInternalServerError,
|
||||||
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
|
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -91,7 +91,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon
|
||||||
for _, probe := range s.provider.Probes() {
|
for _, probe := range s.provider.Probes() {
|
||||||
// shouldn't we run probes in Go routines?
|
// shouldn't we run probes in Go routines?
|
||||||
if err := probe(ctx); err != nil {
|
if err := probe(ctx); err != nil {
|
||||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
return nil, AsStatusError(err, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return NewResponse(Status{Status: "ok"}), nil
|
return NewResponse(Status{Status: "ok"}), nil
|
||||||
|
@ -106,7 +106,7 @@ func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Re
|
||||||
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||||
keys, err := s.provider.Storage().KeySet(ctx)
|
keys, err := s.provider.Storage().KeySet(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
return nil, AsStatusError(err, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return NewResponse(jsonWebKeySet(keys)), nil
|
return NewResponse(jsonWebKeySet(keys)), nil
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.Au
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if r.Data.ClientID == "" {
|
if r.Data.ClientID == "" {
|
||||||
return nil, ErrAuthReqMissingClientID
|
return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error())
|
||||||
}
|
}
|
||||||
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -155,7 +155,7 @@ func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.Auth
|
||||||
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
||||||
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
|
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
return nil, AsStatusError(err, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return NewResponse(response), nil
|
return NewResponse(response), nil
|
||||||
}
|
}
|
||||||
|
@ -248,7 +248,7 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil
|
||||||
}
|
}
|
||||||
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
|
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
|
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue