This commit is contained in:
Livio Amstutz 2021-06-23 06:43:40 +02:00
parent 65a58039dc
commit b546640b5c
2 changed files with 16 additions and 12 deletions

View file

@ -59,11 +59,9 @@ func main() {
//including state handling with secure cookie and the possibility to use PKCE //including state handling with secure cookie and the possibility to use PKCE
http.Handle("/login", rp.AuthURLHandler(state, provider)) http.Handle("/login", rp.AuthURLHandler(state, provider))
//for demonstration purposes the returned tokens (access token, id_token an its parsed claims) //for demonstration purposes the returned userinfo response is written as JSON object onto response
//are written as JSON objects onto response marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { data, err := json.Marshal(info)
_ = state
data, err := json.Marshal(tokens)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -74,7 +72,9 @@ func main() {
//register the CodeExchangeHandler at the callbackPath //register the CodeExchangeHandler at the callbackPath
//the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function //the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function
//with the returned tokens from the token endpoint //with the returned tokens from the token endpoint
http.Handle(callbackPath, rp.CodeExchangeHandler(marshal, provider)) //in this example the callback function itself is wrapped by the UserinfoCallback which
//will call the Userinfo endpoint, check the sub and pass the info into the callback function
http.Handle(callbackPath, rp.CodeExchangeHandler(rp.UserinfoCallback(marshalUserinfo), provider))
lis := fmt.Sprintf("127.0.0.1:%s", port) lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis) logrus.Infof("listening on http://%s/", lis)

View file

@ -327,10 +327,12 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
} }
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
//CodeExchangeHandler extends the `CodeExchange` method with a http handler //CodeExchangeHandler extends the `CodeExchange` method with a http handler
//including cookie handling for secure `state` transfer //including cookie handling for secure `state` transfer
//and optional PKCE code verifier checking //and optional PKCE code verifier checking
func CodeExchangeHandler(callback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string), rp RelyingParty) http.HandlerFunc { func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp) state, err := tryReadStateCookie(w, r, rp)
if err != nil { if err != nil {
@ -364,21 +366,23 @@ func CodeExchangeHandler(callback func(w http.ResponseWriter, r *http.Request, t
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return return
} }
callback(w, r, tokens, state) callback(w, r, tokens, state, rp)
} }
} }
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
//UserinfoCallback wraps the callback function of the CodeExchangeHandler //UserinfoCallback wraps the callback function of the CodeExchangeHandler
//and calls the userinfo endpoint with the access token //and calls the userinfo endpoint with the access token
//on success it will pass the userinfo into its callback function as well //on success it will pass the userinfo into its callback function as well
func UserinfoCallback(provider RelyingParty, f func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, info oidc.UserInfo, state string)) func(http.ResponseWriter, *http.Request, *oidc.Tokens, string) { func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), provider) info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil { if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
return return
} }
f(w, r, tokens, info, state) f(w, r, tokens, state, rp, info)
} }
} }