From e1d50faf9b58798a0b4e397fea37a93a464ac5ba Mon Sep 17 00:00:00 2001 From: David Sharnoff Date: Mon, 27 Mar 2023 13:40:10 -0700 Subject: [PATCH] fix: do not modify userInfo when marshaling --- pkg/oidc/introspection_test.go | 3 ++- pkg/oidc/regression_assert_test.go | 7 +++++-- pkg/oidc/token.go | 1 - pkg/oidc/token_test.go | 5 +++-- pkg/oidc/userinfo_test.go | 5 ++++- pkg/oidc/util.go | 13 +++++++++---- 6 files changed, 23 insertions(+), 11 deletions(-) diff --git a/pkg/oidc/introspection_test.go b/pkg/oidc/introspection_test.go index bd49894..60cf8a4 100644 --- a/pkg/oidc/introspection_test.go +++ b/pkg/oidc/introspection_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "testing" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) { UserInfoProfile: userInfoData.UserInfoProfile, UserInfoEmail: userInfoData.UserInfoEmail, UserInfoPhone: userInfoData.UserInfoPhone, - Claims: userInfoData.Claims, + Claims: gu.MapCopy(userInfoData.Claims), }, }, { diff --git a/pkg/oidc/regression_assert_test.go b/pkg/oidc/regression_assert_test.go index 5e9fb3d..dd9f5ad 100644 --- a/pkg/oidc/regression_assert_test.go +++ b/pkg/oidc/regression_assert_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "io" "os" + "reflect" "strings" "testing" @@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) { assert.JSONEq(t, want, first) + target := reflect.New(reflect.TypeOf(obj).Elem()).Interface() + require.NoError(t, - json.Unmarshal([]byte(first), obj), + json.Unmarshal([]byte(first), target), ) - second, err := json.Marshal(obj) + second, err := json.Marshal(target) require.NoError(t, err) assert.JSONEq(t, want, string(second)) diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 776e758..5283eb5 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -158,7 +158,6 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) { t.UserInfoEmail = i.UserInfoEmail t.UserInfoPhone = i.UserInfoPhone t.Address = i.Address - if t.Claims == nil { t.Claims = make(map[string]any, len(t.Claims)) } diff --git a/pkg/oidc/token_test.go b/pkg/oidc/token_test.go index 7377a84..ef1e77f 100644 --- a/pkg/oidc/token_test.go +++ b/pkg/oidc/token_test.go @@ -4,7 +4,6 @@ import ( "testing" "time" - "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" @@ -182,7 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) { UserInfoEmail: userInfoData.UserInfoEmail, UserInfoPhone: userInfoData.UserInfoPhone, Address: userInfoData.Address, - Claims: gu.MapCopy(userInfoData.Claims), + Claims: map[string]interface{}{ + "foo": "bar", + }, } var got IDTokenClaims diff --git a/pkg/oidc/userinfo_test.go b/pkg/oidc/userinfo_test.go index faab4e3..a574366 100644 --- a/pkg/oidc/userinfo_test.go +++ b/pkg/oidc/userinfo_test.go @@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) { out := new(UserInfo) assert.NoError(t, json.Unmarshal(marshal, out)) - assert.Equal(t, userinfo, out) expected, err := json.Marshal(out) assert.NoError(t, err) assert.Equal(t, expected, marshal) + + out2 := new(UserInfo) + assert.NoError(t, json.Unmarshal(expected, out2)) + assert.Equal(t, out, out2) } func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) { diff --git a/pkg/oidc/util.go b/pkg/oidc/util.go index a89d75e..462ea44 100644 --- a/pkg/oidc/util.go +++ b/pkg/oidc/util.go @@ -9,7 +9,7 @@ import ( // mergeAndMarshalClaims merges registered and the custom // claims map into a single JSON object. // Registered fields overwrite custom claims. -func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) { +func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) { // Use a buffer for memory re-use, instead off letting // json allocate a new []byte for every step. buf := new(bytes.Buffer) @@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error return nil, fmt.Errorf("oidc registered claims: %w", err) } - if len(claims) > 0 { + if len(extraClaims) > 0 { + merged := make(map[string]any) + for k, v := range extraClaims { + merged[k] = v + } + // Merge JSON data into custom claims. // The full-read action by the decoder resets the buffer // to zero len, while retaining underlaying cap. - if err := json.NewDecoder(buf).Decode(&claims); err != nil { + if err := json.NewDecoder(buf).Decode(&merged); err != nil { return nil, fmt.Errorf("oidc registered claims: %w", err) } // Marshal the final result. - if err := json.NewEncoder(buf).Encode(claims); err != nil { + if err := json.NewEncoder(buf).Encode(merged); err != nil { return nil, fmt.Errorf("oidc custom claims: %w", err) } }