feat: Allow CORS policy to be configured (#484)
* Add configurable CORS policy in OpenIDProvider * Add configurable CORS policy to Server * remove duplicated CORS middleware * Allow nil CORS policy to be set to disable CORS middleware * create a separate handler on webServer so type assertion works in tests
This commit is contained in:
parent
ce55068aa9
commit
7b64687990
2 changed files with 40 additions and 4 deletions
27
pkg/op/op.go
27
pkg/op/op.go
|
@ -97,9 +97,19 @@ type OpenIDProvider interface {
|
|||
|
||||
type HttpInterceptor func(http.Handler) http.Handler
|
||||
|
||||
type corsOptioner interface {
|
||||
CORSOptions() *cors.Options
|
||||
}
|
||||
|
||||
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
if co, ok := o.(corsOptioner); ok {
|
||||
if opts := co.CORSOptions(); opts != nil {
|
||||
router.Use(cors.New(*opts).Handler)
|
||||
}
|
||||
} else {
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
}
|
||||
router.Use(intercept(o.IssuerFromRequest, interceptors...))
|
||||
router.HandleFunc(healthEndpoint, healthHandler)
|
||||
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
||||
|
@ -224,6 +234,7 @@ func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (Is
|
|||
storage: storage,
|
||||
endpoints: DefaultEndpoints,
|
||||
timer: make(<-chan time.Time),
|
||||
corsOpts: &defaultCORSOptions,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
|
@ -268,6 +279,7 @@ type Provider struct {
|
|||
timer <-chan time.Time
|
||||
accessTokenVerifierOpts []AccessTokenVerifierOpt
|
||||
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
|
||||
corsOpts *cors.Options
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
|
@ -427,6 +439,10 @@ func (o *Provider) Probes() []ProbesFn {
|
|||
}
|
||||
}
|
||||
|
||||
func (o *Provider) CORSOptions() *cors.Options {
|
||||
return o.corsOpts
|
||||
}
|
||||
|
||||
func (o *Provider) Logger() *slog.Logger {
|
||||
return o.logger
|
||||
}
|
||||
|
@ -587,6 +603,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCORSOptions(opts *cors.Options) Option {
|
||||
return func(o *Provider) error {
|
||||
o.corsOpts = opts
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger lets a logger other than slog.Default().
|
||||
//
|
||||
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
|
||||
|
@ -603,6 +626,6 @@ func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handle
|
|||
for i := len(interceptors) - 1; i >= 0; i-- {
|
||||
handler = interceptors[i](handler)
|
||||
}
|
||||
return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler))
|
||||
return issuerInterceptor.Handler(handler)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,15 +29,19 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption)
|
|||
server: server,
|
||||
endpoints: endpoints,
|
||||
decoder: decoder,
|
||||
corsOpts: &defaultCORSOptions,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
ws.router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
|
||||
for _, option := range options {
|
||||
option(ws)
|
||||
}
|
||||
|
||||
ws.createRouter()
|
||||
ws.handler = ws.router
|
||||
if ws.corsOpts != nil {
|
||||
ws.handler = cors.New(*ws.corsOpts).Handler(ws.router)
|
||||
}
|
||||
return ws
|
||||
}
|
||||
|
||||
|
@ -66,6 +70,13 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithServerCORSOptions sets the CORS policy for the Server's router.
|
||||
func WithServerCORSOptions(opts *cors.Options) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.corsOpts = opts
|
||||
}
|
||||
}
|
||||
|
||||
// WithFallbackLogger overrides the fallback logger, which
|
||||
// is used when no logger was found in the context.
|
||||
// Defaults to [slog.Default].
|
||||
|
@ -78,13 +89,15 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
|||
type webServer struct {
|
||||
server Server
|
||||
router *chi.Mux
|
||||
handler http.Handler
|
||||
endpoints Endpoints
|
||||
decoder httphelper.Decoder
|
||||
corsOpts *cors.Options
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.router.ServeHTTP(w, r)
|
||||
s.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue