make endpoints pointers to enable/disable them
This commit is contained in:
parent
f6cb47fbbb
commit
af22c1a4d8
12 changed files with 229 additions and 93 deletions
|
@ -79,7 +79,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
|
|||
|
||||
handler := http.Handler(provider)
|
||||
if wrapServer {
|
||||
handler = op.NewLegacyServer(provider)
|
||||
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
|
||||
}
|
||||
|
||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||
|
|
|
@ -20,14 +20,14 @@ var (
|
|||
type Configuration interface {
|
||||
IssuerFromRequest(r *http.Request) string
|
||||
Insecure() bool
|
||||
AuthorizationEndpoint() Endpoint
|
||||
TokenEndpoint() Endpoint
|
||||
IntrospectionEndpoint() Endpoint
|
||||
UserinfoEndpoint() Endpoint
|
||||
RevocationEndpoint() Endpoint
|
||||
EndSessionEndpoint() Endpoint
|
||||
KeysEndpoint() Endpoint
|
||||
DeviceAuthorizationEndpoint() Endpoint
|
||||
AuthorizationEndpoint() *Endpoint
|
||||
TokenEndpoint() *Endpoint
|
||||
IntrospectionEndpoint() *Endpoint
|
||||
UserinfoEndpoint() *Endpoint
|
||||
RevocationEndpoint() *Endpoint
|
||||
EndSessionEndpoint() *Endpoint
|
||||
KeysEndpoint() *Endpoint
|
||||
DeviceAuthorizationEndpoint() *Endpoint
|
||||
|
||||
AuthMethodPostSupported() bool
|
||||
CodeMethodS256Supported() bool
|
||||
|
|
|
@ -64,6 +64,37 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di
|
|||
}
|
||||
}
|
||||
|
||||
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
|
||||
issuer := IssuerFromContext(ctx)
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
|
||||
TokenEndpoint: endpoints.Token.Absolute(issuer),
|
||||
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
|
||||
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
|
||||
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
|
||||
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
|
||||
JwksURI: endpoints.JwksURI.Absolute(issuer),
|
||||
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
|
||||
ScopesSupported: Scopes(config),
|
||||
ResponseTypesSupported: ResponseTypes(config),
|
||||
GrantTypesSupported: GrantTypes(config),
|
||||
SubjectTypesSupported: SubjectTypes(config),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
|
||||
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
|
||||
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
|
||||
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
|
||||
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
|
||||
ClaimsSupported: SupportedClaims(config),
|
||||
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
|
||||
UILocalesSupported: config.SupportedUILocales(),
|
||||
RequestParameterSupported: config.RequestObjectSupported(),
|
||||
}
|
||||
}
|
||||
|
||||
func Scopes(c Configuration) []string {
|
||||
return DefaultSupportedScopes // TODO: config
|
||||
}
|
||||
|
|
|
@ -1,32 +1,46 @@
|
|||
package op
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Endpoint struct {
|
||||
path string
|
||||
url string
|
||||
}
|
||||
|
||||
func NewEndpoint(path string) Endpoint {
|
||||
return Endpoint{path: path}
|
||||
func NewEndpoint(path string) *Endpoint {
|
||||
return &Endpoint{path: path}
|
||||
}
|
||||
|
||||
func NewEndpointWithURL(path, url string) Endpoint {
|
||||
return Endpoint{path: path, url: url}
|
||||
func NewEndpointWithURL(path, url string) *Endpoint {
|
||||
return &Endpoint{path: path, url: url}
|
||||
}
|
||||
|
||||
func (e Endpoint) Relative() string {
|
||||
func (e *Endpoint) Relative() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return relativeEndpoint(e.path)
|
||||
}
|
||||
|
||||
func (e Endpoint) Absolute(host string) string {
|
||||
func (e *Endpoint) Absolute(host string) string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.url != "" {
|
||||
return e.url
|
||||
}
|
||||
return absoluteEndpoint(host, e.path)
|
||||
}
|
||||
|
||||
func (e Endpoint) Validate() error {
|
||||
var ErrNilEndpoint = errors.New("nil endpoint")
|
||||
|
||||
func (e *Endpoint) Validate() error {
|
||||
if e == nil {
|
||||
return ErrNilEndpoint
|
||||
}
|
||||
return nil // TODO:
|
||||
}
|
||||
|
||||
|
|
|
@ -3,13 +3,14 @@ package op_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func TestEndpoint_Path(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
e *op.Endpoint
|
||||
want string
|
||||
}{
|
||||
{
|
||||
|
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
|
|||
op.NewEndpointWithURL("/test", "http://test.com/test"),
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
e *op.Endpoint
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
|
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
args{"https://host"},
|
||||
"https://test.com/test",
|
||||
},
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
args{"https://host"},
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
func TestEndpoint_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
wantErr bool
|
||||
e *op.Endpoint
|
||||
wantErr error
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
op.ErrNilEndpoint,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
err := tt.e.Validate()
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
|
|||
}
|
||||
|
||||
// AuthorizationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
|
|||
}
|
||||
|
||||
// DeviceAuthorizationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
|
|||
}
|
||||
|
||||
// EndSessionEndpoint mocks base method.
|
||||
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "EndSessionEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
|
|||
}
|
||||
|
||||
// IntrospectionEndpoint mocks base method.
|
||||
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
|
|||
}
|
||||
|
||||
// KeysEndpoint mocks base method.
|
||||
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeysEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
|
|||
}
|
||||
|
||||
// RevocationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RevocationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
|
|||
}
|
||||
|
||||
// TokenEndpoint mocks base method.
|
||||
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TokenEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
|
|||
}
|
||||
|
||||
// UserinfoEndpoint mocks base method.
|
||||
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserinfoEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
61
pkg/op/op.go
61
pkg/op/op.go
|
@ -131,16 +131,17 @@ type Config struct {
|
|||
DeviceAuthorization DeviceAuthorizationConfig
|
||||
}
|
||||
|
||||
// Endpoints defines endpoint routes.
|
||||
type Endpoints struct {
|
||||
Authorization Endpoint
|
||||
Token Endpoint
|
||||
Introspection Endpoint
|
||||
Userinfo Endpoint
|
||||
Revocation Endpoint
|
||||
EndSession Endpoint
|
||||
CheckSessionIframe Endpoint
|
||||
JwksURI Endpoint
|
||||
DeviceAuthorization Endpoint
|
||||
Authorization *Endpoint
|
||||
Token *Endpoint
|
||||
Introspection *Endpoint
|
||||
Userinfo *Endpoint
|
||||
Revocation *Endpoint
|
||||
EndSession *Endpoint
|
||||
CheckSessionIframe *Endpoint
|
||||
JwksURI *Endpoint
|
||||
DeviceAuthorization *Endpoint
|
||||
}
|
||||
|
||||
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
||||
|
@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool {
|
|||
return o.insecure
|
||||
}
|
||||
|
||||
func (o *Provider) AuthorizationEndpoint() Endpoint {
|
||||
func (o *Provider) AuthorizationEndpoint() *Endpoint {
|
||||
return o.endpoints.Authorization
|
||||
}
|
||||
|
||||
func (o *Provider) TokenEndpoint() Endpoint {
|
||||
func (o *Provider) TokenEndpoint() *Endpoint {
|
||||
return o.endpoints.Token
|
||||
}
|
||||
|
||||
func (o *Provider) IntrospectionEndpoint() Endpoint {
|
||||
func (o *Provider) IntrospectionEndpoint() *Endpoint {
|
||||
return o.endpoints.Introspection
|
||||
}
|
||||
|
||||
func (o *Provider) UserinfoEndpoint() Endpoint {
|
||||
func (o *Provider) UserinfoEndpoint() *Endpoint {
|
||||
return o.endpoints.Userinfo
|
||||
}
|
||||
|
||||
func (o *Provider) RevocationEndpoint() Endpoint {
|
||||
func (o *Provider) RevocationEndpoint() *Endpoint {
|
||||
return o.endpoints.Revocation
|
||||
}
|
||||
|
||||
func (o *Provider) EndSessionEndpoint() Endpoint {
|
||||
func (o *Provider) EndSessionEndpoint() *Endpoint {
|
||||
return o.endpoints.EndSession
|
||||
}
|
||||
|
||||
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
|
||||
func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
|
||||
return o.endpoints.DeviceAuthorization
|
||||
}
|
||||
|
||||
func (o *Provider) KeysEndpoint() Endpoint {
|
||||
func (o *Provider) KeysEndpoint() *Endpoint {
|
||||
return o.endpoints.JwksURI
|
||||
}
|
||||
|
||||
|
@ -420,7 +421,7 @@ func WithAllowInsecure() Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
|
||||
// WithCustomEndpoints sets multiple endpoints at once.
|
||||
// Non of the endpoints may be nil, or an error will
|
||||
// be returned when the Option used by the Provider.
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
|
||||
if err := e.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
o.endpoints.Authorization = auth
|
||||
o.endpoints.Token = token
|
||||
o.endpoints.Userinfo = userInfo
|
||||
|
|
|
@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCustomEndpoints(t *testing.T) {
|
||||
type args struct {
|
||||
auth *op.Endpoint
|
||||
token *op.Endpoint
|
||||
userInfo *op.Endpoint
|
||||
revocation *op.Endpoint
|
||||
endSession *op.Endpoint
|
||||
keys *op.Endpoint
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "all nil",
|
||||
args: args{},
|
||||
wantErr: op.ErrNilEndpoint,
|
||||
},
|
||||
{
|
||||
name: "all set",
|
||||
args: args{
|
||||
auth: op.NewEndpoint("/authorize"),
|
||||
token: op.NewEndpoint("/oauth/token"),
|
||||
userInfo: op.NewEndpoint("/userinfo"),
|
||||
revocation: op.NewEndpoint("/revoke"),
|
||||
endSession: op.NewEndpoint("/end_session"),
|
||||
keys: op.NewEndpoint("/keys"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
|
||||
storage.NewStorage(storage.NewUserStore(testIssuer)),
|
||||
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
|
||||
)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
if tt.wantErr != nil {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
|
||||
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
|
||||
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
|
||||
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
|
||||
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
|
||||
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,13 +20,13 @@ import (
|
|||
// The routes can be customized with [WithEndpoints].
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
func RegisterServer(server Server, options ...ServerOption) http.Handler {
|
||||
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
|
||||
decoder := schema.NewDecoder()
|
||||
decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
ws := &webServer{
|
||||
server: server,
|
||||
endpoints: *DefaultEndpoints,
|
||||
endpoints: endpoints,
|
||||
decoder: decoder,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
@ -49,13 +49,6 @@ func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithEndpoints overrides the [DefaultEndpoints]
|
||||
func WithEndpoints(endpoints Endpoints) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.endpoints = endpoints
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecoder overrides the default decoder,
|
||||
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
||||
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
||||
|
@ -96,17 +89,25 @@ func (s *webServer) createRouter() {
|
|||
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||
router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler)
|
||||
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.withClient(s.deviceAuthorizationHandler))
|
||||
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler)
|
||||
router.HandleFunc(s.endpoints.Introspection.Relative(), s.withClient(s.introspectionHandler))
|
||||
router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler)
|
||||
router.HandleFunc(s.endpoints.Revocation.Relative(), s.withClient(s.revocationHandler))
|
||||
router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler)
|
||||
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys))
|
||||
|
||||
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
|
||||
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
||||
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
|
||||
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
|
||||
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
|
||||
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
||||
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
|
||||
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
||||
s.Handler = router
|
||||
}
|
||||
|
||||
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
|
||||
if e != nil {
|
||||
router.HandleFunc(e.Relative(), hf)
|
||||
s.logger.Info("registered route", "endpoint", e.Relative())
|
||||
}
|
||||
}
|
||||
|
||||
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
|
||||
|
||||
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
)
|
||||
|
||||
func TestServerRoutes(t *testing.T) {
|
||||
server := op.NewLegacyServer(testProvider)
|
||||
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
|
||||
|
||||
storage := testProvider.Storage().(routesTestStorage)
|
||||
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
||||
|
|
|
@ -25,15 +25,14 @@ import (
|
|||
func TestRegisterServer(t *testing.T) {
|
||||
server := UnimplementedServer{}
|
||||
endpoints := Endpoints{
|
||||
Authorization: Endpoint{
|
||||
Authorization: &Endpoint{
|
||||
path: "/auth",
|
||||
},
|
||||
}
|
||||
decoder := schema.NewDecoder()
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
||||
|
||||
h := RegisterServer(server,
|
||||
WithEndpoints(endpoints),
|
||||
h := RegisterServer(server, endpoints,
|
||||
WithDecoder(decoder),
|
||||
WithFallbackLogger(logger),
|
||||
)
|
||||
|
|
|
@ -10,15 +10,31 @@ import (
|
|||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// LegacyServer is an implementation of [Server[] that
|
||||
// simply wraps a [OpenIDProvider].
|
||||
// It can be used to transition from the former Provider/Storage
|
||||
// interfaces to the new Server interface.
|
||||
type LegacyServer struct {
|
||||
UnimplementedServer
|
||||
provider OpenIDProvider
|
||||
endpoints Endpoints
|
||||
}
|
||||
|
||||
func NewLegacyServer(provider OpenIDProvider) http.Handler {
|
||||
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
|
||||
// the Server's router.
|
||||
//
|
||||
// Only non-nil endpoints will be registered on the router.
|
||||
// Nil endpoints are disabled.
|
||||
//
|
||||
// The passed endpoints is also set to the provider,
|
||||
// to be consistent with the discovery config.
|
||||
// Any `With*Endpoint()` option used on the provider is
|
||||
// therefore ineffective.
|
||||
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
|
||||
server := RegisterServer(&LegacyServer{
|
||||
provider: provider,
|
||||
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
||||
endpoints: endpoints,
|
||||
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", server)
|
||||
|
@ -43,7 +59,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon
|
|||
|
||||
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return NewResponse(
|
||||
CreateDiscoveryConfig(ctx, s.provider, s.provider.Storage()),
|
||||
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
|
||||
), nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue