From bab53998599bfe56f1dc3c6a1c3f52b8a97b3a6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 24 Oct 2023 10:20:02 +0300 Subject: [PATCH] feat(op): allow Legacy Server extension (#466) This change splits the constructor and registration of the Legacy Server. This allows it to be extended by struct embedding. --- example/server/exampleop/op.go | 2 +- pkg/op/server_http.go | 56 ++++++++++++++++------------- pkg/op/server_http_routes_test.go | 2 +- pkg/op/server_legacy.go | 58 ++++++++++++++++++++++++------- 4 files changed, 79 insertions(+), 39 deletions(-) diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 74018da..baa2662 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -80,7 +80,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer handler := http.Handler(provider) if wrapServer { - handler = op.NewLegacyServer(provider, *op.DefaultEndpoints) + handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints)) } // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 96ee7a5..750f7a9 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -25,11 +25,13 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) decoder.IgnoreUnknownKeys(true) ws := &webServer{ + router: chi.NewRouter(), server: server, endpoints: endpoints, decoder: decoder, logger: slog.Default(), } + ws.router.Use(cors.New(defaultCORSOptions).Handler) for _, option := range options { option(ws) @@ -45,7 +47,14 @@ type ServerOption func(s *webServer) // the Server's router. func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { return func(s *webServer) { - s.middleware = m + s.router.Use(m...) + } +} + +// WithSetRouter allows customization or the Server's router. +func WithSetRouter(set func(chi.Router)) ServerOption { + return func(s *webServer) { + set(s.router) } } @@ -67,12 +76,15 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption { } type webServer struct { - http.Handler - server Server - middleware []func(http.Handler) http.Handler - endpoints Endpoints - decoder httphelper.Decoder - logger *slog.Logger + server Server + router *chi.Mux + endpoints Endpoints + decoder httphelper.Decoder + logger *slog.Logger +} + +func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.router.ServeHTTP(w, r) } func (s *webServer) getLogger(ctx context.Context) *slog.Logger { @@ -83,27 +95,23 @@ func (s *webServer) getLogger(ctx context.Context) *slog.Logger { } func (s *webServer) createRouter() { - router := chi.NewRouter() - router.Use(cors.New(defaultCORSOptions).Handler) - router.Use(s.middleware...) - router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) - router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) - router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) + s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) + s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) + s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) - s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler) - s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) - s.endpointRoute(router, s.endpoints.Token, s.tokensHandler) - s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler)) - s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler) - s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler)) - s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler) - s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) - s.Handler = router + s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler) + s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) + s.endpointRoute(s.endpoints.Token, s.tokensHandler) + s.endpointRoute(s.endpoints.Introspection, s.withClient(s.introspectionHandler)) + s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler) + s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler)) + s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler) + s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) } -func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) { +func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) { if e != nil { - router.HandleFunc(e.Relative(), hf) + s.router.HandleFunc(e.Relative(), hf) s.logger.Info("registered route", "endpoint", e.Relative()) } } diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go index c7767d2..c50e989 100644 --- a/pkg/op/server_http_routes_test.go +++ b/pkg/op/server_http_routes_test.go @@ -32,7 +32,7 @@ func jwtProfile() (string, error) { } func TestServerRoutes(t *testing.T) { - server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints) + server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints)) storage := testProvider.Storage().(routesTestStorage) ctx := op.ContextWithIssuer(context.Background(), testIssuer) diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index f373b9d..2006e90 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -10,37 +10,69 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -// LegacyServer is an implementation of [Server[] that -// simply wraps a [OpenIDProvider]. +// ExtendedLegacyServer allows embedding [LegacyServer] in a struct, +// so that its methods can be individually overridden. +// +// EXPERIMENTAL: may change until v4 +type ExtendedLegacyServer interface { + Server + Provider() OpenIDProvider + Endpoints() Endpoints +} + +// RegisterLegacyServer registers a [LegacyServer] or an extension thereof. +// It takes care of registering the IssuerFromRequest middleware +// and Authorization Callback Routes. +// Neither are part of the bare [Server] interface. +// +// EXPERIMENTAL: may change until v4 +func RegisterLegacyServer(s ExtendedLegacyServer, options ...ServerOption) http.Handler { + provider := s.Provider() + options = append(options, + WithHTTPMiddleware(intercept(provider.IssuerFromRequest)), + WithSetRouter(func(r chi.Router) { + r.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) + }), + ) + return RegisterServer(s, s.Endpoints(), options...) +} + +// LegacyServer is an implementation of [Server] that +// simply wraps an [OpenIDProvider]. // It can be used to transition from the former Provider/Storage // interfaces to the new Server interface. +// +// EXPERIMENTAL: may change until v4 type LegacyServer struct { UnimplementedServer provider OpenIDProvider endpoints Endpoints } -// NewLegacyServer wraps provider in a `Server` and returns a handler which is -// the Server's router. +// NewLegacyServer wraps provider in a `Server` implementation // // Only non-nil endpoints will be registered on the router. // Nil endpoints are disabled. // -// The passed endpoints is also set to the provider, -// to be consistent with the discovery config. +// The passed endpoints is also used for the discovery config, +// and endpoints already set to the provider are ignored. // Any `With*Endpoint()` option used on the provider is // therefore ineffective. -func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler { - server := RegisterServer(&LegacyServer{ +// +// EXPERIMENTAL: may change until v4 +func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer { + return &LegacyServer{ provider: provider, endpoints: endpoints, - }, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) + } +} - router := chi.NewRouter() - router.Mount("/", server) - router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) +func (s *LegacyServer) Provider() OpenIDProvider { + return s.provider +} - return router +func (s *LegacyServer) Endpoints() Endpoints { + return s.endpoints } func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {