feat: merge the verifier types (#336)
BREAKING CHANGE: - The various verifier types are merged into a oidc.Verifir. - oidc.Verfier became a struct with exported fields * use type aliases for oidc.Verifier this binds the correct contstructor to each verifier usecase. * fix: handle the zero cases for oidc.Time * add unit tests to oidc verifier * fix: correct returned field for JWTTokenRequest JWTTokenRequest.GetIssuedAt() was returning the ExpiresAt field. This change corrects that by returning IssuedAt instead.
This commit is contained in:
parent
c8cf15e266
commit
33c716ddcf
29 changed files with 948 additions and 351 deletions
|
@ -8,6 +8,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/muhlemmer/gu"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
)
|
)
|
||||||
|
@ -17,7 +18,7 @@ type KeySet struct{}
|
||||||
|
|
||||||
// VerifySignature implments op.KeySet.
|
// VerifySignature implments op.KeySet.
|
||||||
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
|
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
|
||||||
if ctx.Err() != nil {
|
if err = ctx.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,6 +46,16 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type JWTProfileKeyStorage struct{}
|
||||||
|
|
||||||
|
func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return gu.Ptr(WebKey.Public()), nil
|
||||||
|
}
|
||||||
|
|
||||||
func signEncodeTokenClaims(claims any) string {
|
func signEncodeTokenClaims(claims any) string {
|
||||||
payload, err := json.Marshal(claims)
|
payload, err := json.Marshal(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T
|
||||||
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
|
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
|
||||||
|
req := &oidc.JWTTokenRequest{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: clientID,
|
||||||
|
Audience: audience,
|
||||||
|
ExpiresAt: oidc.FromTime(expiration),
|
||||||
|
IssuedAt: oidc.FromTime(issuedAt),
|
||||||
|
}
|
||||||
|
// make sure the private claim map is set correctly
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(data, req); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return signEncodeTokenClaims(req), req
|
||||||
|
}
|
||||||
|
|
||||||
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
|
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
|
||||||
|
|
||||||
// These variables always result in a valid token
|
// These variables always result in a valid token
|
||||||
|
@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
|
||||||
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
|
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
|
||||||
|
return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
|
||||||
|
}
|
||||||
|
|
||||||
// ACRVerify is a oidc.ACRVerifier func.
|
// ACRVerify is a oidc.ACRVerifier func.
|
||||||
func ACRVerify(acr string) error {
|
func ACRVerify(acr string) error {
|
||||||
if acr != ValidACR {
|
if acr != ValidACR {
|
||||||
|
|
|
@ -63,8 +63,8 @@ type RelyingParty interface {
|
||||||
// be used to start a DeviceAuthorization flow.
|
// be used to start a DeviceAuthorization flow.
|
||||||
GetDeviceAuthorizationEndpoint() string
|
GetDeviceAuthorizationEndpoint() string
|
||||||
|
|
||||||
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
|
// IDTokenVerifier returns the verifier used for oidc id_token verification
|
||||||
IDTokenVerifier() IDTokenVerifier
|
IDTokenVerifier() *IDTokenVerifier
|
||||||
// ErrorHandler returns the handler used for callback errors
|
// ErrorHandler returns the handler used for callback errors
|
||||||
|
|
||||||
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
||||||
|
@ -88,7 +88,7 @@ type relyingParty struct {
|
||||||
cookieHandler *httphelper.CookieHandler
|
cookieHandler *httphelper.CookieHandler
|
||||||
|
|
||||||
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
||||||
idTokenVerifier IDTokenVerifier
|
idTokenVerifier *IDTokenVerifier
|
||||||
verifierOpts []VerifierOption
|
verifierOpts []VerifierOption
|
||||||
signer jose.Signer
|
signer jose.Signer
|
||||||
}
|
}
|
||||||
|
@ -137,7 +137,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string {
|
||||||
return rp.endpoints.RevokeURL
|
return rp.endpoints.RevokeURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier {
|
func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
|
||||||
if rp.idTokenVerifier == nil {
|
if rp.idTokenVerifier == nil {
|
||||||
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
|
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,19 +9,9 @@ import (
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenVerifier interface {
|
|
||||||
oidc.Verifier
|
|
||||||
ClientID() string
|
|
||||||
SupportedSignAlgs() []string
|
|
||||||
KeySet() oidc.KeySet
|
|
||||||
Nonce(context.Context) string
|
|
||||||
ACR() oidc.ACRVerifier
|
|
||||||
MaxAge() time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
|
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
|
||||||
var nilClaims C
|
var nilClaims C
|
||||||
|
|
||||||
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
||||||
|
@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str
|
||||||
|
|
||||||
// VerifyIDToken validates the id token according to
|
// VerifyIDToken validates the id token according to
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||||
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
|
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
|
||||||
var nilClaims C
|
var nilClaims C
|
||||||
|
|
||||||
decrypted, err := oidc.DecryptToken(token)
|
decrypted, err := oidc.DecryptToken(token)
|
||||||
|
@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
|
if err = oidc.CheckAudience(claims, v.ClientID); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
|
if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type IDTokenVerifier oidc.Verifier
|
||||||
|
|
||||||
// VerifyAccessToken validates the access token according to
|
// VerifyAccessToken validates the access token according to
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||||
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||||
|
@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIDTokenVerifier returns an implementation of `IDTokenVerifier`
|
// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
|
||||||
// for `VerifyTokens` and `VerifyIDToken`
|
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
|
||||||
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier {
|
v := &IDTokenVerifier{
|
||||||
v := &idTokenVerifier{
|
Issuer: issuer,
|
||||||
issuer: issuer,
|
ClientID: clientID,
|
||||||
clientID: clientID,
|
KeySet: keySet,
|
||||||
keySet: keySet,
|
Offset: time.Second,
|
||||||
offset: time.Second,
|
Nonce: func(_ context.Context) string {
|
||||||
nonce: func(_ context.Context) string {
|
|
||||||
return ""
|
return ""
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifierOption is the type for providing dynamic options to the IDTokenVerifier
|
// VerifierOption is the type for providing dynamic options to the IDTokenVerifier
|
||||||
type VerifierOption func(*idTokenVerifier)
|
type VerifierOption func(*IDTokenVerifier)
|
||||||
|
|
||||||
// WithIssuedAtOffset mitigates the risk of iat to be in the future
|
// WithIssuedAtOffset mitigates the risk of iat to be in the future
|
||||||
// because of clock skews with the ability to add an offset to the current time
|
// because of clock skews with the ability to add an offset to the current time
|
||||||
func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
|
func WithIssuedAtOffset(offset time.Duration) VerifierOption {
|
||||||
return func(v *idTokenVerifier) {
|
return func(v *IDTokenVerifier) {
|
||||||
v.offset = offset
|
v.Offset = offset
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
||||||
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
|
func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption {
|
||||||
return func(v *idTokenVerifier) {
|
return func(v *IDTokenVerifier) {
|
||||||
v.maxAgeIAT = maxAge
|
v.MaxAgeIAT = maxAge
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithNonce sets the function to check the nonce
|
// WithNonce sets the function to check the nonce
|
||||||
func WithNonce(nonce func(context.Context) string) VerifierOption {
|
func WithNonce(nonce func(context.Context) string) VerifierOption {
|
||||||
return func(v *idTokenVerifier) {
|
return func(v *IDTokenVerifier) {
|
||||||
v.nonce = nonce
|
v.Nonce = nonce
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithACRVerifier sets the verifier for the acr claim
|
// WithACRVerifier sets the verifier for the acr claim
|
||||||
func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
|
func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
|
||||||
return func(v *idTokenVerifier) {
|
return func(v *IDTokenVerifier) {
|
||||||
v.acr = verifier
|
v.ACR = verifier
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
|
// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
|
||||||
func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
|
func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
|
||||||
return func(v *idTokenVerifier) {
|
return func(v *IDTokenVerifier) {
|
||||||
v.maxAge = maxAge
|
v.MaxAge = maxAge
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
|
// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
|
||||||
func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
|
func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
|
||||||
return func(v *idTokenVerifier) {
|
return func(v *IDTokenVerifier) {
|
||||||
v.supportedSignAlgs = algs
|
v.SupportedSignAlgs = algs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type idTokenVerifier struct {
|
|
||||||
issuer string
|
|
||||||
maxAgeIAT time.Duration
|
|
||||||
offset time.Duration
|
|
||||||
clientID string
|
|
||||||
supportedSignAlgs []string
|
|
||||||
keySet oidc.KeySet
|
|
||||||
acr oidc.ACRVerifier
|
|
||||||
maxAge time.Duration
|
|
||||||
nonce func(ctx context.Context) string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) Issuer() string {
|
|
||||||
return i.issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) MaxAgeIAT() time.Duration {
|
|
||||||
return i.maxAgeIAT
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) Offset() time.Duration {
|
|
||||||
return i.offset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) ClientID() string {
|
|
||||||
return i.clientID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) SupportedSignAlgs() []string {
|
|
||||||
return i.supportedSignAlgs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) KeySet() oidc.KeySet {
|
|
||||||
return i.keySet
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) Nonce(ctx context.Context) string {
|
|
||||||
return i.nonce(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) ACR() oidc.ACRVerifier {
|
|
||||||
return i.acr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenVerifier) MaxAge() time.Duration {
|
|
||||||
return i.maxAge
|
|
||||||
}
|
|
||||||
|
|
|
@ -13,16 +13,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVerifyTokens(t *testing.T) {
|
func TestVerifyTokens(t *testing.T) {
|
||||||
verifier := &idTokenVerifier{
|
verifier := &IDTokenVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
maxAgeIAT: 2 * time.Minute,
|
MaxAgeIAT: 2 * time.Minute,
|
||||||
offset: time.Second,
|
Offset: time.Second,
|
||||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
maxAge: 2 * time.Minute,
|
MaxAge: 2 * time.Minute,
|
||||||
acr: tu.ACRVerify,
|
ACR: tu.ACRVerify,
|
||||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
Nonce: func(context.Context) string { return tu.ValidNonce },
|
||||||
clientID: tu.ValidClientID,
|
ClientID: tu.ValidClientID,
|
||||||
}
|
}
|
||||||
accessToken, _ := tu.ValidAccessToken()
|
accessToken, _ := tu.ValidAccessToken()
|
||||||
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
|
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
|
||||||
|
@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyIDToken(t *testing.T) {
|
func TestVerifyIDToken(t *testing.T) {
|
||||||
verifier := &idTokenVerifier{
|
verifier := &IDTokenVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
maxAgeIAT: 2 * time.Minute,
|
MaxAgeIAT: 2 * time.Minute,
|
||||||
offset: time.Second,
|
Offset: time.Second,
|
||||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
maxAge: 2 * time.Minute,
|
MaxAge: 2 * time.Minute,
|
||||||
acr: tu.ACRVerify,
|
ACR: tu.ACRVerify,
|
||||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
Nonce: func(context.Context) string { return tu.ValidNonce },
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -219,7 +219,7 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
token, want := tt.tokenClaims()
|
token, want := tt.tokenClaims()
|
||||||
verifier.clientID = tt.clientID
|
verifier.ClientID = tt.clientID
|
||||||
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
@ -300,7 +300,7 @@ func TestNewIDTokenVerifier(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
want IDTokenVerifier
|
want *IDTokenVerifier
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "nil nonce", // otherwise assert.Equal will fail on the function
|
name: "nil nonce", // otherwise assert.Equal will fail on the function
|
||||||
|
@ -317,16 +317,16 @@ func TestNewIDTokenVerifier(t *testing.T) {
|
||||||
WithSupportedSigningAlgorithms("ABC", "DEF"),
|
WithSupportedSigningAlgorithms("ABC", "DEF"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: &idTokenVerifier{
|
want: &IDTokenVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
offset: time.Minute,
|
Offset: time.Minute,
|
||||||
maxAgeIAT: time.Hour,
|
MaxAgeIAT: time.Hour,
|
||||||
clientID: tu.ValidClientID,
|
ClientID: tu.ValidClientID,
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
nonce: nil,
|
Nonce: nil,
|
||||||
acr: nil,
|
ACR: nil,
|
||||||
maxAge: 2 * time.Hour,
|
MaxAge: 2 * time.Hour,
|
||||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
SupportedSignAlgs: []string{"ABC", "DEF"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -192,7 +192,7 @@ func (j *JWTTokenRequest) GetExpiration() time.Time {
|
||||||
|
|
||||||
// GetIssuedAt implements the Claims interface
|
// GetIssuedAt implements the Claims interface
|
||||||
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
|
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
|
||||||
return j.ExpiresAt.AsTime()
|
return j.IssuedAt.AsTime()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNonce implements the Claims interface
|
// GetNonce implements the Claims interface
|
||||||
|
|
|
@ -173,10 +173,16 @@ func NewEncoder() *schema.Encoder {
|
||||||
type Time int64
|
type Time int64
|
||||||
|
|
||||||
func (ts Time) AsTime() time.Time {
|
func (ts Time) AsTime() time.Time {
|
||||||
|
if ts == 0 {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
return time.Unix(int64(ts), 0)
|
return time.Unix(int64(ts), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func FromTime(tt time.Time) Time {
|
func FromTime(tt time.Time) Time {
|
||||||
|
if tt.IsZero() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
return Time(tt.Unix())
|
return Time(tt.Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/schema"
|
"github.com/gorilla/schema"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -467,6 +468,56 @@ func TestNewEncoder(t *testing.T) {
|
||||||
assert.Equal(t, a, b)
|
assert.Equal(t, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTime_AsTime(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ts Time
|
||||||
|
want time.Time
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unset",
|
||||||
|
ts: 0,
|
||||||
|
want: time.Time{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set",
|
||||||
|
ts: 1,
|
||||||
|
want: time.Unix(1, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.ts.AsTime()
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTime_FromTime(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tt time.Time
|
||||||
|
want Time
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "zero",
|
||||||
|
tt: time.Time{},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set",
|
||||||
|
tt: time.Unix(1, 0),
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := FromTime(tt.tt)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTime_UnmarshalJSON(t *testing.T) {
|
func TestTime_UnmarshalJSON(t *testing.T) {
|
||||||
type dst struct {
|
type dst struct {
|
||||||
UpdatedAt Time `json:"updated_at"`
|
UpdatedAt Time `json:"updated_at"`
|
||||||
|
|
|
@ -61,10 +61,19 @@ var (
|
||||||
ErrAtHash = errors.New("at_hash does not correspond to access token")
|
ErrAtHash = errors.New("at_hash does not correspond to access token")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Verifier interface {
|
// Verifier caries configuration for the various token verification
|
||||||
Issuer() string
|
// functions. Use package specific constructor functions to know
|
||||||
MaxAgeIAT() time.Duration
|
// which values need to be set.
|
||||||
Offset() time.Duration
|
type Verifier struct {
|
||||||
|
Issuer string
|
||||||
|
MaxAgeIAT time.Duration
|
||||||
|
Offset time.Duration
|
||||||
|
ClientID string
|
||||||
|
SupportedSignAlgs []string
|
||||||
|
MaxAge time.Duration
|
||||||
|
ACR ACRVerifier
|
||||||
|
KeySet KeySet
|
||||||
|
Nonce func(ctx context.Context) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
|
// ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
|
||||||
|
@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckAuthorizedParty checks azp (authorized party) claim requirements.
|
||||||
|
//
|
||||||
|
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
|
||||||
|
// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value.
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||||
func CheckAuthorizedParty(claims Claims, clientID string) error {
|
func CheckAuthorizedParty(claims Claims, clientID string) error {
|
||||||
if len(claims.GetAudience()) > 1 {
|
if len(claims.GetAudience()) > 1 {
|
||||||
if claims.GetAuthorizedParty() == "" {
|
if claims.GetAuthorizedParty() == "" {
|
||||||
|
@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckExpiration(claims Claims, offset time.Duration) error {
|
func CheckExpiration(claims Claims, offset time.Duration) error {
|
||||||
expiration := claims.GetExpiration().Round(time.Second)
|
expiration := claims.GetExpiration()
|
||||||
if !time.Now().UTC().Add(offset).Before(expiration) {
|
if !time.Now().Add(offset).Before(expiration) {
|
||||||
return ErrExpired
|
return ErrExpired
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
|
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
|
||||||
issuedAt := claims.GetIssuedAt().Round(time.Second)
|
issuedAt := claims.GetIssuedAt()
|
||||||
if issuedAt.IsZero() {
|
if issuedAt.IsZero() {
|
||||||
return ErrIatMissing
|
return ErrIatMissing
|
||||||
}
|
}
|
||||||
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second)
|
nowWithOffset := time.Now().Add(offset).Round(time.Second)
|
||||||
if issuedAt.After(nowWithOffset) {
|
if issuedAt.After(nowWithOffset) {
|
||||||
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
|
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
|
||||||
}
|
}
|
||||||
if maxAgeIAT == 0 {
|
if maxAgeIAT == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second)
|
maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second)
|
||||||
if issuedAt.Before(maxAge) {
|
if issuedAt.Before(maxAge) {
|
||||||
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
|
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
|
||||||
}
|
}
|
||||||
|
@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error {
|
||||||
if claims.GetAuthTime().IsZero() {
|
if claims.GetAuthTime().IsZero() {
|
||||||
return ErrAuthTimeNotPresent
|
return ErrAuthTimeNotPresent
|
||||||
}
|
}
|
||||||
authTime := claims.GetAuthTime().Round(time.Second)
|
authTime := claims.GetAuthTime()
|
||||||
maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second)
|
maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
|
||||||
if authTime.Before(maxAuthTime) {
|
if authTime.Before(maxAuthTime) {
|
||||||
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
|
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
|
||||||
}
|
}
|
||||||
|
|
128
pkg/oidc/verifier_parse_test.go
Normal file
128
pkg/oidc/verifier_parse_test.go
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
package oidc_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseToken(t *testing.T) {
|
||||||
|
token, wantClaims := tu.ValidIDToken()
|
||||||
|
wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload
|
||||||
|
|
||||||
|
wantPayload, err := json.Marshal(wantClaims)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenString string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "split error",
|
||||||
|
tokenString: "nope",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "base64 error",
|
||||||
|
tokenString: "foo.~.bar",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
tokenString: token,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotClaims := new(oidc.IDTokenClaims)
|
||||||
|
gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, wantClaims, gotClaims)
|
||||||
|
assert.JSONEq(t, string(wantPayload), string(gotPayload))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSignature(t *testing.T) {
|
||||||
|
errCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
token, _ := tu.ValidIDToken()
|
||||||
|
payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
token string
|
||||||
|
payload []byte
|
||||||
|
supportedSigAlgs []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "parse error",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
token: "~",
|
||||||
|
payload: payload,
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrParse,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default sigAlg",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
token: token,
|
||||||
|
payload: payload,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported sigAlg",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
token: token,
|
||||||
|
payload: payload,
|
||||||
|
supportedSigAlgs: []string{"foo", "bar"},
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrSignatureUnsupportedAlg,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "verify error",
|
||||||
|
args: args{
|
||||||
|
ctx: errCtx,
|
||||||
|
token: token,
|
||||||
|
payload: payload,
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrSignatureInvalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "inequal payloads",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
token: token,
|
||||||
|
payload: []byte{0, 1, 2},
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrSignatureInvalidPayload,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claims := new(oidc.TokenClaims)
|
||||||
|
err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{})
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
374
pkg/oidc/verifier_test.go
Normal file
374
pkg/oidc/verifier_test.go
Normal file
|
@ -0,0 +1,374 @@
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDecryptToken(t *testing.T) {
|
||||||
|
const tokenString = "ABC"
|
||||||
|
got, err := DecryptToken(tokenString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tokenString, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultACRVerifier(t *testing.T) {
|
||||||
|
acrVerfier := DefaultACRVerifier([]string{"foo", "bar"})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
acr string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
acr: "bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error",
|
||||||
|
acr: "hello",
|
||||||
|
wantErr: "expected one of: [foo bar], got: \"hello\"",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := acrVerfier(tt.acr)
|
||||||
|
if tt.wantErr != "" {
|
||||||
|
assert.EqualError(t, err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckSubject(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
wantErr: ErrSubjectMissing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Subject: "foo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckSubject(tt.claims)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckIssuer(t *testing.T) {
|
||||||
|
const issuer = "foo.bar"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
wantErr: ErrIssuerInvalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Issuer: "wrong",
|
||||||
|
},
|
||||||
|
wantErr: ErrIssuerInvalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckIssuer(tt.claims, issuer)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAudience(t *testing.T) {
|
||||||
|
const clientID = "foo.bar"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
wantErr: ErrAudience,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Audience: []string{"wrong"},
|
||||||
|
},
|
||||||
|
wantErr: ErrAudience,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Audience: []string{clientID},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckAudience(tt.claims, clientID)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAuthorizedParty(t *testing.T) {
|
||||||
|
const clientID = "foo.bar"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single audience, no azp",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Audience: []string{clientID},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple audience, no azp",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Audience: []string{clientID, "other"},
|
||||||
|
},
|
||||||
|
wantErr: ErrAzpMissing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single audience, with azp",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Audience: []string{clientID},
|
||||||
|
AuthorizedParty: clientID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple audience, with azp",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Audience: []string{clientID, "other"},
|
||||||
|
AuthorizedParty: clientID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong azp",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
AuthorizedParty: "wrong",
|
||||||
|
},
|
||||||
|
wantErr: ErrAzpInvalid,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckAuthorizedParty(tt.claims, clientID)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckExpiration(t *testing.T) {
|
||||||
|
const offset = time.Minute
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
wantErr: ErrExpired,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Expiration: FromTime(time.Now().Add(-2 * offset)),
|
||||||
|
},
|
||||||
|
wantErr: ErrExpired,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Expiration: FromTime(time.Now().Add(2 * offset)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckExpiration(tt.claims, offset)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckIssuedAt(t *testing.T) {
|
||||||
|
const offset = time.Minute
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
maxAgeIAT time.Duration
|
||||||
|
claims Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
wantErr: ErrIatMissing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "future",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
IssuedAt: FromTime(time.Now().Add(time.Hour)),
|
||||||
|
},
|
||||||
|
wantErr: ErrIatInFuture,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no max",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
IssuedAt: FromTime(time.Now()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "past max",
|
||||||
|
maxAgeIAT: time.Minute,
|
||||||
|
claims: &TokenClaims{
|
||||||
|
IssuedAt: FromTime(time.Now().Add(-time.Hour)),
|
||||||
|
},
|
||||||
|
wantErr: ErrIatToOld,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "within max",
|
||||||
|
maxAgeIAT: time.Hour,
|
||||||
|
claims: &TokenClaims{
|
||||||
|
IssuedAt: FromTime(time.Now()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckNonce(t *testing.T) {
|
||||||
|
const nonce = "123"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
wantErr: ErrNonceInvalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Nonce: "wrong",
|
||||||
|
},
|
||||||
|
wantErr: ErrNonceInvalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
claims: &TokenClaims{
|
||||||
|
Nonce: nonce,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckNonce(tt.claims, nonce)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAuthorizationContextClassReference(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
acr ACRVerifier
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "error",
|
||||||
|
acr: func(s string) error { return errors.New("oops") },
|
||||||
|
wantErr: ErrAcrInvalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
acr: func(s string) error { return nil },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAuthTime(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims Claims
|
||||||
|
maxAge time.Duration
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no max age",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing",
|
||||||
|
claims: &TokenClaims{},
|
||||||
|
maxAge: time.Minute,
|
||||||
|
wantErr: ErrAuthTimeNotPresent,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired",
|
||||||
|
maxAge: time.Minute,
|
||||||
|
claims: &TokenClaims{
|
||||||
|
AuthTime: FromTime(time.Now().Add(-time.Hour)),
|
||||||
|
},
|
||||||
|
wantErr: ErrAuthTimeToOld,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
maxAge: time.Minute,
|
||||||
|
claims: &TokenClaims{
|
||||||
|
AuthTime: NowTime(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CheckAuthTime(tt.claims, tt.maxAge)
|
||||||
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -38,7 +38,7 @@ type Authorizer interface {
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
Decoder() httphelper.Decoder
|
Decoder() httphelper.Decoder
|
||||||
Encoder() httphelper.Encoder
|
Encoder() httphelper.Encoder
|
||||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||||
Crypto() Crypto
|
Crypto() Crypto
|
||||||
RequestObjectSupported() bool
|
RequestObjectSupported() bool
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ type Authorizer interface {
|
||||||
// implementing its own validation mechanism for the auth request
|
// implementing its own validation mechanism for the auth request
|
||||||
type AuthorizeValidator interface {
|
type AuthorizeValidator interface {
|
||||||
Authorizer
|
Authorizer
|
||||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
|
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
|
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
|
||||||
|
@ -204,7 +204,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
|
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
|
||||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
|
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
|
||||||
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -384,7 +384,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType)
|
||||||
|
|
||||||
// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
|
// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
|
||||||
// and returns the `sub` claim
|
// and returns the `sub` claim
|
||||||
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) {
|
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) {
|
||||||
if idTokenHint == "" {
|
if idTokenHint == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/gorilla/schema"
|
"github.com/gorilla/schema"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
@ -146,7 +147,7 @@ func TestValidateAuthRequest(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
authRequest *oidc.AuthRequest
|
authRequest *oidc.AuthRequest
|
||||||
storage op.Storage
|
storage op.Storage
|
||||||
verifier op.IDTokenHintVerifier
|
verifier *op.IDTokenHintVerifier
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -1003,3 +1004,34 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
||||||
|
token, _ := tu.ValidIDToken()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
idTokenHint string
|
||||||
|
want string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "verify err",
|
||||||
|
idTokenHint: "foo",
|
||||||
|
wantErr: oidc.ErrLoginRequired(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
idTokenHint: token,
|
||||||
|
want: tu.ValidSubject,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -81,7 +81,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClientJWTProfile interface {
|
type ClientJWTProfile interface {
|
||||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
JWTProfileVerifier(context.Context) *JWTProfileVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
|
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
|
||||||
|
|
|
@ -22,7 +22,7 @@ import (
|
||||||
|
|
||||||
type testClientJWTProfile struct{}
|
type testClientJWTProfile struct{}
|
||||||
|
|
||||||
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
|
func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil }
|
||||||
|
|
||||||
func TestClientJWTAuth(t *testing.T) {
|
func TestClientJWTAuth(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
|
|
|
@ -79,10 +79,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
// IDTokenHintVerifier mocks base method.
|
// IDTokenHintVerifier mocks base method.
|
||||||
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier {
|
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
|
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
|
||||||
ret0, _ := ret[0].(op.IDTokenHintVerifier)
|
ret0, _ := ret[0].(*op.IDTokenHintVerifier)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) {
|
||||||
func ExpectVerifier(a op.Authorizer, t *testing.T) {
|
func ExpectVerifier(a op.Authorizer, t *testing.T) {
|
||||||
mockA := a.(*MockAuthorizer)
|
mockA := a.(*MockAuthorizer)
|
||||||
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
|
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
|
||||||
func() op.IDTokenHintVerifier {
|
func() *op.IDTokenHintVerifier {
|
||||||
return op.NewIDTokenHintVerifier("", nil)
|
return op.NewIDTokenHintVerifier("", nil)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
10
pkg/op/op.go
10
pkg/op/op.go
|
@ -73,8 +73,8 @@ type OpenIDProvider interface {
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
Decoder() httphelper.Decoder
|
Decoder() httphelper.Decoder
|
||||||
Encoder() httphelper.Encoder
|
Encoder() httphelper.Encoder
|
||||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||||
Crypto() Crypto
|
Crypto() Crypto
|
||||||
DefaultLogoutRedirectURI() string
|
DefaultLogoutRedirectURI() string
|
||||||
Probes() []ProbesFn
|
Probes() []ProbesFn
|
||||||
|
@ -342,15 +342,15 @@ func (o *Provider) Encoder() httphelper.Encoder {
|
||||||
return o.encoder
|
return o.encoder
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier {
|
func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
|
||||||
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
|
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier {
|
func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
|
||||||
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier {
|
func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
|
||||||
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
|
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
type SessionEnder interface {
|
type SessionEnder interface {
|
||||||
Decoder() httphelper.Decoder
|
Decoder() httphelper.Decoder
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||||
DefaultLogoutRedirectURI() string
|
DefaultLogoutRedirectURI() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ type Introspector interface {
|
||||||
Decoder() httphelper.Decoder
|
Decoder() httphelper.Decoder
|
||||||
Crypto() Crypto
|
Crypto() Crypto
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
type IntrospectorJWTProfile interface {
|
type IntrospectorJWTProfile interface {
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
type JWTAuthorizationGrantExchanger interface {
|
type JWTAuthorizationGrantExchanger interface {
|
||||||
Exchanger
|
Exchanger
|
||||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
JWTProfileVerifier(context.Context) *JWTProfileVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1
|
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1
|
||||||
|
|
|
@ -20,8 +20,8 @@ type Exchanger interface {
|
||||||
GrantTypeJWTAuthorizationSupported() bool
|
GrantTypeJWTAuthorizationSupported() bool
|
||||||
GrantTypeClientCredentialsSupported() bool
|
GrantTypeClientCredentialsSupported() bool
|
||||||
GrantTypeDeviceCodeSupported() bool
|
GrantTypeDeviceCodeSupported() bool
|
||||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
|
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -15,14 +15,14 @@ type Revoker interface {
|
||||||
Decoder() httphelper.Decoder
|
Decoder() httphelper.Decoder
|
||||||
Crypto() Crypto
|
Crypto() Crypto
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||||
AuthMethodPrivateKeyJWTSupported() bool
|
AuthMethodPrivateKeyJWTSupported() bool
|
||||||
AuthMethodPostSupported() bool
|
AuthMethodPostSupported() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type RevokerJWTProfile interface {
|
type RevokerJWTProfile interface {
|
||||||
Revoker
|
Revoker
|
||||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
JWTProfileVerifier(context.Context) *JWTProfileVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
|
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
|
||||||
|
|
|
@ -14,7 +14,7 @@ type UserinfoProvider interface {
|
||||||
Decoder() httphelper.Decoder
|
Decoder() httphelper.Decoder
|
||||||
Crypto() Crypto
|
Crypto() Crypto
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
|
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
|
||||||
|
|
|
@ -2,62 +2,25 @@ package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AccessTokenVerifier interface {
|
type AccessTokenVerifier oidc.Verifier
|
||||||
oidc.Verifier
|
|
||||||
SupportedSignAlgs() []string
|
|
||||||
KeySet() oidc.KeySet
|
|
||||||
}
|
|
||||||
|
|
||||||
type accessTokenVerifier struct {
|
type AccessTokenVerifierOpt func(*AccessTokenVerifier)
|
||||||
issuer string
|
|
||||||
maxAgeIAT time.Duration
|
|
||||||
offset time.Duration
|
|
||||||
supportedSignAlgs []string
|
|
||||||
keySet oidc.KeySet
|
|
||||||
}
|
|
||||||
|
|
||||||
// Issuer implements oidc.Verifier interface
|
|
||||||
func (i *accessTokenVerifier) Issuer() string {
|
|
||||||
return i.issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaxAgeIAT implements oidc.Verifier interface
|
|
||||||
func (i *accessTokenVerifier) MaxAgeIAT() time.Duration {
|
|
||||||
return i.maxAgeIAT
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offset implements oidc.Verifier interface
|
|
||||||
func (i *accessTokenVerifier) Offset() time.Duration {
|
|
||||||
return i.offset
|
|
||||||
}
|
|
||||||
|
|
||||||
// SupportedSignAlgs implements AccessTokenVerifier interface
|
|
||||||
func (i *accessTokenVerifier) SupportedSignAlgs() []string {
|
|
||||||
return i.supportedSignAlgs
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeySet implements AccessTokenVerifier interface
|
|
||||||
func (i *accessTokenVerifier) KeySet() oidc.KeySet {
|
|
||||||
return i.keySet
|
|
||||||
}
|
|
||||||
|
|
||||||
type AccessTokenVerifierOpt func(*accessTokenVerifier)
|
|
||||||
|
|
||||||
func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt {
|
func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt {
|
||||||
return func(verifier *accessTokenVerifier) {
|
return func(verifier *AccessTokenVerifier) {
|
||||||
verifier.supportedSignAlgs = algs
|
verifier.SupportedSignAlgs = algs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) AccessTokenVerifier {
|
// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification.
|
||||||
verifier := &accessTokenVerifier{
|
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier {
|
||||||
issuer: issuer,
|
verifier := &AccessTokenVerifier{
|
||||||
keySet: keySet,
|
Issuer: issuer,
|
||||||
|
KeySet: keySet,
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(verifier)
|
opt(verifier)
|
||||||
|
@ -66,7 +29,7 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyAccessToken validates the access token (issuer, signature and expiration).
|
// VerifyAccessToken validates the access token (issuer, signature and expiration).
|
||||||
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) {
|
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) {
|
||||||
var nilClaims C
|
var nilClaims C
|
||||||
|
|
||||||
decrypted, err := oidc.DecryptToken(token)
|
decrypted, err := oidc.DecryptToken(token)
|
||||||
|
@ -78,15 +41,15 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v Acces
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
want AccessTokenVerifier
|
want *AccessTokenVerifier
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "simple",
|
name: "simple",
|
||||||
|
@ -28,9 +28,9 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
||||||
issuer: tu.ValidIssuer,
|
issuer: tu.ValidIssuer,
|
||||||
keySet: tu.KeySet{},
|
keySet: tu.KeySet{},
|
||||||
},
|
},
|
||||||
want: &accessTokenVerifier{
|
want: &AccessTokenVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -42,10 +42,10 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
||||||
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
|
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: &accessTokenVerifier{
|
want: &AccessTokenVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
SupportedSignAlgs: []string{"ABC", "DEF"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -58,12 +58,12 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyAccessToken(t *testing.T) {
|
func TestVerifyAccessToken(t *testing.T) {
|
||||||
verifier := &accessTokenVerifier{
|
verifier := &AccessTokenVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
maxAgeIAT: 2 * time.Minute,
|
MaxAgeIAT: 2 * time.Minute,
|
||||||
offset: time.Second,
|
Offset: time.Second,
|
||||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|
|
@ -2,69 +2,24 @@ package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenHintVerifier interface {
|
type IDTokenHintVerifier oidc.Verifier
|
||||||
oidc.Verifier
|
|
||||||
SupportedSignAlgs() []string
|
|
||||||
KeySet() oidc.KeySet
|
|
||||||
ACR() oidc.ACRVerifier
|
|
||||||
MaxAge() time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type idTokenHintVerifier struct {
|
type IDTokenHintVerifierOpt func(*IDTokenHintVerifier)
|
||||||
issuer string
|
|
||||||
maxAgeIAT time.Duration
|
|
||||||
offset time.Duration
|
|
||||||
supportedSignAlgs []string
|
|
||||||
maxAge time.Duration
|
|
||||||
acr oidc.ACRVerifier
|
|
||||||
keySet oidc.KeySet
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenHintVerifier) Issuer() string {
|
|
||||||
return i.issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration {
|
|
||||||
return i.maxAgeIAT
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenHintVerifier) Offset() time.Duration {
|
|
||||||
return i.offset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenHintVerifier) SupportedSignAlgs() []string {
|
|
||||||
return i.supportedSignAlgs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenHintVerifier) KeySet() oidc.KeySet {
|
|
||||||
return i.keySet
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier {
|
|
||||||
return i.acr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *idTokenHintVerifier) MaxAge() time.Duration {
|
|
||||||
return i.maxAge
|
|
||||||
}
|
|
||||||
|
|
||||||
type IDTokenHintVerifierOpt func(*idTokenHintVerifier)
|
|
||||||
|
|
||||||
func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt {
|
func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt {
|
||||||
return func(verifier *idTokenHintVerifier) {
|
return func(verifier *IDTokenHintVerifier) {
|
||||||
verifier.supportedSignAlgs = algs
|
verifier.SupportedSignAlgs = algs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) IDTokenHintVerifier {
|
func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier {
|
||||||
verifier := &idTokenHintVerifier{
|
verifier := &IDTokenHintVerifier{
|
||||||
issuer: issuer,
|
Issuer: issuer,
|
||||||
keySet: keySet,
|
KeySet: keySet,
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(verifier)
|
opt(verifier)
|
||||||
|
@ -74,7 +29,7 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
|
||||||
|
|
||||||
// VerifyIDTokenHint validates the id token according to
|
// VerifyIDTokenHint validates the id token according to
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||||
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) {
|
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) {
|
||||||
var nilClaims C
|
var nilClaims C
|
||||||
|
|
||||||
decrypted, err := oidc.DecryptToken(token)
|
decrypted, err := oidc.DecryptToken(token)
|
||||||
|
@ -86,27 +41,27 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTok
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
return claims, nil
|
return claims, nil
|
||||||
|
|
|
@ -20,7 +20,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
want IDTokenHintVerifier
|
want *IDTokenHintVerifier
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "simple",
|
name: "simple",
|
||||||
|
@ -28,9 +28,9 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
||||||
issuer: tu.ValidIssuer,
|
issuer: tu.ValidIssuer,
|
||||||
keySet: tu.KeySet{},
|
keySet: tu.KeySet{},
|
||||||
},
|
},
|
||||||
want: &idTokenHintVerifier{
|
want: &IDTokenHintVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -42,10 +42,10 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
||||||
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
|
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: &idTokenHintVerifier{
|
want: &IDTokenHintVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
SupportedSignAlgs: []string{"ABC", "DEF"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -58,14 +58,14 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyIDTokenHint(t *testing.T) {
|
func TestVerifyIDTokenHint(t *testing.T) {
|
||||||
verifier := &idTokenHintVerifier{
|
verifier := &IDTokenHintVerifier{
|
||||||
issuer: tu.ValidIssuer,
|
Issuer: tu.ValidIssuer,
|
||||||
maxAgeIAT: 2 * time.Minute,
|
MaxAgeIAT: 2 * time.Minute,
|
||||||
offset: time.Second,
|
Offset: time.Second,
|
||||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||||
maxAge: 2 * time.Minute,
|
MaxAge: 2 * time.Minute,
|
||||||
acr: tu.ACRVerify,
|
ACR: tu.ACRVerify,
|
||||||
keySet: tu.KeySet{},
|
KeySet: tu.KeySet{},
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|
|
@ -11,28 +11,25 @@ import (
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWTProfileVerifier interface {
|
// JWTProfileVerfiier extends oidc.Verifier with
|
||||||
|
// a jwtProfileKeyStorage and a function to check
|
||||||
|
// the subject in a token.
|
||||||
|
type JWTProfileVerifier struct {
|
||||||
oidc.Verifier
|
oidc.Verifier
|
||||||
Storage() jwtProfileKeyStorage
|
Storage JWTProfileKeyStorage
|
||||||
CheckSubject(request *oidc.JWTTokenRequest) error
|
CheckSubject func(request *oidc.JWTTokenRequest) error
|
||||||
}
|
|
||||||
|
|
||||||
type jwtProfileVerifier struct {
|
|
||||||
storage jwtProfileKeyStorage
|
|
||||||
subjectCheck func(request *oidc.JWTTokenRequest) error
|
|
||||||
issuer string
|
|
||||||
maxAgeIAT time.Duration
|
|
||||||
offset time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
|
// NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
|
||||||
func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier {
|
func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
|
||||||
j := &jwtProfileVerifier{
|
j := &JWTProfileVerifier{
|
||||||
storage: storage,
|
Verifier: oidc.Verifier{
|
||||||
subjectCheck: SubjectIsIssuer,
|
Issuer: issuer,
|
||||||
issuer: issuer,
|
MaxAgeIAT: maxAgeIAT,
|
||||||
maxAgeIAT: maxAgeIAT,
|
Offset: offset,
|
||||||
offset: offset,
|
},
|
||||||
|
Storage: storage,
|
||||||
|
CheckSubject: SubjectIsIssuer,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
@ -42,53 +39,35 @@ func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIA
|
||||||
return j
|
return j
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWTProfileVerifierOption func(*jwtProfileVerifier)
|
type JWTProfileVerifierOption func(*JWTProfileVerifier)
|
||||||
|
|
||||||
|
// SubjectCheck sets a custom function to check the subject.
|
||||||
|
// Defaults to SubjectIsIssuer()
|
||||||
func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption {
|
func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption {
|
||||||
return func(verifier *jwtProfileVerifier) {
|
return func(verifier *JWTProfileVerifier) {
|
||||||
verifier.subjectCheck = check
|
verifier.CheckSubject = check
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *jwtProfileVerifier) Issuer() string {
|
|
||||||
return v.issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage {
|
|
||||||
return v.storage
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *jwtProfileVerifier) MaxAgeIAT() time.Duration {
|
|
||||||
return v.maxAgeIAT
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *jwtProfileVerifier) Offset() time.Duration {
|
|
||||||
return v.offset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error {
|
|
||||||
return v.subjectCheck(request)
|
|
||||||
}
|
|
||||||
|
|
||||||
// VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication)
|
// VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication)
|
||||||
//
|
//
|
||||||
// checks audience, exp, iat, signature and that issuer and sub are the same
|
// checks audience, exp, iat, signature and that issuer and sub are the same
|
||||||
func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
||||||
request := new(oidc.JWTTokenRequest)
|
request := new(oidc.JWTTokenRequest)
|
||||||
payload, err := oidc.ParseToken(assertion, request)
|
payload, err := oidc.ParseToken(assertion, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAudience(request, v.Issuer()); err != nil {
|
if err = oidc.CheckAudience(request, v.Issuer); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckExpiration(request, v.Offset()); err != nil {
|
if err = oidc.CheckExpiration(request, v.Offset); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT(), v.Offset()); err != nil {
|
if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,17 +75,18 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
keySet := &jwtProfileKeySet{storage: v.Storage(), clientID: request.Issuer}
|
keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer}
|
||||||
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
|
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type jwtProfileKeyStorage interface {
|
type JWTProfileKeyStorage interface {
|
||||||
GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SubjectIsIssuer
|
||||||
func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
|
func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
|
||||||
if request.Issuer != request.Subject {
|
if request.Issuer != request.Subject {
|
||||||
return errors.New("delegation not allowed, issuer and sub must be identical")
|
return errors.New("delegation not allowed, issuer and sub must be identical")
|
||||||
|
@ -115,7 +95,7 @@ func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type jwtProfileKeySet struct {
|
type jwtProfileKeySet struct {
|
||||||
storage jwtProfileKeyStorage
|
storage JWTProfileKeyStorage
|
||||||
clientID string
|
clientID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
117
pkg/op/verifier_jwt_profile_test.go
Normal file
117
pkg/op/verifier_jwt_profile_test.go
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
package op_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewJWTProfileVerifier(t *testing.T) {
|
||||||
|
want := &op.JWTProfileVerifier{
|
||||||
|
Verifier: oidc.Verifier{
|
||||||
|
Issuer: tu.ValidIssuer,
|
||||||
|
MaxAgeIAT: time.Minute,
|
||||||
|
Offset: time.Second,
|
||||||
|
},
|
||||||
|
Storage: tu.JWTProfileKeyStorage{},
|
||||||
|
}
|
||||||
|
got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error {
|
||||||
|
return oidc.ErrSubjectMissing
|
||||||
|
}))
|
||||||
|
assert.Equal(t, want.Verifier, got.Verifier)
|
||||||
|
assert.Equal(t, want.Storage, got.Storage)
|
||||||
|
assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyJWTAssertion(t *testing.T) {
|
||||||
|
errCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
newToken func() (string, *oidc.JWTTokenRequest)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "parse error",
|
||||||
|
ctx: context.Background(),
|
||||||
|
newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil },
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong audience",
|
||||||
|
ctx: context.Background(),
|
||||||
|
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||||
|
return tu.NewJWTProfileAssertion(
|
||||||
|
tu.ValidClientID, tu.ValidClientID, []string{"wrong"},
|
||||||
|
time.Now(), tu.ValidExpiration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired",
|
||||||
|
ctx: context.Background(),
|
||||||
|
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||||
|
return tu.NewJWTProfileAssertion(
|
||||||
|
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
|
||||||
|
time.Now(), time.Now().Add(-time.Hour),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid iat",
|
||||||
|
ctx: context.Background(),
|
||||||
|
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||||
|
return tu.NewJWTProfileAssertion(
|
||||||
|
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
|
||||||
|
time.Now().Add(time.Hour), tu.ValidExpiration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid subject",
|
||||||
|
ctx: context.Background(),
|
||||||
|
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||||
|
return tu.NewJWTProfileAssertion(
|
||||||
|
tu.ValidClientID, "wrong", []string{tu.ValidIssuer},
|
||||||
|
time.Now(), tu.ValidExpiration,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check signature fail",
|
||||||
|
ctx: errCtx,
|
||||||
|
newToken: tu.ValidJWTProfileAssertion,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
ctx: context.Background(),
|
||||||
|
newToken: tu.ValidJWTProfileAssertion,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assertion, want := tt.newToken()
|
||||||
|
got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue