rp: allow to set custom URL parameters (#273)
* rp: allow to set prompts in AuthURLHandler Fixes #241 * rp: configuration for handlers with URL options to call RS Fixes #265
This commit is contained in:
parent
ff2729cb23
commit
c8d61c0858
3 changed files with 73 additions and 20 deletions
|
@ -54,10 +54,12 @@ func main() {
|
|||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// register the AuthURLHandler at your preferred path
|
||||
// the AuthURLHandler creates the auth request and redirects the user to the auth server
|
||||
// including state handling with secure cookie and the possibility to use PKCE
|
||||
http.Handle("/login", rp.AuthURLHandler(state, provider))
|
||||
// register the AuthURLHandler at your preferred path.
|
||||
// the AuthURLHandler creates the auth request and redirects the user to the auth server.
|
||||
// including state handling with secure cookie and the possibility to use PKCE.
|
||||
// Prompts can optionally be set to inform the server of
|
||||
// any messages that need to be prompted back to the user.
|
||||
http.Handle("/login", rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!")))
|
||||
|
||||
// for demonstration purposes the returned userinfo response is written as JSON object onto response
|
||||
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
|
||||
|
|
|
@ -75,7 +75,10 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
capturedW := httptest.NewRecorder()
|
||||
get := httptest.NewRequest("GET", localURL.String(), nil)
|
||||
rp.AuthURLHandler(func() string { return state }, provider)(capturedW, get)
|
||||
rp.AuthURLHandler(func() string { return state }, provider,
|
||||
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
|
||||
rp.WithURLParam("custom", "param"),
|
||||
)(capturedW, get)
|
||||
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
|
@ -84,6 +87,8 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
}()
|
||||
require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
|
||||
require.Less(t, capturedW.Code, 400, "captured response code")
|
||||
require.Contains(t, capturedW.Body.String(), `prompt=Hello%2C+World%21+Goodbye%2C+World%21`)
|
||||
require.Contains(t, capturedW.Body.String(), `custom=param`)
|
||||
|
||||
//nolint:bodyclose
|
||||
resp := capturedW.Result()
|
||||
|
@ -140,7 +145,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
email = info.GetEmail()
|
||||
http.Redirect(w, r, targetURL, 302)
|
||||
}
|
||||
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get)
|
||||
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
|
||||
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
|
@ -150,6 +155,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
}()
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
// TODO: how to check the custom header was sent to the server?
|
||||
|
||||
//nolint:bodyclose
|
||||
resp = capturedW.Result()
|
||||
|
@ -193,6 +199,13 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with replacement")
|
||||
}
|
||||
|
||||
t.Run("WithPrompt", func(t *testing.T) {
|
||||
opts := rp.WithPrompt("foo", "bar")()
|
||||
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
|
||||
|
||||
require.Contains(t, url, "prompt=foo+bar")
|
||||
})
|
||||
}
|
||||
|
||||
type deferredHandler struct {
|
||||
|
|
|
@ -255,7 +255,7 @@ func WithVerifierOpts(opts ...VerifierOption) Option {
|
|||
|
||||
// WithClientKey specifies the path to the key.json to be used for the JWT Profile Client Authentication on the token endpoint
|
||||
//
|
||||
//deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead
|
||||
// deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead
|
||||
func WithClientKey(path string) Option {
|
||||
return WithJWTProfile(SignerFromKeyPath(path))
|
||||
}
|
||||
|
@ -304,7 +304,7 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {
|
|||
|
||||
// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
|
||||
//
|
||||
//deprecated: use client.Discover
|
||||
// deprecated: use client.Discover
|
||||
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||
|
@ -323,7 +323,7 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
|||
}
|
||||
|
||||
// AuthURL returns the auth request url
|
||||
//(wrapping the oauth2 `AuthCodeURL`)
|
||||
// (wrapping the oauth2 `AuthCodeURL`)
|
||||
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
||||
authOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
|
@ -333,10 +333,15 @@ func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
|||
}
|
||||
|
||||
// AuthURLHandler extends the `AuthURL` method with a http redirect handler
|
||||
// including handling setting cookie for secure `state` transfer
|
||||
func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc {
|
||||
// including handling setting cookie for secure `state` transfer.
|
||||
// Custom paramaters can optionally be set to the redirect URL.
|
||||
func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
opts := make([]AuthURLOpt, 0)
|
||||
opts := make([]AuthURLOpt, len(urlParam))
|
||||
for i, p := range urlParam {
|
||||
opts[i] = AuthURLOpt(p)
|
||||
}
|
||||
|
||||
state := stateFn()
|
||||
if err := trySetStateCookie(w, state, rp); err != nil {
|
||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
||||
|
@ -350,6 +355,7 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc {
|
|||
}
|
||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||
}
|
||||
|
||||
http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
@ -398,8 +404,9 @@ type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *o
|
|||
|
||||
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
||||
// including cookie handling for secure `state` transfer
|
||||
// and optional PKCE code verifier checking
|
||||
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc {
|
||||
// and optional PKCE code verifier checking.
|
||||
// Custom paramaters can optionally be set to the token URL.
|
||||
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := tryReadStateCookie(w, r, rp)
|
||||
if err != nil {
|
||||
|
@ -411,7 +418,11 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha
|
|||
rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state)
|
||||
return
|
||||
}
|
||||
codeOpts := make([]CodeExchangeOpt, 0)
|
||||
codeOpts := make([]CodeExchangeOpt, len(urlParam))
|
||||
for i, p := range urlParam {
|
||||
codeOpts[i] = CodeExchangeOpt(p)
|
||||
}
|
||||
|
||||
if rp.IsPKCE() {
|
||||
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
|
||||
if err != nil {
|
||||
|
@ -517,6 +528,37 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
|||
}
|
||||
}
|
||||
|
||||
// withURLParam sets custom url paramaters.
|
||||
// This is the generalized, unexported, function used by both
|
||||
// URLParamOpt and AuthURLOpt.
|
||||
func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam(key, value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// withPrompt sets the `prompt` params in the auth request
|
||||
// This is the generalized, unexported, function used by both
|
||||
// URLParamOpt and AuthURLOpt.
|
||||
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
|
||||
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
|
||||
}
|
||||
|
||||
type URLParamOpt func() []oauth2.AuthCodeOption
|
||||
|
||||
// WithURLParam allows setting custom key-vale pairs
|
||||
// to an OAuth2 URL.
|
||||
func WithURLParam(key, value string) URLParamOpt {
|
||||
return withURLParam(key, value)
|
||||
}
|
||||
|
||||
// WithPromptURLParam sets the `prompt` parameter in a URL.
|
||||
func WithPromptURLParam(prompt ...string) URLParamOpt {
|
||||
return withPrompt(prompt...)
|
||||
}
|
||||
|
||||
type AuthURLOpt func() []oauth2.AuthCodeOption
|
||||
|
||||
// WithCodeChallenge sets the `code_challenge` params in the auth request
|
||||
|
@ -531,11 +573,7 @@ func WithCodeChallenge(codeChallenge string) AuthURLOpt {
|
|||
|
||||
// WithPrompt sets the `prompt` params in the auth request
|
||||
func WithPrompt(prompt ...string) AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()),
|
||||
}
|
||||
}
|
||||
return withPrompt(prompt...)
|
||||
}
|
||||
|
||||
type CodeExchangeOpt func() []oauth2.AuthCodeOption
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue