feat(rp): return oidc.Tokens on token refresh (#423)
BREAKING CHANGE: - rename RefreshAccessToken to RefreshToken - RefreshToken returns *oidc.Tokens instead of *oauth2.Token This change allows the return of the id_token in an explicit manner, as part of the oidc.Tokens struct. The return type is now consistent with the CodeExchange function. When an id_token is returned, it is verified. In case no id_token was received, RefreshTokens will not return an error. As per specifictation: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse Upon successful validation of the Refresh Token, the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token. Closes #364
This commit is contained in:
parent
e8262cbf1f
commit
6708ef4c24
3 changed files with 166 additions and 43 deletions
|
@ -356,6 +356,25 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
|
|||
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||
}
|
||||
|
||||
// ErrMissingIDToken is returned when an id_token was expected,
|
||||
// but not received in the token response.
|
||||
var ErrMissingIDToken = errors.New("id_token missing")
|
||||
|
||||
func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
|
||||
if rp.IsOAuth2Only() {
|
||||
return &oidc.Tokens[C]{Token: token}, nil
|
||||
}
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
if !ok {
|
||||
return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken
|
||||
}
|
||||
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
||||
// returning it parsed together with the oauth2 tokens (access, refresh)
|
||||
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
||||
|
@ -369,22 +388,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rp.IsOAuth2Only() {
|
||||
return &oidc.Tokens[C]{Token: token}, nil
|
||||
}
|
||||
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
if !ok {
|
||||
return nil, errors.New("id_token missing")
|
||||
}
|
||||
|
||||
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
return verifyTokenResponse[C](ctx, token, rp)
|
||||
}
|
||||
|
||||
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
|
||||
|
@ -609,11 +613,14 @@ type RefreshTokenRequest struct {
|
|||
GrantType oidc.GrantType `schema:"grant_type"`
|
||||
}
|
||||
|
||||
// RefreshAccessToken performs a token refresh. If it doesn't error, it will always
|
||||
// RefreshTokens performs a token refresh. If it doesn't error, it will always
|
||||
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
|
||||
// the old one should be considered invalid. It may also provide a new IDToken. The
|
||||
// new IDToken can be retrieved with token.Extra("id_token").
|
||||
func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
|
||||
// the old one should be considered invalid.
|
||||
//
|
||||
// In case the RP is not OAuth2 only and an IDToken was part of the response,
|
||||
// the IDToken and AccessToken will be verfied
|
||||
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
|
||||
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
|
||||
request := RefreshTokenRequest{
|
||||
RefreshToken: refreshToken,
|
||||
Scopes: rp.OAuthConfig().Scopes,
|
||||
|
@ -623,7 +630,17 @@ func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clie
|
|||
ClientAssertionType: clientAssertionType,
|
||||
GrantType: oidc.GrantTypeRefreshToken,
|
||||
}
|
||||
return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
|
||||
newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens, err := verifyTokenResponse[C](ctx, newToken, rp)
|
||||
if err == nil || errors.Is(err, ErrMissingIDToken) {
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
|
||||
// ...except that it might not contain an id_token.
|
||||
return tokens, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
||||
|
|
107
pkg/client/rp/relying_party_test.go
Normal file
107
pkg/client/rp/relying_party_test.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func Test_verifyTokenResponse(t *testing.T) {
|
||||
verifier := &IDTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
MaxAgeIAT: 2 * time.Minute,
|
||||
ClientID: tu.ValidClientID,
|
||||
Offset: time.Second,
|
||||
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
KeySet: tu.KeySet{},
|
||||
MaxAge: 2 * time.Minute,
|
||||
ACR: tu.ACRVerify,
|
||||
Nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
oauth2Only bool
|
||||
tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims])
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "succes, oauth2 only",
|
||||
oauth2Only: true,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: token,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "id_token missing error",
|
||||
oauth2Only: false,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: token,
|
||||
}
|
||||
},
|
||||
wantErr: ErrMissingIDToken,
|
||||
},
|
||||
{
|
||||
name: "verify tokens error",
|
||||
oauth2Only: false,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
token = token.WithExtra(map[string]any{
|
||||
"id_token": "foobar",
|
||||
})
|
||||
return token, nil
|
||||
},
|
||||
wantErr: oidc.ErrParse,
|
||||
},
|
||||
{
|
||||
name: "success, with id_token",
|
||||
oauth2Only: false,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
idToken, claims := tu.ValidIDToken()
|
||||
token = token.WithExtra(map[string]any{
|
||||
"id_token": idToken,
|
||||
})
|
||||
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: token,
|
||||
IDTokenClaims: claims,
|
||||
IDToken: idToken,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := &relyingParty{
|
||||
oauth2Only: tt.oauth2Only,
|
||||
idTokenVerifier: verifier,
|
||||
}
|
||||
token, want := tt.tokens()
|
||||
got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue