diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index d8372b8..faf8e7f 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -3,6 +3,7 @@ package oidc import ( "database/sql/driver" "encoding/json" + "errors" "fmt" "reflect" "strings" @@ -77,7 +78,17 @@ func (l *Locale) MarshalJSON() ([]byte, error) { } func (l *Locale) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &l.tag) + err := json.Unmarshal(data, &l.tag) + if err == nil { + return nil + } + + // catch "well-formed but unknown" errors + var target language.ValueError + if errors.As(err, &target) { + return nil + } + return err } type Locales []language.Tag diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index af4f113..df93a73 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -208,20 +208,46 @@ func TestLocale_MarshalJSON(t *testing.T) { } func TestLocale_UnmarshalJSON(t *testing.T) { - type a struct { + type dst struct { Locale *Locale `json:"locale,omitempty"` } - want := a{ - Locale: NewLocale(language.Afrikaans), + tests := []struct { + name string + input string + want dst + wantErr bool + }{ + { + name: "afrikaans, ok", + input: `{"locale": "af"}`, + want: dst{ + Locale: NewLocale(language.Afrikaans), + }, + }, + { + name: "gb, ignored", + input: `{"locale": "gb"}`, + want: dst{ + Locale: &Locale{}, + }, + }, + { + name: "bad form, error", + input: `{"locale": "g!!!!!"}`, + wantErr: true, + }, } - const input = `{"locale": "af"}` - var got a - - require.NoError(t, - json.Unmarshal([]byte(input), &got), - ) - assert.Equal(t, want, got) + for _, tt := range tests { + var got dst + err := json.Unmarshal([]byte(tt.input), &got) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } } func TestParseLocales(t *testing.T) { diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index ef8ebe4..b824160 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -41,7 +41,13 @@ func (u *UserInfo) MarshalJSON() ([]byte, error) { } func (u *UserInfo) UnmarshalJSON(data []byte) error { - return unmarshalJSONMulti(data, (*uiAlias)(u), &u.Claims) + if err := unmarshalJSONMulti(data, (*uiAlias)(u), &u.Claims); err != nil { + return err + } + if u.Locale != nil && u.Locale.tag.IsRoot() { + u.Locale = nil + } + return nil } type UserInfoProfile struct {