fix: do not modify userInfo when marshaling

This commit is contained in:
David Sharnoff 2023-03-27 13:40:10 -07:00 committed by Tim Möhlmann
parent be3cc13c27
commit e1d50faf9b
6 changed files with 23 additions and 11 deletions

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Claims: userInfoData.Claims,
Claims: gu.MapCopy(userInfoData.Claims),
},
},
{

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"io"
"os"
"reflect"
"strings"
"testing"
@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) {
assert.JSONEq(t, want, first)
target := reflect.New(reflect.TypeOf(obj).Elem()).Interface()
require.NoError(t,
json.Unmarshal([]byte(first), obj),
json.Unmarshal([]byte(first), target),
)
second, err := json.Marshal(obj)
second, err := json.Marshal(target)
require.NoError(t, err)
assert.JSONEq(t, want, string(second))

View file

@ -158,7 +158,6 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
t.UserInfoEmail = i.UserInfoEmail
t.UserInfoPhone = i.UserInfoPhone
t.Address = i.Address
if t.Claims == nil {
t.Claims = make(map[string]any, len(t.Claims))
}

View file

@ -4,7 +4,6 @@ import (
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
@ -182,7 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) {
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
Claims: gu.MapCopy(userInfoData.Claims),
Claims: map[string]interface{}{
"foo": "bar",
},
}
var got IDTokenClaims

View file

@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) {
out := new(UserInfo)
assert.NoError(t, json.Unmarshal(marshal, out))
assert.Equal(t, userinfo, out)
expected, err := json.Marshal(out)
assert.NoError(t, err)
assert.Equal(t, expected, marshal)
out2 := new(UserInfo)
assert.NoError(t, json.Unmarshal(expected, out2))
assert.Equal(t, out, out2)
}
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {

View file

@ -9,7 +9,7 @@ import (
// mergeAndMarshalClaims merges registered and the custom
// claims map into a single JSON object.
// Registered fields overwrite custom claims.
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) {
func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) {
// Use a buffer for memory re-use, instead off letting
// json allocate a new []byte for every step.
buf := new(bytes.Buffer)
@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
if len(claims) > 0 {
if len(extraClaims) > 0 {
merged := make(map[string]any)
for k, v := range extraClaims {
merged[k] = v
}
// Merge JSON data into custom claims.
// The full-read action by the decoder resets the buffer
// to zero len, while retaining underlaying cap.
if err := json.NewDecoder(buf).Decode(&claims); err != nil {
if err := json.NewDecoder(buf).Decode(&merged); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
// Marshal the final result.
if err := json.NewEncoder(buf).Encode(claims); err != nil {
if err := json.NewEncoder(buf).Encode(merged); err != nil {
return nil, fmt.Errorf("oidc custom claims: %w", err)
}
}