fix: do not modify userInfo when marshaling

This commit is contained in:
David Sharnoff 2023-03-27 13:40:10 -07:00 committed by Tim Möhlmann
parent be3cc13c27
commit e1d50faf9b
6 changed files with 23 additions and 11 deletions

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
UserInfoProfile: userInfoData.UserInfoProfile, UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail, UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone, UserInfoPhone: userInfoData.UserInfoPhone,
Claims: userInfoData.Claims, Claims: gu.MapCopy(userInfoData.Claims),
}, },
}, },
{ {

View file

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"os" "os"
"reflect"
"strings" "strings"
"testing" "testing"
@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) {
assert.JSONEq(t, want, first) assert.JSONEq(t, want, first)
target := reflect.New(reflect.TypeOf(obj).Elem()).Interface()
require.NoError(t, require.NoError(t,
json.Unmarshal([]byte(first), obj), json.Unmarshal([]byte(first), target),
) )
second, err := json.Marshal(obj) second, err := json.Marshal(target)
require.NoError(t, err) require.NoError(t, err)
assert.JSONEq(t, want, string(second)) assert.JSONEq(t, want, string(second))

View file

@ -158,7 +158,6 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
t.UserInfoEmail = i.UserInfoEmail t.UserInfoEmail = i.UserInfoEmail
t.UserInfoPhone = i.UserInfoPhone t.UserInfoPhone = i.UserInfoPhone
t.Address = i.Address t.Address = i.Address
if t.Claims == nil { if t.Claims == nil {
t.Claims = make(map[string]any, len(t.Claims)) t.Claims = make(map[string]any, len(t.Claims))
} }

View file

@ -4,7 +4,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/text/language" "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -182,7 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) {
UserInfoEmail: userInfoData.UserInfoEmail, UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone, UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address, Address: userInfoData.Address,
Claims: gu.MapCopy(userInfoData.Claims), Claims: map[string]interface{}{
"foo": "bar",
},
} }
var got IDTokenClaims var got IDTokenClaims

View file

@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) {
out := new(UserInfo) out := new(UserInfo)
assert.NoError(t, json.Unmarshal(marshal, out)) assert.NoError(t, json.Unmarshal(marshal, out))
assert.Equal(t, userinfo, out)
expected, err := json.Marshal(out) expected, err := json.Marshal(out)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, marshal) assert.Equal(t, expected, marshal)
out2 := new(UserInfo)
assert.NoError(t, json.Unmarshal(expected, out2))
assert.Equal(t, out, out2)
} }
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) { func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {

View file

@ -9,7 +9,7 @@ import (
// mergeAndMarshalClaims merges registered and the custom // mergeAndMarshalClaims merges registered and the custom
// claims map into a single JSON object. // claims map into a single JSON object.
// Registered fields overwrite custom claims. // Registered fields overwrite custom claims.
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) { func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) {
// Use a buffer for memory re-use, instead off letting // Use a buffer for memory re-use, instead off letting
// json allocate a new []byte for every step. // json allocate a new []byte for every step.
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error
return nil, fmt.Errorf("oidc registered claims: %w", err) return nil, fmt.Errorf("oidc registered claims: %w", err)
} }
if len(claims) > 0 { if len(extraClaims) > 0 {
merged := make(map[string]any)
for k, v := range extraClaims {
merged[k] = v
}
// Merge JSON data into custom claims. // Merge JSON data into custom claims.
// The full-read action by the decoder resets the buffer // The full-read action by the decoder resets the buffer
// to zero len, while retaining underlaying cap. // to zero len, while retaining underlaying cap.
if err := json.NewDecoder(buf).Decode(&claims); err != nil { if err := json.NewDecoder(buf).Decode(&merged); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err) return nil, fmt.Errorf("oidc registered claims: %w", err)
} }
// Marshal the final result. // Marshal the final result.
if err := json.NewEncoder(buf).Encode(claims); err != nil { if err := json.NewEncoder(buf).Encode(merged); err != nil {
return nil, fmt.Errorf("oidc custom claims: %w", err) return nil, fmt.Errorf("oidc custom claims: %w", err)
} }
} }