refactor: use struct types for claim related types (#283)
* oidc: add regression tests for token claim json this helps to verify that the same JSON is produced, after these types are refactored. * refactor: use struct types for claim related types BREAKING CHANGE: The following types are changed from interface to struct type: - AccessTokenClaims - IDTokenClaims - IntrospectionResponse - UserInfo and related types. The following methods of OPStorage now take a pointer to a struct type, instead of an interface: - SetUserinfoFromScopes - SetUserinfoFromToken - SetIntrospectionFromToken The following functions are now generic, so that type-safe extension of Claims is now possible: - op.VerifyIDTokenHint - op.VerifyAccessToken - rp.VerifyTokens - rp.VerifyIDToken - Changed UserInfoAddress to pointer in UserInfo and IntrospectionResponse. This was needed to make omitempty work correctly. - Copy or merge maps in IntrospectionResponse and SetUserInfo * op: add example for VerifyAccessToken * fix: rp: wrong assignment in WithIssuedAtMaxAge WithIssuedAtMaxAge assigned its value to v.maxAge, which was wrong. This change fixes that by assiging the duration to v.maxAgeIAT. * rp: add VerifyTokens example * oidc: add standard references to: - IDTokenClaims - IntrospectionResponse - UserInfo * only count coverage for `./pkg/...`
This commit is contained in:
parent
4bd2b742f9
commit
dea8bc96ea
55 changed files with 2358 additions and 1516 deletions
|
@ -371,7 +371,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
|
|||
if idTokenHint == "" {
|
||||
return "", nil
|
||||
}
|
||||
claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier)
|
||||
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier)
|
||||
if err != nil {
|
||||
return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
|
||||
"If you have any questions, you may contact the administrator of the application.")
|
||||
|
|
|
@ -263,7 +263,7 @@ func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *g
|
|||
}
|
||||
|
||||
// SetIntrospectionFromToken mocks base method.
|
||||
func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
|
||||
func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 *oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetIntrospectionFromToken", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
@ -277,7 +277,7 @@ func (mr *MockStorageMockRecorder) SetIntrospectionFromToken(arg0, arg1, arg2, a
|
|||
}
|
||||
|
||||
// SetUserinfoFromScopes mocks base method.
|
||||
func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3 string, arg4 []string) error {
|
||||
func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3 string, arg4 []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetUserinfoFromScopes", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
@ -291,7 +291,7 @@ func (mr *MockStorageMockRecorder) SetUserinfoFromScopes(arg0, arg1, arg2, arg3,
|
|||
}
|
||||
|
||||
// SetUserinfoFromToken mocks base method.
|
||||
func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3, arg4 string) error {
|
||||
func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3, arg4 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetUserinfoFromToken", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
|
|
@ -59,7 +59,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
|
|||
RedirectURI: ender.DefaultLogoutRedirectURI(),
|
||||
}
|
||||
if req.IdTokenHint != "" {
|
||||
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
|
||||
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
|
||||
}
|
||||
|
|
|
@ -96,7 +96,7 @@ type TokenExchangeStorage interface {
|
|||
|
||||
// SetUserinfoFromTokenExchangeRequest will be called during id token creation.
|
||||
// Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
|
||||
SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo oidc.UserInfoSetter, request TokenExchangeRequest) error
|
||||
SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo *oidc.UserInfo, request TokenExchangeRequest) error
|
||||
}
|
||||
|
||||
// TokenExchangeTokensVerifierStorage is an optional interface used in token exchange process to verify tokens
|
||||
|
@ -111,9 +111,9 @@ var ErrInvalidRefreshToken = errors.New("invalid_refresh_token")
|
|||
type OPStorage interface {
|
||||
GetClientByClientID(ctx context.Context, clientID string) (Client, error)
|
||||
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
|
||||
SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error
|
||||
SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error
|
||||
SetIntrospectionFromToken(ctx context.Context, userinfo oidc.IntrospectionResponse, tokenID, subject, clientID string) error
|
||||
SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
|
||||
SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
|
||||
SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
|
||||
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
|
||||
GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error)
|
||||
ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error)
|
||||
|
|
|
@ -129,7 +129,7 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetPrivateClaims(privateClaims)
|
||||
claims.Claims = privateClaims
|
||||
}
|
||||
signingKey, err := storage.SigningKey(ctx)
|
||||
if err != nil {
|
||||
|
@ -169,7 +169,7 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetAccessTokenHash(atHash)
|
||||
claims.AccessTokenHash = atHash
|
||||
if !client.IDTokenUserinfoClaimsAssertion() {
|
||||
scopes = removeUserinfoScopes(scopes)
|
||||
}
|
||||
|
@ -178,26 +178,26 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
|
|||
tokenExchangeRequest, okReq := request.(TokenExchangeRequest)
|
||||
teStorage, okStorage := storage.(TokenExchangeStorage)
|
||||
if okReq && okStorage {
|
||||
userInfo := oidc.NewUserInfo()
|
||||
userInfo := new(oidc.UserInfo)
|
||||
err := teStorage.SetUserinfoFromTokenExchangeRequest(ctx, userInfo, tokenExchangeRequest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetUserinfo(userInfo)
|
||||
claims.SetUserInfo(userInfo)
|
||||
} else if len(scopes) > 0 {
|
||||
userInfo := oidc.NewUserInfo()
|
||||
userInfo := new(oidc.UserInfo)
|
||||
err := storage.SetUserinfoFromScopes(ctx, userInfo, request.GetSubject(), request.GetClientID(), scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetUserinfo(userInfo)
|
||||
claims.SetUserInfo(userInfo)
|
||||
}
|
||||
if code != "" {
|
||||
codeHash, err := oidc.ClaimHash(code, signingKey.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetCodeHash(codeHash)
|
||||
claims.CodeHash = codeHash
|
||||
}
|
||||
signer, err := SignerFromKey(signingKey)
|
||||
if err != nil {
|
||||
|
|
|
@ -280,9 +280,9 @@ func GetTokenIDAndSubjectFromToken(
|
|||
) (tokenIDOrToken, subject string, claims map[string]interface{}, ok bool) {
|
||||
switch tokenType {
|
||||
case oidc.AccessTokenType:
|
||||
var accessTokenClaims oidc.AccessTokenClaims
|
||||
var accessTokenClaims *oidc.AccessTokenClaims
|
||||
tokenIDOrToken, subject, accessTokenClaims, ok = getTokenIDAndClaims(ctx, exchanger, token)
|
||||
claims = accessTokenClaims.GetClaims()
|
||||
claims = accessTokenClaims.Claims
|
||||
case oidc.RefreshTokenType:
|
||||
refreshTokenRequest, err := exchanger.Storage().TokenRequestByRefreshToken(ctx, token)
|
||||
if err != nil {
|
||||
|
@ -291,12 +291,12 @@ func GetTokenIDAndSubjectFromToken(
|
|||
|
||||
tokenIDOrToken, subject, ok = token, refreshTokenRequest.GetSubject(), true
|
||||
case oidc.IDTokenType:
|
||||
idTokenClaims, err := VerifyIDTokenHint(ctx, token, exchanger.IDTokenHintVerifier(ctx))
|
||||
idTokenClaims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, token, exchanger.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
tokenIDOrToken, subject, claims, ok = token, idTokenClaims.GetSubject(), idTokenClaims.GetClaims(), true
|
||||
tokenIDOrToken, subject, claims, ok = token, idTokenClaims.Subject, idTokenClaims.Claims, true
|
||||
}
|
||||
|
||||
if !ok {
|
||||
|
@ -380,7 +380,7 @@ func CreateTokenExchangeResponse(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, oidc.AccessTokenClaims, bool) {
|
||||
func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, *oidc.AccessTokenClaims, bool) {
|
||||
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
|
||||
if err == nil {
|
||||
splitToken := strings.Split(tokenIDSubject, ":")
|
||||
|
@ -390,10 +390,10 @@ func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider,
|
|||
|
||||
return splitToken[0], splitToken[1], nil, true
|
||||
}
|
||||
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
if err != nil {
|
||||
return "", "", nil, false
|
||||
}
|
||||
|
||||
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), accessTokenClaims, true
|
||||
return accessTokenClaims.JWTID, accessTokenClaims.Subject, accessTokenClaims, true
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *
|
|||
}
|
||||
|
||||
func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) {
|
||||
response := oidc.NewIntrospectionResponse()
|
||||
response := new(oidc.IntrospectionResponse)
|
||||
token, clientID, err := ParseTokenIntrospectionRequest(r, introspector)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
|
@ -44,7 +44,7 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
|
|||
httphelper.MarshalJSON(w, response)
|
||||
return
|
||||
}
|
||||
response.SetActive(true)
|
||||
response.Active = true
|
||||
httphelper.MarshalJSON(w, response)
|
||||
}
|
||||
|
||||
|
|
|
@ -151,9 +151,9 @@ func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider Use
|
|||
}
|
||||
return splitToken[0], splitToken[1], true
|
||||
}
|
||||
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
|
||||
return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
|
|||
http.Error(w, "access token invalid", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
info := oidc.NewUserInfo()
|
||||
info := new(oidc.UserInfo)
|
||||
err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin"))
|
||||
if err != nil {
|
||||
httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden)
|
||||
|
@ -81,9 +81,9 @@ func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider
|
|||
}
|
||||
return splitToken[0], splitToken[1], true
|
||||
}
|
||||
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
|
||||
return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
|
||||
}
|
||||
|
|
|
@ -18,8 +18,6 @@ type accessTokenVerifier struct {
|
|||
maxAgeIAT time.Duration
|
||||
offset time.Duration
|
||||
supportedSignAlgs []string
|
||||
maxAge time.Duration
|
||||
acr oidc.ACRVerifier
|
||||
keySet oidc.KeySet
|
||||
}
|
||||
|
||||
|
@ -67,29 +65,29 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok
|
|||
return verifier
|
||||
}
|
||||
|
||||
// VerifyAccessToken validates the access token (issuer, signature and expiration)
|
||||
func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) {
|
||||
claims := oidc.EmptyAccessTokenClaims()
|
||||
// VerifyAccessToken validates the access token (issuer, signature and expiration).
|
||||
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
payload, err := oidc.ParseToken(decrypted, claims)
|
||||
payload, err := oidc.ParseToken(decrypted, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
|
|
70
pkg/op/verifier_access_token_example_test.go
Normal file
70
pkg/op/verifier_access_token_example_test.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
// MyCustomClaims extends the TokenClaims base,
|
||||
// so it implements the oidc.Claims interface.
|
||||
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
|
||||
type MyCustomClaims struct {
|
||||
oidc.TokenClaims
|
||||
NotBefore oidc.Time `json:"nbf,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
Scopes []string `json:"scope,omitempty"`
|
||||
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
||||
Foo string `json:"foo,omitempty"`
|
||||
Bar *Nested `json:"bar,omitempty"`
|
||||
}
|
||||
|
||||
// Nested struct types are also possible.
|
||||
type Nested struct {
|
||||
Count int `json:"count,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
/*
|
||||
accessToken carries the following claims. foo and bar are custom claims
|
||||
|
||||
{
|
||||
"aud": [
|
||||
"unit",
|
||||
"test"
|
||||
],
|
||||
"bar": {
|
||||
"count": 22,
|
||||
"tags": [
|
||||
"some",
|
||||
"tags"
|
||||
]
|
||||
},
|
||||
"exp": 4802234675,
|
||||
"foo": "Hello, World!",
|
||||
"iat": 1678097014,
|
||||
"iss": "local.com",
|
||||
"jti": "9876",
|
||||
"nbf": 1678097014,
|
||||
"sub": "tim@local.com"
|
||||
}
|
||||
*/
|
||||
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM0Njc1LCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MDk3MDE0LCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MDk3MDE0LCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.OUgk-B7OXjYlYFj-nogqSDJiQE19tPrbzqUHEAjcEiJkaWo6-IpGVfDiGKm-TxjXQsNScxpaY0Pg3XIh1xK6TgtfYtoLQm-5RYw_mXgb9xqZB2VgPs6nNEYFUDM513MOU0EBc0QMyqAEGzW-HiSPAb4ugCvkLtM1yo11Xyy6vksAdZNs_mJDT4X3vFXnr0jk0ugnAW6fTN3_voC0F_9HQUAkmd750OIxkAHxAMvEPQcpbLHenVvX_Q0QMrzClVrxehn5TVMfmkYYg7ocr876Bq9xQGPNHAcrwvVIJqdg5uMUA38L3HC2BEueG6furZGvc7-qDWAT1VR9liM5ieKpPg`
|
||||
|
||||
func ExampleVerifyAccessToken_customClaims() {
|
||||
v := op.NewAccessTokenVerifier("local.com", tu.KeySet{})
|
||||
|
||||
// VerifyAccessToken can be called with the *MyCustomClaims.
|
||||
claims, err := op.VerifyAccessToken[*MyCustomClaims](context.TODO(), accessToken, v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Here we have typesafe access to the custom claims
|
||||
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
|
||||
// Output: Hello, World! 22 [some tags]
|
||||
}
|
126
pkg/op/verifier_access_token_test.go
Normal file
126
pkg/op/verifier_access_token_test.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestNewAccessTokenVerifier(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
keySet oidc.KeySet
|
||||
opts []AccessTokenVerifierOpt
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want AccessTokenVerifier
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
want: &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with signature algorithm",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
opts: []AccessTokenVerifierOpt{
|
||||
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewAccessTokenVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
verifier := &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenClaims func() (string, *oidc.AccessTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
tokenClaims: tu.ValidAccessToken,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
|
||||
return tu.NewAccessToken(
|
||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidJWTID, tu.ValidClientID,
|
||||
tu.ValidSkew,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
|
||||
return tu.NewAccessToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidJWTID, tu.ValidClientID,
|
||||
tu.ValidSkew,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
|
||||
got, err := VerifyAccessToken[*oidc.AccessTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -74,40 +74,40 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
|
|||
|
||||
// VerifyIDTokenHint validates the id token according to
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
|
||||
claims := oidc.EmptyIDTokenClaims()
|
||||
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
payload, err := oidc.ParseToken(decrypted, claims)
|
||||
payload, err := oidc.ParseToken(decrypted, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
|
161
pkg/op/verifier_id_token_hint_test.go
Normal file
161
pkg/op/verifier_id_token_hint_test.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestNewIDTokenHintVerifier(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
keySet oidc.KeySet
|
||||
opts []IDTokenHintVerifierOpt
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want IDTokenHintVerifier
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
want: &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with signature algorithm",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
opts: []IDTokenHintVerifierOpt{
|
||||
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewIDTokenHintVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyIDTokenHint(t *testing.T) {
|
||||
verifier := &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
keySet: tu.KeySet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong IAT",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong acr",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired auth",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
|
||||
got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue