feat: merge the verifier types (#336)

BREAKING CHANGE:

- The various verifier types are merged into a oidc.Verifir.
- oidc.Verfier became a struct with exported fields

* use type aliases for oidc.Verifier

this binds the correct contstructor to each verifier usecase.

* fix: handle the zero cases for oidc.Time

* add unit tests to oidc verifier

* fix: correct returned field for JWTTokenRequest

JWTTokenRequest.GetIssuedAt() was returning the ExpiresAt field.
This change corrects that by returning IssuedAt instead.
This commit is contained in:
Tim Möhlmann 2023-03-22 19:18:41 +02:00 committed by GitHub
parent c8cf15e266
commit 33c716ddcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 948 additions and 351 deletions

View file

@ -8,6 +8,7 @@ import (
"errors" "errors"
"time" "time"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
@ -17,7 +18,7 @@ type KeySet struct{}
// VerifySignature implments op.KeySet. // VerifySignature implments op.KeySet.
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
if ctx.Err() != nil { if err = ctx.Err(); err != nil {
return nil, err return nil, err
} }
@ -45,6 +46,16 @@ func init() {
} }
} }
type JWTProfileKeyStorage struct{}
func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
return gu.Ptr(WebKey.Public()), nil
}
func signEncodeTokenClaims(claims any) string { func signEncodeTokenClaims(claims any) string {
payload, err := json.Marshal(claims) payload, err := json.Marshal(claims)
if err != nil { if err != nil {
@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil) return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
} }
func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
req := &oidc.JWTTokenRequest{
Issuer: issuer,
Subject: clientID,
Audience: audience,
ExpiresAt: oidc.FromTime(expiration),
IssuedAt: oidc.FromTime(issuedAt),
}
// make sure the private claim map is set correctly
data, err := json.Marshal(req)
if err != nil {
panic(err)
}
if err = json.Unmarshal(data, req); err != nil {
panic(err)
}
return signEncodeTokenClaims(req), req
}
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg` const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
// These variables always result in a valid token // These variables always result in a valid token
@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew) return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
} }
func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
}
// ACRVerify is a oidc.ACRVerifier func. // ACRVerify is a oidc.ACRVerifier func.
func ACRVerify(acr string) error { func ACRVerify(acr string) error {
if acr != ValidACR { if acr != ValidACR {

View file

@ -63,8 +63,8 @@ type RelyingParty interface {
// be used to start a DeviceAuthorization flow. // be used to start a DeviceAuthorization flow.
GetDeviceAuthorizationEndpoint() string GetDeviceAuthorizationEndpoint() string
// IDTokenVerifier returns the verifier interface used for oidc id_token verification // IDTokenVerifier returns the verifier used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier IDTokenVerifier() *IDTokenVerifier
// ErrorHandler returns the handler used for callback errors // ErrorHandler returns the handler used for callback errors
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
@ -88,7 +88,7 @@ type relyingParty struct {
cookieHandler *httphelper.CookieHandler cookieHandler *httphelper.CookieHandler
errorHandler func(http.ResponseWriter, *http.Request, string, string, string) errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
idTokenVerifier IDTokenVerifier idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption verifierOpts []VerifierOption
signer jose.Signer signer jose.Signer
} }
@ -137,7 +137,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string {
return rp.endpoints.RevokeURL return rp.endpoints.RevokeURL
} }
func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier { func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
if rp.idTokenVerifier == nil { if rp.idTokenVerifier == nil {
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
} }

View file

@ -9,19 +9,9 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type IDTokenVerifier interface {
oidc.Verifier
ClientID() string
SupportedSignAlgs() []string
KeySet() oidc.KeySet
Nonce(context.Context) string
ACR() oidc.ACRVerifier
MaxAge() time.Duration
}
// VerifyTokens implement the Token Response Validation as defined in OIDC specification // VerifyTokens implement the Token Response Validation as defined in OIDC specification
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) { func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C var nilClaims C
claims, err = VerifyIDToken[C](ctx, idToken, v) claims, err = VerifyIDToken[C](ctx, idToken, v)
@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str
// VerifyIDToken validates the id token according to // VerifyIDToken validates the id token according to
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) { func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C var nilClaims C
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { if err = oidc.CheckAudience(claims, v.ClientID); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
return nilClaims, err return nilClaims, err
} }
@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
return nilClaims, err return nilClaims, err
} }
return claims, nil return claims, nil
} }
type IDTokenVerifier oidc.Verifier
// VerifyAccessToken validates the access token according to // VerifyAccessToken validates the access token according to
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl
return nil return nil
} }
// NewIDTokenVerifier returns an implementation of `IDTokenVerifier` // NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
// for `VerifyTokens` and `VerifyIDToken` func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier { v := &IDTokenVerifier{
v := &idTokenVerifier{ Issuer: issuer,
issuer: issuer, ClientID: clientID,
clientID: clientID, KeySet: keySet,
keySet: keySet, Offset: time.Second,
offset: time.Second, Nonce: func(_ context.Context) string {
nonce: func(_ context.Context) string {
return "" return ""
}, },
} }
@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
} }
// VerifierOption is the type for providing dynamic options to the IDTokenVerifier // VerifierOption is the type for providing dynamic options to the IDTokenVerifier
type VerifierOption func(*idTokenVerifier) type VerifierOption func(*IDTokenVerifier)
// WithIssuedAtOffset mitigates the risk of iat to be in the future // WithIssuedAtOffset mitigates the risk of iat to be in the future
// because of clock skews with the ability to add an offset to the current time // because of clock skews with the ability to add an offset to the current time
func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) { func WithIssuedAtOffset(offset time.Duration) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.offset = offset v.Offset = offset
} }
} }
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now // WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) { func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.maxAgeIAT = maxAge v.MaxAgeIAT = maxAge
} }
} }
// WithNonce sets the function to check the nonce // WithNonce sets the function to check the nonce
func WithNonce(nonce func(context.Context) string) VerifierOption { func WithNonce(nonce func(context.Context) string) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.nonce = nonce v.Nonce = nonce
} }
} }
// WithACRVerifier sets the verifier for the acr claim // WithACRVerifier sets the verifier for the acr claim
func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption { func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.acr = verifier v.ACR = verifier
} }
} }
// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now // WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption { func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.maxAge = maxAge v.MaxAge = maxAge
} }
} }
// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm // WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
func WithSupportedSigningAlgorithms(algs ...string) VerifierOption { func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.supportedSignAlgs = algs v.SupportedSignAlgs = algs
} }
} }
type idTokenVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
clientID string
supportedSignAlgs []string
keySet oidc.KeySet
acr oidc.ACRVerifier
maxAge time.Duration
nonce func(ctx context.Context) string
}
func (i *idTokenVerifier) Issuer() string {
return i.issuer
}
func (i *idTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
func (i *idTokenVerifier) Offset() time.Duration {
return i.offset
}
func (i *idTokenVerifier) ClientID() string {
return i.clientID
}
func (i *idTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
func (i *idTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func (i *idTokenVerifier) Nonce(ctx context.Context) string {
return i.nonce(ctx)
}
func (i *idTokenVerifier) ACR() oidc.ACRVerifier {
return i.acr
}
func (i *idTokenVerifier) MaxAge() time.Duration {
return i.maxAge
}

View file

@ -13,16 +13,16 @@ import (
) )
func TestVerifyTokens(t *testing.T) { func TestVerifyTokens(t *testing.T) {
verifier := &idTokenVerifier{ verifier := &IDTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute, MaxAgeIAT: 2 * time.Minute,
offset: time.Second, Offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
maxAge: 2 * time.Minute, MaxAge: 2 * time.Minute,
acr: tu.ACRVerify, ACR: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce }, Nonce: func(context.Context) string { return tu.ValidNonce },
clientID: tu.ValidClientID, ClientID: tu.ValidClientID,
} }
accessToken, _ := tu.ValidAccessToken() accessToken, _ := tu.ValidAccessToken()
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm) atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) {
} }
func TestVerifyIDToken(t *testing.T) { func TestVerifyIDToken(t *testing.T) {
verifier := &idTokenVerifier{ verifier := &IDTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute, MaxAgeIAT: 2 * time.Minute,
offset: time.Second, Offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
maxAge: 2 * time.Minute, MaxAge: 2 * time.Minute,
acr: tu.ACRVerify, ACR: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce }, Nonce: func(context.Context) string { return tu.ValidNonce },
} }
tests := []struct { tests := []struct {
@ -219,7 +219,7 @@ func TestVerifyIDToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims() token, want := tt.tokenClaims()
verifier.clientID = tt.clientID verifier.ClientID = tt.clientID
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
@ -300,7 +300,7 @@ func TestNewIDTokenVerifier(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want IDTokenVerifier want *IDTokenVerifier
}{ }{
{ {
name: "nil nonce", // otherwise assert.Equal will fail on the function name: "nil nonce", // otherwise assert.Equal will fail on the function
@ -317,16 +317,16 @@ func TestNewIDTokenVerifier(t *testing.T) {
WithSupportedSigningAlgorithms("ABC", "DEF"), WithSupportedSigningAlgorithms("ABC", "DEF"),
}, },
}, },
want: &idTokenVerifier{ want: &IDTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
offset: time.Minute, Offset: time.Minute,
maxAgeIAT: time.Hour, MaxAgeIAT: time.Hour,
clientID: tu.ValidClientID, ClientID: tu.ValidClientID,
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
nonce: nil, Nonce: nil,
acr: nil, ACR: nil,
maxAge: 2 * time.Hour, MaxAge: 2 * time.Hour,
supportedSignAlgs: []string{"ABC", "DEF"}, SupportedSignAlgs: []string{"ABC", "DEF"},
}, },
}, },
} }

View file

@ -192,7 +192,7 @@ func (j *JWTTokenRequest) GetExpiration() time.Time {
// GetIssuedAt implements the Claims interface // GetIssuedAt implements the Claims interface
func (j *JWTTokenRequest) GetIssuedAt() time.Time { func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return j.ExpiresAt.AsTime() return j.IssuedAt.AsTime()
} }
// GetNonce implements the Claims interface // GetNonce implements the Claims interface

View file

@ -173,10 +173,16 @@ func NewEncoder() *schema.Encoder {
type Time int64 type Time int64
func (ts Time) AsTime() time.Time { func (ts Time) AsTime() time.Time {
if ts == 0 {
return time.Time{}
}
return time.Unix(int64(ts), 0) return time.Unix(int64(ts), 0)
} }
func FromTime(tt time.Time) Time { func FromTime(tt time.Time) Time {
if tt.IsZero() {
return 0
}
return Time(tt.Unix()) return Time(tt.Unix())
} }

View file

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/gorilla/schema" "github.com/gorilla/schema"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -467,6 +468,56 @@ func TestNewEncoder(t *testing.T) {
assert.Equal(t, a, b) assert.Equal(t, a, b)
} }
func TestTime_AsTime(t *testing.T) {
tests := []struct {
name string
ts Time
want time.Time
}{
{
name: "unset",
ts: 0,
want: time.Time{},
},
{
name: "set",
ts: 1,
want: time.Unix(1, 0),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.ts.AsTime()
assert.Equal(t, tt.want, got)
})
}
}
func TestTime_FromTime(t *testing.T) {
tests := []struct {
name string
tt time.Time
want Time
}{
{
name: "zero",
tt: time.Time{},
want: 0,
},
{
name: "set",
tt: time.Unix(1, 0),
want: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromTime(tt.tt)
assert.Equal(t, tt.want, got)
})
}
}
func TestTime_UnmarshalJSON(t *testing.T) { func TestTime_UnmarshalJSON(t *testing.T) {
type dst struct { type dst struct {
UpdatedAt Time `json:"updated_at"` UpdatedAt Time `json:"updated_at"`

View file

@ -61,10 +61,19 @@ var (
ErrAtHash = errors.New("at_hash does not correspond to access token") ErrAtHash = errors.New("at_hash does not correspond to access token")
) )
type Verifier interface { // Verifier caries configuration for the various token verification
Issuer() string // functions. Use package specific constructor functions to know
MaxAgeIAT() time.Duration // which values need to be set.
Offset() time.Duration type Verifier struct {
Issuer string
MaxAgeIAT time.Duration
Offset time.Duration
ClientID string
SupportedSignAlgs []string
MaxAge time.Duration
ACR ACRVerifier
KeySet KeySet
Nonce func(ctx context.Context) string
} }
// ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim // ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error {
return nil return nil
} }
// CheckAuthorizedParty checks azp (authorized party) claim requirements.
//
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func CheckAuthorizedParty(claims Claims, clientID string) error { func CheckAuthorizedParty(claims Claims, clientID string) error {
if len(claims.GetAudience()) > 1 { if len(claims.GetAudience()) > 1 {
if claims.GetAuthorizedParty() == "" { if claims.GetAuthorizedParty() == "" {
@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
} }
func CheckExpiration(claims Claims, offset time.Duration) error { func CheckExpiration(claims Claims, offset time.Duration) error {
expiration := claims.GetExpiration().Round(time.Second) expiration := claims.GetExpiration()
if !time.Now().UTC().Add(offset).Before(expiration) { if !time.Now().Add(offset).Before(expiration) {
return ErrExpired return ErrExpired
} }
return nil return nil
} }
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
issuedAt := claims.GetIssuedAt().Round(time.Second) issuedAt := claims.GetIssuedAt()
if issuedAt.IsZero() { if issuedAt.IsZero() {
return ErrIatMissing return ErrIatMissing
} }
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) nowWithOffset := time.Now().Add(offset).Round(time.Second)
if issuedAt.After(nowWithOffset) { if issuedAt.After(nowWithOffset) {
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
} }
if maxAgeIAT == 0 { if maxAgeIAT == 0 {
return nil return nil
} }
maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second) maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second)
if issuedAt.Before(maxAge) { if issuedAt.Before(maxAge) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
} }
@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error {
if claims.GetAuthTime().IsZero() { if claims.GetAuthTime().IsZero() {
return ErrAuthTimeNotPresent return ErrAuthTimeNotPresent
} }
authTime := claims.GetAuthTime().Round(time.Second) authTime := claims.GetAuthTime()
maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second) maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
if authTime.Before(maxAuthTime) { if authTime.Before(maxAuthTime) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime)) return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
} }

View file

@ -0,0 +1,128 @@
package oidc_test
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
func TestParseToken(t *testing.T) {
token, wantClaims := tu.ValidIDToken()
wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload
wantPayload, err := json.Marshal(wantClaims)
require.NoError(t, err)
tests := []struct {
name string
tokenString string
wantErr bool
}{
{
name: "split error",
tokenString: "nope",
wantErr: true,
},
{
name: "base64 error",
tokenString: "foo.~.bar",
wantErr: true,
},
{
name: "success",
tokenString: token,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClaims := new(oidc.IDTokenClaims)
gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, wantClaims, gotClaims)
assert.JSONEq(t, string(wantPayload), string(gotPayload))
})
}
}
func TestCheckSignature(t *testing.T) {
errCtx, cancel := context.WithCancel(context.Background())
cancel()
token, _ := tu.ValidIDToken()
payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{})
require.NoError(t, err)
type args struct {
ctx context.Context
token string
payload []byte
supportedSigAlgs []string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "parse error",
args: args{
ctx: context.Background(),
token: "~",
payload: payload,
},
wantErr: oidc.ErrParse,
},
{
name: "default sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
},
},
{
name: "unsupported sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
supportedSigAlgs: []string{"foo", "bar"},
},
wantErr: oidc.ErrSignatureUnsupportedAlg,
},
{
name: "verify error",
args: args{
ctx: errCtx,
token: token,
payload: payload,
},
wantErr: oidc.ErrSignatureInvalid,
},
{
name: "inequal payloads",
args: args{
ctx: context.Background(),
token: token,
payload: []byte{0, 1, 2},
},
wantErr: oidc.ErrSignatureInvalidPayload,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims := new(oidc.TokenClaims)
err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{})
assert.ErrorIs(t, err, tt.wantErr)
})
}
}

374
pkg/oidc/verifier_test.go Normal file
View file

@ -0,0 +1,374 @@
package oidc
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecryptToken(t *testing.T) {
const tokenString = "ABC"
got, err := DecryptToken(tokenString)
require.NoError(t, err)
assert.Equal(t, tokenString, got)
}
func TestDefaultACRVerifier(t *testing.T) {
acrVerfier := DefaultACRVerifier([]string{"foo", "bar"})
tests := []struct {
name string
acr string
wantErr string
}{
{
name: "ok",
acr: "bar",
},
{
name: "error",
acr: "hello",
wantErr: "expected one of: [foo bar], got: \"hello\"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := acrVerfier(tt.acr)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
return
}
require.NoError(t, err)
})
}
}
func TestCheckSubject(t *testing.T) {
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrSubjectMissing,
},
{
name: "ok",
claims: &TokenClaims{
Subject: "foo",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckSubject(tt.claims)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckIssuer(t *testing.T) {
const issuer = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrIssuerInvalid,
},
{
name: "wrong",
claims: &TokenClaims{
Issuer: "wrong",
},
wantErr: ErrIssuerInvalid,
},
{
name: "ok",
claims: &TokenClaims{
Issuer: issuer,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckIssuer(tt.claims, issuer)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAudience(t *testing.T) {
const clientID = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrAudience,
},
{
name: "wrong",
claims: &TokenClaims{
Audience: []string{"wrong"},
},
wantErr: ErrAudience,
},
{
name: "ok",
claims: &TokenClaims{
Audience: []string{clientID},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAudience(tt.claims, clientID)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthorizedParty(t *testing.T) {
const clientID = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "single audience, no azp",
claims: &TokenClaims{
Audience: []string{clientID},
},
},
{
name: "multiple audience, no azp",
claims: &TokenClaims{
Audience: []string{clientID, "other"},
},
wantErr: ErrAzpMissing,
},
{
name: "single audience, with azp",
claims: &TokenClaims{
Audience: []string{clientID},
AuthorizedParty: clientID,
},
},
{
name: "multiple audience, with azp",
claims: &TokenClaims{
Audience: []string{clientID, "other"},
AuthorizedParty: clientID,
},
},
{
name: "wrong azp",
claims: &TokenClaims{
AuthorizedParty: "wrong",
},
wantErr: ErrAzpInvalid,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthorizedParty(tt.claims, clientID)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckExpiration(t *testing.T) {
const offset = time.Minute
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrExpired,
},
{
name: "expired",
claims: &TokenClaims{
Expiration: FromTime(time.Now().Add(-2 * offset)),
},
wantErr: ErrExpired,
},
{
name: "valid",
claims: &TokenClaims{
Expiration: FromTime(time.Now().Add(2 * offset)),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckExpiration(tt.claims, offset)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckIssuedAt(t *testing.T) {
const offset = time.Minute
tests := []struct {
name string
maxAgeIAT time.Duration
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrIatMissing,
},
{
name: "future",
claims: &TokenClaims{
IssuedAt: FromTime(time.Now().Add(time.Hour)),
},
wantErr: ErrIatInFuture,
},
{
name: "no max",
claims: &TokenClaims{
IssuedAt: FromTime(time.Now()),
},
},
{
name: "past max",
maxAgeIAT: time.Minute,
claims: &TokenClaims{
IssuedAt: FromTime(time.Now().Add(-time.Hour)),
},
wantErr: ErrIatToOld,
},
{
name: "within max",
maxAgeIAT: time.Hour,
claims: &TokenClaims{
IssuedAt: FromTime(time.Now()),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckNonce(t *testing.T) {
const nonce = "123"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrNonceInvalid,
},
{
name: "wrong",
claims: &TokenClaims{
Nonce: "wrong",
},
wantErr: ErrNonceInvalid,
},
{
name: "ok",
claims: &TokenClaims{
Nonce: nonce,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckNonce(tt.claims, nonce)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthorizationContextClassReference(t *testing.T) {
tests := []struct {
name string
acr ACRVerifier
wantErr error
}{
{
name: "error",
acr: func(s string) error { return errors.New("oops") },
wantErr: ErrAcrInvalid,
},
{
name: "ok",
acr: func(s string) error { return nil },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthTime(t *testing.T) {
tests := []struct {
name string
claims Claims
maxAge time.Duration
wantErr error
}{
{
name: "no max age",
claims: &TokenClaims{},
},
{
name: "missing",
claims: &TokenClaims{},
maxAge: time.Minute,
wantErr: ErrAuthTimeNotPresent,
},
{
name: "expired",
maxAge: time.Minute,
claims: &TokenClaims{
AuthTime: FromTime(time.Now().Add(-time.Hour)),
},
wantErr: ErrAuthTimeToOld,
},
{
name: "ok",
maxAge: time.Minute,
claims: &TokenClaims{
AuthTime: NowTime(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthTime(tt.claims, tt.maxAge)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}

View file

@ -38,7 +38,7 @@ type Authorizer interface {
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Encoder() httphelper.Encoder Encoder() httphelper.Encoder
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
Crypto() Crypto Crypto() Crypto
RequestObjectSupported() bool RequestObjectSupported() bool
} }
@ -47,7 +47,7 @@ type Authorizer interface {
// implementing its own validation mechanism for the auth request // implementing its own validation mechanism for the auth request
type AuthorizeValidator interface { type AuthorizeValidator interface {
Authorizer Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
} }
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
@ -204,7 +204,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
} }
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed // ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil { if err != nil {
return "", err return "", err
@ -384,7 +384,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType)
// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request) // ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
// and returns the `sub` claim // and returns the `sub` claim
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) { func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) {
if idTokenHint == "" { if idTokenHint == "" {
return "", nil return "", nil
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/gorilla/schema" "github.com/gorilla/schema"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
httphelper "github.com/zitadel/oidc/v3/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
@ -146,7 +147,7 @@ func TestValidateAuthRequest(t *testing.T) {
type args struct { type args struct {
authRequest *oidc.AuthRequest authRequest *oidc.AuthRequest
storage op.Storage storage op.Storage
verifier op.IDTokenHintVerifier verifier *op.IDTokenHintVerifier
} }
tests := []struct { tests := []struct {
name string name string
@ -1003,3 +1004,34 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) {
}) })
} }
} }
func TestValidateAuthReqIDTokenHint(t *testing.T) {
token, _ := tu.ValidIDToken()
tests := []struct {
name string
idTokenHint string
want string
wantErr error
}{
{
name: "empty",
},
{
name: "verify err",
idTokenHint: "foo",
wantErr: oidc.ErrLoginRequired(),
},
{
name: "ok",
idTokenHint: token,
want: tu.ValidSubject,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -81,7 +81,7 @@ var (
) )
type ClientJWTProfile interface { type ClientJWTProfile interface {
JWTProfileVerifier(context.Context) JWTProfileVerifier JWTProfileVerifier(context.Context) *JWTProfileVerifier
} }
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {

View file

@ -22,7 +22,7 @@ import (
type testClientJWTProfile struct{} type testClientJWTProfile struct{}
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil }
func TestClientJWTAuth(t *testing.T) { func TestClientJWTAuth(t *testing.T) {
type args struct { type args struct {

View file

@ -79,10 +79,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
} }
// IDTokenHintVerifier mocks base method. // IDTokenHintVerifier mocks base method.
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier { func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0) ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
ret0, _ := ret[0].(op.IDTokenHintVerifier) ret0, _ := ret[0].(*op.IDTokenHintVerifier)
return ret0 return ret0
} }

View file

@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) {
func ExpectVerifier(a op.Authorizer, t *testing.T) { func ExpectVerifier(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer) mockA := a.(*MockAuthorizer)
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn( mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
func() op.IDTokenHintVerifier { func() *op.IDTokenHintVerifier {
return op.NewIDTokenHintVerifier("", nil) return op.NewIDTokenHintVerifier("", nil)
}) })
} }

View file

@ -73,8 +73,8 @@ type OpenIDProvider interface {
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Encoder() httphelper.Encoder Encoder() httphelper.Encoder
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
AccessTokenVerifier(context.Context) AccessTokenVerifier AccessTokenVerifier(context.Context) *AccessTokenVerifier
Crypto() Crypto Crypto() Crypto
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
Probes() []ProbesFn Probes() []ProbesFn
@ -342,15 +342,15 @@ func (o *Provider) Encoder() httphelper.Encoder {
return o.encoder return o.encoder
} }
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier { func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...) return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
} }
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier { func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second) return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
} }
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier { func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...) return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
} }

View file

@ -13,7 +13,7 @@ import (
type SessionEnder interface { type SessionEnder interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Storage() Storage Storage() Storage
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
} }

View file

@ -13,7 +13,7 @@ type Introspector interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier(context.Context) AccessTokenVerifier AccessTokenVerifier(context.Context) *AccessTokenVerifier
} }
type IntrospectorJWTProfile interface { type IntrospectorJWTProfile interface {

View file

@ -11,7 +11,7 @@ import (
type JWTAuthorizationGrantExchanger interface { type JWTAuthorizationGrantExchanger interface {
Exchanger Exchanger
JWTProfileVerifier(context.Context) JWTProfileVerifier JWTProfileVerifier(context.Context) *JWTProfileVerifier
} }
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1 // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1

View file

@ -20,8 +20,8 @@ type Exchanger interface {
GrantTypeJWTAuthorizationSupported() bool GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool GrantTypeDeviceCodeSupported() bool
AccessTokenVerifier(context.Context) AccessTokenVerifier AccessTokenVerifier(context.Context) *AccessTokenVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
} }
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {

View file

@ -15,14 +15,14 @@ type Revoker interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier(context.Context) AccessTokenVerifier AccessTokenVerifier(context.Context) *AccessTokenVerifier
AuthMethodPrivateKeyJWTSupported() bool AuthMethodPrivateKeyJWTSupported() bool
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
} }
type RevokerJWTProfile interface { type RevokerJWTProfile interface {
Revoker Revoker
JWTProfileVerifier(context.Context) JWTProfileVerifier JWTProfileVerifier(context.Context) *JWTProfileVerifier
} }
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {

View file

@ -14,7 +14,7 @@ type UserinfoProvider interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier(context.Context) AccessTokenVerifier AccessTokenVerifier(context.Context) *AccessTokenVerifier
} }
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {

View file

@ -2,62 +2,25 @@ package op
import ( import (
"context" "context"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type AccessTokenVerifier interface { type AccessTokenVerifier oidc.Verifier
oidc.Verifier
SupportedSignAlgs() []string
KeySet() oidc.KeySet
}
type accessTokenVerifier struct { type AccessTokenVerifierOpt func(*AccessTokenVerifier)
issuer string
maxAgeIAT time.Duration
offset time.Duration
supportedSignAlgs []string
keySet oidc.KeySet
}
// Issuer implements oidc.Verifier interface
func (i *accessTokenVerifier) Issuer() string {
return i.issuer
}
// MaxAgeIAT implements oidc.Verifier interface
func (i *accessTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
// Offset implements oidc.Verifier interface
func (i *accessTokenVerifier) Offset() time.Duration {
return i.offset
}
// SupportedSignAlgs implements AccessTokenVerifier interface
func (i *accessTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
// KeySet implements AccessTokenVerifier interface
func (i *accessTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}
type AccessTokenVerifierOpt func(*accessTokenVerifier)
func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt { func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt {
return func(verifier *accessTokenVerifier) { return func(verifier *AccessTokenVerifier) {
verifier.supportedSignAlgs = algs verifier.SupportedSignAlgs = algs
} }
} }
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) AccessTokenVerifier { // NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification.
verifier := &accessTokenVerifier{ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier {
issuer: issuer, verifier := &AccessTokenVerifier{
keySet: keySet, Issuer: issuer,
KeySet: keySet,
} }
for _, opt := range opts { for _, opt := range opts {
opt(verifier) opt(verifier)
@ -66,7 +29,7 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok
} }
// VerifyAccessToken validates the access token (issuer, signature and expiration). // VerifyAccessToken validates the access token (issuer, signature and expiration).
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) { func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) {
var nilClaims C var nilClaims C
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
@ -78,15 +41,15 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v Acces
return nilClaims, err return nilClaims, err
} }
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
return nilClaims, err return nilClaims, err
} }

View file

@ -20,7 +20,7 @@ func TestNewAccessTokenVerifier(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want AccessTokenVerifier want *AccessTokenVerifier
}{ }{
{ {
name: "simple", name: "simple",
@ -28,9 +28,9 @@ func TestNewAccessTokenVerifier(t *testing.T) {
issuer: tu.ValidIssuer, issuer: tu.ValidIssuer,
keySet: tu.KeySet{}, keySet: tu.KeySet{},
}, },
want: &accessTokenVerifier{ want: &AccessTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
}, },
}, },
{ {
@ -42,10 +42,10 @@ func TestNewAccessTokenVerifier(t *testing.T) {
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"), WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
}, },
}, },
want: &accessTokenVerifier{ want: &AccessTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
supportedSignAlgs: []string{"ABC", "DEF"}, SupportedSignAlgs: []string{"ABC", "DEF"},
}, },
}, },
} }
@ -58,12 +58,12 @@ func TestNewAccessTokenVerifier(t *testing.T) {
} }
func TestVerifyAccessToken(t *testing.T) { func TestVerifyAccessToken(t *testing.T) {
verifier := &accessTokenVerifier{ verifier := &AccessTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute, MaxAgeIAT: 2 * time.Minute,
offset: time.Second, Offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
} }
tests := []struct { tests := []struct {

View file

@ -2,69 +2,24 @@ package op
import ( import (
"context" "context"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type IDTokenHintVerifier interface { type IDTokenHintVerifier oidc.Verifier
oidc.Verifier
SupportedSignAlgs() []string
KeySet() oidc.KeySet
ACR() oidc.ACRVerifier
MaxAge() time.Duration
}
type idTokenHintVerifier struct { type IDTokenHintVerifierOpt func(*IDTokenHintVerifier)
issuer string
maxAgeIAT time.Duration
offset time.Duration
supportedSignAlgs []string
maxAge time.Duration
acr oidc.ACRVerifier
keySet oidc.KeySet
}
func (i *idTokenHintVerifier) Issuer() string {
return i.issuer
}
func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
func (i *idTokenHintVerifier) Offset() time.Duration {
return i.offset
}
func (i *idTokenHintVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
func (i *idTokenHintVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier {
return i.acr
}
func (i *idTokenHintVerifier) MaxAge() time.Duration {
return i.maxAge
}
type IDTokenHintVerifierOpt func(*idTokenHintVerifier)
func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt { func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt {
return func(verifier *idTokenHintVerifier) { return func(verifier *IDTokenHintVerifier) {
verifier.supportedSignAlgs = algs verifier.SupportedSignAlgs = algs
} }
} }
func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) IDTokenHintVerifier { func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier {
verifier := &idTokenHintVerifier{ verifier := &IDTokenHintVerifier{
issuer: issuer, Issuer: issuer,
keySet: keySet, KeySet: keySet,
} }
for _, opt := range opts { for _, opt := range opts {
opt(verifier) opt(verifier)
@ -74,7 +29,7 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
// VerifyIDTokenHint validates the id token according to // VerifyIDTokenHint validates the id token according to
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) { func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) {
var nilClaims C var nilClaims C
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
@ -86,27 +41,27 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTok
return nilClaims, err return nilClaims, err
} }
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
return nilClaims, err return nilClaims, err
} }
return claims, nil return claims, nil

View file

@ -20,7 +20,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want IDTokenHintVerifier want *IDTokenHintVerifier
}{ }{
{ {
name: "simple", name: "simple",
@ -28,9 +28,9 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
issuer: tu.ValidIssuer, issuer: tu.ValidIssuer,
keySet: tu.KeySet{}, keySet: tu.KeySet{},
}, },
want: &idTokenHintVerifier{ want: &IDTokenHintVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
}, },
}, },
{ {
@ -42,10 +42,10 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"), WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
}, },
}, },
want: &idTokenHintVerifier{ want: &IDTokenHintVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
supportedSignAlgs: []string{"ABC", "DEF"}, SupportedSignAlgs: []string{"ABC", "DEF"},
}, },
}, },
} }
@ -58,14 +58,14 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
} }
func TestVerifyIDTokenHint(t *testing.T) { func TestVerifyIDTokenHint(t *testing.T) {
verifier := &idTokenHintVerifier{ verifier := &IDTokenHintVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute, MaxAgeIAT: 2 * time.Minute,
offset: time.Second, Offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
maxAge: 2 * time.Minute, MaxAge: 2 * time.Minute,
acr: tu.ACRVerify, ACR: tu.ACRVerify,
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
} }
tests := []struct { tests := []struct {

View file

@ -11,28 +11,25 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type JWTProfileVerifier interface { // JWTProfileVerfiier extends oidc.Verifier with
// a jwtProfileKeyStorage and a function to check
// the subject in a token.
type JWTProfileVerifier struct {
oidc.Verifier oidc.Verifier
Storage() jwtProfileKeyStorage Storage JWTProfileKeyStorage
CheckSubject(request *oidc.JWTTokenRequest) error CheckSubject func(request *oidc.JWTTokenRequest) error
}
type jwtProfileVerifier struct {
storage jwtProfileKeyStorage
subjectCheck func(request *oidc.JWTTokenRequest) error
issuer string
maxAgeIAT time.Duration
offset time.Duration
} }
// NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) // NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier { func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
j := &jwtProfileVerifier{ j := &JWTProfileVerifier{
storage: storage, Verifier: oidc.Verifier{
subjectCheck: SubjectIsIssuer, Issuer: issuer,
issuer: issuer, MaxAgeIAT: maxAgeIAT,
maxAgeIAT: maxAgeIAT, Offset: offset,
offset: offset, },
Storage: storage,
CheckSubject: SubjectIsIssuer,
} }
for _, opt := range opts { for _, opt := range opts {
@ -42,53 +39,35 @@ func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIA
return j return j
} }
type JWTProfileVerifierOption func(*jwtProfileVerifier) type JWTProfileVerifierOption func(*JWTProfileVerifier)
// SubjectCheck sets a custom function to check the subject.
// Defaults to SubjectIsIssuer()
func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption { func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption {
return func(verifier *jwtProfileVerifier) { return func(verifier *JWTProfileVerifier) {
verifier.subjectCheck = check verifier.CheckSubject = check
} }
} }
func (v *jwtProfileVerifier) Issuer() string {
return v.issuer
}
func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage {
return v.storage
}
func (v *jwtProfileVerifier) MaxAgeIAT() time.Duration {
return v.maxAgeIAT
}
func (v *jwtProfileVerifier) Offset() time.Duration {
return v.offset
}
func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error {
return v.subjectCheck(request)
}
// VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication) // VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication)
// //
// checks audience, exp, iat, signature and that issuer and sub are the same // checks audience, exp, iat, signature and that issuer and sub are the same
func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
request := new(oidc.JWTTokenRequest) request := new(oidc.JWTTokenRequest)
payload, err := oidc.ParseToken(assertion, request) payload, err := oidc.ParseToken(assertion, request)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = oidc.CheckAudience(request, v.Issuer()); err != nil { if err = oidc.CheckAudience(request, v.Issuer); err != nil {
return nil, err return nil, err
} }
if err = oidc.CheckExpiration(request, v.Offset()); err != nil { if err = oidc.CheckExpiration(request, v.Offset); err != nil {
return nil, err return nil, err
} }
if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT(), v.Offset()); err != nil { if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil {
return nil, err return nil, err
} }
@ -96,17 +75,18 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
return nil, err return nil, err
} }
keySet := &jwtProfileKeySet{storage: v.Storage(), clientID: request.Issuer} keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer}
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
return nil, err return nil, err
} }
return request, nil return request, nil
} }
type jwtProfileKeyStorage interface { type JWTProfileKeyStorage interface {
GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
} }
// SubjectIsIssuer
func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
if request.Issuer != request.Subject { if request.Issuer != request.Subject {
return errors.New("delegation not allowed, issuer and sub must be identical") return errors.New("delegation not allowed, issuer and sub must be identical")
@ -115,7 +95,7 @@ func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
} }
type jwtProfileKeySet struct { type jwtProfileKeySet struct {
storage jwtProfileKeyStorage storage JWTProfileKeyStorage
clientID string clientID string
} }

View file

@ -0,0 +1,117 @@
package op_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func TestNewJWTProfileVerifier(t *testing.T) {
want := &op.JWTProfileVerifier{
Verifier: oidc.Verifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: time.Minute,
Offset: time.Second,
},
Storage: tu.JWTProfileKeyStorage{},
}
got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error {
return oidc.ErrSubjectMissing
}))
assert.Equal(t, want.Verifier, got.Verifier)
assert.Equal(t, want.Storage, got.Storage)
assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing)
}
func TestVerifyJWTAssertion(t *testing.T) {
errCtx, cancel := context.WithCancel(context.Background())
cancel()
verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0)
tests := []struct {
name string
ctx context.Context
newToken func() (string, *oidc.JWTTokenRequest)
wantErr bool
}{
{
name: "parse error",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil },
wantErr: true,
},
{
name: "wrong audience",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, tu.ValidClientID, []string{"wrong"},
time.Now(), tu.ValidExpiration,
)
},
wantErr: true,
},
{
name: "expired",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
time.Now(), time.Now().Add(-time.Hour),
)
},
wantErr: true,
},
{
name: "invalid iat",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
time.Now().Add(time.Hour), tu.ValidExpiration,
)
},
wantErr: true,
},
{
name: "invalid subject",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, "wrong", []string{tu.ValidIssuer},
time.Now(), tu.ValidExpiration,
)
},
wantErr: true,
},
{
name: "check signature fail",
ctx: errCtx,
newToken: tu.ValidJWTProfileAssertion,
wantErr: true,
},
{
name: "ok",
ctx: context.Background(),
newToken: tu.ValidJWTProfileAssertion,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertion, want := tt.newToken()
got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, want, got)
})
}
}