Merge pull request #374 from zitadel/main-to-next

chore: merge main into next
This commit is contained in:
Tim Möhlmann 2023-05-02 17:40:20 +03:00 committed by GitHub
commit a446f4f9da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 308 additions and 146 deletions

View file

@ -10,7 +10,7 @@ jobs:
name: Add issue to project
runs-on: ubuntu-latest
steps:
- uses: actions/add-to-project@v0.4.1
- uses: actions/add-to-project@v0.5.0
with:
# You can target a repository in a different organization
# to the issue

View file

@ -21,11 +21,11 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v3
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/...
- uses: codecov/codecov-action@v3.1.1
- uses: codecov/codecov-action@v3.1.2
with:
file: ./profile.cov
name: codecov-go

View file

@ -108,7 +108,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: issuer + "device",
UserFormPath: "/device",
UserCode: op.UserCodeBase20,
},
}

View file

@ -32,6 +32,8 @@ type Client struct {
devMode bool
idTokenUserinfoClaimsAssertion bool
clockSkew time.Duration
postLogoutRedirectURIGlobs []string
redirectURIGlobs []string
}
// GetID must return the client_id
@ -44,21 +46,11 @@ func (c *Client) RedirectURIs() []string {
return c.redirectURIs
}
// RedirectURIGlobs provide wildcarding for additional valid redirects
func (c *Client) RedirectURIGlobs() []string {
return nil
}
// PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs
func (c *Client) PostLogoutRedirectURIs() []string {
return []string{}
}
// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects
func (c *Client) PostLogoutRedirectURIGlobs() []string {
return nil
}
// ApplicationType must return the type of the client (app, native, user agent)
func (c *Client) ApplicationType() op.ApplicationType {
return c.applicationType
@ -200,3 +192,26 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
clockSkew: 0,
}
}
type hasRedirectGlobs struct {
*Client
}
// RedirectURIGlobs provide wildcarding for additional valid redirects
func (c hasRedirectGlobs) RedirectURIGlobs() []string {
return c.redirectURIGlobs
}
// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects
func (c hasRedirectGlobs) PostLogoutRedirectURIGlobs() []string {
return c.postLogoutRedirectURIGlobs
}
// RedirectGlobsClient wraps the client in a op.HasRedirectGlobs
// only if DevMode is enabled.
func RedirectGlobsClient(client *Client) op.Client {
if client.devMode {
return hasRedirectGlobs{client}
}
return client
}

View file

