fix: Handle case where verifier Nonce func is nil (#594)

* Skip nonce check if verifier nonce func is nil

* add unit test
This commit is contained in:
Yuval Marcus 2024-05-02 03:46:12 -04:00 committed by GitHub
parent 37ca0e472a
commit 24d43f538e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 38 additions and 28 deletions

View file

@ -73,8 +73,10 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenV
return nilClaims, err
}
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
return nilClaims, err
if v.Nonce != nil {
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
return nilClaims, err
}
}
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {

View file

@ -100,22 +100,21 @@ func TestVerifyIDToken(t *testing.T) {
MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce },
ClientID: tu.ValidClientID,
}
tests := []struct {
name string
clientID string
tokenClaims func() (string, *oidc.IDTokenClaims)
wantErr bool
name string
tokenClaims func() (string, *oidc.IDTokenClaims)
customVerifier func(verifier *IDTokenVerifier)
wantErr bool
}{
{
name: "success",
clientID: tu.ValidClientID,
tokenClaims: tu.ValidIDToken,
},
{
name: "custom claims",
clientID: tu.ValidClientID,
name: "custom claims",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDTokenCustom(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -125,21 +124,31 @@ func TestVerifyIDToken(t *testing.T) {
)
},
},
{
name: "skip nonce check",
customVerifier: func(verifier *IDTokenVerifier) {
verifier.Nonce = nil
},
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, "foo",
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
},
{
name: "parse err",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
wantErr: true,
},
{
name: "invalid signature",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
wantErr: true,
},
{
name: "empty subject",
clientID: tu.ValidClientID,
name: "empty subject",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, "", tu.ValidAudience,
@ -150,8 +159,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true,
},
{
name: "wrong issuer",
clientID: tu.ValidClientID,
name: "wrong issuer",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
"foo", tu.ValidSubject, tu.ValidAudience,
@ -162,14 +170,15 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true,
},
{
name: "wrong clientID",
clientID: "foo",
name: "wrong clientID",
customVerifier: func(verifier *IDTokenVerifier) {
verifier.ClientID = "foo"
},
tokenClaims: tu.ValidIDToken,
wantErr: true,
},
{
name: "expired",
clientID: tu.ValidClientID,
name: "expired",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -180,8 +189,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true,
},
{
name: "wrong IAT",
clientID: tu.ValidClientID,
name: "wrong IAT",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -192,8 +200,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true,
},
{
name: "wrong acr",
clientID: tu.ValidClientID,
name: "wrong acr",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -204,8 +211,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true,
},
{
name: "expired auth",
clientID: tu.ValidClientID,
name: "expired auth",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -216,8 +222,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true,
},
{
name: "wrong nonce",
clientID: tu.ValidClientID,
name: "wrong nonce",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -231,7 +236,10 @@ func TestVerifyIDToken(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims()
verifier.ClientID = tt.clientID
if tt.customVerifier != nil {
tt.customVerifier(verifier)
}
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr {
assert.Error(t, err)