feat(op): allow Legacy Server extension (#466)

This change splits the constructor and registration of the Legacy Server.
This allows it to be extended by struct embedding.
This commit is contained in:
Tim Möhlmann 2023-10-24 10:20:02 +03:00 committed by GitHub
parent 164c5b28c7
commit bab5399859
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 79 additions and 39 deletions

View file

@ -80,7 +80,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
handler := http.Handler(provider) handler := http.Handler(provider)
if wrapServer { if wrapServer {
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints) handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints))
} }
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)

View file

@ -25,11 +25,13 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption)
decoder.IgnoreUnknownKeys(true) decoder.IgnoreUnknownKeys(true)
ws := &webServer{ ws := &webServer{
router: chi.NewRouter(),
server: server, server: server,
endpoints: endpoints, endpoints: endpoints,
decoder: decoder, decoder: decoder,
logger: slog.Default(), logger: slog.Default(),
} }
ws.router.Use(cors.New(defaultCORSOptions).Handler)
for _, option := range options { for _, option := range options {
option(ws) option(ws)
@ -45,7 +47,14 @@ type ServerOption func(s *webServer)
// the Server's router. // the Server's router.
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
return func(s *webServer) { return func(s *webServer) {
s.middleware = m s.router.Use(m...)
}
}
// WithSetRouter allows customization or the Server's router.
func WithSetRouter(set func(chi.Router)) ServerOption {
return func(s *webServer) {
set(s.router)
} }
} }
@ -67,14 +76,17 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption {
} }
type webServer struct { type webServer struct {
http.Handler
server Server server Server
middleware []func(http.Handler) http.Handler router *chi.Mux
endpoints Endpoints endpoints Endpoints
decoder httphelper.Decoder decoder httphelper.Decoder
logger *slog.Logger logger *slog.Logger
} }
func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
func (s *webServer) getLogger(ctx context.Context) *slog.Logger { func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
if logger, ok := logging.FromContext(ctx); ok { if logger, ok := logging.FromContext(ctx); ok {
return logger return logger
@ -83,27 +95,23 @@ func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
} }
func (s *webServer) createRouter() { func (s *webServer) createRouter() {
router := chi.NewRouter() s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.Use(cors.New(defaultCORSOptions).Handler) s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.Use(s.middleware...) s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler) s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler)
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler) s.endpointRoute(s.endpoints.Token, s.tokensHandler)
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler)) s.endpointRoute(s.endpoints.Introspection, s.withClient(s.introspectionHandler))
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler) s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler)
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler)) s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler))
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler) s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
s.Handler = router
} }
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) { func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) {
if e != nil { if e != nil {
router.HandleFunc(e.Relative(), hf) s.router.HandleFunc(e.Relative(), hf)
s.logger.Info("registered route", "endpoint", e.Relative()) s.logger.Info("registered route", "endpoint", e.Relative())
} }
} }

View file

@ -32,7 +32,7 @@ func jwtProfile() (string, error) {
} }
func TestServerRoutes(t *testing.T) { func TestServerRoutes(t *testing.T) {
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints) server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints))
storage := testProvider.Storage().(routesTestStorage) storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer) ctx := op.ContextWithIssuer(context.Background(), testIssuer)

View file

@ -10,37 +10,69 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
// LegacyServer is an implementation of [Server[] that // ExtendedLegacyServer allows embedding [LegacyServer] in a struct,
// simply wraps a [OpenIDProvider]. // so that its methods can be individually overridden.
//
// EXPERIMENTAL: may change until v4
type ExtendedLegacyServer interface {
Server
Provider() OpenIDProvider
Endpoints() Endpoints
}
// RegisterLegacyServer registers a [LegacyServer] or an extension thereof.
// It takes care of registering the IssuerFromRequest middleware
// and Authorization Callback Routes.
// Neither are part of the bare [Server] interface.
//
// EXPERIMENTAL: may change until v4
func RegisterLegacyServer(s ExtendedLegacyServer, options ...ServerOption) http.Handler {
provider := s.Provider()
options = append(options,
WithHTTPMiddleware(intercept(provider.IssuerFromRequest)),
WithSetRouter(func(r chi.Router) {
r.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
}),
)
return RegisterServer(s, s.Endpoints(), options...)
}
// LegacyServer is an implementation of [Server] that
// simply wraps an [OpenIDProvider].
// It can be used to transition from the former Provider/Storage // It can be used to transition from the former Provider/Storage
// interfaces to the new Server interface. // interfaces to the new Server interface.
//
// EXPERIMENTAL: may change until v4
type LegacyServer struct { type LegacyServer struct {
UnimplementedServer UnimplementedServer
provider OpenIDProvider provider OpenIDProvider
endpoints Endpoints endpoints Endpoints
} }
// NewLegacyServer wraps provider in a `Server` and returns a handler which is // NewLegacyServer wraps provider in a `Server` implementation
// the Server's router.
// //
// Only non-nil endpoints will be registered on the router. // Only non-nil endpoints will be registered on the router.
// Nil endpoints are disabled. // Nil endpoints are disabled.
// //
// The passed endpoints is also set to the provider, // The passed endpoints is also used for the discovery config,
// to be consistent with the discovery config. // and endpoints already set to the provider are ignored.
// Any `With*Endpoint()` option used on the provider is // Any `With*Endpoint()` option used on the provider is
// therefore ineffective. // therefore ineffective.
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler { //
server := RegisterServer(&LegacyServer{ // EXPERIMENTAL: may change until v4
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer {
return &LegacyServer{
provider: provider, provider: provider,
endpoints: endpoints, endpoints: endpoints,
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) }
}
router := chi.NewRouter() func (s *LegacyServer) Provider() OpenIDProvider {
router.Mount("/", server) return s.provider
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) }
return router func (s *LegacyServer) Endpoints() Endpoints {
return s.endpoints
} }
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {