Merge branch 'main' into pr/ay4toh5i/721

This commit is contained in:
Tim Möhlmann 2025-03-24 17:54:12 +02:00
commit c3cac2bedd
10 changed files with 106 additions and 72 deletions

View file

@ -3,6 +3,7 @@ package oidc
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
@ -77,16 +78,25 @@ func (l *Locale) MarshalJSON() ([]byte, error) {
}
// UnmarshalJSON implements json.Unmarshaler.
// All unmarshal errors for are ignored.
// When an error is encountered, the containing tag will be set
// When [language.ValueError] is encountered, the containing tag will be set
// to an empty value (language "und") and no error will be returned.
// This state can be checked with the `l.Tag().IsRoot()` method.
func (l *Locale) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, &l.tag)
if err != nil {
l.tag = language.Tag{}
if len(data) == 0 || string(data) == "\"\"" {
return nil
}
return nil
err := json.Unmarshal(data, &l.tag)
if err == nil {
return nil
}
// catch "well-formed but unknown" errors
var target language.ValueError
if errors.As(err, &target) {
l.tag = language.Tag{}
return nil
}
return err
}
type Locales []language.Tag

View file

@ -217,6 +217,30 @@ func TestLocale_UnmarshalJSON(t *testing.T) {
want dst
wantErr bool
}{
{
name: "value not present",
input: `{}`,
wantErr: false,
want: dst{
Locale: nil,
},
},
{
name: "null",
input: `{"locale": null}`,
wantErr: false,
want: dst{
Locale: nil,
},
},
{
name: "empty, ignored",
input: `{"locale": ""}`,
wantErr: false,
want: dst{
Locale: &Locale{},
},
},
{
name: "afrikaans, ok",
input: `{"locale": "af"}`,
@ -232,23 +256,22 @@ func TestLocale_UnmarshalJSON(t *testing.T) {
},
},
{
name: "bad form, error",
input: `{"locale": "g!!!!!"}`,
want: dst{
Locale: &Locale{},
},
name: "bad form, error",
input: `{"locale": "g!!!!!"}`,
wantErr: true,
},
}
for _, tt := range tests {
var got dst
err := json.Unmarshal([]byte(tt.input), &got)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
t.Run(tt.name, func(t *testing.T) {
var got dst
err := json.Unmarshal([]byte(tt.input), &got)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -91,10 +91,7 @@ func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizatio
}
config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
deviceCode, _ := NewDeviceCode(RecommendedDeviceCodeBytes)
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
@ -163,11 +160,14 @@ func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuth
// results in a 22 character base64 encoded string.
const RecommendedDeviceCodeBytes = 16
// NewDeviceCode generates a new cryptographically secure device code as a base64 encoded string.
// The length of the string is nBytes * 4 / 3.
// An error is never returned.
//
// TODO(v4): change return type to string alone.
func NewDeviceCode(nBytes int) (string, error) {
bytes := make([]byte, nBytes)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("%w getting entropy for device code", err)
}
rand.Read(bytes)
return base64.RawURLEncoding.EncodeToString(bytes), nil
}

View file

@ -145,21 +145,11 @@ func runWithRandReader(r io.Reader, f func()) {
}
func TestNewDeviceCode(t *testing.T) {
t.Run("reader error", func(t *testing.T) {
runWithRandReader(errReader{}, func() {
_, err := op.NewDeviceCode(16)
require.Error(t, err)
})
})
t.Run("different lengths, rand reader", func(t *testing.T) {
for i := 1; i <= 32; i++ {
got, err := op.NewDeviceCode(i)
require.NoError(t, err)
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
}
})
for i := 1; i <= 32; i++ {
got, err := op.NewDeviceCode(i)
require.NoError(t, err)
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
}
}
func TestNewUserCode(t *testing.T) {

View file

@ -144,6 +144,12 @@ type CanSetUserinfoFromRequest interface {
SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error
}
// CanGetPrivateClaimsFromRequest is an optional additional interface that may be implemented by
// implementors of Storage. It allows setting the jwt token claims based on the request.
type CanGetPrivateClaimsFromRequest interface {
GetPrivateClaimsFromRequest(ctx context.Context, request TokenRequest, restrictedScopes []string) (map[string]any, error)
}
// Storage is a required parameter for NewOpenIDProvider(). In addition to the
// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
// then the grant type "client_credentials" will be supported. In that case, the access

View file

@ -147,7 +147,11 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex
tokenExchangeRequest,
)
} else {
privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
if fromRequest, ok := storage.(CanGetPrivateClaimsFromRequest); ok {
privateClaims, err = fromRequest.GetPrivateClaimsFromRequest(ctx, tokenRequest, removeUserinfoScopes(restrictedScopes))
} else {
privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
}
}
if err != nil {