@ -418,7 +418,7 @@ func (s *Storage) GetClientByClientID(ctx context.Context, clientID string) (op.
if !ok {
return nil, fmt.Errorf("client not found")
}
return client, nil
return RedirectGlobsClient(client), nil
}
// AuthorizeClientIDSecret implements the op.Storage interface
@ -438,10 +438,17 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS
return nil
}
// SetUserinfoFromScopes implements the op.Storage interface
// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
// SetUserinfoFromScopes implements the op.Storage interface.
// Provide an empty implementation and use SetUserinfoFromRequest instead.
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
return s.setUserinfo(ctx, userinfo, userID, clientID, scopes)
return nil
}
// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the
// next major release, it will be required for op.Storage.
// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
func (s *Storage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error {
return s.setUserinfo(ctx, userinfo, token.GetSubject(), token.GetClientID(), scopes)
}
// SetUserinfoFromToken implements the op.Storage interface

View file

@ -196,8 +196,8 @@ func (s *multiStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, cl
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
}
// SetUserinfoFromScopes implements the op.Storage interface
// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
// SetUserinfoFromScopes implements the op.Storage interface.
// Provide an empty implementation and use SetUserinfoFromRequest instead.
func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
storage, err := s.storageFromContext(ctx)
if err != nil {
@ -206,6 +206,17 @@ func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc
return storage.SetUserinfoFromScopes(ctx, userinfo, userID, clientID, scopes)
}
// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the
// next major release, it will be required for op.Storage.
// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
func (s *multiStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error {
storage, err := s.storageFromContext(ctx)
if err != nil {
return err
}
return storage.SetUserinfoFromRequest(ctx, userinfo, token, scopes)
}
// SetUserinfoFromToken implements the op.Storage interface
// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function
func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {

12
go.mod
View file

@ -10,12 +10,12 @@ require (
github.com/gorilla/securecookie v1.1.1
github.com/jeremija/gosubmit v0.2.7
github.com/muhlemmer/gu v0.3.1
github.com/rs/cors v1.8.3
github.com/rs/cors v1.9.0
github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.2
github.com/zitadel/schema v1.3.0
golang.org/x/oauth2 v0.6.0
golang.org/x/text v0.8.0
golang.org/x/oauth2 v0.7.0
golang.org/x/text v0.9.0
gopkg.in/square/go-jose.v2 v2.6.0
)
@ -26,10 +26,10 @@ require (
github.com/google/go-querystring v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/net v0.8.0 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/net v0.9.0 // indirect
golang.org/x/sys v0.7.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.29.0 // indirect
google.golang.org/protobuf v1.29.1 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

24
go.sum
View file

@ -34,8 +34,8 @@ github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM=
github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo=
github.com/rs/cors v1.8.3/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -60,11 +60,11 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw=
golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw=
golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g=
golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -73,14 +73,14 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
@ -93,8 +93,8 @@ google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.29.0 h1:44S3JjaKmLEE4YIkjzexaP+NzZsudE3Zin5Njn/pYX0=
google.golang.org/protobuf v1.29.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.29.1 h1:7QBf+IK2gx70Ap/hDsOmam3GE0v9HicjfEdAxE62UoM=
google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View file

@ -61,12 +61,18 @@ func callTokenEndpoint(ctx context.Context, request interface{}, authFn interfac
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err
}
return &oauth2.Token{
token := &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken,
Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
}, nil
}
if tokenRes.IDToken != "" {
token = token.WithExtra(map[string]any{
"id_token": tokenRes.IDToken,
})
}
return token, nil
}
type EndSessionCaller interface {

View file

@ -68,6 +68,7 @@ func TestRelyingPartySession(t *testing.T) {
t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken")
t.Log("------ end session (logout) ------")
@ -158,7 +159,6 @@ func TestResourceServerTokenExchange(t *testing.T) {
require.Error(t, err, "refresh token")
assert.Contains(t, err.Error(), "subject_token is invalid")
require.Nil(t, tokenExchangeResponse, "token exchange response")
}
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {

View file

@ -17,7 +17,9 @@ type TokenSource interface {
TokenCtx(context.Context) (*oauth2.Token, error)
}
// jwtProfileTokenSource implements the TokenSource
// jwtProfileTokenSource implement the oauth2.TokenSource
// it will request a token using the OAuth2 JWT Profile Grant
// therefore sending an `assertion` by signing a JWT with the provided private key
type jwtProfileTokenSource struct {
clientID string
audience []string

View file

@ -599,6 +599,10 @@ type RefreshTokenRequest struct {
GrantType oidc.GrantType `schema:"grant_type"`
}
// RefreshAccessToken performs a token refresh. If it doesn't error, it will always
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
// the old one should be considered invalid. It may also provide a new IDToken. The
// new IDToken can be retrieved with token.Extra("id_token").
func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
request := RefreshTokenRequest{
RefreshToken: refreshToken,

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Claims: userInfoData.Claims,
Claims: gu.MapCopy(userInfoData.Claims),
},
},
{

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"io"
"os"
"reflect"
"strings"
"testing"
@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) {
assert.JSONEq(t, want, first)
target := reflect.New(reflect.TypeOf(obj).Elem()).Interface()
require.NoError(t,
json.Unmarshal([]byte(first), obj),
json.Unmarshal([]byte(first), target),
)
second, err := json.Marshal(obj)
second, err := json.Marshal(target)
require.NoError(t, err)
assert.JSONEq(t, want, string(second))

View file

@ -8,6 +8,7 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/crypto"
)
@ -157,6 +158,21 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
t.UserInfoEmail = i.UserInfoEmail
t.UserInfoPhone = i.UserInfoPhone
t.Address = i.Address
if t.Claims == nil {
t.Claims = make(map[string]any, len(t.Claims))
}
gu.MapMerge(i.Claims, t.Claims)
}
func (t *IDTokenClaims) GetUserInfo() *UserInfo {
return &UserInfo{
Subject: t.Subject,
UserInfoProfile: t.UserInfoProfile,
UserInfoEmail: t.UserInfoEmail,
UserInfoPhone: t.UserInfoPhone,
Address: t.Address,
Claims: gu.MapCopy(t.Claims),
}
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims {

View file

@ -181,6 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) {
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
Claims: map[string]interface{}{
"foo": "bar",
},
}
var got IDTokenClaims
@ -225,3 +228,16 @@ func TestNewIDTokenClaims(t *testing.T) {
assert.Equal(t, want, got)
}
func TestIDTokenClaims_GetUserInfo(t *testing.T) {
want := &UserInfo{
Subject: idTokenData.Subject,
UserInfoProfile: idTokenData.UserInfoProfile,
UserInfoEmail: idTokenData.UserInfoEmail,
UserInfoPhone: idTokenData.UserInfoPhone,
Address: idTokenData.Address,
Claims: idTokenData.Claims,
}
got := idTokenData.GetUserInfo()
assert.Equal(t, want, got)
}

View file

@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) {
out := new(UserInfo)
assert.NoError(t, json.Unmarshal(marshal, out))
assert.Equal(t, userinfo, out)
expected, err := json.Marshal(out)
assert.NoError(t, err)
assert.Equal(t, expected, marshal)
out2 := new(UserInfo)
assert.NoError(t, json.Unmarshal(expected, out2))
assert.Equal(t, out, out2)
}
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {

View file

@ -9,7 +9,7 @@ import (
// mergeAndMarshalClaims merges registered and the custom
// claims map into a single JSON object.
// Registered fields overwrite custom claims.
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) {
func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) {
// Use a buffer for memory re-use, instead off letting
// json allocate a new []byte for every step.
buf := new(bytes.Buffer)
@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
if len(claims) > 0 {
if len(extraClaims) > 0 {
merged := make(map[string]any)
for k, v := range extraClaims {
merged[k] = v
}
// Merge JSON data into custom claims.
// The full-read action by the decoder resets the buffer
// to zero len, while retaining underlaying cap.
if err := json.NewDecoder(buf).Decode(&claims); err != nil {
if err := json.NewDecoder(buf).Decode(&merged); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
// Marshal the final result.
if err := json.NewEncoder(buf).Encode(claims); err != nil {
if err := json.NewEncoder(buf).Encode(merged); err != nil {
return nil, fmt.Errorf("oidc custom claims: %w", err)
}
}

View file

@ -67,7 +67,7 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, nil, err, authorizer.Encoder())
return
}
ctx := r.Context()
@ -273,9 +273,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
return scopes, nil
}
// checkURIAginstRedirects just checks aginst the valid redirect URIs and ignores
// checkURIAgainstRedirects just checks aginst the valid redirect URIs and ignores
// other factors.
func checkURIAginstRedirects(client Client, uri string) error {
func checkURIAgainstRedirects(client Client, uri string) error {
if str.Contains(client.RedirectURIs(), uri) {
return nil
}
@ -302,12 +302,12 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
"Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
}
if strings.HasPrefix(uri, "https://") {
return checkURIAginstRedirects(client, uri)
return checkURIAgainstRedirects(client, uri)
}
if client.ApplicationType() == ApplicationTypeNative {
return validateAuthReqRedirectURINative(client, uri, responseType)
}
if err := checkURIAginstRedirects(client, uri); err != nil {
if err := checkURIAgainstRedirects(client, uri); err != nil {
return err
}
if strings.HasPrefix(uri, "http://") {
@ -328,7 +328,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error {
parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
isCustomSchema := !strings.HasPrefix(uri, "http://")
if err := checkURIAginstRedirects(client, uri); err == nil {
if err := checkURIAgainstRedirects(client, uri); err == nil {
if client.DevMode() {
return nil
}

View file

@ -9,6 +9,7 @@ import (
"reflect"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
@ -19,60 +20,34 @@ import (
"github.com/zitadel/schema"
)
//
// TOOD: tests will be implemented in branch for service accounts
// func TestAuthorize(t *testing.T) {
// // testCallback := func(t *testing.T, clienID string) callbackHandler {
// // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
// // // require.Equal(t, clientID, client.)
// // }
// // }
// // testErr := func(t *testing.T, expected error) errorHandler {
// // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
// // require.Equal(t, expected, err)
// // }
// // }
// type args struct {
// w http.ResponseWriter
// r *http.Request
// authorizer op.Authorizer
// }
// tests := []struct {
// name string
// args args
// }{
// {
// "parsing fails",
// args{
// httptest.NewRecorder(),
// &http.Request{Method: "POST", Body: nil},
// mock.NewAuthorizerExpectValid(t, true),
// // testCallback(t, ""),
// // testErr(t, ErrInvalidRequest("cannot parse form")),
// },
// },
// {
// "decoding fails",
// args{
// httptest.NewRecorder(),
// func() *http.Request {
// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
// r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// return r
// }(),
// mock.NewAuthorizerExpectValid(t, true),
// // testCallback(t, ""),
// // testErr(t, ErrInvalidRequest("cannot parse auth request")),
// },
// },
// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
// })
// }
//}
func TestAuthorize(t *testing.T) {
tests := []struct {
name string
req *http.Request
expect func(a *mock.MockAuthorizerMockRecorder)
}{
{
name: "parse error", // used to panic, see issue #315
req: httptest.NewRequest(http.MethodPost, "/?;", nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
authorizer := mock.NewMockAuthorizer(gomock.NewController(t))
expect := authorizer.EXPECT()
expect.Decoder().Return(schema.NewDecoder())
expect.Encoder().Return(schema.NewEncoder())
if tt.expect != nil {
tt.expect(expect)
}
op.Authorize(w, tt.req, authorizer)
})
}
}
func TestParseAuthorizeRequest(t *testing.T) {
type args struct {

View file

@ -56,6 +56,12 @@ type Client interface {
// interpretation. Redirect URIs that match either the non-glob version or the
// glob version will be accepted. Glob URIs are only partially supported for native
// clients: "http://" is not allowed except for loopback or in dev mode.
//
// Note that globbing / wildcards are not permitted by the OIDC
// standard and implementing this interface can have security implications.
// It is advised to only return a client of this type in rare cases,
// such as DevMode for the client being enabled.
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type HasRedirectGlobs interface {
RedirectURIGlobs() []string
PostLogoutRedirectURIGlobs() []string
@ -145,21 +151,30 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
}
data := new(clientData)
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
if err = p.Decoder().Decode(data, r.Form); err != nil {
return "", false, err
}
JWTProfile, ok := p.(ClientJWTProfile)
if ok {
if ok && data.ClientAssertion != "" {
// if JWTProfile is supported and client sent an assertion, check it and use it as response
// regardless if it succeeded or failed
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
return clientID, err == nil, err
}
if !ok || errors.Is(err, ErrNoClientCredentials) {
clientID, err = ClientBasicAuth(r, p.Storage())
}
// try basic auth
clientID, err = ClientBasicAuth(r, p.Storage())
// if that succeeded, use it
if err == nil {
return clientID, true, nil
}
// if the client did not send a Basic Auth Header, ignore the `ErrNoClientCredentials`
// but return other errors immediately
if err != nil && !errors.Is(err, ErrNoClientCredentials) {
return "", false, err
}
// if the client did not authenticate (public clients) it must at least send a client_id
if data.ClientID == "" {
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
}

View file

@ -8,6 +8,7 @@ import (
"fmt"
"math/big"
"net/http"
"net/url"
"strings"
"time"
@ -18,7 +19,14 @@ import (
type DeviceAuthorizationConfig struct {
Lifetime time.Duration
PollInterval time.Duration
UserFormURL string // the URL where the user must go to authorize the device
// UserFormURL is the complete URL where the user must go to authorize the device.
// Deprecated: use UserFormPath instead.
UserFormURL string
// UserFormPath is the path where the user must go to authorize the device.
// The hostname for the URL is taken from the request by IssuerFromContext.
UserFormPath string
UserCode UserCodeConfig
}
@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
return err
}
var verification *url.URL
if config.UserFormURL != "" {
if verification, err = url.Parse(config.UserFormURL); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
}
} else {
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
}
verification.Path = config.UserFormPath
}
response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: config.UserFormURL,
VerificationURI: verification.String(),
ExpiresIn: int(config.Lifetime / time.Second),
Interval: int(config.PollInterval / time.Second),
}
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
verification.RawQuery = "user_code=" + userCode
response.VerificationURIComplete = verification.String()
httphelper.MarshalJSON(w, response)
return nil

View file

@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
@ -20,29 +21,60 @@ import (
)
func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"},
ClientID: "web",
type conf struct {
UserFormURL string
UserFormPath string
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
tests := []struct {
name string
conf conf
}{
{
name: "UserFormURL",
conf: conf{
UserFormURL: "https://localhost:9998/device",
},
},
{
name: "UserFormPath",
conf: conf{
UserFormPath: "/device",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := gu.PtrCopy(testConfig)
conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
provider := newTestProvider(conf)
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"},
ClientID: "web",
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(testProvider)(w, r)
})
w := httptest.NewRecorder()
result := w.Result()
runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(provider)(w, r)
})
assert.Less(t, result.StatusCode, 300)
result := w.Result()
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
assert.Less(t, result.StatusCode, 300)
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
})
}
}
func TestParseDeviceCodeRequest(t *testing.T) {

View file

@ -480,6 +480,16 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.DeviceAuthorization = endpoint
return nil
}
}
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
return func(o *Provider) error {
o.endpoints.Authorization = auth

View file

@ -20,15 +20,9 @@ import (
"golang.org/x/text/language"
)
var testProvider op.OpenIDProvider
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
var (
testProvider op.OpenIDProvider
testConfig = &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
@ -40,24 +34,35 @@ func init() {
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device",
UserFormPath: "/device",
UserCode: op.UserCodeBase20,
},
}
)
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
storage.RegisterClients(
storage.NativeClient("native"),
storage.WebClient("web", "secret", "https://example.com"),
storage.WebClient("api", "secret"),
)
var err error
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
testProvider = newTestProvider(testConfig)
}
func newTestProvider(config *op.Config) op.OpenIDProvider {
provider, err := op.NewOpenIDProvider(testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
)
if err != nil {
panic(err)
}
return provider
}
type routesTestStorage interface {

View file

@ -113,6 +113,8 @@ type OPStorage interface {
// handle the current request.
GetClientByClientID(ctx context.Context, clientID string) (Client, error)
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
// SetUserinfoFromScopes is deprecated and should have an empty implementation for now.
// Implement SetUserinfoFromRequest instead.
SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
@ -127,6 +129,13 @@ type JWTProfileTokenStorage interface {
JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error)
}
// CanSetUserinfoFromRequest is an optional additional interface that may be implemented by
// implementors of Storage. It allows additional data to be set in id_tokens based on the
// request.
type CanSetUserinfoFromRequest interface {
SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error
}
// Storage is a required parameter for NewOpenIDProvider(). In addition to the
// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
// then the grant type "client_credentials" will be supported. In that case, the access

View file

@ -190,6 +190,12 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
if err != nil {
return "", err
}
if fromRequest, ok := storage.(CanSetUserinfoFromRequest); ok {
err := fromRequest.SetUserinfoFromRequest(ctx, userInfo, request, scopes)
if err != nil {
return "", err
}
}
claims.SetUserInfo(userInfo)
}
if code != "" {