op: finish verifier tests

This commit is contained in:
Tim Möhlmann 2023-03-21 15:31:21 +02:00
parent 522f65670d
commit 137fcdfd33
3 changed files with 183 additions and 0 deletions

View file

@ -8,6 +8,7 @@ import (
"errors"
"time"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/oidc"
"gopkg.in/square/go-jose.v2"
)
@ -45,6 +46,16 @@ func init() {
}
}
type JWTProfileKeyStorage struct{}
func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
return gu.Ptr(WebKey.Public()), nil
}
func signEncodeTokenClaims(claims any) string {
payload, err := json.Marshal(claims)
if err != nil {
@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
}
func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
req := &oidc.JWTTokenRequest{
Issuer: issuer,
Subject: clientID,
Audience: audience,
ExpiresAt: oidc.FromTime(expiration),
IssuedAt: oidc.FromTime(issuedAt),
}
// make sure the private claim map is set correctly
data, err := json.Marshal(req)
if err != nil {
panic(err)
}
if err = json.Unmarshal(data, req); err != nil {
panic(err)
}
return signEncodeTokenClaims(req), req
}
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
// These variables always result in a valid token
@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
}
func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
}
// ACRVerify is a oidc.ACRVerifier func.
func ACRVerify(acr string) error {
if acr != ValidACR {

View file

@ -12,6 +12,7 @@ import (
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
@ -1003,3 +1004,34 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) {
})
}
}
func TestValidateAuthReqIDTokenHint(t *testing.T) {
token, _ := tu.ValidIDToken()
tests := []struct {
name string
idTokenHint string
want string
wantErr error
}{
{
name: "empty",
},
{
name: "verify err",
idTokenHint: "foo",
wantErr: oidc.ErrLoginRequired(),
},
{
name: "ok",
idTokenHint: token,
want: tu.ValidSubject,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -0,0 +1,117 @@
package op_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func TestNewJWTProfileVerifier(t *testing.T) {
want := &op.JWTProfileVerifier{
Verifier: oidc.Verifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: time.Minute,
Offset: time.Second,
},
Storage: tu.JWTProfileKeyStorage{},
}
got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error {
return oidc.ErrSubjectMissing
}))
assert.Equal(t, want.Verifier, got.Verifier)
assert.Equal(t, want.Storage, got.Storage)
assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing)
}
func TestVerifyJWTAssertion(t *testing.T) {
errCtx, cancel := context.WithCancel(context.Background())
cancel()
verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0)
tests := []struct {
name string
ctx context.Context
newToken func() (string, *oidc.JWTTokenRequest)
wantErr bool
}{
{
name: "parse error",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil },
wantErr: true,
},
{
name: "wrong audience",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, tu.ValidClientID, []string{"wrong"},
time.Now(), tu.ValidExpiration,
)
},
wantErr: true,
},
{
name: "expired",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
time.Now(), time.Now().Add(-time.Hour),
)
},
wantErr: true,
},
{
name: "invalid iat",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
time.Now().Add(time.Hour), tu.ValidExpiration,
)
},
wantErr: true,
},
{
name: "invalid subject",
ctx: context.Background(),
newToken: func() (string, *oidc.JWTTokenRequest) {
return tu.NewJWTProfileAssertion(
tu.ValidClientID, "wrong", []string{tu.ValidIssuer},
time.Now(), tu.ValidExpiration,
)
},
wantErr: true,
},
{
name: "check signature fail",
ctx: errCtx,
newToken: tu.ValidJWTProfileAssertion,
wantErr: true,
},
{
name: "ok",
ctx: context.Background(),
newToken: tu.ValidJWTProfileAssertion,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertion, want := tt.newToken()
got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, want, got)
})
}
}