diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go index cd2fab4..24d35af 100644 --- a/pkg/client/rp/verifier_test.go +++ b/pkg/client/rp/verifier_test.go @@ -100,22 +100,21 @@ func TestVerifyIDToken(t *testing.T) { MaxAge: 2 * time.Minute, ACR: tu.ACRVerify, Nonce: func(context.Context) string { return tu.ValidNonce }, + ClientID: tu.ValidClientID, } tests := []struct { - name string - clientID string - tokenClaims func() (string, *oidc.IDTokenClaims) - wantErr bool + name string + tokenClaims func() (string, *oidc.IDTokenClaims) + customVerifier func(verifier *IDTokenVerifier) + wantErr bool }{ { name: "success", - clientID: tu.ValidClientID, tokenClaims: tu.ValidIDToken, }, { - name: "custom claims", - clientID: tu.ValidClientID, + name: "custom claims", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDTokenCustom( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -125,21 +124,31 @@ func TestVerifyIDToken(t *testing.T) { ) }, }, + { + name: "skip nonce check", + customVerifier: func(verifier *IDTokenVerifier) { + verifier.Nonce = nil + }, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return tu.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, "foo", + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + }, { name: "parse err", - clientID: tu.ValidClientID, tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil }, wantErr: true, }, { name: "invalid signature", - clientID: tu.ValidClientID, tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil }, wantErr: true, }, { - name: "empty subject", - clientID: tu.ValidClientID, + name: "empty subject", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, "", tu.ValidAudience, @@ -150,8 +159,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong issuer", - clientID: tu.ValidClientID, + name: "wrong issuer", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( "foo", tu.ValidSubject, tu.ValidAudience, @@ -162,14 +170,15 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong clientID", - clientID: "foo", + name: "wrong clientID", + customVerifier: func(verifier *IDTokenVerifier) { + verifier.ClientID = "foo" + }, tokenClaims: tu.ValidIDToken, wantErr: true, }, { - name: "expired", - clientID: tu.ValidClientID, + name: "expired", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -180,8 +189,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong IAT", - clientID: tu.ValidClientID, + name: "wrong IAT", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -192,8 +200,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong acr", - clientID: tu.ValidClientID, + name: "wrong acr", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -204,8 +211,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "expired auth", - clientID: tu.ValidClientID, + name: "expired auth", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -216,8 +222,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong nonce", - clientID: tu.ValidClientID, + name: "wrong nonce", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -231,7 +236,10 @@ func TestVerifyIDToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, want := tt.tokenClaims() - verifier.ClientID = tt.clientID + if tt.customVerifier != nil { + tt.customVerifier(verifier) + } + got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) if tt.wantErr { assert.Error(t, err)