zitadel-oidc/pkg/op/token.go
Marc Alvarez 8e0e489a67 fix: clarify refresh token parameter names and improve code readability
- Rename misleading parameters in Storage interface:
  - refreshTokenID → refreshToken
  - newRefreshTokenID → newRefreshToken

- Make bare returns explicit in token.go for better readability

- Add documentation to clarify token creation flow

The interface parameters were misleadingly named with 'ID' suffix when they
actually contain the full token values. The example implementations already
used the semantically correct names, creating inconsistency. This change
aligns the interface with its implementations and prevents confusion.
2025-06-12 15:52:06 -06:00

283 lines
9.3 KiB
Go

package op
import (
"context"
"slices"
"time"
"github.com/zitadel/oidc/v3/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type TokenCreator interface {
Storage() Storage
Crypto() Crypto
}
// TokenRequest is the minimal view of a request that access token creation
// needs: the subject the token is issued for, the intended audience, and the
// granted scopes.
type TokenRequest interface {
	GetScopes() []string
	GetAudience() []string
	GetSubject() string
}
// AccessTokenClient is the subset of client configuration consulted when
// issuing access tokens.
type AccessTokenClient interface {
	// GetID returns the client identifier.
	GetID() string
	// GrantTypes lists the grant types the client is allowed to use.
	GrantTypes() []oidc.GrantType
	// ClockSkew is the tolerance applied to token time claims.
	ClockSkew() time.Duration
	// RestrictAdditionalAccessTokenScopes returns a filter applied to the
	// requested scopes before private access token claims are looked up.
	RestrictAdditionalAccessTokenScopes() func(scopes []string) []string
}
// CreateTokenResponse assembles the token endpoint response for request.
//
// When createAccessToken is true, an access token is created via
// CreateAccessToken, with refreshToken forwarded to it. Depending on the
// request, a new refresh token may be created as well. An ID token is always
// created, and code is passed through to CreateIDToken. If request is an
// AuthRequest, it is deleted from storage. Its state is echoed back only when
// code is empty (implicit flow).
//
// ExpiresIn is the access token validity in whole seconds. It is 0 when no
// access token was created or when the computed validity is not positive.
func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Client, creator TokenCreator, createAccessToken bool, code, refreshToken string) (*oidc.AccessTokenResponse, error) {
	ctx, span := tracer.Start(ctx, "CreateTokenResponse")
	defer span.End()
	var accessToken, newRefreshToken string
	var validity time.Duration
	if createAccessToken {
		var err error
		accessToken, newRefreshToken, validity, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client, refreshToken)
		if err != nil {
			return nil, err
		}
	}
	idToken, err := CreateIDToken(ctx, IssuerFromContext(ctx), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), client)
	if err != nil {
		return nil, err
	}
	var state string
	if authRequest, ok := request.(AuthRequest); ok {
		err = creator.Storage().DeleteAuthRequest(ctx, authRequest.GetID())
		if err != nil {
			return nil, err
		}
		// only implicit flow requires state to be returned.
		if code == "" {
			state = authRequest.GetState()
		}
	}
	// validity can be negative when storage returns an expiration that is
	// already in the past. Converting a negative float64 to uint64 is
	// implementation-dependent in Go and can produce a huge expires_in, so
	// clamp to zero instead.
	var expiresIn uint64
	if validity > 0 {
		expiresIn = uint64(validity.Seconds())
	}
	return &oidc.AccessTokenResponse{
		AccessToken:  accessToken,
		IDToken:      idToken,
		RefreshToken: newRefreshToken,
		TokenType:    oidc.BearerToken,
		ExpiresIn:    expiresIn,
		State:        state,
		Scope:        request.GetScopes(),
	}, nil
}
// createTokens asks storage to persist the tokens that tokenRequest calls for.
// It always returns the access token ID and its expiration.
//
// If needsRefreshToken reports false, only Storage.CreateAccessToken is called
// and newRefreshToken is returned as "". Otherwise,
// Storage.CreateAccessAndRefreshTokens is called with refreshToken, and its
// results, including the new refresh token value, are returned unchanged.
func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) {
	ctx, span := tracer.Start(ctx, "createTokens")
	defer span.End()
	if !needsRefreshToken(tokenRequest, client) {
		accessID, expiration, createErr := storage.CreateAccessToken(ctx, tokenRequest)
		return accessID, "", expiration, createErr
	}
	return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken)
}
// needsRefreshToken reports whether a refresh token should be created
// alongside the access token for tokenRequest:
//   - AuthRequest: only if the offline access scope was requested, the
//     response type is code, and the client may use the refresh token grant.
//   - TokenExchangeRequest: only if the requested token type is the refresh
//     token type.
//   - RefreshTokenRequest: always.
//   - *DeviceAuthorizationState: only if the offline access scope was
//     requested and the client may use the refresh token grant.
//   - anything else: never.
//
// The checks run in the order listed, so a request that satisfies several of
// these types is classified by the first match.
func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool {
	if authReq, ok := tokenRequest.(AuthRequest); ok {
		if !slices.Contains(authReq.GetScopes(), oidc.ScopeOfflineAccess) {
			return false
		}
		return authReq.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
	}
	if exchangeReq, ok := tokenRequest.(TokenExchangeRequest); ok {
		return exchangeReq.GetRequestedTokenType() == oidc.RefreshTokenType
	}
	if _, ok := tokenRequest.(RefreshTokenRequest); ok {
		return true
	}
	if deviceState, ok := tokenRequest.(*DeviceAuthorizationState); ok {
		if !slices.Contains(deviceState.GetScopes(), oidc.ScopeOfflineAccess) {
			return false
		}
		return ValidateGrantType(client, oidc.GrantTypeRefreshToken)
	}
	return false
}
// CreateAccessToken creates an access token and, when the request calls for
// one, passes through the refresh token created by storage.
//
// Storage creates the token records through createTokens. The returned token
// ID is then rendered as a signed JWT (AccessTokenTypeJWT) or as an encrypted
// bearer token (any other type). newRefreshToken is whatever storage returned.
// It is "" when no refresh token was needed, for example for client
// credentials or implicit flow requests.
//
// validity is the time remaining until the stored expiration plus the
// client's clock skew. The skew is zero when client is nil.
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) {
	ctx, span := tracer.Start(ctx, "CreateAccessToken")
	defer span.End()

	tokenID, storedRefresh, expiration, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client)
	if err != nil {
		return "", "", 0, err
	}

	skew := time.Duration(0)
	if client != nil {
		skew = client.ClockSkew()
	}
	validity = expiration.Add(skew).Sub(time.Now().UTC())

	switch accessTokenType {
	case AccessTokenTypeJWT:
		accessToken, err = CreateJWT(ctx, IssuerFromContext(ctx), tokenRequest, expiration, tokenID, client, creator.Storage())
	default:
		_, bearerSpan := tracer.Start(ctx, "CreateBearerToken")
		accessToken, err = CreateBearerToken(tokenID, tokenRequest.GetSubject(), creator.Crypto())
		bearerSpan.End()
	}
	return accessToken, storedRefresh, validity, err
}
// CreateBearerToken produces a bearer access token by encrypting the string
// "<tokenID>:<subject>" with the given Crypto implementation.
func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
	plaintext := tokenID + ":" + subject
	return crypto.Encrypt(plaintext)
}
// TokenActorRequest is optionally implemented by token requests that carry an
// actor. When a request implements it, the returned actor claims are copied
// into the access token (CreateJWT) or ID token (CreateIDToken) claims.
type TokenActorRequest interface {
	GetActor() *oidc.ActorClaims
}
// CreateJWT builds and signs a JWT access token for tokenRequest.
//
// The base claims are built from issuer, the request's subject and audience,
// exp, the token id, and the client's ID and clock skew. When client is
// non-nil, private claims are added from storage. The source depends on the
// request and storage:
//   - TokenExchangeStorage.GetPrivateClaimsFromTokenExchangeRequest, for a
//     TokenExchangeRequest when storage supports token exchange;
//   - otherwise CanGetPrivateClaimsFromRequest.GetPrivateClaimsFromRequest,
//     if storage implements it;
//   - otherwise Storage.GetPrivateClaimsFromScopes.
//
// The last two receive the client-restricted scopes with userinfo scopes
// removed. An actor claim is set when the request implements
// TokenActorRequest. The claims are signed with storage's signing key.
//
// client may be nil. In that case the client ID and clock skew are zero
// values and no private claims are fetched.
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, client AccessTokenClient, storage Storage) (string, error) {
	ctx, span := tracer.Start(ctx, "CreateJWT")
	defer span.End()

	// Read client settings only behind a nil check. The private-claims lookup
	// below already treats a nil client as valid, so building the base claims
	// must not dereference it either.
	var (
		clientID  string
		clockSkew time.Duration
	)
	if client != nil {
		clientID = client.GetID()
		clockSkew = client.ClockSkew()
	}
	claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id, clientID, clockSkew)

	if client != nil {
		restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes())
		var (
			privateClaims map[string]any
			err           error
		)
		tokenExchangeRequest, isTokenExchange := tokenRequest.(TokenExchangeRequest)
		teStorage, hasTokenExchangeStorage := storage.(TokenExchangeStorage)
		if isTokenExchange && hasTokenExchangeStorage {
			privateClaims, err = teStorage.GetPrivateClaimsFromTokenExchangeRequest(ctx, tokenExchangeRequest)
		} else if fromRequest, ok := storage.(CanGetPrivateClaimsFromRequest); ok {
			privateClaims, err = fromRequest.GetPrivateClaimsFromRequest(ctx, tokenRequest, removeUserinfoScopes(restrictedScopes))
		} else {
			privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), clientID, removeUserinfoScopes(restrictedScopes))
		}
		if err != nil {
			return "", err
		}
		claims.Claims = privateClaims
	}

	if actorReq, ok := tokenRequest.(TokenActorRequest); ok {
		claims.Actor = actorReq.GetActor()
	}

	signingKey, err := storage.SigningKey(ctx)
	if err != nil {
		return "", err
	}
	signer, err := SignerFromKey(signingKey)
	if err != nil {
		return "", err
	}
	return crypto.Sign(claims, signer)
}
// IDTokenRequest is the view of a request that CreateIDToken needs in order to
// build ID token claims: subject and client, intended audience, granted
// scopes, and authentication details (time and methods).
type IDTokenRequest interface {
	GetSubject() string
	GetClientID() string
	GetAudience() []string
	GetScopes() []string
	GetAuthTime() time.Time
	GetAMR() []string
}
// CreateIDToken builds and signs an OpenID Connect ID token for request.
//
// The expiration is the current UTC time plus the client's clock skew plus
// validity. For an AuthRequest, the acr and nonce values are copied from the
// request. For a TokenActorRequest, the actor claim is set.
//
// When accessToken is non-empty, the access token hash claim is set. Unless
// the client has IDTokenUserinfoClaimsAssertion enabled, userinfo scopes are
// then removed from the client-restricted scope list.
//
// Userinfo claims are filled from
// TokenExchangeStorage.SetUserinfoFromTokenExchangeRequest for token exchange
// requests when storage supports it. Otherwise, if any scopes remain, they
// come from Storage.SetUserinfoFromScopes followed by
// CanSetUserinfoFromRequest.SetUserinfoFromRequest (if implemented).
//
// When code is non-empty, the code hash claim is set. The claims are signed
// with storage's signing key.
func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, client Client) (string, error) {
	ctx, span := tracer.Start(ctx, "CreateIDToken")
	defer span.End()

	expiration := time.Now().UTC().Add(client.ClockSkew()).Add(validity)

	var acr, nonce string
	if authReq, isAuth := request.(AuthRequest); isAuth {
		acr, nonce = authReq.GetACR(), authReq.GetNonce()
	}

	claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), expiration, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew())
	if actorReq, isActor := request.(TokenActorRequest); isActor {
		claims.Actor = actorReq.GetActor()
	}

	scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes())

	key, err := storage.SigningKey(ctx)
	if err != nil {
		return "", err
	}

	if accessToken != "" {
		accessHash, hashErr := oidc.ClaimHash(accessToken, key.SignatureAlgorithm())
		if hashErr != nil {
			return "", hashErr
		}
		claims.AccessTokenHash = accessHash
		if !client.IDTokenUserinfoClaimsAssertion() {
			scopes = removeUserinfoScopes(scopes)
		}
	}

	exchangeReq, isTokenExchange := request.(TokenExchangeRequest)
	exchangeStorage, hasExchangeStorage := storage.(TokenExchangeStorage)
	switch {
	case isTokenExchange && hasExchangeStorage:
		info := new(oidc.UserInfo)
		if setErr := exchangeStorage.SetUserinfoFromTokenExchangeRequest(ctx, info, exchangeReq); setErr != nil {
			return "", setErr
		}
		claims.SetUserInfo(info)
	case len(scopes) > 0:
		info := new(oidc.UserInfo)
		if setErr := storage.SetUserinfoFromScopes(ctx, info, request.GetSubject(), request.GetClientID(), scopes); setErr != nil {
			return "", setErr
		}
		if fromRequest, canSet := storage.(CanSetUserinfoFromRequest); canSet {
			if setErr := fromRequest.SetUserinfoFromRequest(ctx, info, request, scopes); setErr != nil {
				return "", setErr
			}
		}
		claims.SetUserInfo(info)
	}

	if code != "" {
		codeHash, hashErr := oidc.ClaimHash(code, key.SignatureAlgorithm())
		if hashErr != nil {
			return "", hashErr
		}
		claims.CodeHash = codeHash
	}

	signer, err := SignerFromKey(key)
	if err != nil {
		return "", err
	}
	return crypto.Sign(claims, signer)
}
// removeUserinfoScopes returns a new slice that holds scopes in their original
// order, without the profile, email, address and phone scopes. The input is
// not modified. The result is always non-nil, even for a nil or empty input.
func removeUserinfoScopes(scopes []string) []string {
	kept := make([]string, 0, len(scopes))
	for _, s := range scopes {
		isUserinfo := s == oidc.ScopeProfile ||
			s == oidc.ScopeEmail ||
			s == oidc.ScopeAddress ||
			s == oidc.ScopePhone
		if !isUserinfo {
			kept = append(kept, s)
		}
	}
	return kept
}