diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 732df21..a42da6a 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -42,18 +42,18 @@ var ( ) type DefaultOP struct { - config *Config - endpoints *endpoints - storage Storage - signer Signer - verifier rp.Verifier - crypto Crypto - http http.Handler - decoder *schema.Decoder - encoder *schema.Encoder - interceptor HttpInterceptor - retry func(int) (bool, int) - timer <-chan time.Time + config *Config + endpoints *endpoints + storage Storage + signer Signer + verifier rp.Verifier + crypto Crypto + http http.Handler + decoder *schema.Decoder + encoder *schema.Encoder + interceptors []HttpInterceptor + retry func(int) (bool, int) + timer <-chan time.Time } type Config struct { @@ -132,9 +132,9 @@ func WithCustomKeysEndpoint(endpoint Endpoint) DefaultOPOpts { } } -func WithHttpInterceptor(h HttpInterceptor) DefaultOPOpts { +func WithHttpInterceptors(interceptors ...HttpInterceptor) DefaultOPOpts { return func(o *DefaultOP) error { - o.interceptor = h + o.interceptors = append(o.interceptors, interceptors...) return nil } } @@ -185,7 +185,7 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts . p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration()) - p.http = CreateRouter(p, p.interceptor) + p.http = CreateRouter(p, p.interceptors...) p.decoder = schema.NewDecoder() p.decoder.IgnoreUnknownKeys(true) diff --git a/pkg/op/op.go b/pkg/op/op.go index fda9315..624a8a1 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -27,22 +27,14 @@ type OpenIDProvider interface { HttpHandler() http.Handler } -type HttpInterceptor func(http.HandlerFunc) http.HandlerFunc - -var DefaultInterceptor = func(h http.HandlerFunc) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h(w, r) - }) -} +type HttpInterceptor func(http.Handler) http.Handler var allowAllOrigins = func(_ string) bool { return true } -func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router { - if h == nil { - h = DefaultInterceptor - } +func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { + intercept := buildInterceptor(interceptors...) router := mux.NewRouter() router.Use(handlers.CORS( handlers.AllowCredentials(), @@ -52,11 +44,27 @@ func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router { router.HandleFunc(healthzEndpoint, Healthz) router.HandleFunc(readinessEndpoint, o.HandleReady) router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) - router.HandleFunc(o.AuthorizationEndpoint().Relative(), h(o.HandleAuthorize)) - router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", h(o.HandleAuthorizeCallback)) - router.HandleFunc(o.TokenEndpoint().Relative(), h(o.HandleExchange)) + router.Handle(o.AuthorizationEndpoint().Relative(), intercept(o.HandleAuthorize)) + router.Handle(o.AuthorizationEndpoint().Relative()+"/{id}", intercept(o.HandleAuthorizeCallback)) + router.Handle(o.TokenEndpoint().Relative(), intercept(o.HandleExchange)) router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo) - router.HandleFunc(o.EndSessionEndpoint().Relative(), h(o.HandleEndSession)) + router.Handle(o.EndSessionEndpoint().Relative(), intercept(o.HandleEndSession)) router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys) return router } + +func buildInterceptor(interceptors ...HttpInterceptor) func(http.HandlerFunc) http.Handler { + return func(handlerFunc http.HandlerFunc) http.Handler { + handler := handlerFuncToHandler(handlerFunc) + for i := len(interceptors) - 1; i >= 0; i-- { + handler = interceptors[i](handler) + } + return handler + } +} + +func handlerFuncToHandler(handlerFunc http.HandlerFunc) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerFunc(w, r) + }) +}