This commit is contained in:
Livio Amstutz 2021-06-23 06:43:40 +02:00
parent 65a58039dc
commit b546640b5c
2 changed files with 16 additions and 12 deletions

View file

@ -327,10 +327,12 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
//CodeExchangeHandler extends the `CodeExchange` method with a http handler
//including cookie handling for secure `state` transfer
//and optional PKCE code verifier checking
func CodeExchangeHandler(callback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string), rp RelyingParty) http.HandlerFunc {
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp)
if err != nil {
@ -364,21 +366,23 @@ func CodeExchangeHandler(callback func(w http.ResponseWriter, r *http.Request, t
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
}
callback(w, r, tokens, state)
callback(w, r, tokens, state, rp)
}
}
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
//UserinfoCallback wraps the callback function of the CodeExchangeHandler
//and calls the userinfo endpoint with the access token
//on success it will pass the userinfo into its callback function as well
func UserinfoCallback(provider RelyingParty, f func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, info oidc.UserInfo, state string)) func(http.ResponseWriter, *http.Request, *oidc.Tokens, string) {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), provider)
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
return
}
f(w, r, tokens, info, state)
f(w, r, tokens, state, rp, info)
}
}