cleanup nested types and add some unit tests

This commit is contained in:
Tim Möhlmann 2023-02-19 22:14:55 +02:00
parent 62a3af61f3
commit 3940b520a8
4 changed files with 247 additions and 259 deletions

View file

@ -1,84 +0,0 @@
package oidc
// Some expirimental stuff, no sure yet if it can be used
// or deleted before final PR.
/*
// CustomClaims allows the joining of any type
// with Registered fields and a map of custom Claims.
type CustomClaims[R any] struct {
Registered R
Claims map[string]any
}
func (c *CustomClaims[_]) AppendClaims(k string, v any) {
if c.Claims == nil {
c.Claims = make(map[string]any)
}
c.Claims[k] = v
}
// MarshalJSON implements the json.Marshaller interface.
// The Registered and Claims map are merged into a
// single JSON object. Registered fields overwrite
// custom Claims.
func (c *CustomClaims[_]) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims(&c.Registered, c.Claims)
}
// UnmashalJSON implements the json.Unmarshaller interface.
// Matching values from the JSON document are set in Registered.
// The map Claims will contain all claims from the JSON document.
func (c *CustomClaims[_]) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, &c.Registered, &c.Claims)
}
// CustomTokenClaims allows the joining of a Claims
// type with registered fields and a map of custom Claims.
// CustomTokenClaims implements the Claims interface,
// and any type that embeds TokenClaims can be used as
// type argument.
type CustomTokenClaims[TC Claims] struct {
Registered TC
Claims map[string]any
}
func (c *CustomTokenClaims[_]) AppendClaims(k string, v any) {
if c.Claims == nil {
c.Claims = make(map[string]any)
}
c.Claims[k] = v
}
// MarshalJSON implements the json.Marshaller interface.
// The Registered and Claims map are merged into a
// single JSON object. Registered fields overwrite
// custom Claims.
func (c *CustomTokenClaims[_]) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims(&c.Registered, c.Claims)
}
// UnmashalJSON implements the json.Unmarshaller interface.
// Matching values from the JSON document are set in Registered.
// The map Claims will contain all claims from the JSON document.
func (c *CustomTokenClaims[_]) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, &c.Registered, &c.Claims)
}
func (c *CustomTokenClaims[_]) GetIssuer() string { return c.Registered.GetIssuer() }
func (c *CustomTokenClaims[_]) GetSubject() string { return c.Registered.GetSubject() }
func (c *CustomTokenClaims[_]) GetAudience() []string { return c.Registered.GetAudience() }
func (c *CustomTokenClaims[_]) GetExpiration() time.Time { return c.Registered.GetExpiration() }
func (c *CustomTokenClaims[_]) GetIssuedAt() time.Time { return c.Registered.GetIssuedAt() }
func (c *CustomTokenClaims[_]) GetNonce() string { return c.Registered.GetNonce() }
func (c *CustomTokenClaims[_]) GetAuthTime() time.Time { return c.Registered.GetAuthTime() }
func (c *CustomTokenClaims[_]) GetAuthorizedParty() string {
return c.Registered.GetAuthorizedParty()
}
func (c *CustomTokenClaims[_]) GetAuthenticationContextClassReference() string {
return c.Registered.GetAuthenticationContextClassReference()
}
func (c *CustomTokenClaims[_]) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
c.Registered.SetSignatureAlgorithm(algorithm)
}
*/

View file

