rewrite auth request validation

This commit is contained in:
Tim Möhlmann 2023-09-13 18:03:50 +03:00
parent f4dac05713
commit fe3f98a4f9
2 changed files with 52 additions and 34 deletions

View file

@ -74,7 +74,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
}
ctx := r.Context()
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer)
return
@ -130,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A
// ParseRequestObject parse the `request` parameter, validates the token including the signature
// and copies the token claims into the auth request
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) {
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error {
requestObject := new(oidc.RequestObject)
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
if err != nil {
return nil, err
return err
}
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if requestObject.Issuer != requestObject.ClientID {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if !str.Contains(requestObject.Audience, issuer) {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
return authReq, err
return err
}
CopyRequestObjectToAuthRequest(authReq, requestObject)
return authReq, nil
return nil
}
// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
@ -228,6 +228,44 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier)
}
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequestV2(ctx context.Context, authReq *oidc.AuthRequest, provider Authorizer) (sub string, err error) {
if authReq.RequestParam != "" && provider.RequestObjectSupported() {
err := ParseRequestObject(ctx, authReq, provider.Storage(), IssuerFromContext(ctx))
if err != nil {
return "", err
}
}
if authReq.ClientID == "" {
return "", ErrAuthReqMissingClientID
}
if authReq.RedirectURI == "" {
return "", ErrAuthReqMissingRedirectURI
}
if authReq.RequestParam != "" {
return "", oidc.ErrRequestNotSupported()
}
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return "", err
}
client, err := provider.Storage().GetClientByClientID(ctx, authReq.ClientID)
if err != nil {
return "", oidc.DefaultToServerError(err, "unable to retrieve client by id")
}
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
return "", err
}
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return "", err
}
if err := ValidateAuthReqResponseType(client, authReq.ResponseType); err != nil {
return "", err
}
return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, provider.IDTokenHintVerifier(ctx))
}
// ValidateAuthReqPrompt validates the passed prompt values and sets max_age to 0 if prompt login is present
func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) {
for _, prompt := range prompts {

View file

@ -50,37 +50,17 @@ var (
)
func (s *LegacyServer) Authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
authReq := r.Data
if authReq.RequestParam != "" && s.provider.RequestObjectSupported() {
authReq, err = ParseRequestObject(ctx, authReq, s.provider.Storage(), IssuerFromContext(ctx))
if err != nil {
return nil, NewStatusError(err, http.StatusBadRequest)
}
}
if authReq.ClientID == "" {
return TryErrorRedirect(ctx, authReq, ErrAuthReqMissingClientID, s.provider.Encoder(), s.provider.Logger())
}
if authReq.RedirectURI == "" {
return TryErrorRedirect(ctx, authReq, ErrAuthReqMissingRedirectURI, s.provider.Encoder(), s.provider.Logger())
}
validation := ValidateAuthRequest
if validater, ok := s.provider.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest
}
userID, err := validation(ctx, authReq, s.provider.Storage(), s.provider.IDTokenHintVerifier(ctx))
userID, err := ValidateAuthRequestV2(ctx, r.Data, s.provider)
if err != nil {
return TryErrorRedirect(ctx, authReq, err, s.provider.Encoder(), s.provider.Logger())
return nil, err
}
if authReq.RequestParam != "" {
return TryErrorRedirect(ctx, authReq, oidc.ErrRequestNotSupported(), s.provider.Encoder(), s.provider.Logger())
}
req, err := s.provider.Storage().CreateAuthRequest(ctx, authReq, userID)
req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
if err != nil {
return TryErrorRedirect(ctx, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
}
client, err := s.provider.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
return TryErrorRedirect(ctx, authReq, oidc.DefaultToServerError(err, "unable to retrieve client by id"), s.provider.Encoder(), s.provider.Logger())
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to retrieve client by id"), s.provider.Encoder(), s.provider.Logger())
}
return NewRedirect(client.LoginURL(req.GetID())), nil
}