Merge branch 'main' into main-to-next

This commit is contained in:
Tim Möhlmann 2023-04-18 12:32:04 +03:00
commit 8dff7ddee0
27 changed files with 308 additions and 146 deletions

View file

@ -10,7 +10,7 @@ jobs:
name: Add issue to project name: Add issue to project
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/add-to-project@v0.4.1 - uses: actions/add-to-project@v0.5.0
with: with:
# You can target a repository in a different organization # You can target a repository in a different organization
# to the issue # to the issue

View file

@ -21,11 +21,11 @@ jobs:
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup go - name: Setup go
uses: actions/setup-go@v3 uses: actions/setup-go@v4
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/... - run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/...
- uses: codecov/codecov-action@v3.1.1 - uses: codecov/codecov-action@v3.1.2
with: with:
file: ./profile.cov file: ./profile.cov
name: codecov-go name: codecov-go

View file

@ -108,7 +108,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
DeviceAuthorization: op.DeviceAuthorizationConfig{ DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute, Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second, PollInterval: 5 * time.Second,
UserFormURL: issuer + "device", UserFormPath: "/device",
UserCode: op.UserCodeBase20, UserCode: op.UserCodeBase20,
}, },
} }

View file

@ -32,6 +32,8 @@ type Client struct {
devMode bool devMode bool
idTokenUserinfoClaimsAssertion bool idTokenUserinfoClaimsAssertion bool
clockSkew time.Duration clockSkew time.Duration
postLogoutRedirectURIGlobs []string
redirectURIGlobs []string
} }
// GetID must return the client_id // GetID must return the client_id
@ -44,21 +46,11 @@ func (c *Client) RedirectURIs() []string {
return c.redirectURIs return c.redirectURIs
} }
// RedirectURIGlobs provide wildcarding for additional valid redirects
func (c *Client) RedirectURIGlobs() []string {
return nil
}
// PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs // PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs
func (c *Client) PostLogoutRedirectURIs() []string { func (c *Client) PostLogoutRedirectURIs() []string {
return []string{} return []string{}
} }
// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects
func (c *Client) PostLogoutRedirectURIGlobs() []string {
return nil
}
// ApplicationType must return the type of the client (app, native, user agent) // ApplicationType must return the type of the client (app, native, user agent)
func (c *Client) ApplicationType() op.ApplicationType { func (c *Client) ApplicationType() op.ApplicationType {
return c.applicationType return c.applicationType
@ -200,3 +192,26 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
clockSkew: 0, clockSkew: 0,
} }
} }
type hasRedirectGlobs struct {
*Client
}
// RedirectURIGlobs provide wildcarding for additional valid redirects
func (c hasRedirectGlobs) RedirectURIGlobs() []string {
return c.redirectURIGlobs
}
// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects
func (c hasRedirectGlobs) PostLogoutRedirectURIGlobs() []string {
return c.postLogoutRedirectURIGlobs
}
// RedirectGlobsClient wraps the client in a op.HasRedirectGlobs
// only if DevMode is enabled.
func RedirectGlobsClient(client *Client) op.Client {
if client.devMode {
return hasRedirectGlobs{client}
}
return client
}

View file

@ -418,7 +418,7 @@ func (s *Storage) GetClientByClientID(ctx context.Context, clientID string) (op.
if !ok { if !ok {
return nil, fmt.Errorf("client not found") return nil, fmt.Errorf("client not found")
} }
return client, nil return RedirectGlobsClient(client), nil
} }
// AuthorizeClientIDSecret implements the op.Storage interface // AuthorizeClientIDSecret implements the op.Storage interface
@ -438,10 +438,17 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS
return nil return nil
} }
// SetUserinfoFromScopes implements the op.Storage interface // SetUserinfoFromScopes implements the op.Storage interface.
// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check // Provide an empty implementation and use SetUserinfoFromRequest instead.
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
return s.setUserinfo(ctx, userinfo, userID, clientID, scopes) return nil
}
// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the
// next major release, it will be required for op.Storage.
// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
func (s *Storage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error {
return s.setUserinfo(ctx, userinfo, token.GetSubject(), token.GetClientID(), scopes)
} }
// SetUserinfoFromToken implements the op.Storage interface // SetUserinfoFromToken implements the op.Storage interface

View file

@ -196,8 +196,8 @@ func (s *multiStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, cl
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
} }
// SetUserinfoFromScopes implements the op.Storage interface // SetUserinfoFromScopes implements the op.Storage interface.
// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check // Provide an empty implementation and use SetUserinfoFromRequest instead.
func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
storage, err := s.storageFromContext(ctx) storage, err := s.storageFromContext(ctx)
if err != nil { if err != nil {
@ -206,6 +206,17 @@ func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc
return storage.SetUserinfoFromScopes(ctx, userinfo, userID, clientID, scopes) return storage.SetUserinfoFromScopes(ctx, userinfo, userID, clientID, scopes)
} }
// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the
// next major release, it will be required for op.Storage.
// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
func (s *multiStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error {
storage, err := s.storageFromContext(ctx)
if err != nil {
return err
}
return storage.SetUserinfoFromRequest(ctx, userinfo, token, scopes)
}
// SetUserinfoFromToken implements the op.Storage interface // SetUserinfoFromToken implements the op.Storage interface
// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function // it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function
func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error { func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {

12
go.mod
View file

@ -10,12 +10,12 @@ require (
github.com/gorilla/securecookie v1.1.1 github.com/gorilla/securecookie v1.1.1
github.com/jeremija/gosubmit v0.2.7 github.com/jeremija/gosubmit v0.2.7
github.com/muhlemmer/gu v0.3.1 github.com/muhlemmer/gu v0.3.1
github.com/rs/cors v1.8.3 github.com/rs/cors v1.9.0
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.2 github.com/stretchr/testify v1.8.2
github.com/zitadel/schema v1.3.0 github.com/zitadel/schema v1.3.0
golang.org/x/oauth2 v0.6.0 golang.org/x/oauth2 v0.7.0
golang.org/x/text v0.8.0 golang.org/x/text v0.9.0
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
) )
@ -26,10 +26,10 @@ require (
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.7.0 // indirect golang.org/x/crypto v0.7.0 // indirect
golang.org/x/net v0.8.0 // indirect golang.org/x/net v0.9.0 // indirect
golang.org/x/sys v0.6.0 // indirect golang.org/x/sys v0.7.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.29.0 // indirect google.golang.org/protobuf v1.29.1 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

24
go.sum
View file

@ -34,8 +34,8 @@ github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM=
github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo= github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.8.3/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -60,11 +60,11 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g=
golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -73,14 +73,14 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
@ -93,8 +93,8 @@ google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.29.0 h1:44S3JjaKmLEE4YIkjzexaP+NzZsudE3Zin5Njn/pYX0= google.golang.org/protobuf v1.29.1 h1:7QBf+IK2gx70Ap/hDsOmam3GE0v9HicjfEdAxE62UoM=
google.golang.org/protobuf v1.29.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View file

@ -61,12 +61,18 @@ func callTokenEndpoint(ctx context.Context, request interface{}, authFn interfac
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil { if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err return nil, err
} }
return &oauth2.Token{ token := &oauth2.Token{
AccessToken: tokenRes.AccessToken, AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType, TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken, RefreshToken: tokenRes.RefreshToken,
Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second), Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
}, nil }
if tokenRes.IDToken != "" {
token = token.WithExtra(map[string]any{
"id_token": tokenRes.IDToken,
})
}
return token, nil
} }
type EndSessionCaller interface { type EndSessionCaller interface {

View file

@ -68,6 +68,7 @@ func TestRelyingPartySession(t *testing.T) {
t.Logf("new token type %s", newTokens.TokenType) t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken") require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken")
t.Log("------ end session (logout) ------") t.Log("------ end session (logout) ------")
@ -158,7 +159,6 @@ func TestResourceServerTokenExchange(t *testing.T) {
require.Error(t, err, "refresh token") require.Error(t, err, "refresh token")
assert.Contains(t, err.Error(), "subject_token is invalid") assert.Contains(t, err.Error(), "subject_token is invalid")
require.Nil(t, tokenExchangeResponse, "token exchange response") require.Nil(t, tokenExchangeResponse, "token exchange response")
} }
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {

View file

@ -17,7 +17,9 @@ type TokenSource interface {
TokenCtx(context.Context) (*oauth2.Token, error) TokenCtx(context.Context) (*oauth2.Token, error)
} }
// jwtProfileTokenSource implements the TokenSource // jwtProfileTokenSource implement the oauth2.TokenSource
// it will request a token using the OAuth2 JWT Profile Grant
// therefore sending an `assertion` by signing a JWT with the provided private key
type jwtProfileTokenSource struct { type jwtProfileTokenSource struct {
clientID string clientID string
audience []string audience []string

View file

@ -599,6 +599,10 @@ type RefreshTokenRequest struct {
GrantType oidc.GrantType `schema:"grant_type"` GrantType oidc.GrantType `schema:"grant_type"`
} }
// RefreshAccessToken performs a token refresh. If it doesn't error, it will always
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
// the old one should be considered invalid. It may also provide a new IDToken. The
// new IDToken can be retrieved with token.Extra("id_token").
func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
request := RefreshTokenRequest{ request := RefreshTokenRequest{
RefreshToken: refreshToken, RefreshToken: refreshToken,

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -25,7 +26,7 @@ func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
UserInfoProfile: userInfoData.UserInfoProfile, UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail, UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone, UserInfoPhone: userInfoData.UserInfoPhone,
Claims: userInfoData.Claims, Claims: gu.MapCopy(userInfoData.Claims),
}, },
}, },
{ {

View file

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"os" "os"
"reflect"
"strings" "strings"
"testing" "testing"
@ -38,10 +39,12 @@ func Test_assert_regression(t *testing.T) {
assert.JSONEq(t, want, first) assert.JSONEq(t, want, first)
target := reflect.New(reflect.TypeOf(obj).Elem()).Interface()
require.NoError(t, require.NoError(t,
json.Unmarshal([]byte(first), obj), json.Unmarshal([]byte(first), target),
) )
second, err := json.Marshal(obj) second, err := json.Marshal(target)
require.NoError(t, err) require.NoError(t, err)
assert.JSONEq(t, want, string(second)) assert.JSONEq(t, want, string(second))

View file

@ -8,6 +8,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/crypto" "github.com/zitadel/oidc/v3/pkg/crypto"
) )
@ -157,6 +158,21 @@ func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
t.UserInfoEmail = i.UserInfoEmail t.UserInfoEmail = i.UserInfoEmail
t.UserInfoPhone = i.UserInfoPhone t.UserInfoPhone = i.UserInfoPhone
t.Address = i.Address t.Address = i.Address
if t.Claims == nil {
t.Claims = make(map[string]any, len(t.Claims))
}
gu.MapMerge(i.Claims, t.Claims)
}
func (t *IDTokenClaims) GetUserInfo() *UserInfo {
return &UserInfo{
Subject: t.Subject,
UserInfoProfile: t.UserInfoProfile,
UserInfoEmail: t.UserInfoEmail,
UserInfoPhone: t.UserInfoPhone,
Address: t.Address,
Claims: gu.MapCopy(t.Claims),
}
} }
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims { func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims {

View file

@ -181,6 +181,9 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) {
UserInfoEmail: userInfoData.UserInfoEmail, UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone, UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address, Address: userInfoData.Address,
Claims: map[string]interface{}{
"foo": "bar",
},
} }
var got IDTokenClaims var got IDTokenClaims
@ -225,3 +228,16 @@ func TestNewIDTokenClaims(t *testing.T) {
assert.Equal(t, want, got) assert.Equal(t, want, got)
} }
func TestIDTokenClaims_GetUserInfo(t *testing.T) {
want := &UserInfo{
Subject: idTokenData.Subject,
UserInfoProfile: idTokenData.UserInfoProfile,
UserInfoEmail: idTokenData.UserInfoEmail,
UserInfoPhone: idTokenData.UserInfoPhone,
Address: idTokenData.Address,
Claims: idTokenData.Claims,
}
got := idTokenData.GetUserInfo()
assert.Equal(t, want, got)
}

View file

@ -52,11 +52,14 @@ func TestUserInfoMarshal(t *testing.T) {
out := new(UserInfo) out := new(UserInfo)
assert.NoError(t, json.Unmarshal(marshal, out)) assert.NoError(t, json.Unmarshal(marshal, out))
assert.Equal(t, userinfo, out)
expected, err := json.Marshal(out) expected, err := json.Marshal(out)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, marshal) assert.Equal(t, expected, marshal)
out2 := new(UserInfo)
assert.NoError(t, json.Unmarshal(expected, out2))
assert.Equal(t, out, out2)
} }
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) { func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {

View file

@ -9,7 +9,7 @@ import (
// mergeAndMarshalClaims merges registered and the custom // mergeAndMarshalClaims merges registered and the custom
// claims map into a single JSON object. // claims map into a single JSON object.
// Registered fields overwrite custom claims. // Registered fields overwrite custom claims.
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) { func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) {
// Use a buffer for memory re-use, instead off letting // Use a buffer for memory re-use, instead off letting
// json allocate a new []byte for every step. // json allocate a new []byte for every step.
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
@ -19,16 +19,21 @@ func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error
return nil, fmt.Errorf("oidc registered claims: %w", err) return nil, fmt.Errorf("oidc registered claims: %w", err)
} }
if len(claims) > 0 { if len(extraClaims) > 0 {
merged := make(map[string]any)
for k, v := range extraClaims {
merged[k] = v
}
// Merge JSON data into custom claims. // Merge JSON data into custom claims.
// The full-read action by the decoder resets the buffer // The full-read action by the decoder resets the buffer
// to zero len, while retaining underlaying cap. // to zero len, while retaining underlaying cap.
if err := json.NewDecoder(buf).Decode(&claims); err != nil { if err := json.NewDecoder(buf).Decode(&merged); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err) return nil, fmt.Errorf("oidc registered claims: %w", err)
} }
// Marshal the final result. // Marshal the final result.
if err := json.NewEncoder(buf).Encode(claims); err != nil { if err := json.NewEncoder(buf).Encode(merged); err != nil {
return nil, fmt.Errorf("oidc custom claims: %w", err) return nil, fmt.Errorf("oidc custom claims: %w", err)
} }
} }

View file

@ -67,7 +67,7 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, nil, err, authorizer.Encoder())
return return
} }
ctx := r.Context() ctx := r.Context()
@ -273,9 +273,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
return scopes, nil return scopes, nil
} }
// checkURIAginstRedirects just checks aginst the valid redirect URIs and ignores // checkURIAgainstRedirects just checks aginst the valid redirect URIs and ignores
// other factors. // other factors.
func checkURIAginstRedirects(client Client, uri string) error { func checkURIAgainstRedirects(client Client, uri string) error {
if str.Contains(client.RedirectURIs(), uri) { if str.Contains(client.RedirectURIs(), uri) {
return nil return nil
} }
@ -302,12 +302,12 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
"Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") "Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
} }
if strings.HasPrefix(uri, "https://") { if strings.HasPrefix(uri, "https://") {
return checkURIAginstRedirects(client, uri) return checkURIAgainstRedirects(client, uri)
} }
if client.ApplicationType() == ApplicationTypeNative { if client.ApplicationType() == ApplicationTypeNative {
return validateAuthReqRedirectURINative(client, uri, responseType) return validateAuthReqRedirectURINative(client, uri, responseType)
} }
if err := checkURIAginstRedirects(client, uri); err != nil { if err := checkURIAgainstRedirects(client, uri); err != nil {
return err return err
} }
if strings.HasPrefix(uri, "http://") { if strings.HasPrefix(uri, "http://") {
@ -328,7 +328,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error { func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error {
parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri) parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
isCustomSchema := !strings.HasPrefix(uri, "http://") isCustomSchema := !strings.HasPrefix(uri, "http://")
if err := checkURIAginstRedirects(client, uri); err == nil { if err := checkURIAgainstRedirects(client, uri); err == nil {
if client.DevMode() { if client.DevMode() {
return nil return nil
} }

View file

@ -9,6 +9,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil" tu "github.com/zitadel/oidc/v3/internal/testutil"
@ -19,60 +20,34 @@ import (
"github.com/zitadel/schema" "github.com/zitadel/schema"
) )
// func TestAuthorize(t *testing.T) {
// TOOD: tests will be implemented in branch for service accounts tests := []struct {
// func TestAuthorize(t *testing.T) { name string
// // testCallback := func(t *testing.T, clienID string) callbackHandler { req *http.Request
// // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { expect func(a *mock.MockAuthorizerMockRecorder)
// // // require.Equal(t, clientID, client.) }{
// // } {
// // } name: "parse error", // used to panic, see issue #315
// // testErr := func(t *testing.T, expected error) errorHandler { req: httptest.NewRequest(http.MethodPost, "/?;", nil),
// // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { },
// // require.Equal(t, expected, err) }
// // } for _, tt := range tests {
// // } t.Run(tt.name, func(t *testing.T) {
// type args struct { w := httptest.NewRecorder()
// w http.ResponseWriter authorizer := mock.NewMockAuthorizer(gomock.NewController(t))
// r *http.Request
// authorizer op.Authorizer expect := authorizer.EXPECT()
// } expect.Decoder().Return(schema.NewDecoder())
// tests := []struct { expect.Encoder().Return(schema.NewEncoder())
// name string
// args args if tt.expect != nil {
// }{ tt.expect(expect)
// { }
// "parsing fails",
// args{ op.Authorize(w, tt.req, authorizer)
// httptest.NewRecorder(), })
// &http.Request{Method: "POST", Body: nil}, }
// mock.NewAuthorizerExpectValid(t, true), }
// // testCallback(t, ""),
// // testErr(t, ErrInvalidRequest("cannot parse form")),
// },
// },
// {
// "decoding fails",
// args{
// httptest.NewRecorder(),
// func() *http.Request {
// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
// r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// return r
// }(),
// mock.NewAuthorizerExpectValid(t, true),
// // testCallback(t, ""),
// // testErr(t, ErrInvalidRequest("cannot parse auth request")),
// },
// },
// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
// })
// }
//}
func TestParseAuthorizeRequest(t *testing.T) { func TestParseAuthorizeRequest(t *testing.T) {
type args struct { type args struct {

View file

@ -56,6 +56,12 @@ type Client interface {
// interpretation. Redirect URIs that match either the non-glob version or the // interpretation. Redirect URIs that match either the non-glob version or the
// glob version will be accepted. Glob URIs are only partially supported for native // glob version will be accepted. Glob URIs are only partially supported for native
// clients: "http://" is not allowed except for loopback or in dev mode. // clients: "http://" is not allowed except for loopback or in dev mode.
//
// Note that globbing / wildcards are not permitted by the OIDC
// standard and implementing this interface can have security implications.
// It is advised to only return a client of this type in rare cases,
// such as DevMode for the client being enabled.
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type HasRedirectGlobs interface { type HasRedirectGlobs interface {
RedirectURIGlobs() []string RedirectURIGlobs() []string
PostLogoutRedirectURIGlobs() []string PostLogoutRedirectURIGlobs() []string
@ -145,21 +151,30 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
} }
data := new(clientData) data := new(clientData)
if err = p.Decoder().Decode(data, r.PostForm); err != nil { if err = p.Decoder().Decode(data, r.Form); err != nil {
return "", false, err return "", false, err
} }
JWTProfile, ok := p.(ClientJWTProfile) JWTProfile, ok := p.(ClientJWTProfile)
if ok { if ok && data.ClientAssertion != "" {
// if JWTProfile is supported and client sent an assertion, check it and use it as response
// regardless if it succeeded or failed
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile) clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
return clientID, err == nil, err
} }
if !ok || errors.Is(err, ErrNoClientCredentials) { // try basic auth
clientID, err = ClientBasicAuth(r, p.Storage()) clientID, err = ClientBasicAuth(r, p.Storage())
} // if that succeeded, use it
if err == nil { if err == nil {
return clientID, true, nil return clientID, true, nil
} }
// if the client did not send a Basic Auth Header, ignore the `ErrNoClientCredentials`
// but return other errors immediately
if err != nil && !errors.Is(err, ErrNoClientCredentials) {
return "", false, err
}
// if the client did not authenticate (public clients) it must at least send a client_id
if data.ClientID == "" { if data.ClientID == "" {
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID) return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
} }

View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@ -18,7 +19,14 @@ import (
type DeviceAuthorizationConfig struct { type DeviceAuthorizationConfig struct {
Lifetime time.Duration Lifetime time.Duration
PollInterval time.Duration PollInterval time.Duration
UserFormURL string // the URL where the user must go to authorize the device
// UserFormURL is the complete URL where the user must go to authorize the device.
// Deprecated: use UserFormPath instead.
UserFormURL string
// UserFormPath is the path where the user must go to authorize the device.
// The hostname for the URL is taken from the request by IssuerFromContext.
UserFormPath string
UserCode UserCodeConfig UserCode UserCodeConfig
} }
@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
return err return err
} }
var verification *url.URL
if config.UserFormURL != "" {
if verification, err = url.Parse(config.UserFormURL); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
}
} else {
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
}
verification.Path = config.UserFormPath
}
response := &oidc.DeviceAuthorizationResponse{ response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode, DeviceCode: deviceCode,
UserCode: userCode, UserCode: userCode,
VerificationURI: config.UserFormURL, VerificationURI: verification.String(),
ExpiresIn: int(config.Lifetime / time.Second), ExpiresIn: int(config.Lifetime / time.Second),
Interval: int(config.PollInterval / time.Second), Interval: int(config.PollInterval / time.Second),
} }
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode) verification.RawQuery = "user_code=" + userCode
response.VerificationURIComplete = verification.String()
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
return nil return nil

View file

@ -13,6 +13,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
@ -20,29 +21,60 @@ import (
) )
func Test_deviceAuthorizationHandler(t *testing.T) { func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{ type conf struct {
Scopes: []string{"foo", "bar"}, UserFormURL string
ClientID: "web", UserFormPath string
} }
values := make(url.Values) tests := []struct {
testProvider.Encoder().Encode(req, values) name string
body := strings.NewReader(values.Encode()) conf conf
}{
{
name: "UserFormURL",
conf: conf{
UserFormURL: "https://localhost:9998/device",
},
},
{
name: "UserFormPath",
conf: conf{
UserFormPath: "/device",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := gu.PtrCopy(testConfig)
conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
provider := newTestProvider(conf)
r := httptest.NewRequest(http.MethodPost, "/", body) req := &oidc.DeviceAuthorizationRequest{
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") Scopes: []string{"foo", "bar"},
ClientID: "web",
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
runWithRandReader(mr.New(mr.NewSource(1)), func() { w := httptest.NewRecorder()
op.DeviceAuthorizationHandler(testProvider)(w, r)
})
result := w.Result() runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(provider)(w, r)
})
assert.Less(t, result.StatusCode, 300) result := w.Result()
got, _ := io.ReadAll(result.Body) assert.Less(t, result.StatusCode, 300)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
})
}
} }
func TestParseDeviceCodeRequest(t *testing.T) { func TestParseDeviceCodeRequest(t *testing.T) {

View file

@ -480,6 +480,16 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.DeviceAuthorization = endpoint
return nil
}
}
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
o.endpoints.Authorization = auth o.endpoints.Authorization = auth

