fix: do not modify userInfo when marshaling
This commit is contained in:
parent
be3cc13c27
commit
e1d50faf9b
6 changed files with 23 additions and 11 deletions
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/muhlemmer/gu"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
|
||||||
UserInfoProfile: userInfoData.UserInfoProfile,
|
UserInfoProfile: userInfoData.UserInfoProfile,
|
||||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||||
Claims: userInfoData.Claims,
|
Claims: gu.MapCopy(userInfoData.Claims),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) {
|
||||||
|
|
||||||
assert.JSONEq(t, want, first)
|
assert.JSONEq(t, want, first)
|
||||||
|
|
||||||
|
target := reflect.New(reflect.TypeOf(obj).Elem()).Interface()
|
||||||
|
|
||||||
require.NoError(t,
|
require.NoError(t,
|
||||||
json.Unmarshal([]byte(first), obj),
|
json.Unmarshal([]byte(first), target),
|
||||||
)
|
)
|
||||||
second, err := json.Marshal(obj)
|
second, err := json.Marshal(target)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.JSONEq(t, want, string(second))
|
assert.JSONEq(t, want, string(second))
|
||||||
|
|
|
@ -158,7 +158,6 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
|
||||||
t.UserInfoEmail = i.UserInfoEmail
|
t.UserInfoEmail = i.UserInfoEmail
|
||||||
t.UserInfoPhone = i.UserInfoPhone
|
t.UserInfoPhone = i.UserInfoPhone
|
||||||
t.Address = i.Address
|
t.Address = i.Address
|
||||||
|
|
||||||
if t.Claims == nil {
|
if t.Claims == nil {
|
||||||
t.Claims = make(map[string]any, len(t.Claims))
|
t.Claims = make(map[string]any, len(t.Claims))
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/muhlemmer/gu"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
@ -182,7 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) {
|
||||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||||
Address: userInfoData.Address,
|
Address: userInfoData.Address,
|
||||||
Claims: gu.MapCopy(userInfoData.Claims),
|
Claims: map[string]interface{}{
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var got IDTokenClaims
|
var got IDTokenClaims
|
||||||
|
|
|
@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) {
|
||||||
|
|
||||||
out := new(UserInfo)
|
out := new(UserInfo)
|
||||||
assert.NoError(t, json.Unmarshal(marshal, out))
|
assert.NoError(t, json.Unmarshal(marshal, out))
|
||||||
assert.Equal(t, userinfo, out)
|
|
||||||
expected, err := json.Marshal(out)
|
expected, err := json.Marshal(out)
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, marshal)
|
assert.Equal(t, expected, marshal)
|
||||||
|
|
||||||
|
out2 := new(UserInfo)
|
||||||
|
assert.NoError(t, json.Unmarshal(expected, out2))
|
||||||
|
assert.Equal(t, out, out2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
|
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
// mergeAndMarshalClaims merges registered and the custom
|
// mergeAndMarshalClaims merges registered and the custom
|
||||||
// claims map into a single JSON object.
|
// claims map into a single JSON object.
|
||||||
// Registered fields overwrite custom claims.
|
// Registered fields overwrite custom claims.
|
||||||
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) {
|
func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) {
|
||||||
// Use a buffer for memory re-use, instead off letting
|
// Use a buffer for memory re-use, instead off letting
|
||||||
// json allocate a new []byte for every step.
|
// json allocate a new []byte for every step.
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error
|
||||||
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(claims) > 0 {
|
if len(extraClaims) > 0 {
|
||||||
|
merged := make(map[string]any)
|
||||||
|
for k, v := range extraClaims {
|
||||||
|
merged[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
// Merge JSON data into custom claims.
|
// Merge JSON data into custom claims.
|
||||||
// The full-read action by the decoder resets the buffer
|
// The full-read action by the decoder resets the buffer
|
||||||
// to zero len, while retaining underlaying cap.
|
// to zero len, while retaining underlaying cap.
|
||||||
if err := json.NewDecoder(buf).Decode(&claims); err != nil {
|
if err := json.NewDecoder(buf).Decode(&merged); err != nil {
|
||||||
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal the final result.
|
// Marshal the final result.
|
||||||
if err := json.NewEncoder(buf).Encode(claims); err != nil {
|
if err := json.NewEncoder(buf).Encode(merged); err != nil {
|
||||||
return nil, fmt.Errorf("oidc custom claims: %w", err)
|
return nil, fmt.Errorf("oidc custom claims: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue