add unit test

This commit is contained in:
Yuval Marcus 2024-04-23 15:01:07 -04:00
parent 0f227323e6
commit 883c156bd0

View file

@ -100,22 +100,21 @@ func TestVerifyIDToken(t *testing.T) {
MaxAge: 2 * time.Minute, MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify, ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce }, Nonce: func(context.Context) string { return tu.ValidNonce },
ClientID: tu.ValidClientID,
} }
tests := []struct { tests := []struct {
name string name string
clientID string tokenClaims func() (string, *oidc.IDTokenClaims)
tokenClaims func() (string, *oidc.IDTokenClaims) customVerifier func(verifier *IDTokenVerifier)
wantErr bool wantErr bool
}{ }{
{ {
name: "success", name: "success",
clientID: tu.ValidClientID,
tokenClaims: tu.ValidIDToken, tokenClaims: tu.ValidIDToken,
}, },
{ {
name: "custom claims", name: "custom claims",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDTokenCustom( return tu.NewIDTokenCustom(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -125,21 +124,31 @@ func TestVerifyIDToken(t *testing.T) {
) )
}, },
}, },
{
name: "skip nonce check",
customVerifier: func(verifier *IDTokenVerifier) {
verifier.Nonce = nil
},
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, "foo",
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
},
{ {
name: "parse err", name: "parse err",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil }, tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
wantErr: true, wantErr: true,
}, },
{ {
name: "invalid signature", name: "invalid signature",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil }, tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
wantErr: true, wantErr: true,
}, },
{ {
name: "empty subject", name: "empty subject",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken( return tu.NewIDToken(
tu.ValidIssuer, "", tu.ValidAudience, tu.ValidIssuer, "", tu.ValidAudience,
@ -150,8 +159,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong issuer", name: "wrong issuer",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken( return tu.NewIDToken(
"foo", tu.ValidSubject, tu.ValidAudience, "foo", tu.ValidSubject, tu.ValidAudience,
@ -162,14 +170,15 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong clientID", name: "wrong clientID",
clientID: "foo", customVerifier: func(verifier *IDTokenVerifier) {
verifier.ClientID = "foo"
},
tokenClaims: tu.ValidIDToken, tokenClaims: tu.ValidIDToken,
wantErr: true, wantErr: true,
}, },
{ {
name: "expired", name: "expired",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken( return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -180,8 +189,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong IAT", name: "wrong IAT",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken( return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -192,8 +200,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong acr", name: "wrong acr",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken( return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -204,8 +211,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "expired auth", name: "expired auth",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken( return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -216,8 +222,7 @@ func TestVerifyIDToken(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "wrong nonce", name: "wrong nonce",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken( return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
@ -231,7 +236,10 @@ func TestVerifyIDToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims() token, want := tt.tokenClaims()
verifier.ClientID = tt.clientID if tt.customVerifier != nil {
tt.customVerifier(verifier)
}
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)