fix: improve interceptor handling

This commit is contained in:
Livio Amstutz 2020-08-28 14:40:53 +02:00
parent d02653e75d
commit 87884357fb
2 changed files with 38 additions and 30 deletions

View file

@ -51,7 +51,7 @@ type DefaultOP struct {
http http.Handler http http.Handler
decoder *schema.Decoder decoder *schema.Decoder
encoder *schema.Encoder encoder *schema.Encoder
interceptor HttpInterceptor interceptors []HttpInterceptor
retry func(int) (bool, int) retry func(int) (bool, int)
timer <-chan time.Time timer <-chan time.Time
} }
@ -132,9 +132,9 @@ func WithCustomKeysEndpoint(endpoint Endpoint) DefaultOPOpts {
} }
} }
func WithHttpInterceptor(h HttpInterceptor) DefaultOPOpts { func WithHttpInterceptors(interceptors ...HttpInterceptor) DefaultOPOpts {
return func(o *DefaultOP) error { return func(o *DefaultOP) error {
o.interceptor = h o.interceptors = append(o.interceptors, interceptors...)
return nil return nil
} }
} }
@ -185,7 +185,7 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts .
p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration()) p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration())
p.http = CreateRouter(p, p.interceptor) p.http = CreateRouter(p, p.interceptors...)
p.decoder = schema.NewDecoder() p.decoder = schema.NewDecoder()
p.decoder.IgnoreUnknownKeys(true) p.decoder.IgnoreUnknownKeys(true)

View file

@ -27,22 +27,14 @@ type OpenIDProvider interface {
HttpHandler() http.Handler HttpHandler() http.Handler
} }
type HttpInterceptor func(http.HandlerFunc) http.HandlerFunc type HttpInterceptor func(http.Handler) http.Handler
var DefaultInterceptor = func(h http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h(w, r)
})
}
var allowAllOrigins = func(_ string) bool { var allowAllOrigins = func(_ string) bool {
return true return true
} }
func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router { func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
if h == nil { intercept := buildInterceptor(interceptors...)
h = DefaultInterceptor
}
router := mux.NewRouter() router := mux.NewRouter()
router.Use(handlers.CORS( router.Use(handlers.CORS(
handlers.AllowCredentials(), handlers.AllowCredentials(),
@ -52,11 +44,27 @@ func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router {
router.HandleFunc(healthzEndpoint, Healthz) router.HandleFunc(healthzEndpoint, Healthz)
router.HandleFunc(readinessEndpoint, o.HandleReady) router.HandleFunc(readinessEndpoint, o.HandleReady)
router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery)
router.HandleFunc(o.AuthorizationEndpoint().Relative(), h(o.HandleAuthorize)) router.Handle(o.AuthorizationEndpoint().Relative(), intercept(o.HandleAuthorize))
router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", h(o.HandleAuthorizeCallback)) router.Handle(o.AuthorizationEndpoint().Relative()+"/{id}", intercept(o.HandleAuthorizeCallback))
router.HandleFunc(o.TokenEndpoint().Relative(), h(o.HandleExchange)) router.Handle(o.TokenEndpoint().Relative(), intercept(o.HandleExchange))
router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo) router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
router.HandleFunc(o.EndSessionEndpoint().Relative(), h(o.HandleEndSession)) router.Handle(o.EndSessionEndpoint().Relative(), intercept(o.HandleEndSession))
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys) router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
return router return router
} }
func buildInterceptor(interceptors ...HttpInterceptor) func(http.HandlerFunc) http.Handler {
return func(handlerFunc http.HandlerFunc) http.Handler {
handler := handlerFuncToHandler(handlerFunc)
for i := len(interceptors) - 1; i >= 0; i-- {
handler = interceptors[i](handler)
}
return handler
}
}
func handlerFuncToHandler(handlerFunc http.HandlerFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerFunc(w, r)
})
}