fix: don't force server errors in legacy server

This commit is contained in:
Tim Möhlmann 2024-01-16 13:19:38 +02:00
parent 844e2337bb
commit c22649b20b
3 changed files with 15 additions and 12 deletions

View file

@ -138,20 +138,20 @@ func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage
} }
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
return oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest().WithDescription("missing or wrong client id in request")
} }
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
return oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest().WithDescription("missing or wrong response type in request")
} }
if requestObject.Issuer != requestObject.ClientID { if requestObject.Issuer != requestObject.ClientID {
return oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request")
} }
if !str.Contains(requestObject.Audience, issuer) { if !str.Contains(requestObject.Audience, issuer) {
return oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience")
} }
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
return err return oidc.ErrInvalidRequest().WithParent(err).WithDescription(err.Error())
} }
CopyRequestObjectToAuthRequest(authReq, requestObject) CopyRequestObjectToAuthRequest(authReq, requestObject)
return nil return nil

View file

@ -160,10 +160,13 @@ func (e StatusError) Is(err error) bool {
// WriteError asserts for a StatusError containing an [oidc.Error]. // WriteError asserts for a StatusError containing an [oidc.Error].
// If no StatusError is found, the status code will default to [http.StatusBadRequest]. // If no StatusError is found, the status code will default to [http.StatusBadRequest].
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError]. // If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
// When the final oidc Error is a server error, the status code is adjusted to [http.StatusInternalServerError].
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
statusError := AsStatusError(err, http.StatusBadRequest) statusError := AsStatusError(err, http.StatusBadRequest)
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()) e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
if e.ErrorType == oidc.ServerError {
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) statusError.statusCode = http.StatusInternalServerError
}
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e, "status_code", statusError.statusCode)
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode) httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
} }

View file

@ -91,7 +91,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon
for _, probe := range s.provider.Probes() { for _, probe := range s.provider.Probes() {
// shouldn't we run probes in Go routines? // shouldn't we run probes in Go routines?
if err := probe(ctx); err != nil { if err := probe(ctx); err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError) return nil, AsStatusError(err, http.StatusInternalServerError)
} }
} }
return NewResponse(Status{Status: "ok"}), nil return NewResponse(Status{Status: "ok"}), nil
@ -106,7 +106,7 @@ func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Re
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
keys, err := s.provider.Storage().KeySet(ctx) keys, err := s.provider.Storage().KeySet(ctx)
if err != nil { if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError) return nil, AsStatusError(err, http.StatusInternalServerError)
} }
return NewResponse(jsonWebKeySet(keys)), nil return NewResponse(jsonWebKeySet(keys)), nil
} }
@ -127,7 +127,7 @@ func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.Au
} }
} }
if r.Data.ClientID == "" { if r.Data.ClientID == "" {
return nil, ErrAuthReqMissingClientID return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error())
} }
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
if err != nil { if err != nil {
@ -155,7 +155,7 @@ func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.Auth
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
if err != nil { if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError) return nil, AsStatusError(err, http.StatusInternalServerError)
} }
return NewResponse(response), nil return NewResponse(response), nil
} }
@ -248,7 +248,7 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil
} }
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
if err != nil { if err != nil {
return nil, err return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid")
} }
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope) tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)