diff --git a/example/client/app/app.go b/example/client/app/app.go index e3ddd15..e997f41 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -59,11 +59,9 @@ func main() { //including state handling with secure cookie and the possibility to use PKCE http.Handle("/login", rp.AuthURLHandler(state, provider)) - //for demonstration purposes the returned tokens (access token, id_token an its parsed claims) - //are written as JSON objects onto response - marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { - _ = state - data, err := json.Marshal(tokens) + //for demonstration purposes the returned userinfo response is written as JSON object onto response + marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) { + data, err := json.Marshal(info) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -74,7 +72,9 @@ func main() { //register the CodeExchangeHandler at the callbackPath //the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function //with the returned tokens from the token endpoint - http.Handle(callbackPath, rp.CodeExchangeHandler(marshal, provider)) + //in this example the callback function itself is wrapped by the UserinfoCallback which + //will call the Userinfo endpoint, check the sub and pass the info into the callback function + http.Handle(callbackPath, rp.CodeExchangeHandler(rp.UserinfoCallback(marshalUserinfo), provider)) lis := fmt.Sprintf("127.0.0.1:%s", port) logrus.Infof("listening on http://%s/", lis) diff --git a/pkg/client/rp/relaying_party.go b/pkg/client/rp/relaying_party.go index 9e02e65..4db43eb 100644 --- a/pkg/client/rp/relaying_party.go +++ b/pkg/client/rp/relaying_party.go @@ -327,10 +327,12 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil } +type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) + //CodeExchangeHandler extends the `CodeExchange` method with a http handler //including cookie handling for secure `state` transfer //and optional PKCE code verifier checking -func CodeExchangeHandler(callback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string), rp RelyingParty) http.HandlerFunc { +func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { state, err := tryReadStateCookie(w, r, rp) if err != nil { @@ -364,21 +366,23 @@ func CodeExchangeHandler(callback func(w http.ResponseWriter, r *http.Request, t http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) return } - callback(w, r, tokens, state) + callback(w, r, tokens, state, rp) } } +type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo) + //UserinfoCallback wraps the callback function of the CodeExchangeHandler //and calls the userinfo endpoint with the access token //on success it will pass the userinfo into its callback function as well -func UserinfoCallback(provider RelyingParty, f func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, info oidc.UserInfo, state string)) func(http.ResponseWriter, *http.Request, *oidc.Tokens, string) { - return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { - info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), provider) +func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback { + return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) { + info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) return } - f(w, r, tokens, info, state) + f(w, r, tokens, state, rp, info) } }