rp: allow to set custom URL parameters (#273)

* rp: allow to set prompts in AuthURLHandler

Fixes #241

* rp: configuration for handlers with URL options to call RS

Fixes #265
This commit is contained in:
Tim Möhlmann 2023-02-13 11:28:46 +02:00 committed by GitHub
parent ff2729cb23
commit c8d61c0858
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 20 deletions

View file

@ -54,10 +54,12 @@ func main() {
return uuid.New().String() return uuid.New().String()
} }
// register the AuthURLHandler at your preferred path // register the AuthURLHandler at your preferred path.
// the AuthURLHandler creates the auth request and redirects the user to the auth server // the AuthURLHandler creates the auth request and redirects the user to the auth server.
// including state handling with secure cookie and the possibility to use PKCE // including state handling with secure cookie and the possibility to use PKCE.
http.Handle("/login", rp.AuthURLHandler(state, provider)) // Prompts can optionally be set to inform the server of
// any messages that need to be prompted back to the user.
http.Handle("/login", rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!")))
// for demonstration purposes the returned userinfo response is written as JSON object onto response // for demonstration purposes the returned userinfo response is written as JSON object onto response
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) { marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {

View file

@ -75,7 +75,10 @@ func TestRelyingPartySession(t *testing.T) {
state := "state-" + strconv.FormatInt(seed.Int63(), 25) state := "state-" + strconv.FormatInt(seed.Int63(), 25)
capturedW := httptest.NewRecorder() capturedW := httptest.NewRecorder()
get := httptest.NewRequest("GET", localURL.String(), nil) get := httptest.NewRequest("GET", localURL.String(), nil)
rp.AuthURLHandler(func() string { return state }, provider)(capturedW, get) rp.AuthURLHandler(func() string { return state }, provider,
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
rp.WithURLParam("custom", "param"),
)(capturedW, get)
defer func() { defer func() {
if t.Failed() { if t.Failed() {
@ -84,6 +87,8 @@ func TestRelyingPartySession(t *testing.T) {
}() }()
require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code") require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
require.Less(t, capturedW.Code, 400, "captured response code") require.Less(t, capturedW.Code, 400, "captured response code")
require.Contains(t, capturedW.Body.String(), `prompt=Hello%2C+World%21+Goodbye%2C+World%21`)
require.Contains(t, capturedW.Body.String(), `custom=param`)
//nolint:bodyclose //nolint:bodyclose
resp := capturedW.Result() resp := capturedW.Result()
@ -140,7 +145,7 @@ func TestRelyingPartySession(t *testing.T) {
email = info.GetEmail() email = info.GetEmail()
http.Redirect(w, r, targetURL, 302) http.Redirect(w, r, targetURL, 302)
} }
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get) rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
defer func() { defer func() {
if t.Failed() { if t.Failed() {
@ -150,6 +155,7 @@ func TestRelyingPartySession(t *testing.T) {
}() }()
require.Less(t, capturedW.Code, 400, "token exchange response code") require.Less(t, capturedW.Code, 400, "token exchange response code")
require.Less(t, capturedW.Code, 400, "token exchange response code") require.Less(t, capturedW.Code, 400, "token exchange response code")
// TODO: how to check the custom header was sent to the server?
//nolint:bodyclose //nolint:bodyclose
resp = capturedW.Result() resp = capturedW.Result()
@ -193,6 +199,13 @@ func TestRelyingPartySession(t *testing.T) {
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement") assert.Errorf(t, err, "refresh with replacement")
} }
t.Run("WithPrompt", func(t *testing.T) {
opts := rp.WithPrompt("foo", "bar")()
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
require.Contains(t, url, "prompt=foo+bar")
})
} }
type deferredHandler struct { type deferredHandler struct {

View file

@ -333,10 +333,15 @@ func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
} }
// AuthURLHandler extends the `AuthURL` method with a http redirect handler // AuthURLHandler extends the `AuthURL` method with a http redirect handler
// including handling setting cookie for secure `state` transfer // including handling setting cookie for secure `state` transfer.
func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc { // Custom paramaters can optionally be set to the redirect URL.
func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
opts := make([]AuthURLOpt, 0) opts := make([]AuthURLOpt, len(urlParam))
for i, p := range urlParam {
opts[i] = AuthURLOpt(p)
}
state := stateFn() state := stateFn()
if err := trySetStateCookie(w, state, rp); err != nil { if err := trySetStateCookie(w, state, rp); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
@ -350,6 +355,7 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc {
} }
opts = append(opts, WithCodeChallenge(codeChallenge)) opts = append(opts, WithCodeChallenge(codeChallenge))
} }
http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound) http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound)
} }
} }
@ -398,8 +404,9 @@ type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *o
// CodeExchangeHandler extends the `CodeExchange` method with a http handler // CodeExchangeHandler extends the `CodeExchange` method with a http handler
// including cookie handling for secure `state` transfer // including cookie handling for secure `state` transfer
// and optional PKCE code verifier checking // and optional PKCE code verifier checking.
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc { // Custom paramaters can optionally be set to the token URL.
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp) state, err := tryReadStateCookie(w, r, rp)
if err != nil { if err != nil {
@ -411,7 +418,11 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha
rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state) rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state)
return return
} }
codeOpts := make([]CodeExchangeOpt, 0) codeOpts := make([]CodeExchangeOpt, len(urlParam))
for i, p := range urlParam {
codeOpts[i] = CodeExchangeOpt(p)
}
if rp.IsPKCE() { if rp.IsPKCE() {
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode) codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
if err != nil { if err != nil {
@ -517,6 +528,37 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
} }
} }
// withURLParam sets custom url paramaters.
// This is the generalized, unexported, function used by both
// URLParamOpt and AuthURLOpt.
func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam(key, value),
}
}
}
// withPrompt sets the `prompt` params in the auth request
// This is the generalized, unexported, function used by both
// URLParamOpt and AuthURLOpt.
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
}
type URLParamOpt func() []oauth2.AuthCodeOption
// WithURLParam allows setting custom key-vale pairs
// to an OAuth2 URL.
func WithURLParam(key, value string) URLParamOpt {
return withURLParam(key, value)
}
// WithPromptURLParam sets the `prompt` parameter in a URL.
func WithPromptURLParam(prompt ...string) URLParamOpt {
return withPrompt(prompt...)
}
type AuthURLOpt func() []oauth2.AuthCodeOption type AuthURLOpt func() []oauth2.AuthCodeOption
// WithCodeChallenge sets the `code_challenge` params in the auth request // WithCodeChallenge sets the `code_challenge` params in the auth request
@ -531,11 +573,7 @@ func WithCodeChallenge(codeChallenge string) AuthURLOpt {
// WithPrompt sets the `prompt` params in the auth request // WithPrompt sets the `prompt` params in the auth request
func WithPrompt(prompt ...string) AuthURLOpt { func WithPrompt(prompt ...string) AuthURLOpt {
return func() []oauth2.AuthCodeOption { return withPrompt(prompt...)
return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()),
}
}
} }
type CodeExchangeOpt func() []oauth2.AuthCodeOption type CodeExchangeOpt func() []oauth2.AuthCodeOption