feat(rp): return oidc.Tokens on token refresh (#423)

BREAKING CHANGE:
- rename RefreshAccessToken to RefreshToken
- RefreshToken returns *oidc.Tokens instead of *oauth2.Token

This change allows the return of the id_token in an explicit manner,
as part of the oidc.Tokens struct.
The return type is now consistent with the CodeExchange function.

When an id_token is returned, it is verified.
In case no id_token was received,
RefreshTokens will not return an error.

As per specifictation:
https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse

Upon successful validation of the Refresh Token,
the response body is the Token Response of Section 3.1.3.3
except that it might not contain an id_token.

Closes #364
This commit is contained in:
Tim Möhlmann 2023-08-18 15:36:39 +03:00 committed by GitHub
parent e8262cbf1f
commit 6708ef4c24
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 166 additions and 43 deletions

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
@ -56,11 +55,11 @@ func TestRelyingPartySession(t *testing.T) {
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
t.Log("------- run authorization code flow ------") t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
t.Log("------- refresh tokens ------") t.Log("------- refresh tokens ------")
newTokens, err := rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
require.NoError(t, err, "refresh token") require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token") assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken) t.Logf("new access token %s", newTokens.AccessToken)
@ -68,11 +67,13 @@ func TestRelyingPartySession(t *testing.T) {
t.Logf("new token type %s", newTokens.TokenType) t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken") require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken") assert.NotEmpty(t, newTokens.IDToken, "new idToken")
assert.NotNil(t, newTokens.IDTokenClaims)
assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject)
t.Log("------ end session (logout) ------") t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout") require.NoError(t, err, "logout")
if newLoc != nil { if newLoc != nil {
t.Logf("redirect to %s", newLoc) t.Logf("redirect to %s", newLoc)
@ -81,12 +82,12 @@ func TestRelyingPartySession(t *testing.T) {
} }
t.Log("------ attempt refresh again (should fail) ------") t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken) t.Log("trying original refresh token", tokens.RefreshToken)
_, err = rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with original") assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" { if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken) t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(CTX, provider, newTokens.RefreshToken, "", "") _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement") assert.Errorf(t, err, "refresh with replacement")
} }
} }
@ -106,7 +107,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
clientSecret := "secret" clientSecret := "secret"
t.Log("------- run authorization code flow ------") t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
require.NoError(t, err, "new resource server") require.NoError(t, err, "new resource server")
@ -116,7 +117,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
tokenExchangeResponse, err := tokenexchange.ExchangeToken( tokenExchangeResponse, err := tokenexchange.ExchangeToken(
CTX, CTX,
resourceServer, resourceServer,
refreshToken, tokens.RefreshToken,
oidc.RefreshTokenType, oidc.RefreshTokenType,
"", "",
"", "",
@ -134,7 +135,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
t.Log("------ end session (logout) ------") t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout") require.NoError(t, err, "logout")
if newLoc != nil { if newLoc != nil {
t.Logf("redirect to %s", newLoc) t.Logf("redirect to %s", newLoc)
@ -147,7 +148,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
tokenExchangeResponse, err = tokenexchange.ExchangeToken( tokenExchangeResponse, err = tokenexchange.ExchangeToken(
CTX, CTX,
resourceServer, resourceServer,
refreshToken, tokens.RefreshToken,
oidc.RefreshTokenType, oidc.RefreshTokenType,
"", "",
"", "",
@ -161,7 +162,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
require.Nil(t, tokenExchangeResponse, "token exchange response") require.Nil(t, tokenExchangeResponse, "token exchange response")
} }
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
targetURL := "http://local-site" targetURL := "http://local-site"
localURL, err := url.Parse(targetURL + "/login?requestID=1234") localURL, err := url.Parse(targetURL + "/login?requestID=1234")
require.NoError(t, err, "local url") require.NoError(t, err, "local url")
@ -258,7 +259,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
} }
var email string var email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
tokens = newTokens
require.NotNil(t, tokens, "tokens") require.NotNil(t, tokens, "tokens")
require.NotNil(t, info, "info") require.NotNil(t, info, "info")
t.Log("access token", tokens.AccessToken) t.Log("access token", tokens.AccessToken)
@ -266,9 +268,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
t.Log("id token", tokens.IDToken) t.Log("id token", tokens.IDToken)
t.Log("email", info.Email) t.Log("email", info.Email)
accessToken = tokens.AccessToken
refreshToken = tokens.RefreshToken
idToken = tokens.IDToken
email = info.Email email = info.Email
http.Redirect(w, r, targetURL, 302) http.Redirect(w, r, targetURL, 302)
} }
@ -290,12 +289,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
require.NoError(t, err, "get fully-authorizied redirect location") require.NoError(t, err, "get fully-authorizied redirect location")
require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location") require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location")
require.NotEmpty(t, idToken, "id token") require.NotEmpty(t, tokens.IDToken, "id token")
assert.NotEmpty(t, refreshToken, "refresh token") assert.NotEmpty(t, tokens.RefreshToken, "refresh token")
assert.NotEmpty(t, accessToken, "access token") assert.NotEmpty(t, tokens.AccessToken, "access token")
assert.NotEmpty(t, email, "email") assert.NotEmpty(t, email, "email")
return provider, accessToken, refreshToken, idToken return provider, tokens
} }
type deferredHandler struct { type deferredHandler struct {
@ -343,7 +342,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [
func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL { func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL {
// TODO: switch to io.NopCloser when go1.15 support is dropped // TODO: switch to io.NopCloser when go1.15 support is dropped
req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)..., append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)...,
) )
if req.URL.Scheme == "" { if req.URL.Scheme == "" {

View file

@ -356,6 +356,25 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
return oidc.NewSHACodeChallenge(codeVerifier), nil return oidc.NewSHACodeChallenge(codeVerifier), nil
} }
// ErrMissingIDToken is returned when an id_token was expected,
// but not received in the token response.
var ErrMissingIDToken = errors.New("id_token missing")
func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken
}
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
// returning it parsed together with the oauth2 tokens (access, refresh) // returning it parsed together with the oauth2 tokens (access, refresh)
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
@ -369,22 +388,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
if err != nil { if err != nil {
return nil, err return nil, err
} }
return verifyTokenResponse[C](ctx, token, rp)
if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return nil, errors.New("id_token missing")
}
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
} }
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
@ -609,11 +613,14 @@ type RefreshTokenRequest struct {
GrantType oidc.GrantType `schema:"grant_type"` GrantType oidc.GrantType `schema:"grant_type"`
} }
// RefreshAccessToken performs a token refresh. If it doesn't error, it will always // RefreshTokens performs a token refresh. If it doesn't error, it will always
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then // provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
// the old one should be considered invalid. It may also provide a new IDToken. The // the old one should be considered invalid.
// new IDToken can be retrieved with token.Extra("id_token"). //
func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { // In case the RP is not OAuth2 only and an IDToken was part of the response,
// the IDToken and AccessToken will be verfied
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
request := RefreshTokenRequest{ request := RefreshTokenRequest{
RefreshToken: refreshToken, RefreshToken: refreshToken,
Scopes: rp.OAuthConfig().Scopes, Scopes: rp.OAuthConfig().Scopes,
@ -623,7 +630,17 @@ func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clie
ClientAssertionType: clientAssertionType, ClientAssertionType: clientAssertionType,
GrantType: oidc.GrantTypeRefreshToken, GrantType: oidc.GrantTypeRefreshToken,
} }
return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
if err != nil {
return nil, err
}
tokens, err := verifyTokenResponse[C](ctx, newToken, rp)
if err == nil || errors.Is(err, ErrMissingIDToken) {
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
// ...except that it might not contain an id_token.
return tokens, nil
}
return nil, err
} }
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {

View file

@ -0,0 +1,107 @@
package rp
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
)
func Test_verifyTokenResponse(t *testing.T) {
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: 2 * time.Minute,
ClientID: tu.ValidClientID,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
KeySet: tu.KeySet{},
MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce },
}
tests := []struct {
name string
oauth2Only bool
tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims])
wantErr error
}{
{
name: "succes, oauth2 only",
oauth2Only: true,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
},
{
name: "id_token missing error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
wantErr: ErrMissingIDToken,
},
{
name: "verify tokens error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
token = token.WithExtra(map[string]any{
"id_token": "foobar",
})
return token, nil
},
wantErr: oidc.ErrParse,
},
{
name: "success, with id_token",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
idToken, claims := tu.ValidIDToken()
token = token.WithExtra(map[string]any{
"id_token": idToken,
})
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
IDTokenClaims: claims,
IDToken: idToken,
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &relyingParty{
oauth2Only: tt.oauth2Only,
idTokenVerifier: verifier,
}
token, want := tt.tokens()
got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, want, got)
})
}
}