rp: allow to set custom URL parameters (#273)
* rp: allow to set prompts in AuthURLHandler Fixes #241 * rp: configuration for handlers with URL options to call RS Fixes #265
This commit is contained in:
parent
ff2729cb23
commit
c8d61c0858
3 changed files with 73 additions and 20 deletions
|
@ -54,10 +54,12 @@ func main() {
|
||||||
return uuid.New().String()
|
return uuid.New().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// register the AuthURLHandler at your preferred path
|
// register the AuthURLHandler at your preferred path.
|
||||||
// the AuthURLHandler creates the auth request and redirects the user to the auth server
|
// the AuthURLHandler creates the auth request and redirects the user to the auth server.
|
||||||
// including state handling with secure cookie and the possibility to use PKCE
|
// including state handling with secure cookie and the possibility to use PKCE.
|
||||||
http.Handle("/login", rp.AuthURLHandler(state, provider))
|
// Prompts can optionally be set to inform the server of
|
||||||
|
// any messages that need to be prompted back to the user.
|
||||||
|
http.Handle("/login", rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!")))
|
||||||
|
|
||||||
// for demonstration purposes the returned userinfo response is written as JSON object onto response
|
// for demonstration purposes the returned userinfo response is written as JSON object onto response
|
||||||
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
|
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
|
||||||
|
|
|
@ -75,7 +75,10 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
||||||
capturedW := httptest.NewRecorder()
|
capturedW := httptest.NewRecorder()
|
||||||
get := httptest.NewRequest("GET", localURL.String(), nil)
|
get := httptest.NewRequest("GET", localURL.String(), nil)
|
||||||
rp.AuthURLHandler(func() string { return state }, provider)(capturedW, get)
|
rp.AuthURLHandler(func() string { return state }, provider,
|
||||||
|
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
|
||||||
|
rp.WithURLParam("custom", "param"),
|
||||||
|
)(capturedW, get)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if t.Failed() {
|
if t.Failed() {
|
||||||
|
@ -84,6 +87,8 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
|
require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
|
||||||
require.Less(t, capturedW.Code, 400, "captured response code")
|
require.Less(t, capturedW.Code, 400, "captured response code")
|
||||||
|
require.Contains(t, capturedW.Body.String(), `prompt=Hello%2C+World%21+Goodbye%2C+World%21`)
|
||||||
|
require.Contains(t, capturedW.Body.String(), `custom=param`)
|
||||||
|
|
||||||
//nolint:bodyclose
|
//nolint:bodyclose
|
||||||
resp := capturedW.Result()
|
resp := capturedW.Result()
|
||||||
|
@ -140,7 +145,7 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
email = info.GetEmail()
|
email = info.GetEmail()
|
||||||
http.Redirect(w, r, targetURL, 302)
|
http.Redirect(w, r, targetURL, 302)
|
||||||
}
|
}
|
||||||
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get)
|
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if t.Failed() {
|
if t.Failed() {
|
||||||
|
@ -150,6 +155,7 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||||
|
// TODO: how to check the custom header was sent to the server?
|
||||||
|
|
||||||
//nolint:bodyclose
|
//nolint:bodyclose
|
||||||
resp = capturedW.Result()
|
resp = capturedW.Result()
|
||||||
|
@ -193,6 +199,13 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||||
assert.Errorf(t, err, "refresh with replacement")
|
assert.Errorf(t, err, "refresh with replacement")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("WithPrompt", func(t *testing.T) {
|
||||||
|
opts := rp.WithPrompt("foo", "bar")()
|
||||||
|
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
|
||||||
|
|
||||||
|
require.Contains(t, url, "prompt=foo+bar")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type deferredHandler struct {
|
type deferredHandler struct {
|
||||||
|
|
|
@ -255,7 +255,7 @@ func WithVerifierOpts(opts ...VerifierOption) Option {
|
||||||
|
|
||||||
// WithClientKey specifies the path to the key.json to be used for the JWT Profile Client Authentication on the token endpoint
|
// WithClientKey specifies the path to the key.json to be used for the JWT Profile Client Authentication on the token endpoint
|
||||||
//
|
//
|
||||||
//deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead
|
// deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead
|
||||||
func WithClientKey(path string) Option {
|
func WithClientKey(path string) Option {
|
||||||
return WithJWTProfile(SignerFromKeyPath(path))
|
return WithJWTProfile(SignerFromKeyPath(path))
|
||||||
}
|
}
|
||||||
|
@ -304,7 +304,7 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {
|
||||||
|
|
||||||
// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
|
// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
|
||||||
//
|
//
|
||||||
//deprecated: use client.Discover
|
// deprecated: use client.Discover
|
||||||
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
||||||
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
||||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||||
|
@ -323,7 +323,7 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthURL returns the auth request url
|
// AuthURL returns the auth request url
|
||||||
//(wrapping the oauth2 `AuthCodeURL`)
|
// (wrapping the oauth2 `AuthCodeURL`)
|
||||||
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
||||||
authOpts := make([]oauth2.AuthCodeOption, 0)
|
authOpts := make([]oauth2.AuthCodeOption, 0)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
@ -333,10 +333,15 @@ func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthURLHandler extends the `AuthURL` method with a http redirect handler
|
// AuthURLHandler extends the `AuthURL` method with a http redirect handler
|
||||||
// including handling setting cookie for secure `state` transfer
|
// including handling setting cookie for secure `state` transfer.
|
||||||
func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc {
|
// Custom paramaters can optionally be set to the redirect URL.
|
||||||
|
func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
opts := make([]AuthURLOpt, 0)
|
opts := make([]AuthURLOpt, len(urlParam))
|
||||||
|
for i, p := range urlParam {
|
||||||
|
opts[i] = AuthURLOpt(p)
|
||||||
|
}
|
||||||
|
|
||||||
state := stateFn()
|
state := stateFn()
|
||||||
if err := trySetStateCookie(w, state, rp); err != nil {
|
if err := trySetStateCookie(w, state, rp); err != nil {
|
||||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
||||||
|
@ -350,6 +355,7 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound)
|
http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -398,8 +404,9 @@ type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *o
|
||||||
|
|
||||||
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
||||||
// including cookie handling for secure `state` transfer
|
// including cookie handling for secure `state` transfer
|
||||||
// and optional PKCE code verifier checking
|
// and optional PKCE code verifier checking.
|
||||||
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc {
|
// Custom paramaters can optionally be set to the token URL.
|
||||||
|
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
state, err := tryReadStateCookie(w, r, rp)
|
state, err := tryReadStateCookie(w, r, rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -411,7 +418,11 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha
|
||||||
rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state)
|
rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
codeOpts := make([]CodeExchangeOpt, 0)
|
codeOpts := make([]CodeExchangeOpt, len(urlParam))
|
||||||
|
for i, p := range urlParam {
|
||||||
|
codeOpts[i] = CodeExchangeOpt(p)
|
||||||
|
}
|
||||||
|
|
||||||
if rp.IsPKCE() {
|
if rp.IsPKCE() {
|
||||||
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
|
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -517,6 +528,37 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// withURLParam sets custom url paramaters.
|
||||||
|
// This is the generalized, unexported, function used by both
|
||||||
|
// URLParamOpt and AuthURLOpt.
|
||||||
|
func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
|
||||||
|
return func() []oauth2.AuthCodeOption {
|
||||||
|
return []oauth2.AuthCodeOption{
|
||||||
|
oauth2.SetAuthURLParam(key, value),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// withPrompt sets the `prompt` params in the auth request
|
||||||
|
// This is the generalized, unexported, function used by both
|
||||||
|
// URLParamOpt and AuthURLOpt.
|
||||||
|
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
|
||||||
|
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
|
||||||
|
}
|
||||||
|
|
||||||
|
type URLParamOpt func() []oauth2.AuthCodeOption
|
||||||
|
|
||||||
|
// WithURLParam allows setting custom key-vale pairs
|
||||||
|
// to an OAuth2 URL.
|
||||||
|
func WithURLParam(key, value string) URLParamOpt {
|
||||||
|
return withURLParam(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPromptURLParam sets the `prompt` parameter in a URL.
|
||||||
|
func WithPromptURLParam(prompt ...string) URLParamOpt {
|
||||||
|
return withPrompt(prompt...)
|
||||||
|
}
|
||||||
|
|
||||||
type AuthURLOpt func() []oauth2.AuthCodeOption
|
type AuthURLOpt func() []oauth2.AuthCodeOption
|
||||||
|
|
||||||
// WithCodeChallenge sets the `code_challenge` params in the auth request
|
// WithCodeChallenge sets the `code_challenge` params in the auth request
|
||||||
|
@ -531,11 +573,7 @@ func WithCodeChallenge(codeChallenge string) AuthURLOpt {
|
||||||
|
|
||||||
// WithPrompt sets the `prompt` params in the auth request
|
// WithPrompt sets the `prompt` params in the auth request
|
||||||
func WithPrompt(prompt ...string) AuthURLOpt {
|
func WithPrompt(prompt ...string) AuthURLOpt {
|
||||||
return func() []oauth2.AuthCodeOption {
|
return withPrompt(prompt...)
|
||||||
return []oauth2.AuthCodeOption{
|
|
||||||
oauth2.SetAuthURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CodeExchangeOpt func() []oauth2.AuthCodeOption
|
type CodeExchangeOpt func() []oauth2.AuthCodeOption
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue