feat: merge the verifier types (#336)
BREAKING CHANGE: - The various verifier types are merged into a oidc.Verifir. - oidc.Verfier became a struct with exported fields * use type aliases for oidc.Verifier this binds the correct contstructor to each verifier usecase. * fix: handle the zero cases for oidc.Time * add unit tests to oidc verifier * fix: correct returned field for JWTTokenRequest JWTTokenRequest.GetIssuedAt() was returning the ExpiresAt field. This change corrects that by returning IssuedAt instead.
This commit is contained in:
parent
c8cf15e266
commit
33c716ddcf
29 changed files with 948 additions and 351 deletions
|
@ -38,7 +38,7 @@ type Authorizer interface {
|
|||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
Encoder() httphelper.Encoder
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
Crypto() Crypto
|
||||
RequestObjectSupported() bool
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ type Authorizer interface {
|
|||
// implementing its own validation mechanism for the auth request
|
||||
type AuthorizeValidator interface {
|
||||
Authorizer
|
||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
|
||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
|
||||
}
|
||||
|
||||
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -204,7 +204,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
|
|||
}
|
||||
|
||||
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
|
||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
|
||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
|
||||
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -384,7 +384,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType)
|
|||
|
||||
// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
|
||||
// and returns the `sub` claim
|
||||
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) {
|
||||
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) {
|
||||
if idTokenHint == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
@ -146,7 +147,7 @@ func TestValidateAuthRequest(t *testing.T) {
|
|||
type args struct {
|
||||
authRequest *oidc.AuthRequest
|
||||
storage op.Storage
|
||||
verifier op.IDTokenHintVerifier
|
||||
verifier *op.IDTokenHintVerifier
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -1003,3 +1004,34 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
||||
token, _ := tu.ValidIDToken()
|
||||
tests := []struct {
|
||||
name string
|
||||
idTokenHint string
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "verify err",
|
||||
idTokenHint: "foo",
|
||||
wantErr: oidc.ErrLoginRequired(),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
idTokenHint: token,
|
||||
want: tu.ValidSubject,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ var (
|
|||
)
|
||||
|
||||
type ClientJWTProfile interface {
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
JWTProfileVerifier(context.Context) *JWTProfileVerifier
|
||||
}
|
||||
|
||||
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
|
||||
type testClientJWTProfile struct{}
|
||||
|
||||
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
|
||||
func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil }
|
||||
|
||||
func TestClientJWTAuth(t *testing.T) {
|
||||
type args struct {
|
||||
|
|
|
@ -79,10 +79,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
|||
}
|
||||
|
||||
// IDTokenHintVerifier mocks base method.
|
||||
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier {
|
||||
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
|
||||
ret0, _ := ret[0].(op.IDTokenHintVerifier)
|
||||
ret0, _ := ret[0].(*op.IDTokenHintVerifier)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) {
|
|||
func ExpectVerifier(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
|
||||
func() op.IDTokenHintVerifier {
|
||||
func() *op.IDTokenHintVerifier {
|
||||
return op.NewIDTokenHintVerifier("", nil)
|
||||
})
|
||||
}
|
||||
|
|
10
pkg/op/op.go
10
pkg/op/op.go
|
@ -73,8 +73,8 @@ type OpenIDProvider interface {
|
|||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
Encoder() httphelper.Encoder
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||
Crypto() Crypto
|
||||
DefaultLogoutRedirectURI() string
|
||||
Probes() []ProbesFn
|
||||
|
@ -342,15 +342,15 @@ func (o *Provider) Encoder() httphelper.Encoder {
|
|||
return o.encoder
|
||||
}
|
||||
|
||||
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier {
|
||||
func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
|
||||
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
|
||||
}
|
||||
|
||||
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier {
|
||||
func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
|
||||
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
||||
}
|
||||
|
||||
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier {
|
||||
func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
|
||||
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
type SessionEnder interface {
|
||||
Decoder() httphelper.Decoder
|
||||
Storage() Storage
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
DefaultLogoutRedirectURI() string
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ type Introspector interface {
|
|||
Decoder() httphelper.Decoder
|
||||
Crypto() Crypto
|
||||
Storage() Storage
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||
}
|
||||
|
||||
type IntrospectorJWTProfile interface {
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
type JWTAuthorizationGrantExchanger interface {
|
||||
Exchanger
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
JWTProfileVerifier(context.Context) *JWTProfileVerifier
|
||||
}
|
||||
|
||||
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1
|
||||
|
|
|
@ -20,8 +20,8 @@ type Exchanger interface {
|
|||
GrantTypeJWTAuthorizationSupported() bool
|
||||
GrantTypeClientCredentialsSupported() bool
|
||||
GrantTypeDeviceCodeSupported() bool
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
}
|
||||
|
||||
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -15,14 +15,14 @@ type Revoker interface {
|
|||
Decoder() httphelper.Decoder
|
||||
Crypto() Crypto
|
||||
Storage() Storage
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||
AuthMethodPrivateKeyJWTSupported() bool
|
||||
AuthMethodPostSupported() bool
|
||||
}
|
||||
|
||||
type RevokerJWTProfile interface {
|
||||
Revoker
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
JWTProfileVerifier(context.Context) *JWTProfileVerifier
|
||||
}
|
||||
|
||||
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
|
||||
|
|
|
@ -14,7 +14,7 @@ type UserinfoProvider interface {
|
|||
Decoder() httphelper.Decoder
|
||||
Crypto() Crypto
|
||||
Storage() Storage
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||
}
|
||||
|
||||
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
|
||||
|
|
|
@ -2,62 +2,25 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type AccessTokenVerifier interface {
|
||||
oidc.Verifier
|
||||
SupportedSignAlgs() []string
|
||||
KeySet() oidc.KeySet
|
||||
}
|
||||
type AccessTokenVerifier oidc.Verifier
|
||||
|
||||
type accessTokenVerifier struct {
|
||||
issuer string
|
||||
maxAgeIAT time.Duration
|
||||
offset time.Duration
|
||||
supportedSignAlgs []string
|
||||
keySet oidc.KeySet
|
||||
}
|
||||
|
||||
// Issuer implements oidc.Verifier interface
|
||||
func (i *accessTokenVerifier) Issuer() string {
|
||||
return i.issuer
|
||||
}
|
||||
|
||||
// MaxAgeIAT implements oidc.Verifier interface
|
||||
func (i *accessTokenVerifier) MaxAgeIAT() time.Duration {
|
||||
return i.maxAgeIAT
|
||||
}
|
||||
|
||||
// Offset implements oidc.Verifier interface
|
||||
func (i *accessTokenVerifier) Offset() time.Duration {
|
||||
return i.offset
|
||||
}
|
||||
|
||||
// SupportedSignAlgs implements AccessTokenVerifier interface
|
||||
func (i *accessTokenVerifier) SupportedSignAlgs() []string {
|
||||
return i.supportedSignAlgs
|
||||
}
|
||||
|
||||
// KeySet implements AccessTokenVerifier interface
|
||||
func (i *accessTokenVerifier) KeySet() oidc.KeySet {
|
||||
return i.keySet
|
||||
}
|
||||
|
||||
type AccessTokenVerifierOpt func(*accessTokenVerifier)
|
||||
type AccessTokenVerifierOpt func(*AccessTokenVerifier)
|
||||
|
||||
func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt {
|
||||
return func(verifier *accessTokenVerifier) {
|
||||
verifier.supportedSignAlgs = algs
|
||||
return func(verifier *AccessTokenVerifier) {
|
||||
verifier.SupportedSignAlgs = algs
|
||||
}
|
||||
}
|
||||
|
||||
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) AccessTokenVerifier {
|
||||
verifier := &accessTokenVerifier{
|
||||
issuer: issuer,
|
||||
keySet: keySet,
|
||||
// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification.
|
||||
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier {
|
||||
verifier := &AccessTokenVerifier{
|
||||
Issuer: issuer,
|
||||
KeySet: keySet,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(verifier)
|
||||
|
@ -66,7 +29,7 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok
|
|||
}
|
||||
|
||||
// VerifyAccessToken validates the access token (issuer, signature and expiration).
|
||||
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) {
|
||||
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
|
@ -78,15 +41,15 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v Acces
|
|||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want AccessTokenVerifier
|
||||
want *AccessTokenVerifier
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
|
@ -28,9 +28,9 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
|||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
want: &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
want: &AccessTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
KeySet: tu.KeySet{},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -42,10 +42,10 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
|||
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
want: &AccessTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
KeySet: tu.KeySet{},
|
||||
SupportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -58,12 +58,12 @@ func TestNewAccessTokenVerifier(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
verifier := &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
verifier := &AccessTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
MaxAgeIAT: 2 * time.Minute,
|
||||
Offset: time.Second,
|
||||
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
KeySet: tu.KeySet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
|
|
@ -2,69 +2,24 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type IDTokenHintVerifier interface {
|
||||
oidc.Verifier
|
||||
SupportedSignAlgs() []string
|
||||
KeySet() oidc.KeySet
|
||||
ACR() oidc.ACRVerifier
|
||||
MaxAge() time.Duration
|
||||
}
|
||||
type IDTokenHintVerifier oidc.Verifier
|
||||
|
||||
type idTokenHintVerifier struct {
|
||||
issuer string
|
||||
maxAgeIAT time.Duration
|
||||
offset time.Duration
|
||||
supportedSignAlgs []string
|
||||
maxAge time.Duration
|
||||
acr oidc.ACRVerifier
|
||||
keySet oidc.KeySet
|
||||
}
|
||||
|
||||
func (i *idTokenHintVerifier) Issuer() string {
|
||||
return i.issuer
|
||||
}
|
||||
|
||||
func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration {
|
||||
return i.maxAgeIAT
|
||||
}
|
||||
|
||||
func (i *idTokenHintVerifier) Offset() time.Duration {
|
||||
return i.offset
|
||||
}
|
||||
|
||||
func (i *idTokenHintVerifier) SupportedSignAlgs() []string {
|
||||
return i.supportedSignAlgs
|
||||
}
|
||||
|
||||
func (i *idTokenHintVerifier) KeySet() oidc.KeySet {
|
||||
return i.keySet
|
||||
}
|
||||
|
||||
func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier {
|
||||
return i.acr
|
||||
}
|
||||
|
||||
func (i *idTokenHintVerifier) MaxAge() time.Duration {
|
||||
return i.maxAge
|
||||
}
|
||||
|
||||
type IDTokenHintVerifierOpt func(*idTokenHintVerifier)
|
||||
type IDTokenHintVerifierOpt func(*IDTokenHintVerifier)
|
||||
|
||||
func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt {
|
||||
return func(verifier *idTokenHintVerifier) {
|
||||
verifier.supportedSignAlgs = algs
|
||||
return func(verifier *IDTokenHintVerifier) {
|
||||
verifier.SupportedSignAlgs = algs
|
||||
}
|
||||
}
|
||||
|
||||
func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) IDTokenHintVerifier {
|
||||
verifier := &idTokenHintVerifier{
|
||||
issuer: issuer,
|
||||
keySet: keySet,
|
||||
func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier {
|
||||
verifier := &IDTokenHintVerifier{
|
||||
Issuer: issuer,
|
||||
KeySet: keySet,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(verifier)
|
||||
|
@ -74,7 +29,7 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
|
|||
|
||||
// VerifyIDTokenHint validates the id token according to
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) {
|
||||
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
|
@ -86,27 +41,27 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTok
|
|||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
return claims, nil
|
||||
|
|
|
@ -20,7 +20,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want IDTokenHintVerifier
|
||||
want *IDTokenHintVerifier
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
|
@ -28,9 +28,9 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
|||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
want: &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
want: &IDTokenHintVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
KeySet: tu.KeySet{},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -42,10 +42,10 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
|||
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
want: &IDTokenHintVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
KeySet: tu.KeySet{},
|
||||
SupportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -58,14 +58,14 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestVerifyIDTokenHint(t *testing.T) {
|
||||
verifier := &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
keySet: tu.KeySet{},
|
||||
verifier := &IDTokenHintVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
MaxAgeIAT: 2 * time.Minute,
|
||||
Offset: time.Second,
|
||||
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
MaxAge: 2 * time.Minute,
|
||||
ACR: tu.ACRVerify,
|
||||
KeySet: tu.KeySet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
|
|
@ -11,28 +11,25 @@ import (
|
|||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type JWTProfileVerifier interface {
|
||||
// JWTProfileVerfiier extends oidc.Verifier with
|
||||
// a jwtProfileKeyStorage and a function to check
|
||||
// the subject in a token.
|
||||
type JWTProfileVerifier struct {
|
||||
oidc.Verifier
|
||||
Storage() jwtProfileKeyStorage
|
||||
CheckSubject(request *oidc.JWTTokenRequest) error
|
||||
}
|
||||
|
||||
type jwtProfileVerifier struct {
|
||||
storage jwtProfileKeyStorage
|
||||
subjectCheck func(request *oidc.JWTTokenRequest) error
|
||||
issuer string
|
||||
maxAgeIAT time.Duration
|
||||
offset time.Duration
|
||||
Storage JWTProfileKeyStorage
|
||||
CheckSubject func(request *oidc.JWTTokenRequest) error
|
||||
}
|
||||
|
||||
// NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
|
||||
func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier {
|
||||
j := &jwtProfileVerifier{
|
||||
storage: storage,
|
||||
subjectCheck: SubjectIsIssuer,
|
||||
issuer: issuer,
|
||||
maxAgeIAT: maxAgeIAT,
|
||||
offset: offset,
|
||||
func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
|
||||
j := &JWTProfileVerifier{
|
||||
Verifier: oidc.Verifier{
|
||||
Issuer: issuer,
|
||||
MaxAgeIAT: maxAgeIAT,
|
||||
Offset: offset,
|
||||
},
|
||||
Storage: storage,
|
||||
CheckSubject: SubjectIsIssuer,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
|
@ -42,53 +39,35 @@ func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIA
|
|||
return j
|
||||
}
|
||||
|
||||
type JWTProfileVerifierOption func(*jwtProfileVerifier)
|
||||
type JWTProfileVerifierOption func(*JWTProfileVerifier)
|
||||
|
||||
// SubjectCheck sets a custom function to check the subject.
|
||||
// Defaults to SubjectIsIssuer()
|
||||
func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption {
|
||||
return func(verifier *jwtProfileVerifier) {
|
||||
verifier.subjectCheck = check
|
||||
return func(verifier *JWTProfileVerifier) {
|
||||
verifier.CheckSubject = check
|
||||
}
|
||||
}
|
||||
|
||||
func (v *jwtProfileVerifier) Issuer() string {
|
||||
return v.issuer
|
||||
}
|
||||
|
||||
func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage {
|
||||
return v.storage
|
||||
}
|
||||
|
||||
func (v *jwtProfileVerifier) MaxAgeIAT() time.Duration {
|
||||
return v.maxAgeIAT
|
||||
}
|
||||
|
||||
func (v *jwtProfileVerifier) Offset() time.Duration {
|
||||
return v.offset
|
||||
}
|
||||
|
||||
func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error {
|
||||
return v.subjectCheck(request)
|
||||
}
|
||||
|
||||
// VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication)
|
||||
//
|
||||
// checks audience, exp, iat, signature and that issuer and sub are the same
|
||||
func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
||||
func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
||||
request := new(oidc.JWTTokenRequest)
|
||||
payload, err := oidc.ParseToken(assertion, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAudience(request, v.Issuer()); err != nil {
|
||||
if err = oidc.CheckAudience(request, v.Issuer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(request, v.Offset()); err != nil {
|
||||
if err = oidc.CheckExpiration(request, v.Offset); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -96,17 +75,18 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
|
|||
return nil, err
|
||||
}
|
||||
|
||||
keySet := &jwtProfileKeySet{storage: v.Storage(), clientID: request.Issuer}
|
||||
keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer}
|
||||
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
type jwtProfileKeyStorage interface {
|
||||
type JWTProfileKeyStorage interface {
|
||||
GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
||||
}
|
||||
|
||||
// SubjectIsIssuer
|
||||
func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
|
||||
if request.Issuer != request.Subject {
|
||||
return errors.New("delegation not allowed, issuer and sub must be identical")
|
||||
|
@ -115,7 +95,7 @@ func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
|
|||
}
|
||||
|
||||
type jwtProfileKeySet struct {
|
||||
storage jwtProfileKeyStorage
|
||||
storage JWTProfileKeyStorage
|
||||
clientID string
|
||||
}
|
||||
|
||||
|
|
117
pkg/op/verifier_jwt_profile_test.go
Normal file
117
pkg/op/verifier_jwt_profile_test.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func TestNewJWTProfileVerifier(t *testing.T) {
|
||||
want := &op.JWTProfileVerifier{
|
||||
Verifier: oidc.Verifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
MaxAgeIAT: time.Minute,
|
||||
Offset: time.Second,
|
||||
},
|
||||
Storage: tu.JWTProfileKeyStorage{},
|
||||
}
|
||||
got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error {
|
||||
return oidc.ErrSubjectMissing
|
||||
}))
|
||||
assert.Equal(t, want.Verifier, got.Verifier)
|
||||
assert.Equal(t, want.Storage, got.Storage)
|
||||
assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing)
|
||||
}
|
||||
|
||||
func TestVerifyJWTAssertion(t *testing.T) {
|
||||
errCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0)
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
newToken func() (string, *oidc.JWTTokenRequest)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "parse error",
|
||||
ctx: context.Background(),
|
||||
newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong audience",
|
||||
ctx: context.Background(),
|
||||
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||
return tu.NewJWTProfileAssertion(
|
||||
tu.ValidClientID, tu.ValidClientID, []string{"wrong"},
|
||||
time.Now(), tu.ValidExpiration,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
ctx: context.Background(),
|
||||
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||
return tu.NewJWTProfileAssertion(
|
||||
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
|
||||
time.Now(), time.Now().Add(-time.Hour),
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid iat",
|
||||
ctx: context.Background(),
|
||||
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||
return tu.NewJWTProfileAssertion(
|
||||
tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
|
||||
time.Now().Add(time.Hour), tu.ValidExpiration,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid subject",
|
||||
ctx: context.Background(),
|
||||
newToken: func() (string, *oidc.JWTTokenRequest) {
|
||||
return tu.NewJWTProfileAssertion(
|
||||
tu.ValidClientID, "wrong", []string{tu.ValidIssuer},
|
||||
time.Now(), tu.ValidExpiration,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "check signature fail",
|
||||
ctx: errCtx,
|
||||
newToken: tu.ValidJWTProfileAssertion,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
ctx: context.Background(),
|
||||
newToken: tu.ValidJWTProfileAssertion,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assertion, want := tt.newToken()
|
||||
got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue