feat(op): allow Legacy Server extension (#466)

This change splits the constructor and registration of the Legacy Server.
This allows it to be extended by struct embedding.
This commit is contained in:
Tim Möhlmann 2023-10-24 10:20:02 +03:00 committed by GitHub
parent 164c5b28c7
commit bab5399859
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 79 additions and 39 deletions

View file

@ -80,7 +80,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
handler := http.Handler(provider)
if wrapServer {
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints))
}
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)

View file

@ -25,11 +25,13 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption)
decoder.IgnoreUnknownKeys(true)
ws := &webServer{
router: chi.NewRouter(),
server: server,
endpoints: endpoints,
decoder: decoder,
logger: slog.Default(),
}
ws.router.Use(cors.New(defaultCORSOptions).Handler)
for _, option := range options {
option(ws)
@ -45,7 +47,14 @@ type ServerOption func(s *webServer)
// the Server's router.
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
return func(s *webServer) {
s.middleware = m
s.router.Use(m...)
}
}
// WithSetRouter allows customization or the Server's router.
func WithSetRouter(set func(chi.Router)) ServerOption {
return func(s *webServer) {
set(s.router)
}
}
@ -67,12 +76,15 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption {
}
type webServer struct {
http.Handler
server Server
middleware []func(http.Handler) http.Handler
endpoints Endpoints
decoder httphelper.Decoder
logger *slog.Logger
server Server
router *chi.Mux
endpoints Endpoints
decoder httphelper.Decoder
logger *slog.Logger
}
func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
@ -83,27 +95,23 @@ func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
}
func (s *webServer) createRouter() {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(s.middleware...)
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
s.Handler = router
s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler)
s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
s.endpointRoute(s.endpoints.Token, s.tokensHandler)
s.endpointRoute(s.endpoints.Introspection, s.withClient(s.introspectionHandler))
s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler)
s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler))
s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
}
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) {
if e != nil {
router.HandleFunc(e.Relative(), hf)
s.router.HandleFunc(e.Relative(), hf)
s.logger.Info("registered route", "endpoint", e.Relative())
}
}

View file

@ -32,7 +32,7 @@ func jwtProfile() (string, error) {
}
func TestServerRoutes(t *testing.T) {
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints))
storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer)

View file

@ -10,37 +10,69 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// LegacyServer is an implementation of [Server[] that
// simply wraps a [OpenIDProvider].
// ExtendedLegacyServer allows embedding [LegacyServer] in a struct,
// so that its methods can be individually overridden.
//
// EXPERIMENTAL: may change until v4
type ExtendedLegacyServer interface {
Server
Provider() OpenIDProvider
Endpoints() Endpoints
}
// RegisterLegacyServer registers a [LegacyServer] or an extension thereof.
// It takes care of registering the IssuerFromRequest middleware
// and Authorization Callback Routes.
// Neither are part of the bare [Server] interface.
//
// EXPERIMENTAL: may change until v4
func RegisterLegacyServer(s ExtendedLegacyServer, options ...ServerOption) http.Handler {
provider := s.Provider()
options = append(options,
WithHTTPMiddleware(intercept(provider.IssuerFromRequest)),
WithSetRouter(func(r chi.Router) {
r.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
}),
)
return RegisterServer(s, s.Endpoints(), options...)
}
// LegacyServer is an implementation of [Server] that
// simply wraps an [OpenIDProvider].
// It can be used to transition from the former Provider/Storage
// interfaces to the new Server interface.
//
// EXPERIMENTAL: may change until v4
type LegacyServer struct {
UnimplementedServer
provider OpenIDProvider
endpoints Endpoints
}
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
// the Server's router.
// NewLegacyServer wraps provider in a `Server` implementation
//
// Only non-nil endpoints will be registered on the router.
// Nil endpoints are disabled.
//
// The passed endpoints is also set to the provider,
// to be consistent with the discovery config.
// The passed endpoints is also used for the discovery config,
// and endpoints already set to the provider are ignored.
// Any `With*Endpoint()` option used on the provider is
// therefore ineffective.
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
server := RegisterServer(&LegacyServer{
//
// EXPERIMENTAL: may change until v4
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer {
return &LegacyServer{
provider: provider,
endpoints: endpoints,
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
}
}
router := chi.NewRouter()
router.Mount("/", server)
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
func (s *LegacyServer) Provider() OpenIDProvider {
return s.provider
}
return router
func (s *LegacyServer) Endpoints() Endpoints {
return s.endpoints
}
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {