From cf8128a60b7e02b2724e21bd0eb5f1ce9af384bd Mon Sep 17 00:00:00 2001 From: Emil Bektimirov Date: Tue, 20 Dec 2022 01:49:55 +0100 Subject: [PATCH] feat: add token exchange to client --- pkg/client/client.go | 19 +++ pkg/client/{rp => }/integration_test.go | 157 ++++++++++++++++------ pkg/client/rs/resource_server.go | 5 + pkg/client/tokenexchange/tokenexchange.go | 127 +++++++++++++++++ 4 files changed, 270 insertions(+), 38 deletions(-) rename pkg/client/{rp => }/integration_test.go (73%) create mode 100644 pkg/client/tokenexchange/tokenexchange.go diff --git a/pkg/client/client.go b/pkg/client/client.go index 344e26b..243fe6d 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -90,6 +90,9 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS return http.ErrUseLastResponse } resp, err := client.Do(req) + if err != nil { + return nil, err + } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 400 { // TODO: switch to io.ReadAll when go1.15 support is retired @@ -150,6 +153,22 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa return nil } +func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { + return callTokenExchangeEndpoint(request, authFn, caller) +} + +func callTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { + req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) + if err != nil { + return nil, err + } + tokenRes := new(oidc.TokenExchangeResponse) + if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil { + return nil, err + } + return tokenRes, nil +} + func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) { privateKey, err := crypto.BytesToPrivateKey(key) if err != nil { diff --git a/pkg/client/rp/integration_test.go b/pkg/client/integration_test.go similarity index 73% rename from pkg/client/rp/integration_test.go rename to pkg/client/integration_test.go index e08e2eb..75f8c7e 100644 --- a/pkg/client/rp/integration_test.go +++ b/pkg/client/integration_test.go @@ -1,4 +1,4 @@ -package rp_test +package client_test import ( "bytes" @@ -21,6 +21,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zitadel/oidc/pkg/client/rp" + "github.com/zitadel/oidc/pkg/client/rs" + "github.com/zitadel/oidc/pkg/client/tokenexchange" httphelper "github.com/zitadel/oidc/pkg/http" "github.com/zitadel/oidc/pkg/oidc" ) @@ -35,13 +37,119 @@ func TestRelyingPartySession(t *testing.T) { t.Logf("auth server at %s", opServer.URL) dh.Handler = exampleop.SetupServer(ctx, opServer.URL, exampleStorage) + seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) + clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) + + t.Log("------- run authorization code flow ------") + provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") + + t.Log("------- refresh tokens ------") + + newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "") + require.NoError(t, err, "refresh token") + assert.NotNil(t, newTokens, "access token") + t.Logf("new access token %s", newTokens.AccessToken) + t.Logf("new refresh token %s", newTokens.RefreshToken) + t.Logf("new token type %s", newTokens.TokenType) + t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) + require.NotEmpty(t, newTokens.AccessToken, "new accessToken") + + t.Log("------ end session (logout) ------") + + newLoc, err := rp.EndSession(provider, idToken, "", "") + require.NoError(t, err, "logout") + if newLoc != nil { + t.Logf("redirect to %s", newLoc) + } else { + t.Logf("no redirect") + } + + t.Log("------ attempt refresh again (should fail) ------") + t.Log("trying original refresh token", refreshToken) + _, err = rp.RefreshAccessToken(provider, refreshToken, "", "") + assert.Errorf(t, err, "refresh with original") + if newTokens.RefreshToken != "" { + t.Log("trying replacement refresh token", newTokens.RefreshToken) + _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") + assert.Errorf(t, err, "refresh with replacement") + } +} + +func TestResourceServerTokenExchange(t *testing.T) { + t.Log("------- start example OP ------") + ctx := context.Background() + exampleStorage := storage.NewStorage(storage.NewUserStore()) + var dh deferredHandler + opServer := httptest.NewServer(&dh) + defer opServer.Close() + t.Logf("auth server at %s", opServer.URL) + dh.Handler = exampleop.SetupServer(ctx, opServer.URL, exampleStorage) + + seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) + clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) + clientSecret := "secret" + + t.Log("------- run authorization code flow ------") + provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) + + resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret) + require.NoError(t, err, "new resource server") + + t.Log("------- exchage refresh tokens (impersonation) ------") + + tokenExchangeResponse, err := tokenexchange.ExchangeToken( + resourceServer, + refreshToken, + oidc.RefreshTokenType, + "", + "", + []string{}, + []string{}, + []string{"profile", "custom_scope:impersonate:id2"}, + oidc.RefreshTokenType, + ) + require.NoError(t, err, "refresh token") + require.NotNil(t, tokenExchangeResponse, "token exchange response") + assert.Equal(t, tokenExchangeResponse.IssuedTokenType, oidc.RefreshTokenType) + assert.NotEmpty(t, tokenExchangeResponse.AccessToken, "access token") + assert.NotEmpty(t, tokenExchangeResponse.RefreshToken, "refresh token") + assert.Equal(t, []string(tokenExchangeResponse.Scopes), []string{"profile", "custom_scope:impersonate:id2"}) + + t.Log("------ end session (logout) ------") + + newLoc, err := rp.EndSession(provider, idToken, "", "") + require.NoError(t, err, "logout") + if newLoc != nil { + t.Logf("redirect to %s", newLoc) + } else { + t.Logf("no redirect") + } + + t.Log("------- attempt exchage again (should fail) ------") + + tokenExchangeResponse, err = tokenexchange.ExchangeToken( + resourceServer, + refreshToken, + oidc.RefreshTokenType, + "", + "", + []string{}, + []string{}, + []string{"profile", "custom_scope:impersonate:id2"}, + oidc.RefreshTokenType, + ) + require.Error(t, err, "refresh token") + assert.Contains(t, err.Error(), "subject_token is invalid") + require.Nil(t, tokenExchangeResponse, "token exchange response") + +} + +func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { targetURL := "http://local-site" localURL, err := url.Parse(targetURL + "/login?requestID=1234") require.NoError(t, err, "local url") - seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) - clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) - client := storage.WebClient(clientID, "secret", targetURL) + client := storage.WebClient(clientID, clientSecret, targetURL) storage.RegisterClients(client) jar, err := cookiejar.New(nil) @@ -57,10 +165,10 @@ func TestRelyingPartySession(t *testing.T) { t.Log("------- create RP ------") key := []byte("test1234test1234") cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) - provider, err := rp.NewRelyingPartyOIDC( + provider, err = rp.NewRelyingPartyOIDC( opServer.URL, clientID, - "secret", + clientSecret, targetURL, []string{"openid", "email", "profile", "offline_access"}, rp.WithPKCE(cookieHandler), @@ -69,8 +177,10 @@ func TestRelyingPartySession(t *testing.T) { rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"), ), ) + require.NoError(t, err, "new rp") t.Log("------- get redirect from local client (rp) to OP ------") + seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) state := "state-" + strconv.FormatInt(seed.Int63(), 25) capturedW := httptest.NewRecorder() get := httptest.NewRequest("GET", localURL.String(), nil) @@ -124,7 +234,7 @@ func TestRelyingPartySession(t *testing.T) { t.Logf("setting cookie %s", cookie) } - var accessToken, refreshToken, idToken, email string + var email string redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) { require.NotNil(t, tokens, "tokens") require.NotNil(t, info, "info") @@ -137,7 +247,7 @@ func TestRelyingPartySession(t *testing.T) { refreshToken = tokens.RefreshToken idToken = tokens.IDToken email = info.GetEmail() - http.Redirect(w, r, targetURL, 302) + http.Redirect(w, r, targetURL, http.StatusFound) } rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get) @@ -162,36 +272,7 @@ func TestRelyingPartySession(t *testing.T) { assert.NotEmpty(t, accessToken, "access token") assert.NotEmpty(t, email, "email") - t.Log("------- refresh tokens ------") - - newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "") - require.NoError(t, err, "refresh token") - assert.NotNil(t, newTokens, "access token") - t.Logf("new access token %s", newTokens.AccessToken) - t.Logf("new refresh token %s", newTokens.RefreshToken) - t.Logf("new token type %s", newTokens.TokenType) - t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) - require.NotEmpty(t, newTokens.AccessToken, "new accessToken") - - t.Log("------ end session (logout) ------") - - newLoc, err := rp.EndSession(provider, idToken, "", "") - require.NoError(t, err, "logout") - if newLoc != nil { - t.Logf("redirect to %s", newLoc) - } else { - t.Logf("no redirect") - } - - t.Log("------ attempt refresh again (should fail) ------") - t.Log("trying original refresh token", refreshToken) - _, err = rp.RefreshAccessToken(provider, refreshToken, "", "") - assert.Errorf(t, err, "refresh with original") - if newTokens.RefreshToken != "" { - t.Log("trying replacement refresh token", newTokens.RefreshToken) - _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") - assert.Errorf(t, err, "refresh with replacement") - } + return provider, accessToken, refreshToken, idToken } type deferredHandler struct { diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index b1bc47e..6536c7f 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -13,6 +13,7 @@ import ( type ResourceServer interface { IntrospectionURL() string + TokenEndpoint() string HttpClient() *http.Client AuthFn() (interface{}, error) } @@ -29,6 +30,10 @@ func (r *resourceServer) IntrospectionURL() string { return r.introspectURL } +func (r *resourceServer) TokenEndpoint() string { + return r.tokenURL +} + func (r *resourceServer) HttpClient() *http.Client { return r.httpClient } diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go new file mode 100644 index 0000000..15b5d40 --- /dev/null +++ b/pkg/client/tokenexchange/tokenexchange.go @@ -0,0 +1,127 @@ +package tokenexchange + +import ( + "errors" + "net/http" + + "github.com/zitadel/oidc/pkg/client" + httphelper "github.com/zitadel/oidc/pkg/http" + "github.com/zitadel/oidc/pkg/oidc" +) + +type TokenExchanger interface { + TokenEndpoint() string + HttpClient() *http.Client + AuthFn() (interface{}, error) +} + +type OAuthTokenExchange struct { + httpClient *http.Client + tokenEndpoint string + authFn func() (interface{}, error) +} + +func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + return newOAuthTokenExchange(issuer, nil, options...) +} + +func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + authorizer := func() (interface{}, error) { + return httphelper.AuthorizeBasic(clientID, clientSecret), nil + } + return newOAuthTokenExchange(issuer, authorizer, options...) +} + +func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { + te := &OAuthTokenExchange{ + httpClient: httphelper.DefaultHTTPClient, + } + for _, opt := range options { + opt(te) + } + + if te.tokenEndpoint == "" { + config, err := client.Discover(issuer, te.httpClient) + if err != nil { + return nil, err + } + + te.tokenEndpoint = config.TokenEndpoint + } + + if te.tokenEndpoint == "" { + return nil, errors.New("tokenURL is empty: please provide with either `WithStaticTokenEndpoint` or a discovery url") + } + + te.authFn = authorizer + + return te, nil +} + +func WithHTTPClient(client *http.Client) func(*OAuthTokenExchange) { + return func(source *OAuthTokenExchange) { + source.httpClient = client + } +} + +func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*OAuthTokenExchange) { + return func(source *OAuthTokenExchange) { + source.tokenEndpoint = tokenEndpoint + } +} + +func (te *OAuthTokenExchange) TokenEndpoint() string { + return te.tokenEndpoint +} + +func (te *OAuthTokenExchange) HttpClient() *http.Client { + return te.httpClient +} + +func (te *OAuthTokenExchange) AuthFn() (interface{}, error) { + if te.authFn != nil { + return te.authFn() + } + + return nil, nil +} + +// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint. +// SubjectToken and SubjectTokenType are required parameters. +func ExchangeToken( + te TokenExchanger, + SubjectToken string, + SubjectTokenType oidc.TokenType, + ActorToken string, + ActorTokenType oidc.TokenType, + Resource []string, + Audience []string, + Scopes []string, + RequestedTokenType oidc.TokenType, +) (*oidc.TokenExchangeResponse, error) { + if SubjectToken == "" { + return nil, errors.New("empty subject_token") + } + if SubjectTokenType == "" { + return nil, errors.New("empty subject_token_type") + } + + authFn, err := te.AuthFn() + if err != nil { + return nil, err + } + + request := oidc.TokenExchangeRequest{ + GrantType: oidc.GrantTypeTokenExchange, + SubjectToken: SubjectToken, + SubjectTokenType: SubjectTokenType, + ActorToken: ActorToken, + ActorTokenType: ActorTokenType, + Resource: Resource, + Audience: Audience, + Scopes: Scopes, + RequestedTokenType: RequestedTokenType, + } + + return client.CallTokenExchangeEndpoint(request, authFn, te) +}