fix: Handle case where verifier Nonce func is nil (#594)
* Skip nonce check if verifier nonce func is nil * add unit test
This commit is contained in:
parent
37ca0e472a
commit
24d43f538e
2 changed files with 38 additions and 28 deletions
|
@ -73,9 +73,11 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenV
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v.Nonce != nil {
|
||||||
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
|
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
|
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
|
||||||
return nilClaims, err
|
return nilClaims, err
|
||||||
|
|
|
@ -100,22 +100,21 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
MaxAge: 2 * time.Minute,
|
MaxAge: 2 * time.Minute,
|
||||||
ACR: tu.ACRVerify,
|
ACR: tu.ACRVerify,
|
||||||
Nonce: func(context.Context) string { return tu.ValidNonce },
|
Nonce: func(context.Context) string { return tu.ValidNonce },
|
||||||
|
ClientID: tu.ValidClientID,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
clientID string
|
|
||||||
tokenClaims func() (string, *oidc.IDTokenClaims)
|
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||||
|
customVerifier func(verifier *IDTokenVerifier)
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "success",
|
name: "success",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: tu.ValidIDToken,
|
tokenClaims: tu.ValidIDToken,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "custom claims",
|
name: "custom claims",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDTokenCustom(
|
return tu.NewIDTokenCustom(
|
||||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
@ -125,21 +124,31 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "skip nonce check",
|
||||||
|
customVerifier: func(verifier *IDTokenVerifier) {
|
||||||
|
verifier.Nonce = nil
|
||||||
|
},
|
||||||
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
|
return tu.NewIDToken(
|
||||||
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
tu.ValidExpiration, tu.ValidAuthTime, "foo",
|
||||||
|
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "parse err",
|
name: "parse err",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid signature",
|
name: "invalid signature",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty subject",
|
name: "empty subject",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDToken(
|
return tu.NewIDToken(
|
||||||
tu.ValidIssuer, "", tu.ValidAudience,
|
tu.ValidIssuer, "", tu.ValidAudience,
|
||||||
|
@ -151,7 +160,6 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wrong issuer",
|
name: "wrong issuer",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDToken(
|
return tu.NewIDToken(
|
||||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||||
|
@ -163,13 +171,14 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wrong clientID",
|
name: "wrong clientID",
|
||||||
clientID: "foo",
|
customVerifier: func(verifier *IDTokenVerifier) {
|
||||||
|
verifier.ClientID = "foo"
|
||||||
|
},
|
||||||
tokenClaims: tu.ValidIDToken,
|
tokenClaims: tu.ValidIDToken,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "expired",
|
name: "expired",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDToken(
|
return tu.NewIDToken(
|
||||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
@ -181,7 +190,6 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wrong IAT",
|
name: "wrong IAT",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDToken(
|
return tu.NewIDToken(
|
||||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
@ -193,7 +201,6 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wrong acr",
|
name: "wrong acr",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDToken(
|
return tu.NewIDToken(
|
||||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
@ -205,7 +212,6 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "expired auth",
|
name: "expired auth",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDToken(
|
return tu.NewIDToken(
|
||||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
@ -217,7 +223,6 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wrong nonce",
|
name: "wrong nonce",
|
||||||
clientID: tu.ValidClientID,
|
|
||||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||||
return tu.NewIDToken(
|
return tu.NewIDToken(
|
||||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||||
|
@ -231,7 +236,10 @@ func TestVerifyIDToken(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
token, want := tt.tokenClaims()
|
token, want := tt.tokenClaims()
|
||||||
verifier.ClientID = tt.clientID
|
if tt.customVerifier != nil {
|
||||||
|
tt.customVerifier(verifier)
|
||||||
|
}
|
||||||
|
|
||||||
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue