From 37e01449e0cf479d90c85aaa8ccc7fd452130fb1 Mon Sep 17 00:00:00 2001 From: Kory Prince Date: Mon, 13 Nov 2023 18:31:39 -0600 Subject: [PATCH] Allow nil CORS policy to be set to disable CORS middleware --- pkg/op/op.go | 14 ++++++++------ pkg/op/server_http.go | 11 +++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/pkg/op/op.go b/pkg/op/op.go index 1a7ccdb..939ebf8 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -98,13 +98,15 @@ type OpenIDProvider interface { type HttpInterceptor func(http.Handler) http.Handler type corsOptioner interface { - CORSOptions() cors.Options + CORSOptions() *cors.Options } func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { router := chi.NewRouter() if co, ok := o.(corsOptioner); ok { - router.Use(cors.New(co.CORSOptions()).Handler) + if opts := co.CORSOptions(); opts != nil { + router.Use(cors.New(*opts).Handler) + } } else { router.Use(cors.New(defaultCORSOptions).Handler) } @@ -232,7 +234,7 @@ func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (Is storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), - corsOpts: defaultCORSOptions, + corsOpts: &defaultCORSOptions, logger: slog.Default(), } @@ -277,7 +279,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt - corsOpts cors.Options + corsOpts *cors.Options logger *slog.Logger } @@ -437,7 +439,7 @@ func (o *Provider) Probes() []ProbesFn { } } -func (o *Provider) CORSOptions() cors.Options { +func (o *Provider) CORSOptions() *cors.Options { return o.corsOpts } @@ -601,7 +603,7 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } -func WithCORSOptions(opts cors.Options) Option { +func WithCORSOptions(opts *cors.Options) Option { return func(o *Provider) error { o.corsOpts = opts return nil diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 6fd2a29..34a322f 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -29,7 +29,7 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) server: server, endpoints: endpoints, decoder: decoder, - corsOpts: defaultCORSOptions, + corsOpts: &defaultCORSOptions, logger: slog.Default(), } @@ -38,7 +38,10 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) } ws.createRouter() - return cors.New(ws.corsOpts).Handler(ws) + if ws.corsOpts != nil { + return cors.New(*ws.corsOpts).Handler(ws) + } + return ws } type ServerOption func(s *webServer) @@ -67,7 +70,7 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption { } // WithServerCORSOptions sets the CORS policy for the Server's router. -func WithServerCORSOptions(opts cors.Options) ServerOption { +func WithServerCORSOptions(opts *cors.Options) ServerOption { return func(s *webServer) { s.corsOpts = opts } @@ -87,7 +90,7 @@ type webServer struct { router *chi.Mux endpoints Endpoints decoder httphelper.Decoder - corsOpts cors.Options + corsOpts *cors.Options logger *slog.Logger }