View file

@ -20,15 +20,9 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
) )
var testProvider op.OpenIDProvider var (
testProvider op.OpenIDProvider
const ( testConfig = &op.Config{
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
CryptoKey: sha256.Sum256([]byte("test")), CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut, DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true, CodeMethodS256: true,
@ -40,24 +34,35 @@ func init() {
DeviceAuthorization: op.DeviceAuthorizationConfig{ DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute, Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second, PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device", UserFormPath: "/device",
UserCode: op.UserCodeBase20, UserCode: op.UserCodeBase20,
}, },
} }
)
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
storage.RegisterClients( storage.RegisterClients(
storage.NativeClient("native"), storage.NativeClient("native"),
storage.WebClient("web", "secret", "https://example.com"), storage.WebClient("web", "secret", "https://example.com"),
storage.WebClient("api", "secret"), storage.WebClient("api", "secret"),
) )
var err error testProvider = newTestProvider(testConfig)
testProvider, err = op.NewOpenIDProvider(testIssuer, config, }
func newTestProvider(config *op.Config) op.OpenIDProvider {
provider, err := op.NewOpenIDProvider(testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
) )
if err != nil { if err != nil {
panic(err) panic(err)
} }
return provider
} }
type routesTestStorage interface { type routesTestStorage interface {

View file

@ -113,6 +113,8 @@ type OPStorage interface {
// handle the current request. // handle the current request.
GetClientByClientID(ctx context.Context, clientID string) (Client, error) GetClientByClientID(ctx context.Context, clientID string) (Client, error)
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
// SetUserinfoFromScopes is deprecated and should have an empty implementation for now.
// Implement SetUserinfoFromRequest instead.
SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
@ -127,6 +129,13 @@ type JWTProfileTokenStorage interface {
JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error) JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error)
} }
// CanSetUserinfoFromRequest is an optional additional interface that may be implemented by
// implementors of Storage. It allows additional data to be set in id_tokens based on the
// request.
type CanSetUserinfoFromRequest interface {
SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error
}
// Storage is a required parameter for NewOpenIDProvider(). In addition to the // Storage is a required parameter for NewOpenIDProvider(). In addition to the
// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage // embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
// then the grant type "client_credentials" will be supported. In that case, the access // then the grant type "client_credentials" will be supported. In that case, the access

View file

@ -190,6 +190,12 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
if err != nil { if err != nil {
return "", err return "", err
} }
if fromRequest, ok := storage.(CanSetUserinfoFromRequest); ok {
err := fromRequest.SetUserinfoFromRequest(ctx, userInfo, request, scopes)
if err != nil {
return "", err
}
}
claims.SetUserInfo(userInfo) claims.SetUserInfo(userInfo)
} }
if code != "" { if code != "" {