@ -11,8 +11,6 @@ import (
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
const dataDir = "regression_data"
@ -33,139 +31,10 @@ func encodeJSON(t *testing.T, w io.Writer, obj interface{}) {
require.NoError(t, enc.Encode(obj))
}
var (
accessTokenRegressData = &AccessTokenClaims{
RegisteredAccessTokenClaims: RegisteredAccessTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
JWTID: "900",
AuthorizedParty: "just@me.com",
Nonce: "6969",
AuthTime: 12000,
AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"},
ClientID: "777",
SignatureAlg: jose.ES256,
},
NotBefore: 12000,
CodeHash: "hashhash",
SessionID: "666",
Scopes: []string{"email", "phone"},
AccessTokenUseNumber: 22,
},
Claims: map[string]interface{}{
"foo": "bar",
},
}
idTokenRegressData = &IDTokenClaims{
RegisteredIDTokenClaims: RegisteredIDTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
JWTID: "900",
AuthorizedParty: "just@me.com",
Nonce: "6969",
AuthTime: 12000,
AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"},
ClientID: "777",
SignatureAlg: jose.ES256,
},
NotBefore: 12000,
AccessTokenHash: "acthashhash",
CodeHash: "hashhash",
UserInfoProfile: userInfoRegressData.UserInfoProfile,
UserInfoEmail: userInfoRegressData.UserInfoEmail,
UserInfoPhone: userInfoRegressData.UserInfoPhone,
Address: userInfoRegressData.Address,
},
Claims: map[string]interface{}{
"foo": "bar",
},
}
introspectionResponseRegressData = &IntrospectionResponse{
Active: true,
Scope: SpaceDelimitedArray{"email", "phone"},
ClientID: "777",
TokenType: "idtoken",
Expiration: 12345,
IssuedAt: 12000,
NotBefore: 12000,
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Issuer: "zitadel",
JWTID: "900",
Username: "muhlemmer",
UserInfoProfile: userInfoRegressData.UserInfoProfile,
UserInfoEmail: userInfoRegressData.UserInfoEmail,
UserInfoPhone: userInfoRegressData.UserInfoPhone,
Address: userInfoRegressData.Address,
Claims: map[string]interface{}{
"foo": "bar",
},
}
userInfoRegressData = &UserInfo{
Subject: "hello@me.com",
UserInfoProfile: UserInfoProfile{
Name: "Tim Möhlmann",
GivenName: "Tim",
FamilyName: "Möhlmann",
MiddleName: "Danger",
Nickname: "muhlemmer",
Profile: "https://github.com/muhlemmer",
Picture: "https://avatars.githubusercontent.com/u/5411563?v=4",
Website: "https://zitadel.com",
Gender: "male",
Birthdate: "1st of April",
Zoneinfo: "Europe/Amsterdam",
Locale: NewLocale(language.Dutch),
UpdatedAt: 1,
PreferredUsername: "muhlemmer",
},
UserInfoEmail: UserInfoEmail{
Email: "tim@zitadel.com",
EmailVerified: true,
},
UserInfoPhone: UserInfoPhone{
PhoneNumber: "+1234567890",
PhoneNumberVerified: true,
},
Address: UserInfoAddress{
Formatted: "Sesame street 666\n666-666, Smallvile\nMoon",
StreetAddress: "Sesame street 666",
Locality: "Smallvile",
Region: "Outer space",
PostalCode: "666-666",
Country: "Moon",
},
Claims: map[string]interface{}{
"foo": "bar",
},
}
jwtProfileAssertionRegressData = &JWTProfileAssertionClaims{
PrivateKeyID: "8888",
PrivateKey: []byte("qwerty"),
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
Claims: map[string]interface{}{
"foo": "bar",
},
}
regressionData = []interface{}{
accessTokenRegressData,
idTokenRegressData,
introspectionResponseRegressData,
userInfoRegressData,
jwtProfileAssertionRegressData,
}
)
var regressionData = []interface{}{
accessTokenData,
idTokenData,
introspectionResponseData,
userInfoData,
jwtProfileAssertionData,
}

View file

