fix: do not modify userInfo when marshaling
This commit is contained in:
parent
be3cc13c27
commit
e1d50faf9b
6 changed files with 23 additions and 11 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
|
|||
UserInfoProfile: userInfoData.UserInfoProfile,
|
||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||
Claims: userInfoData.Claims,
|
||||
Claims: gu.MapCopy(userInfoData.Claims),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) {
|
|||
|
||||
assert.JSONEq(t, want, first)
|
||||
|
||||
target := reflect.New(reflect.TypeOf(obj).Elem()).Interface()
|
||||
|
||||
require.NoError(t,
|
||||
json.Unmarshal([]byte(first), obj),
|
||||
json.Unmarshal([]byte(first), target),
|
||||
)
|
||||
second, err := json.Marshal(obj)
|
||||
second, err := json.Marshal(target)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.JSONEq(t, want, string(second))
|
||||
|
|
|
@ -158,7 +158,6 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
|
|||
t.UserInfoEmail = i.UserInfoEmail
|
||||
t.UserInfoPhone = i.UserInfoPhone
|
||||
t.Address = i.Address
|
||||
|
||||
if t.Claims == nil {
|
||||
t.Claims = make(map[string]any, len(t.Claims))
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
@ -182,7 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) {
|
|||
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||
Address: userInfoData.Address,
|
||||
Claims: gu.MapCopy(userInfoData.Claims),
|
||||
Claims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
}
|
||||
|
||||
var got IDTokenClaims
|
||||
|
|
|
@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) {
|
|||
|
||||
out := new(UserInfo)
|
||||
assert.NoError(t, json.Unmarshal(marshal, out))
|
||||
assert.Equal(t, userinfo, out)
|
||||
expected, err := json.Marshal(out)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, marshal)
|
||||
|
||||
out2 := new(UserInfo)
|
||||
assert.NoError(t, json.Unmarshal(expected, out2))
|
||||
assert.Equal(t, out, out2)
|
||||
}
|
||||
|
||||
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
// mergeAndMarshalClaims merges registered and the custom
|
||||
// claims map into a single JSON object.
|
||||
// Registered fields overwrite custom claims.
|
||||
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) {
|
||||
func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) {
|
||||
// Use a buffer for memory re-use, instead off letting
|
||||
// json allocate a new []byte for every step.
|
||||
buf := new(bytes.Buffer)
|
||||
|
@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error
|
|||
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
||||
}
|
||||
|
||||
if len(claims) > 0 {
|
||||
if len(extraClaims) > 0 {
|
||||
merged := make(map[string]any)
|
||||
for k, v := range extraClaims {
|
||||
merged[k] = v
|
||||
}
|
||||
|
||||
// Merge JSON data into custom claims.
|
||||
// The full-read action by the decoder resets the buffer
|
||||
// to zero len, while retaining underlaying cap.
|
||||
if err := json.NewDecoder(buf).Decode(&claims); err != nil {
|
||||
if err := json.NewDecoder(buf).Decode(&merged); err != nil {
|
||||
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
||||
}
|
||||
|
||||
// Marshal the final result.
|
||||
if err := json.NewEncoder(buf).Encode(claims); err != nil {
|
||||
if err := json.NewEncoder(buf).Encode(merged); err != nil {
|
||||
return nil, fmt.Errorf("oidc custom claims: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue