diff --git a/pkg/op/op.go b/pkg/op/op.go index ba36c61..939ebf8 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -97,9 +97,19 @@ type OpenIDProvider interface { type HttpInterceptor func(http.Handler) http.Handler +type corsOptioner interface { + CORSOptions() *cors.Options +} + func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { router := chi.NewRouter() - router.Use(cors.New(defaultCORSOptions).Handler) + if co, ok := o.(corsOptioner); ok { + if opts := co.CORSOptions(); opts != nil { + router.Use(cors.New(*opts).Handler) + } + } else { + router.Use(cors.New(defaultCORSOptions).Handler) + } router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) @@ -224,6 +234,7 @@ func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (Is storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), + corsOpts: &defaultCORSOptions, logger: slog.Default(), } @@ -268,6 +279,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + corsOpts *cors.Options logger *slog.Logger } @@ -427,6 +439,10 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) CORSOptions() *cors.Options { + return o.corsOpts +} + func (o *Provider) Logger() *slog.Logger { return o.logger } @@ -587,6 +603,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +func WithCORSOptions(opts *cors.Options) Option { + return func(o *Provider) error { + o.corsOpts = opts + return nil + } +} + // WithLogger lets a logger other than slog.Default(). // // EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 @@ -603,6 +626,6 @@ func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handle for i := len(interceptors) - 1; i >= 0; i-- { handler = interceptors[i](handler) } - return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler)) + return issuerInterceptor.Handler(handler) } } diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 6d379c6..2220e44 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -29,15 +29,19 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) server: server, endpoints: endpoints, decoder: decoder, + corsOpts: &defaultCORSOptions, logger: slog.Default(), } - ws.router.Use(cors.New(defaultCORSOptions).Handler) for _, option := range options { option(ws) } ws.createRouter() + ws.handler = ws.router + if ws.corsOpts != nil { + ws.handler = cors.New(*ws.corsOpts).Handler(ws.router) + } return ws } @@ -66,6 +70,13 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption { } } +// WithServerCORSOptions sets the CORS policy for the Server's router. +func WithServerCORSOptions(opts *cors.Options) ServerOption { + return func(s *webServer) { + s.corsOpts = opts + } +} + // WithFallbackLogger overrides the fallback logger, which // is used when no logger was found in the context. // Defaults to [slog.Default]. @@ -78,13 +89,15 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption { type webServer struct { server Server router *chi.Mux + handler http.Handler endpoints Endpoints decoder httphelper.Decoder + corsOpts *cors.Options logger *slog.Logger } func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.router.ServeHTTP(w, r) + s.handler.ServeHTTP(w, r) } func (s *webServer) getLogger(ctx context.Context) *slog.Logger {