From dd5b1ca3684f960070a9bdeecfdc9d20525de9ea Mon Sep 17 00:00:00 2001 From: Kory Prince Date: Thu, 16 Nov 2023 04:08:25 -0600 Subject: [PATCH] feat: Allow CORS policy to be configured (#485) --- pkg/op/op.go | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/pkg/op/op.go b/pkg/op/op.go index c4be14f..286dcca 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -90,9 +90,19 @@ type OpenIDProvider interface { type HttpInterceptor func(http.Handler) http.Handler +type corsOptioner interface { + CORSOptions() *cors.Options +} + func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { router := mux.NewRouter() - router.Use(cors.New(defaultCORSOptions).Handler) + if co, ok := o.(corsOptioner); ok { + if opts := co.CORSOptions(); opts != nil { + router.Use(cors.New(*opts).Handler) + } + } else { + router.Use(cors.New(defaultCORSOptions).Handler) + } router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) @@ -186,6 +196,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), + corsOpts: &defaultCORSOptions, } for _, optFunc := range opOpts { @@ -229,6 +240,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + corsOpts *cors.Options } func (o *Provider) IssuerFromRequest(r *http.Request) string { @@ -387,6 +399,10 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) CORSOptions() *cors.Options { + return o.corsOpts +} + func (o *Provider) HttpHandler() http.Handler { return o.httpHandler } @@ -534,12 +550,19 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +func WithCORSOptions(opts *cors.Options) Option { + return func(o *Provider) error { + o.corsOpts = opts + return nil + } +} + func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { issuerInterceptor := NewIssuerInterceptor(i) return func(handler http.Handler) http.Handler { for i := len(interceptors) - 1; i >= 0; i-- { handler = interceptors[i](handler) } - return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler)) + return issuerInterceptor.Handler(handler) } }