Merge branch 'main' into update-1.24

This commit is contained in:
Tim Möhlmann 2025-03-14 14:27:41 +02:00 committed by GitHub
commit e09cdecdbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 64 additions and 21 deletions

View file

@ -3,6 +3,7 @@ package oidc
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -77,17 +78,26 @@ func (l *Locale) MarshalJSON() ([]byte, error) {
} }
// UnmarshalJSON implements json.Unmarshaler. // UnmarshalJSON implements json.Unmarshaler.
// All unmarshal errors for are ignored. // When [language.ValueError] is encountered, the containing tag will be set
// When an error is encountered, the containing tag will be set
// to an empty value (language "und") and no error will be returned. // to an empty value (language "und") and no error will be returned.
// This state can be checked with the `l.Tag().IsRoot()` method. // This state can be checked with the `l.Tag().IsRoot()` method.
func (l *Locale) UnmarshalJSON(data []byte) error { func (l *Locale) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, &l.tag) if len(data) == 0 || string(data) == "\"\"" {
if err != nil {
l.tag = language.Tag{}
}
return nil return nil
} }
err := json.Unmarshal(data, &l.tag)
if err == nil {
return nil
}
// catch "well-formed but unknown" errors
var target language.ValueError
if errors.As(err, &target) {
l.tag = language.Tag{}
return nil
}
return err
}
type Locales []language.Tag type Locales []language.Tag

View file

@ -217,6 +217,30 @@ func TestLocale_UnmarshalJSON(t *testing.T) {
want dst want dst
wantErr bool wantErr bool
}{ }{
{
name: "value not present",
input: `{}`,
wantErr: false,
want: dst{
Locale: nil,
},
},
{
name: "null",
input: `{"locale": null}`,
wantErr: false,
want: dst{
Locale: nil,
},
},
{
name: "empty, ignored",
input: `{"locale": ""}`,
wantErr: false,
want: dst{
Locale: &Locale{},
},
},
{ {
name: "afrikaans, ok", name: "afrikaans, ok",
input: `{"locale": "af"}`, input: `{"locale": "af"}`,
@ -234,13 +258,11 @@ func TestLocale_UnmarshalJSON(t *testing.T) {
{ {
name: "bad form, error", name: "bad form, error",
input: `{"locale": "g!!!!!"}`, input: `{"locale": "g!!!!!"}`,
want: dst{ wantErr: true,
Locale: &Locale{},
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got dst var got dst
err := json.Unmarshal([]byte(tt.input), &got) err := json.Unmarshal([]byte(tt.input), &got)
if tt.wantErr { if tt.wantErr {
@ -249,6 +271,7 @@ func TestLocale_UnmarshalJSON(t *testing.T) {
} }
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
})
} }
} }

View file

@ -144,6 +144,12 @@ type CanSetUserinfoFromRequest interface {
SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error
} }
// CanGetPrivateClaimsFromRequest is an optional additional interface that may be implemented by
// implementors of Storage. It allows setting the jwt token claims based on the request.
type CanGetPrivateClaimsFromRequest interface {
GetPrivateClaimsFromRequest(ctx context.Context, request TokenRequest, restrictedScopes []string) (map[string]any, error)
}
// Storage is a required parameter for NewOpenIDProvider(). In addition to the // Storage is a required parameter for NewOpenIDProvider(). In addition to the
// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage // embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
// then the grant type "client_credentials" will be supported. In that case, the access // then the grant type "client_credentials" will be supported. In that case, the access

View file

@ -146,9 +146,13 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex
ctx, ctx,
tokenExchangeRequest, tokenExchangeRequest,
) )
} else {
if fromRequest, ok := storage.(CanGetPrivateClaimsFromRequest); ok {
privateClaims, err = fromRequest.GetPrivateClaimsFromRequest(ctx, tokenRequest, removeUserinfoScopes(restrictedScopes))
} else { } else {
privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes)) privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
} }
}
if err != nil { if err != nil {
return "", err return "", err