@ -66,38 +66,32 @@ func (c *TokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
c.SignatureAlg = algorithm
}
type RegisteredAccessTokenClaims struct {
type AccessTokenClaims struct {
TokenClaims
NotBefore Time `json:"nbf,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
SessionID string `json:"sid,omitempty"`
Scopes []string `json:"scope,omitempty"`
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
}
type AccessTokenClaims struct {
RegisteredAccessTokenClaims
Claims map[string]any `json:"-"`
}
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id, clientID string, skew time.Duration) *AccessTokenClaims {
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration) *AccessTokenClaims {
now := time.Now().UTC().Add(-skew)
if len(audience) == 0 {
audience = append(audience, clientID)
}
return &AccessTokenClaims{
RegisteredAccessTokenClaims: RegisteredAccessTokenClaims{
TokenClaims: TokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: FromTime(expiration),
IssuedAt: FromTime(now),
JWTID: id,
},
NotBefore: FromTime(now),
TokenClaims: TokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: FromTime(expiration),
IssuedAt: FromTime(now),
JWTID: jwtid,
},
NotBefore: FromTime(now),
}
}
@ -111,7 +105,7 @@ func (a *AccessTokenClaims) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*atcAlias)(a), &a.Claims)
}
type RegisteredIDTokenClaims struct {
type IDTokenClaims struct {
TokenClaims
NotBefore Time `json:"nbf,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
@ -120,14 +114,15 @@ type RegisteredIDTokenClaims struct {
UserInfoEmail
UserInfoPhone
Address UserInfoAddress `json:"address,omitempty"`
Claims map[string]any `json:"-"`
}
// GetAccessTokenHash implements the IDTokenClaims interface
func (t *RegisteredIDTokenClaims) GetAccessTokenHash() string {
func (t *IDTokenClaims) GetAccessTokenHash() string {
return t.AccessTokenHash
}
func (t *RegisteredIDTokenClaims) SetUserInfo(i *UserInfo) {
func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
t.Subject = i.Subject
t.UserInfoProfile = i.UserInfoProfile
t.UserInfoEmail = i.UserInfoEmail
@ -135,27 +130,20 @@ func (t *RegisteredIDTokenClaims) SetUserInfo(i *UserInfo) {
t.Address = i.Address
}
type IDTokenClaims struct {
RegisteredIDTokenClaims
Claims map[string]any `json:"-"`
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims {
audience = AppendClientIDToAudience(clientID, audience)
return &IDTokenClaims{
RegisteredIDTokenClaims: RegisteredIDTokenClaims{
TokenClaims: TokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: FromTime(expiration),
IssuedAt: FromTime(time.Now().Add(-skew)),
AuthTime: FromTime(authTime.Add(-skew)),
Nonce: nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID,
},
TokenClaims: TokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: FromTime(expiration),
IssuedAt: FromTime(time.Now().Add(-skew)),
AuthTime: FromTime(authTime.Add(-skew)),
Nonce: nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID,
},
}
}

215
pkg/oidc/token_test.go Normal file
View file

@ -0,0 +1,215 @@
package oidc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
var (
tokenClaimsData = TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
JWTID: "900",
AuthorizedParty: "just@me.com",
Nonce: "6969",
AuthTime: 12000,
AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"},
ClientID: "777",
SignatureAlg: jose.ES256,
}
accessTokenData = &AccessTokenClaims{
TokenClaims: tokenClaimsData,
NotBefore: 12000,
CodeHash: "hashhash",
SessionID: "666",
Scopes: []string{"email", "phone"},
AccessTokenUseNumber: 22,
Claims: map[string]interface{}{
"foo": "bar",
},
}
idTokenData = &IDTokenClaims{
TokenClaims: tokenClaimsData,
NotBefore: 12000,
AccessTokenHash: "acthashhash",
CodeHash: "hashhash",
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
Claims: map[string]interface{}{
"foo": "bar",
},
}
introspectionResponseData = &IntrospectionResponse{
Active: true,
Scope: SpaceDelimitedArray{"email", "phone"},
ClientID: "777",
TokenType: "idtoken",
Expiration: 12345,
IssuedAt: 12000,
NotBefore: 12000,
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Issuer: "zitadel",
JWTID: "900",
Username: "muhlemmer",
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
Claims: map[string]interface{}{
"foo": "bar",
},
}
userInfoData = &UserInfo{
Subject: "hello@me.com",
UserInfoProfile: UserInfoProfile{
Name: "Tim Möhlmann",
GivenName: "Tim",
FamilyName: "Möhlmann",
MiddleName: "Danger",
Nickname: "muhlemmer",
Profile: "https://github.com/muhlemmer",
Picture: "https://avatars.githubusercontent.com/u/5411563?v=4",
Website: "https://zitadel.com",
Gender: "male",
Birthdate: "1st of April",
Zoneinfo: "Europe/Amsterdam",
Locale: NewLocale(language.Dutch),
UpdatedAt: 1,
PreferredUsername: "muhlemmer",
},
UserInfoEmail: UserInfoEmail{
Email: "tim@zitadel.com",
EmailVerified: true,
},
UserInfoPhone: UserInfoPhone{
PhoneNumber: "+1234567890",
PhoneNumberVerified: true,
},
Address: UserInfoAddress{
Formatted: "Sesame street 666\n666-666, Smallvile\nMoon",
StreetAddress: "Sesame street 666",
Locality: "Smallvile",
Region: "Outer space",
PostalCode: "666-666",
Country: "Moon",
},
Claims: map[string]interface{}{
"foo": "bar",
},
}
jwtProfileAssertionData = &JWTProfileAssertionClaims{
PrivateKeyID: "8888",
PrivateKey: []byte("qwerty"),
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
Claims: map[string]interface{}{
"foo": "bar",
},
}
)
func TestTokenClaims(t *testing.T) {
claims := tokenClaimsData
assert.Equal(t, claims.Issuer, tokenClaimsData.GetIssuer())
assert.Equal(t, claims.Subject, tokenClaimsData.GetSubject())
assert.Equal(t, []string(claims.Audience), tokenClaimsData.GetAudience())
assert.Equal(t, claims.Expiration.AsTime(), tokenClaimsData.GetExpiration())
assert.Equal(t, claims.IssuedAt.AsTime(), tokenClaimsData.GetIssuedAt())
assert.Equal(t, claims.Nonce, tokenClaimsData.GetNonce())
assert.Equal(t, claims.AuthTime.AsTime(), tokenClaimsData.GetAuthTime())
assert.Equal(t, claims.AuthorizedParty, tokenClaimsData.GetAuthorizedParty())
assert.Equal(t, claims.SignatureAlg, tokenClaimsData.GetSignatureAlgorithm())
assert.Equal(t, claims.AuthenticationContextClassReference, tokenClaimsData.GetAuthenticationContextClassReference())
claims.SetSignatureAlgorithm(jose.ES384)
assert.Equal(t, jose.ES384, claims.SignatureAlg)
}
func TestNewAccessTokenClaims(t *testing.T) {
want := &AccessTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo"},
Expiration: 12345,
JWTID: "900",
},
}
got := NewAccessTokenClaims(
want.Issuer, want.Subject, nil,
want.Expiration.AsTime(), want.JWTID, "foo", time.Second,
)
// Make equal not fail on dynamic timestamp
got.IssuedAt = 0
got.NotBefore = 0
assert.Equal(t, want, got)
}
func TestIDTokenClaims_GetAccessTokenHash(t *testing.T) {
assert.Equal(t, idTokenData.AccessTokenHash, idTokenData.GetAccessTokenHash())
}
func TestIDTokenClaims_SetUserInfo(t *testing.T) {
want := IDTokenClaims{
TokenClaims: TokenClaims{
Subject: userInfoData.Subject,
},
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
}
var got IDTokenClaims
got.SetUserInfo(userInfoData)
assert.Equal(t, want, got)
}
func TestNewIDTokenClaims(t *testing.T) {
want := &IDTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "just@me.com"},
Expiration: 12345,
AuthTime: 12000,
Nonce: "6969",
AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"},
AuthorizedParty: "just@me.com",
},
}
got := NewIDTokenClaims(
want.Issuer, want.Subject, want.Audience,
want.Expiration.AsTime(),
want.AuthTime.AsTime().Add(time.Second),
want.Nonce, want.AuthenticationContextClassReference,
want.AuthenticationMethodsReferences, want.AuthorizedParty,
time.Second,
)
// Make equal not fail on dynamic timestamp
got.IssuedAt = 0
assert.Equal(t, want, got)
}