diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index ca59454..98221d9 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -10,15 +10,18 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) +// VerifyTokenFn is a custom verification function for use in [VerifyIDToken] +type VerifyTokenFn func(claims oidc.Claims, v *IDTokenVerifier) error + // VerifyTokens implement the Token Response Validation as defined in OIDC specification // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) { +func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier, options ...VerifyTokenFn) (claims C, err error) { ctx, span := client.Tracer.Start(ctx, "VerifyTokens") defer span.End() var nilClaims C - claims, err = VerifyIDToken[C](ctx, idToken, v) + claims, err = VerifyIDToken[C](ctx, idToken, v, options...) if err != nil { return nilClaims, err } @@ -30,7 +33,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str // VerifyIDToken validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) { +func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier, options ...VerifyTokenFn) (claims C, err error) { ctx, span := client.Tracer.Start(ctx, "VerifyIDToken") defer span.End() @@ -57,10 +60,6 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenV return nilClaims, err } - if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil { - return nilClaims, err - } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } @@ -86,9 +85,23 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenV if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } + + for _, verifyFn := range options { + if err := verifyFn(claims, v); err != nil { + return nilClaims, err + } + } + return claims, nil } +// WithCheckAuthorizedParty checks azp (authorized party) claim requirements. +func WithCheckAuthorizedParty() VerifyTokenFn { + return func(claims oidc.Claims, v *IDTokenVerifier) error { + return oidc.CheckAuthorizedParty(claims, v.ClientID) + } +} + type IDTokenVerifier oidc.Verifier // VerifyAccessToken validates the access token according to