impelement unit tests for the token Verifiers
This commit is contained in:
parent
d41f4b5d21
commit
7b613c63eb
8 changed files with 786 additions and 7 deletions
143
internal/testutil/token.go
Normal file
143
internal/testutil/token.go
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
// Package testuril helps setting up required data for testing,
|
||||||
|
// such as tokens, claims and verifiers.
|
||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const SignatureAlgorithm = jose.PS512
|
||||||
|
|
||||||
|
// KeySet implements oidc.Keys and
|
||||||
|
// additionally can create tokens and claims that can
|
||||||
|
// be validated by this KeySet.
|
||||||
|
type KeySet struct {
|
||||||
|
Private *rsa.PrivateKey
|
||||||
|
Public *rsa.PublicKey
|
||||||
|
|
||||||
|
Signer jose.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKeySet() *KeySet {
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: SignatureAlgorithm, Key: privateKey}, nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return &KeySet{
|
||||||
|
Private: privateKey,
|
||||||
|
Public: &privateKey.PublicKey,
|
||||||
|
Signer: signer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KeySet) signEncodeTokenClaims(claims any) string {
|
||||||
|
payload, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
object, err := k.Signer.Sign(payload)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
token, err := object.CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func claimsMap(claims any) map[string]any {
|
||||||
|
data, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
dst := make(map[string]any)
|
||||||
|
if err = json.Unmarshal(data, &dst); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIDToken creates a new IDTokenClaims with passed data and returns a signed token and claims.
|
||||||
|
func (k *KeySet) NewIDToken(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration, atHash string) (string, *oidc.IDTokenClaims) {
|
||||||
|
claims := oidc.NewIDTokenClaims(issuer, subject, audience, expiration, authTime, nonce, acr, amr, clientID, skew)
|
||||||
|
claims.AccessTokenHash = atHash
|
||||||
|
token := k.signEncodeTokenClaims(claims)
|
||||||
|
|
||||||
|
// set this so that assertion in tests will work
|
||||||
|
claims.SignatureAlg = SignatureAlgorithm
|
||||||
|
claims.Claims = claimsMap(claims)
|
||||||
|
return token, claims
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAcccessToken creates a new AccessTokenClaims with passed data and returns a signed token and claims.
|
||||||
|
func (k *KeySet) NewAccessToken(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration) (string, *oidc.AccessTokenClaims) {
|
||||||
|
claims := oidc.NewAccessTokenClaims(issuer, subject, audience, expiration, jwtid, clientID, skew)
|
||||||
|
token := k.signEncodeTokenClaims(claims)
|
||||||
|
|
||||||
|
// set this so that assertion in tests will work
|
||||||
|
claims.SignatureAlg = SignatureAlgorithm
|
||||||
|
claims.Claims = claimsMap(claims)
|
||||||
|
return token, claims
|
||||||
|
}
|
||||||
|
|
||||||
|
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
|
||||||
|
|
||||||
|
// These variables always result in a valid token
|
||||||
|
// for the same test run.
|
||||||
|
var (
|
||||||
|
ValidIssuer = "local.com"
|
||||||
|
ValidSubject = "tim@local.com"
|
||||||
|
ValidAudience = []string{"unit", "test"}
|
||||||
|
ValidAuthTime = time.Now().Add(-time.Minute) // authtime is always 1 minute in the past
|
||||||
|
ValidExpiration = ValidAuthTime.Add(2 * time.Minute) // token is always 1 more minute available
|
||||||
|
ValidJWTID = "9876"
|
||||||
|
ValidNonce = "12345"
|
||||||
|
ValidACR = "something"
|
||||||
|
ValidAMR = []string{"foo", "bar"}
|
||||||
|
ValidClientID = "555666"
|
||||||
|
ValidSkew = time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidIDToken returns a token and claims that are in the token.
|
||||||
|
// It uses the Valid* global variables and the token always passes
|
||||||
|
// verification within the same test run.
|
||||||
|
func (k *KeySet) ValidIDToken() (string, *oidc.IDTokenClaims) {
|
||||||
|
return k.NewIDToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidAuthTime, ValidNonce, ValidACR, ValidAMR, ValidClientID, ValidSkew, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidAccessToken returns a token and claims that are in the token.
|
||||||
|
// It uses the Valid* global variables and the token always passes
|
||||||
|
// verification within the same test run.
|
||||||
|
func (k *KeySet) ValidAccessToken() (string, *oidc.AccessTokenClaims) {
|
||||||
|
return k.NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySignature implments op.KeySet.
|
||||||
|
func (k *KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return jws.Verify(k.Public)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ACRVerify is a oidc.ACRVerifier func.
|
||||||
|
func ACRVerify(acr string) error {
|
||||||
|
if acr != ValidACR {
|
||||||
|
return errors.New("invalid acr")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -21,17 +21,17 @@ type IDTokenVerifier interface {
|
||||||
|
|
||||||
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (claims C, err error) {
|
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
|
||||||
var nilClaims C
|
var nilClaims C
|
||||||
|
|
||||||
idToken, err := VerifyIDToken[C](ctx, idTokenString, v)
|
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
|
if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
return idToken, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyIDToken validates the id token according to
|
// VerifyIDToken validates the id token according to
|
||||||
|
@ -114,7 +114,7 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
clientID: clientID,
|
clientID: clientID,
|
||||||
keySet: keySet,
|
keySet: keySet,
|
||||||
offset: 1 * time.Second,
|
offset: time.Second,
|
||||||
nonce: func(_ context.Context) string {
|
nonce: func(_ context.Context) string {
|
||||||
return ""
|
return ""
|
||||||
},
|
},
|
||||||
|
|
343
pkg/client/rp/verifier_test.go
Normal file
343
pkg/client/rp/verifier_test.go
Normal file
|
@ -0,0 +1,343 @@
|
||||||
|
package rp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVerifyTokens(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
verifier := &idTokenVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
maxAgeIAT: 2 * time.Minute,
|
||||||
|
offset: time.Second,
|
||||||
|
supportedSignAlgs: []string{string(jose.PS512)},
|
||||||
|
keySet: keySet,
|
||||||
|
maxAge: 2 * time.Minute,
|
||||||
|
acr: tu.ACRVerify,
|
||||||
|
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
}
|
||||||
|
accessToken, _ := keySet.ValidAccessToken()
|
||||||
|
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessToken string
|
||||||
|
idTokenClaims func() (string, *oidc.IDTokenClaims)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without access token",
|
||||||
|
idTokenClaims: keySet.ValidIDToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with access token",
|
||||||
|
accessToken: accessToken,
|
||||||
|
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired id token",
|
||||||
|
accessToken: accessToken,
|
||||||
|
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wronf access token",
|
||||||
|
accessToken: accessToken,
|
||||||
|
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
idToken, want := tt.idTokenClaims()
|
||||||
|
got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
assert.Equal(t, got, want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDToken(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
verifier := &idTokenVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
maxAgeIAT: 2 * time.Minute,
|
||||||
|
offset: time.Second,
|
||||||
|
supportedSignAlgs: []string{string(jose.PS512)},
|
||||||
|
keySet: keySet,
|
||||||
|
maxAge: 2 * time.Minute,
|
||||||
|
acr: tu.ACRVerify,
|
||||||
|
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
clientID string
|
||||||
|
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: keySet.ValidIDToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parse err",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid signature",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty subject",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, "", tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong issuer",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong clientID",
|
||||||
|
clientID: "foo",
|
||||||
|
tokenClaims: keySet.ValidIDToken,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong IAT",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong acr",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired auth",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong nonce",
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, "foo",
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token, want := tt.tokenClaims()
|
||||||
|
verifier.clientID = tt.clientID
|
||||||
|
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
assert.Equal(t, got, want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAccessToken(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
token, _ := keySet.ValidAccessToken()
|
||||||
|
hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
accessToken string
|
||||||
|
atHash string
|
||||||
|
sigAlgorithm jose.SignatureAlgorithm
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty hash",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
args: args{
|
||||||
|
accessToken: token,
|
||||||
|
atHash: hash,
|
||||||
|
sigAlgorithm: tu.SignatureAlgorithm,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid algorithm",
|
||||||
|
args: args{
|
||||||
|
accessToken: token,
|
||||||
|
atHash: hash,
|
||||||
|
sigAlgorithm: "foo",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mismatch",
|
||||||
|
args: args{
|
||||||
|
accessToken: token,
|
||||||
|
atHash: "~~",
|
||||||
|
sigAlgorithm: tu.SignatureAlgorithm,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewIDTokenVerifier(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
type args struct {
|
||||||
|
issuer string
|
||||||
|
clientID string
|
||||||
|
keySet oidc.KeySet
|
||||||
|
options []VerifierOption
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want IDTokenVerifier
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil nonce", // otherwise assert.Equal will fail on the function
|
||||||
|
args: args{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
keySet: keySet,
|
||||||
|
options: []VerifierOption{
|
||||||
|
WithIssuedAtOffset(time.Minute),
|
||||||
|
//WithIssuedAtMaxAge(time.Hour),
|
||||||
|
WithNonce(nil), // otherwise assert.Equal will fail on the function
|
||||||
|
WithACRVerifier(nil),
|
||||||
|
WithAuthTimeMaxAge(2 * time.Hour),
|
||||||
|
WithSupportedSigningAlgorithms("ABC", "DEF"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: &idTokenVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
offset: time.Minute,
|
||||||
|
//maxAgeIAT: time.Hour, // Maybe BUG?
|
||||||
|
clientID: tu.ValidClientID,
|
||||||
|
keySet: keySet,
|
||||||
|
nonce: nil,
|
||||||
|
acr: nil,
|
||||||
|
maxAge: 2 * time.Hour,
|
||||||
|
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -144,6 +144,7 @@ func NewIDTokenClaims(issuer, subject string, audience []string, expiration, aut
|
||||||
AuthenticationContextClassReference: acr,
|
AuthenticationContextClassReference: acr,
|
||||||
AuthenticationMethodsReferences: amr,
|
AuthenticationMethodsReferences: amr,
|
||||||
AuthorizedParty: clientID,
|
AuthorizedParty: clientID,
|
||||||
|
ClientID: clientID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -196,6 +196,7 @@ func TestNewIDTokenClaims(t *testing.T) {
|
||||||
AuthenticationContextClassReference: "something",
|
AuthenticationContextClassReference: "something",
|
||||||
AuthenticationMethodsReferences: []string{"some", "methods"},
|
AuthenticationMethodsReferences: []string{"some", "methods"},
|
||||||
AuthorizedParty: "just@me.com",
|
AuthorizedParty: "just@me.com",
|
||||||
|
ClientID: "just@me.com",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,6 @@ type accessTokenVerifier struct {
|
||||||
maxAgeIAT time.Duration
|
maxAgeIAT time.Duration
|
||||||
offset time.Duration
|
offset time.Duration
|
||||||
supportedSignAlgs []string
|
supportedSignAlgs []string
|
||||||
maxAge time.Duration
|
|
||||||
acr oidc.ACRVerifier
|
|
||||||
keySet oidc.KeySet
|
keySet oidc.KeySet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
129
pkg/op/verifier_access_token_test.go
Normal file
129
pkg/op/verifier_access_token_test.go
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
package op
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewAccessTokenVerifier(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
type args struct {
|
||||||
|
issuer string
|
||||||
|
keySet oidc.KeySet
|
||||||
|
opts []AccessTokenVerifierOpt
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want AccessTokenVerifier
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple",
|
||||||
|
args: args{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
},
|
||||||
|
want: &accessTokenVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with signature algorithm",
|
||||||
|
args: args{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
opts: []AccessTokenVerifierOpt{
|
||||||
|
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: &accessTokenVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := NewAccessTokenVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAccessToken(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
verifier := &accessTokenVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
maxAgeIAT: 2 * time.Minute,
|
||||||
|
offset: time.Second,
|
||||||
|
supportedSignAlgs: []string{string(jose.PS512)},
|
||||||
|
keySet: keySet,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenClaims func() (string, *oidc.AccessTokenClaims)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
tokenClaims: keySet.ValidAccessToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parse err",
|
||||||
|
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return "~~~~", nil },
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid signature",
|
||||||
|
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong issuer",
|
||||||
|
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
|
||||||
|
return keySet.NewAccessToken(
|
||||||
|
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidJWTID, tu.ValidClientID,
|
||||||
|
tu.ValidSkew,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired",
|
||||||
|
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
|
||||||
|
return keySet.NewAccessToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration.Add(-time.Hour), tu.ValidJWTID, tu.ValidClientID,
|
||||||
|
tu.ValidSkew,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token, want := tt.tokenClaims()
|
||||||
|
|
||||||
|
got, err := VerifyAccessToken[*oidc.AccessTokenClaims](context.Background(), token, verifier)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
assert.Equal(t, got, want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
164
pkg/op/verifier_id_token_hint_test.go
Normal file
164
pkg/op/verifier_id_token_hint_test.go
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
package op
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewIDTokenHintVerifier(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
type args struct {
|
||||||
|
issuer string
|
||||||
|
keySet oidc.KeySet
|
||||||
|
opts []IDTokenHintVerifierOpt
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want IDTokenHintVerifier
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple",
|
||||||
|
args: args{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
},
|
||||||
|
want: &idTokenHintVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with signature algorithm",
|
||||||
|
args: args{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
opts: []IDTokenHintVerifierOpt{
|
||||||
|
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: &idTokenHintVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
keySet: keySet,
|
||||||
|
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := NewIDTokenHintVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDTokenHint(t *testing.T) {
|
||||||
|
keySet := tu.NewKeySet()
|
||||||
|
verifier := &idTokenHintVerifier{
|
||||||
|
issuer: tu.ValidIssuer,
|
||||||
|
maxAgeIAT: 2 * time.Minute,
|
||||||
|
offset: time.Second,
|
||||||
|
supportedSignAlgs: []string{string(jose.PS512)},
|
||||||
|
maxAge: 2 * time.Minute,
|
||||||
|
acr: tu.ACRVerify,
|
||||||
|
keySet: keySet,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
tokenClaims: keySet.ValidIDToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parse err",
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid signature",
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong issuer",
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired",
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong IAT",
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong acr",
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||||
|
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired auth",
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return keySet.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token, want := tt.tokenClaims()
|
||||||
|
|
||||||
|
got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
assert.Equal(t, got, want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue