diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7610248..85c8ef4 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -74,7 +74,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { - authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) + err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) return @@ -130,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A // ParseRequestObject parse the `request` parameter, validates the token including the signature // and copies the token claims into the auth request -func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) { +func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error { requestObject := new(oidc.RequestObject) payload, err := oidc.ParseToken(authReq.RequestParam, requestObject) if err != nil { - return nil, err + return err } if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if requestObject.Issuer != requestObject.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } if !str.Contains(requestObject.Audience, issuer) { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest() } keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { - return authReq, err + return err } CopyRequestObjectToAuthRequest(authReq, requestObject) - return authReq, nil + return nil } // CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request @@ -228,6 +228,44 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier) } +// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed +func ValidateAuthRequestV2(ctx context.Context, authReq *oidc.AuthRequest, provider Authorizer) (sub string, err error) { + if authReq.RequestParam != "" && provider.RequestObjectSupported() { + err := ParseRequestObject(ctx, authReq, provider.Storage(), IssuerFromContext(ctx)) + if err != nil { + return "", err + } + } + if authReq.ClientID == "" { + return "", ErrAuthReqMissingClientID + } + if authReq.RedirectURI == "" { + return "", ErrAuthReqMissingRedirectURI + } + if authReq.RequestParam != "" { + return "", oidc.ErrRequestNotSupported() + } + authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) + if err != nil { + return "", err + } + client, err := provider.Storage().GetClientByClientID(ctx, authReq.ClientID) + if err != nil { + return "", oidc.DefaultToServerError(err, "unable to retrieve client by id") + } + authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes) + if err != nil { + return "", err + } + if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil { + return "", err + } + if err := ValidateAuthReqResponseType(client, authReq.ResponseType); err != nil { + return "", err + } + return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, provider.IDTokenHintVerifier(ctx)) +} + // ValidateAuthReqPrompt validates the passed prompt values and sets max_age to 0 if prompt login is present func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) { for _, prompt := range prompts { diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index 4411619..03caf33 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -50,37 +50,17 @@ var ( ) func (s *LegacyServer) Authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { - authReq := r.Data - if authReq.RequestParam != "" && s.provider.RequestObjectSupported() { - authReq, err = ParseRequestObject(ctx, authReq, s.provider.Storage(), IssuerFromContext(ctx)) - if err != nil { - return nil, NewStatusError(err, http.StatusBadRequest) - } - } - if authReq.ClientID == "" { - return TryErrorRedirect(ctx, authReq, ErrAuthReqMissingClientID, s.provider.Encoder(), s.provider.Logger()) - } - if authReq.RedirectURI == "" { - return TryErrorRedirect(ctx, authReq, ErrAuthReqMissingRedirectURI, s.provider.Encoder(), s.provider.Logger()) - } - validation := ValidateAuthRequest - if validater, ok := s.provider.(AuthorizeValidator); ok { - validation = validater.ValidateAuthRequest - } - userID, err := validation(ctx, authReq, s.provider.Storage(), s.provider.IDTokenHintVerifier(ctx)) + userID, err := ValidateAuthRequestV2(ctx, r.Data, s.provider) if err != nil { - return TryErrorRedirect(ctx, authReq, err, s.provider.Encoder(), s.provider.Logger()) + return nil, err } - if authReq.RequestParam != "" { - return TryErrorRedirect(ctx, authReq, oidc.ErrRequestNotSupported(), s.provider.Encoder(), s.provider.Logger()) - } - req, err := s.provider.Storage().CreateAuthRequest(ctx, authReq, userID) + req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID) if err != nil { - return TryErrorRedirect(ctx, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger()) + return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger()) } client, err := s.provider.Storage().GetClientByClientID(ctx, req.GetClientID()) if err != nil { - return TryErrorRedirect(ctx, authReq, oidc.DefaultToServerError(err, "unable to retrieve client by id"), s.provider.Encoder(), s.provider.Logger()) + return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to retrieve client by id"), s.provider.Encoder(), s.provider.Logger()) } return NewRedirect(client.LoginURL(req.GetID())), nil }