diff --git a/pkg/op/config.go b/pkg/op/config.go index c383480..9fec7cc 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -54,7 +54,24 @@ type Configuration interface { type IssuerFromRequest func(r *http.Request) string func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) { - return issuerFromForwardedOrHost(path, false) + return issuerFromForwardedOrHost(path, new(issuerConfig)) +} + +type IssuerFromOption func(c *issuerConfig) + +// WithIssuerFromCustomHeaders can be used to customize the header names used. +// The same rules apply where the first successful host is returned. +func WithIssuerFromCustomHeaders(headers ...string) IssuerFromOption { + return func(c *issuerConfig) { + for i, h := range headers { + headers[i] = http.CanonicalHeaderKey(h) + } + c.headers = headers + } +} + +type issuerConfig struct { + headers []string } // IssuerFromForwardedOrHost tries to establish the Issuer based @@ -64,11 +81,18 @@ func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) { // If the Forwarded header is not present, no host field is found, // or there is a parser error the Request Host will be used as a fallback. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded -func IssuerFromForwardedOrHost(path string) func(bool) (IssuerFromRequest, error) { - return issuerFromForwardedOrHost(path, true) +func IssuerFromForwardedOrHost(path string, opts ...IssuerFromOption) func(bool) (IssuerFromRequest, error) { + c := &issuerConfig{ + headers: []string{http.CanonicalHeaderKey("forwarded")}, + } + for _, opt := range opts { + opt(c) + } + + return issuerFromForwardedOrHost(path, c) } -func issuerFromForwardedOrHost(path string, parseForwarded bool) func(bool) (IssuerFromRequest, error) { +func issuerFromForwardedOrHost(path string, c *issuerConfig) func(bool) (IssuerFromRequest, error) { return func(allowInsecure bool) (IssuerFromRequest, error) { issuerPath, err := url.Parse(path) if err != nil { @@ -78,26 +102,26 @@ func issuerFromForwardedOrHost(path string, parseForwarded bool) func(bool) (Iss return nil, err } return func(r *http.Request) string { - if parseForwarded { - if host, ok := hostFromForwarded(r); ok { - return dynamicIssuer(host, path, allowInsecure) - } + if host, ok := hostFromForwarded(r, c.headers); ok { + return dynamicIssuer(host, path, allowInsecure) } return dynamicIssuer(r.Host, path, allowInsecure) }, nil } } -func hostFromForwarded(r *http.Request) (host string, ok bool) { - fwd, err := httpforwarded.ParseFromRequest(r) - if err != nil { - log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch - return "", false +func hostFromForwarded(r *http.Request, headers []string) (host string, ok bool) { + for _, header := range headers { + hosts, err := httpforwarded.ParseParameter("host", r.Header[header]) + if err != nil { + log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch + continue + } + if len(hosts) > 0 { + return hosts[0], true + } } - if fwd == nil || len(fwd["host"]) == 0 { - return "", false - } - return fwd["host"][0], true + return "", false } func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) { diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index dcafc3a..d739348 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -1,6 +1,7 @@ package op import ( + "net/http" "net/http/httptest" "net/url" "testing" @@ -264,9 +265,10 @@ func TestIssuerFromHost(t *testing.T) { func TestIssuerFromForwardedOrHost(t *testing.T) { type args struct { - path string - target string - forwarded []string + path string + opts []IssuerFromOption + target string + header map[string][]string } type res struct { issuer string @@ -279,9 +281,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { { "header parse error", args{ - path: "/custom/", - target: "https://issuer.com", - forwarded: []string{"~~~"}, + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{"Forwarded": {"~~~~"}}, }, res{ issuer: "https://issuer.com/custom/", @@ -303,9 +305,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;proto=https`, - }, + }}, }, res{ issuer: "https://issuer.com/custom/", @@ -316,9 +318,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;host=first.com;proto=https`, - }, + }}, }, res{ issuer: "https://first.com/custom/", @@ -329,9 +331,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, - }, + }}, }, res{ issuer: "https://first.com/custom/", @@ -342,23 +344,45 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, `by=identifier;for=identifier;host=third.com;proto=https`, - }, + }}, }, res{ issuer: "https://first.com/custom/", }, }, + { + "custom header first", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{ + "Forwarded": { + `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, + `by=identifier;for=identifier;host=third.com;proto=https`, + }, + "X-Custom-Forwarded": { + `by=identifier;for=identifier;host=custom.com;proto=https,host=custom2.com`, + }, + }, + opts: []IssuerFromOption{ + WithIssuerFromCustomHeaders("x-custom-forwarded"), + }, + }, + res{ + issuer: "https://custom.com/custom/", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - issuer, err := IssuerFromForwardedOrHost(tt.args.path)(false) + issuer, err := IssuerFromForwardedOrHost(tt.args.path, tt.args.opts...)(false) require.NoError(t, err) req := httptest.NewRequest("", tt.args.target, nil) - if tt.args.forwarded != nil { - req.Header["Forwarded"] = tt.args.forwarded + for k, v := range tt.args.header { + req.Header[http.CanonicalHeaderKey(k)] = v } assert.Equal(t, tt.res.issuer, issuer(req)) })