feat(op): allow Legacy Server extension (#466)
This change splits the constructor and registration of the Legacy Server. This allows it to be extended by struct embedding.
This commit is contained in:
parent
164c5b28c7
commit
bab5399859
4 changed files with 79 additions and 39 deletions
|
@ -80,7 +80,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
|
||||||
|
|
||||||
handler := http.Handler(provider)
|
handler := http.Handler(provider)
|
||||||
if wrapServer {
|
if wrapServer {
|
||||||
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
|
handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints))
|
||||||
}
|
}
|
||||||
|
|
||||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||||
|
|
|
@ -25,11 +25,13 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption)
|
||||||
decoder.IgnoreUnknownKeys(true)
|
decoder.IgnoreUnknownKeys(true)
|
||||||
|
|
||||||
ws := &webServer{
|
ws := &webServer{
|
||||||
|
router: chi.NewRouter(),
|
||||||
server: server,
|
server: server,
|
||||||
endpoints: endpoints,
|
endpoints: endpoints,
|
||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
logger: slog.Default(),
|
logger: slog.Default(),
|
||||||
}
|
}
|
||||||
|
ws.router.Use(cors.New(defaultCORSOptions).Handler)
|
||||||
|
|
||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
option(ws)
|
option(ws)
|
||||||
|
@ -45,7 +47,14 @@ type ServerOption func(s *webServer)
|
||||||
// the Server's router.
|
// the Server's router.
|
||||||
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
||||||
return func(s *webServer) {
|
return func(s *webServer) {
|
||||||
s.middleware = m
|
s.router.Use(m...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSetRouter allows customization or the Server's router.
|
||||||
|
func WithSetRouter(set func(chi.Router)) ServerOption {
|
||||||
|
return func(s *webServer) {
|
||||||
|
set(s.router)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,12 +76,15 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
||||||
}
|
}
|
||||||
|
|
||||||
type webServer struct {
|
type webServer struct {
|
||||||
http.Handler
|
server Server
|
||||||
server Server
|
router *chi.Mux
|
||||||
middleware []func(http.Handler) http.Handler
|
endpoints Endpoints
|
||||||
endpoints Endpoints
|
decoder httphelper.Decoder
|
||||||
decoder httphelper.Decoder
|
logger *slog.Logger
|
||||||
logger *slog.Logger
|
}
|
||||||
|
|
||||||
|
func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.router.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||||
|
@ -83,27 +95,23 @@ func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) createRouter() {
|
func (s *webServer) createRouter() {
|
||||||
router := chi.NewRouter()
|
s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||||
router.Use(s.middleware...)
|
s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||||
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
|
||||||
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
|
||||||
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
|
||||||
|
|
||||||
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
|
s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler)
|
||||||
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
||||||
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
|
s.endpointRoute(s.endpoints.Token, s.tokensHandler)
|
||||||
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
|
s.endpointRoute(s.endpoints.Introspection, s.withClient(s.introspectionHandler))
|
||||||
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
|
s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler)
|
||||||
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
||||||
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
|
s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler)
|
||||||
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
||||||
s.Handler = router
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
|
func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) {
|
||||||
if e != nil {
|
if e != nil {
|
||||||
router.HandleFunc(e.Relative(), hf)
|
s.router.HandleFunc(e.Relative(), hf)
|
||||||
s.logger.Info("registered route", "endpoint", e.Relative())
|
s.logger.Info("registered route", "endpoint", e.Relative())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ func jwtProfile() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerRoutes(t *testing.T) {
|
func TestServerRoutes(t *testing.T) {
|
||||||
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
|
server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints))
|
||||||
|
|
||||||
storage := testProvider.Storage().(routesTestStorage)
|
storage := testProvider.Storage().(routesTestStorage)
|
||||||
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
||||||
|
|
|
@ -10,37 +10,69 @@ import (
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LegacyServer is an implementation of [Server[] that
|
// ExtendedLegacyServer allows embedding [LegacyServer] in a struct,
|
||||||
// simply wraps a [OpenIDProvider].
|
// so that its methods can be individually overridden.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
|
type ExtendedLegacyServer interface {
|
||||||
|
Server
|
||||||
|
Provider() OpenIDProvider
|
||||||
|
Endpoints() Endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterLegacyServer registers a [LegacyServer] or an extension thereof.
|
||||||
|
// It takes care of registering the IssuerFromRequest middleware
|
||||||
|
// and Authorization Callback Routes.
|
||||||
|
// Neither are part of the bare [Server] interface.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
|
func RegisterLegacyServer(s ExtendedLegacyServer, options ...ServerOption) http.Handler {
|
||||||
|
provider := s.Provider()
|
||||||
|
options = append(options,
|
||||||
|
WithHTTPMiddleware(intercept(provider.IssuerFromRequest)),
|
||||||
|
WithSetRouter(func(r chi.Router) {
|
||||||
|
r.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
return RegisterServer(s, s.Endpoints(), options...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LegacyServer is an implementation of [Server] that
|
||||||
|
// simply wraps an [OpenIDProvider].
|
||||||
// It can be used to transition from the former Provider/Storage
|
// It can be used to transition from the former Provider/Storage
|
||||||
// interfaces to the new Server interface.
|
// interfaces to the new Server interface.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
type LegacyServer struct {
|
type LegacyServer struct {
|
||||||
UnimplementedServer
|
UnimplementedServer
|
||||||
provider OpenIDProvider
|
provider OpenIDProvider
|
||||||
endpoints Endpoints
|
endpoints Endpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
|
// NewLegacyServer wraps provider in a `Server` implementation
|
||||||
// the Server's router.
|
|
||||||
//
|
//
|
||||||
// Only non-nil endpoints will be registered on the router.
|
// Only non-nil endpoints will be registered on the router.
|
||||||
// Nil endpoints are disabled.
|
// Nil endpoints are disabled.
|
||||||
//
|
//
|
||||||
// The passed endpoints is also set to the provider,
|
// The passed endpoints is also used for the discovery config,
|
||||||
// to be consistent with the discovery config.
|
// and endpoints already set to the provider are ignored.
|
||||||
// Any `With*Endpoint()` option used on the provider is
|
// Any `With*Endpoint()` option used on the provider is
|
||||||
// therefore ineffective.
|
// therefore ineffective.
|
||||||
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
|
//
|
||||||
server := RegisterServer(&LegacyServer{
|
// EXPERIMENTAL: may change until v4
|
||||||
|
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer {
|
||||||
|
return &LegacyServer{
|
||||||
provider: provider,
|
provider: provider,
|
||||||
endpoints: endpoints,
|
endpoints: endpoints,
|
||||||
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
}
|
||||||
|
}
|
||||||
|
|
||||||
router := chi.NewRouter()
|
func (s *LegacyServer) Provider() OpenIDProvider {
|
||||||
router.Mount("/", server)
|
return s.provider
|
||||||
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
|
}
|
||||||
|
|
||||||
return router
|
func (s *LegacyServer) Endpoints() Endpoints {
|
||||||
|
return s.endpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue