374 lines
6.7 KiB
Go
374 lines
6.7 KiB
Go
package oidc
|
|
|
|
import (
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestDecryptToken(t *testing.T) {
|
|
const tokenString = "ABC"
|
|
got, err := DecryptToken(tokenString)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tokenString, got)
|
|
}
|
|
|
|
func TestDefaultACRVerifier(t *testing.T) {
|
|
acrVerfier := DefaultACRVerifier([]string{"foo", "bar"})
|
|
|
|
tests := []struct {
|
|
name string
|
|
acr string
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "ok",
|
|
acr: "bar",
|
|
},
|
|
{
|
|
name: "error",
|
|
acr: "hello",
|
|
wantErr: "expected one of: [foo bar], got: \"hello\"",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := acrVerfier(tt.acr)
|
|
if tt.wantErr != "" {
|
|
assert.EqualError(t, err, tt.wantErr)
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckSubject(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
claims Claims
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "missing",
|
|
claims: &TokenClaims{},
|
|
wantErr: ErrSubjectMissing,
|
|
},
|
|
{
|
|
name: "ok",
|
|
claims: &TokenClaims{
|
|
Subject: "foo",
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckSubject(tt.claims)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckIssuer(t *testing.T) {
|
|
const issuer = "foo.bar"
|
|
tests := []struct {
|
|
name string
|
|
claims Claims
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "missing",
|
|
claims: &TokenClaims{},
|
|
wantErr: ErrIssuerInvalid,
|
|
},
|
|
{
|
|
name: "wrong",
|
|
claims: &TokenClaims{
|
|
Issuer: "wrong",
|
|
},
|
|
wantErr: ErrIssuerInvalid,
|
|
},
|
|
{
|
|
name: "ok",
|
|
claims: &TokenClaims{
|
|
Issuer: issuer,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckIssuer(tt.claims, issuer)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckAudience(t *testing.T) {
|
|
const clientID = "foo.bar"
|
|
tests := []struct {
|
|
name string
|
|
claims Claims
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "missing",
|
|
claims: &TokenClaims{},
|
|
wantErr: ErrAudience,
|
|
},
|
|
{
|
|
name: "wrong",
|
|
claims: &TokenClaims{
|
|
Audience: []string{"wrong"},
|
|
},
|
|
wantErr: ErrAudience,
|
|
},
|
|
{
|
|
name: "ok",
|
|
claims: &TokenClaims{
|
|
Audience: []string{clientID},
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckAudience(tt.claims, clientID)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckAuthorizedParty(t *testing.T) {
|
|
const clientID = "foo.bar"
|
|
tests := []struct {
|
|
name string
|
|
claims Claims
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "single audience, no azp",
|
|
claims: &TokenClaims{
|
|
Audience: []string{clientID},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple audience, no azp",
|
|
claims: &TokenClaims{
|
|
Audience: []string{clientID, "other"},
|
|
},
|
|
wantErr: ErrAzpMissing,
|
|
},
|
|
{
|
|
name: "single audience, with azp",
|
|
claims: &TokenClaims{
|
|
Audience: []string{clientID},
|
|
AuthorizedParty: clientID,
|
|
},
|
|
},
|
|
{
|
|
name: "multiple audience, with azp",
|
|
claims: &TokenClaims{
|
|
Audience: []string{clientID, "other"},
|
|
AuthorizedParty: clientID,
|
|
},
|
|
},
|
|
{
|
|
name: "wrong azp",
|
|
claims: &TokenClaims{
|
|
AuthorizedParty: "wrong",
|
|
},
|
|
wantErr: ErrAzpInvalid,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckAuthorizedParty(tt.claims, clientID)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckExpiration(t *testing.T) {
|
|
const offset = time.Minute
|
|
tests := []struct {
|
|
name string
|
|
claims Claims
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "missing",
|
|
claims: &TokenClaims{},
|
|
wantErr: ErrExpired,
|
|
},
|
|
{
|
|
name: "expired",
|
|
claims: &TokenClaims{
|
|
Expiration: FromTime(time.Now().Add(-2 * offset)),
|
|
},
|
|
wantErr: ErrExpired,
|
|
},
|
|
{
|
|
name: "valid",
|
|
claims: &TokenClaims{
|
|
Expiration: FromTime(time.Now().Add(2 * offset)),
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckExpiration(tt.claims, offset)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckIssuedAt(t *testing.T) {
|
|
const offset = time.Minute
|
|
tests := []struct {
|
|
name string
|
|
maxAgeIAT time.Duration
|
|
claims Claims
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "missing",
|
|
claims: &TokenClaims{},
|
|
wantErr: ErrIatMissing,
|
|
},
|
|
{
|
|
name: "future",
|
|
claims: &TokenClaims{
|
|
IssuedAt: FromTime(time.Now().Add(time.Hour)),
|
|
},
|
|
wantErr: ErrIatInFuture,
|
|
},
|
|
{
|
|
name: "no max",
|
|
claims: &TokenClaims{
|
|
IssuedAt: FromTime(time.Now()),
|
|
},
|
|
},
|
|
{
|
|
name: "past max",
|
|
maxAgeIAT: time.Minute,
|
|
claims: &TokenClaims{
|
|
IssuedAt: FromTime(time.Now().Add(-time.Hour)),
|
|
},
|
|
wantErr: ErrIatToOld,
|
|
},
|
|
{
|
|
name: "within max",
|
|
maxAgeIAT: time.Hour,
|
|
claims: &TokenClaims{
|
|
IssuedAt: FromTime(time.Now()),
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckNonce(t *testing.T) {
|
|
const nonce = "123"
|
|
tests := []struct {
|
|
name string
|
|
claims Claims
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "missing",
|
|
claims: &TokenClaims{},
|
|
wantErr: ErrNonceInvalid,
|
|
},
|
|
{
|
|
name: "wrong",
|
|
claims: &TokenClaims{
|
|
Nonce: "wrong",
|
|
},
|
|
wantErr: ErrNonceInvalid,
|
|
},
|
|
{
|
|
name: "ok",
|
|
claims: &TokenClaims{
|
|
Nonce: nonce,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckNonce(tt.claims, nonce)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckAuthorizationContextClassReference(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
acr ACRVerifier
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "error",
|
|
acr: func(s string) error { return errors.New("oops") },
|
|
wantErr: ErrAcrInvalid,
|
|
},
|
|
{
|
|
name: "ok",
|
|
acr: func(s string) error { return nil },
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckAuthTime(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
claims Claims
|
|
maxAge time.Duration
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "no max age",
|
|
claims: &TokenClaims{},
|
|
},
|
|
{
|
|
name: "missing",
|
|
claims: &TokenClaims{},
|
|
maxAge: time.Minute,
|
|
wantErr: ErrAuthTimeNotPresent,
|
|
},
|
|
{
|
|
name: "expired",
|
|
maxAge: time.Minute,
|
|
claims: &TokenClaims{
|
|
AuthTime: FromTime(time.Now().Add(-time.Hour)),
|
|
},
|
|
wantErr: ErrAuthTimeToOld,
|
|
},
|
|
{
|
|
name: "ok",
|
|
maxAge: time.Minute,
|
|
claims: &TokenClaims{
|
|
AuthTime: NowTime(),
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := CheckAuthTime(tt.claims, tt.maxAge)
|
|
assert.ErrorIs(t, err, tt.wantErr)
|
|
})
|
|
}
|
|
}
|