diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index b807382..f1906ba 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -79,7 +79,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer handler := http.Handler(provider) if wrapServer { - handler = op.NewLegacyServer(provider) + handler = op.NewLegacyServer(provider, *op.DefaultEndpoints) } // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) diff --git a/pkg/op/config.go b/pkg/op/config.go index c40ed39..f61412a 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -20,14 +20,14 @@ var ( type Configuration interface { IssuerFromRequest(r *http.Request) string Insecure() bool - AuthorizationEndpoint() Endpoint - TokenEndpoint() Endpoint - IntrospectionEndpoint() Endpoint - UserinfoEndpoint() Endpoint - RevocationEndpoint() Endpoint - EndSessionEndpoint() Endpoint - KeysEndpoint() Endpoint - DeviceAuthorizationEndpoint() Endpoint + AuthorizationEndpoint() *Endpoint + TokenEndpoint() *Endpoint + IntrospectionEndpoint() *Endpoint + UserinfoEndpoint() *Endpoint + RevocationEndpoint() *Endpoint + EndSessionEndpoint() *Endpoint + KeysEndpoint() *Endpoint + DeviceAuthorizationEndpoint() *Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index d376032..8251261 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -64,6 +64,37 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di } } +func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) + return &oidc.DiscoveryConfiguration{ + Issuer: issuer, + AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer), + TokenEndpoint: endpoints.Token.Absolute(issuer), + IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer), + UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer), + RevocationEndpoint: endpoints.Revocation.Absolute(issuer), + EndSessionEndpoint: endpoints.EndSession.Absolute(issuer), + JwksURI: endpoints.JwksURI.Absolute(issuer), + DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer), + ScopesSupported: Scopes(config), + ResponseTypesSupported: ResponseTypes(config), + GrantTypesSupported: GrantTypes(config), + SubjectTypesSupported: SubjectTypes(config), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), + RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), + TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), + TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), + IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config), + IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config), + RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config), + RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config), + ClaimsSupported: SupportedClaims(config), + CodeChallengeMethodsSupported: CodeChallengeMethods(config), + UILocalesSupported: config.SupportedUILocales(), + RequestParameterSupported: config.RequestObjectSupported(), + } +} + func Scopes(c Configuration) []string { return DefaultSupportedScopes // TODO: config } diff --git a/pkg/op/endpoint.go b/pkg/op/endpoint.go index b1e1507..1ac1cad 100644 --- a/pkg/op/endpoint.go +++ b/pkg/op/endpoint.go @@ -1,32 +1,46 @@ package op -import "strings" +import ( + "errors" + "strings" +) type Endpoint struct { path string url string } -func NewEndpoint(path string) Endpoint { - return Endpoint{path: path} +func NewEndpoint(path string) *Endpoint { + return &Endpoint{path: path} } -func NewEndpointWithURL(path, url string) Endpoint { - return Endpoint{path: path, url: url} +func NewEndpointWithURL(path, url string) *Endpoint { + return &Endpoint{path: path, url: url} } -func (e Endpoint) Relative() string { +func (e *Endpoint) Relative() string { + if e == nil { + return "" + } return relativeEndpoint(e.path) } -func (e Endpoint) Absolute(host string) string { +func (e *Endpoint) Absolute(host string) string { + if e == nil { + return "" + } if e.url != "" { return e.url } return absoluteEndpoint(host, e.path) } -func (e Endpoint) Validate() error { +var ErrNilEndpoint = errors.New("nil endpoint") + +func (e *Endpoint) Validate() error { + if e == nil { + return ErrNilEndpoint + } return nil // TODO: } diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go index 46e5d47..bf112ef 100644 --- a/pkg/op/endpoint_test.go +++ b/pkg/op/endpoint_test.go @@ -3,13 +3,14 @@ package op_test import ( "testing" + "github.com/stretchr/testify/require" "github.com/zitadel/oidc/v3/pkg/op" ) func TestEndpoint_Path(t *testing.T) { tests := []struct { name string - e op.Endpoint + e *op.Endpoint want string }{ { @@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) { op.NewEndpointWithURL("/test", "http://test.com/test"), "/test", }, + { + "nil", + nil, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) { } tests := []struct { name string - e op.Endpoint + e *op.Endpoint args args want string }{ @@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) { args{"https://host"}, "https://test.com/test", }, + { + "nil", + nil, + args{"https://host"}, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) { func TestEndpoint_Validate(t *testing.T) { tests := []struct { name string - e op.Endpoint - wantErr bool + e *op.Endpoint + wantErr error }{ - // TODO: Add test cases. + { + "nil", + nil, + op.ErrNilEndpoint, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.e.Validate(); (err != nil) != tt.wantErr { - t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) - } + err := tt.e.Validate() + require.ErrorIs(t, err, tt.wantErr) }) } } diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index 96429dd..f392a45 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom } // AuthorizationEndpoint mocks base method. -func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call { } // DeviceAuthorizationEndpoint mocks base method. -func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C } // EndSessionEndpoint mocks base method. -func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { +func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EndSessionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup } // IntrospectionEndpoint mocks base method. -func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { +func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IntrospectionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go } // KeysEndpoint mocks base method. -func (m *MockConfiguration) KeysEndpoint() op.Endpoint { +func (m *MockConfiguration) KeysEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeysEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor } // RevocationEndpoint mocks base method. -func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { +func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RevocationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call { } // TokenEndpoint mocks base method. -func (m *MockConfiguration) TokenEndpoint() op.Endpoint { +func (m *MockConfiguration) TokenEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TokenEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported } // UserinfoEndpoint mocks base method. -func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { +func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UserinfoEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } diff --git a/pkg/op/op.go b/pkg/op/op.go index 5b318e3..55ee986 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -131,16 +131,17 @@ type Config struct { DeviceAuthorization DeviceAuthorizationConfig } +// Endpoints defines endpoint routes. type Endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint - DeviceAuthorization Endpoint + Authorization *Endpoint + Token *Endpoint + Introspection *Endpoint + Userinfo *Endpoint + Revocation *Endpoint + EndSession *Endpoint + CheckSessionIframe *Endpoint + JwksURI *Endpoint + DeviceAuthorization *Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool { return o.insecure } -func (o *Provider) AuthorizationEndpoint() Endpoint { +func (o *Provider) AuthorizationEndpoint() *Endpoint { return o.endpoints.Authorization } -func (o *Provider) TokenEndpoint() Endpoint { +func (o *Provider) TokenEndpoint() *Endpoint { return o.endpoints.Token } -func (o *Provider) IntrospectionEndpoint() Endpoint { +func (o *Provider) IntrospectionEndpoint() *Endpoint { return o.endpoints.Introspection } -func (o *Provider) UserinfoEndpoint() Endpoint { +func (o *Provider) UserinfoEndpoint() *Endpoint { return o.endpoints.Userinfo } -func (o *Provider) RevocationEndpoint() Endpoint { +func (o *Provider) RevocationEndpoint() *Endpoint { return o.endpoints.Revocation } -func (o *Provider) EndSessionEndpoint() Endpoint { +func (o *Provider) EndSessionEndpoint() *Endpoint { return o.endpoints.EndSession } -func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { +func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint { return o.endpoints.DeviceAuthorization } -func (o *Provider) KeysEndpoint() Endpoint { +func (o *Provider) KeysEndpoint() *Endpoint { return o.endpoints.JwksURI } @@ -420,7 +421,7 @@ func WithAllowInsecure() Option { } } -func WithCustomAuthEndpoint(endpoint Endpoint) Option { +func WithCustomAuthEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option { } } -func WithCustomTokenEndpoint(endpoint Endpoint) Option { +func WithCustomTokenEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option { } } -func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { +func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { } } -func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { +func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { } } -func WithCustomRevocationEndpoint(endpoint Endpoint) Option { +func WithCustomRevocationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { +func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { } } -func WithCustomKeysEndpoint(endpoint Endpoint) Option { +func WithCustomKeysEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } -func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { +func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { +// WithCustomEndpoints sets multiple endpoints at once. +// Non of the endpoints may be nil, or an error will +// be returned when the Option used by the Provider. +func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option { return func(o *Provider) error { + for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} { + if err := e.Validate(); err != nil { + return err + } + } o.endpoints.Authorization = auth o.endpoints.Token = token o.endpoints.Userinfo = userInfo diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index d33b39d..abe53bc 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) { }) } } + +func TestWithCustomEndpoints(t *testing.T) { + type args struct { + auth *op.Endpoint + token *op.Endpoint + userInfo *op.Endpoint + revocation *op.Endpoint + endSession *op.Endpoint + keys *op.Endpoint + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "all nil", + args: args{}, + wantErr: op.ErrNilEndpoint, + }, + { + name: "all set", + args: args{ + auth: op.NewEndpoint("/authorize"), + token: op.NewEndpoint("/oauth/token"), + userInfo: op.NewEndpoint("/userinfo"), + revocation: op.NewEndpoint("/revoke"), + endSession: op.NewEndpoint("/end_session"), + keys: op.NewEndpoint("/keys"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := op.NewOpenIDProvider(testIssuer, testConfig, + storage.NewStorage(storage.NewUserStore(testIssuer)), + op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys), + ) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErr != nil { + return + } + assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint()) + assert.Equal(t, tt.args.token, provider.TokenEndpoint()) + assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint()) + assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint()) + assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint()) + assert.Equal(t, tt.args.keys, provider.KeysEndpoint()) + }) + } +} diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 8956436..3fb481d 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -20,13 +20,13 @@ import ( // The routes can be customized with [WithEndpoints]. // // EXPERIMENTAL: may change until v4 -func RegisterServer(server Server, options ...ServerOption) http.Handler { +func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler { decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(true) ws := &webServer{ server: server, - endpoints: *DefaultEndpoints, + endpoints: endpoints, decoder: decoder, logger: slog.Default(), } @@ -49,13 +49,6 @@ func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { } } -// WithEndpoints overrides the [DefaultEndpoints] -func WithEndpoints(endpoints Endpoints) ServerOption { - return func(s *webServer) { - s.endpoints = endpoints - } -} - // WithDecoder overrides the default decoder, // which is a [schema.Decoder] with IgnoreUnknownKeys set to true. func WithDecoder(decoder httphelper.Decoder) ServerOption { @@ -96,17 +89,25 @@ func (s *webServer) createRouter() { router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) - router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler) - router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.withClient(s.deviceAuthorizationHandler)) - router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler) - router.HandleFunc(s.endpoints.Introspection.Relative(), s.withClient(s.introspectionHandler)) - router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler) - router.HandleFunc(s.endpoints.Revocation.Relative(), s.withClient(s.revocationHandler)) - router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler) - router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys)) + + s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler) + s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) + s.endpointRoute(router, s.endpoints.Token, s.tokensHandler) + s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler)) + s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler) + s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler)) + s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler) + s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) s.Handler = router } +func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) { + if e != nil { + router.HandleFunc(e.Relative(), hf) + s.logger.Info("registered route", "endpoint", e.Relative()) + } +} + type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go index 3addea2..730e745 100644 --- a/pkg/op/server_http_routes_test.go +++ b/pkg/op/server_http_routes_test.go @@ -18,7 +18,7 @@ import ( ) func TestServerRoutes(t *testing.T) { - server := op.NewLegacyServer(testProvider) + server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints) storage := testProvider.Storage().(routesTestStorage) ctx := op.ContextWithIssuer(context.Background(), testIssuer) diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go index 5ee14e3..86fe7ed 100644 --- a/pkg/op/server_http_test.go +++ b/pkg/op/server_http_test.go @@ -25,15 +25,14 @@ import ( func TestRegisterServer(t *testing.T) { server := UnimplementedServer{} endpoints := Endpoints{ - Authorization: Endpoint{ + Authorization: &Endpoint{ path: "/auth", }, } decoder := schema.NewDecoder() logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) - h := RegisterServer(server, - WithEndpoints(endpoints), + h := RegisterServer(server, endpoints, WithDecoder(decoder), WithFallbackLogger(logger), ) diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index 8f6ef17..0a7de85 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -10,15 +10,31 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) +// LegacyServer is an implementation of [Server[] that +// simply wraps a [OpenIDProvider]. +// It can be used to transition from the former Provider/Storage +// interfaces to the new Server interface. type LegacyServer struct { UnimplementedServer - provider OpenIDProvider + provider OpenIDProvider + endpoints Endpoints } -func NewLegacyServer(provider OpenIDProvider) http.Handler { +// NewLegacyServer wraps provider in a `Server` and returns a handler which is +// the Server's router. +// +// Only non-nil endpoints will be registered on the router. +// Nil endpoints are disabled. +// +// The passed endpoints is also set to the provider, +// to be consistent with the discovery config. +// Any `With*Endpoint()` option used on the provider is +// therefore ineffective. +func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler { server := RegisterServer(&LegacyServer{ - provider: provider, - }, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) + provider: provider, + endpoints: endpoints, + }, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) router := chi.NewRouter() router.Mount("/", server) @@ -43,7 +59,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { return NewResponse( - CreateDiscoveryConfig(ctx, s.provider, s.provider.Storage()), + createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints), ), nil }