diff --git a/example/client/app/app.go b/example/client/app/app.go index e33c795..560ac02 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -60,7 +60,7 @@ func main() { http.Handle("/login", rp.AuthURLHandler(state, provider)) // for demonstration purposes the returned userinfo response is written as JSON object onto response - marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info *oidc.UserInfo) { + marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { data, err := json.Marshal(info) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/example/client/github/github.go b/example/client/github/github.go index 57bb3ae..9cb813c 100644 --- a/example/client/github/github.go +++ b/example/client/github/github.go @@ -13,6 +13,7 @@ import ( "github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/oidc/v2/pkg/client/rp/cli" "github.com/zitadel/oidc/v2/pkg/http" + "github.com/zitadel/oidc/v2/pkg/oidc" ) var ( @@ -43,7 +44,7 @@ func main() { state := func() string { return uuid.New().String() } - token := cli.CodeFlow(ctx, relyingParty, callbackPath, port, state) + token := cli.CodeFlow[*oidc.IDTokenClaims](ctx, relyingParty, callbackPath, port, state) client := github.NewClient(relyingParty.OAuthConfig().Client(ctx, token.Token)) diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 1280454..d401823 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -238,7 +238,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, } var email string - redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info *oidc.UserInfo) { + redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { require.NotNil(t, tokens, "tokens") require.NotNil(t, info, "info") t.Log("access token", tokens.AccessToken) diff --git a/pkg/client/rp/cli/cli.go b/pkg/client/rp/cli/cli.go index 936f319..91b200d 100644 --- a/pkg/client/rp/cli/cli.go +++ b/pkg/client/rp/cli/cli.go @@ -13,13 +13,13 @@ const ( loginPath = "/login" ) -func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens { +func CodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens[C] { codeflowCtx, codeflowCancel := context.WithCancel(ctx) defer codeflowCancel() - tokenChan := make(chan *oidc.Tokens, 1) + tokenChan := make(chan *oidc.Tokens[C], 1) - callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) { + callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) { tokenChan <- tokens msg := "

Success!

" msg = msg + "

You are authenticated and can now return to the CLI.

" diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index dda0869..8aa7b1e 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -373,7 +373,7 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) -func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) { +func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) codeOpts := make([]oauth2.AuthCodeOption, 0) for _, opt := range opts { @@ -386,7 +386,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod } if rp.IsOAuth2Only() { - return &oidc.Tokens{Token: token}, nil + return &oidc.Tokens[C]{Token: token}, nil } idTokenString, ok := token.Extra(idTokenKey).(string) @@ -394,20 +394,20 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod return nil, errors.New("id_token missing") } - idToken, err := VerifyTokens[*oidc.IDTokenClaims](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) + idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) if err != nil { return nil, err } - return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil + return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil } -type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) +type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) // CodeExchangeHandler extends the `CodeExchange` method with a http handler // including cookie handling for secure `state` transfer // and optional PKCE code verifier checking -func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc { +func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { state, err := tryReadStateCookie(w, r, rp) if err != nil { @@ -436,7 +436,7 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha } codeOpts = append(codeOpts, WithClientAssertionJWT(assertion)) } - tokens, err := CodeExchange(r.Context(), params.Get("code"), rp, codeOpts...) + tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...) if err != nil { http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) return @@ -445,14 +445,14 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha } } -type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info *oidc.UserInfo) +type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo) // UserinfoCallback wraps the callback function of the CodeExchangeHandler // and calls the userinfo endpoint with the access token // on success it will pass the userinfo into its callback function as well -func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback { - return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) { - info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, rp) +func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] { + return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { + info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) return diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 62614ca..d79d2b7 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -18,9 +18,9 @@ const ( PrefixBearer = BearerToken + " " ) -type Tokens struct { +type Tokens[C IDClaims] struct { *oauth2.Token - IDTokenClaims *IDTokenClaims + IDTokenClaims C IDToken string }