feat: add storage info to token responses

This commit is contained in:
Stefan Benz 2023-11-25 17:14:21 +01:00
parent a8ef8de87b
commit dcd3f46f02
No known key found for this signature in database
GPG key ID: 9D2FE4EA50BEFE68
8 changed files with 57 additions and 26 deletions

View file

@ -5,10 +5,11 @@ import (
"os"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/crypto"
)
@ -205,12 +206,13 @@ func (i *IDTokenClaims) UnmarshalJSON(data []byte) error {
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
StorageInfo map[string]string `json:"storage_info,omitempty" schema:"storage_info,omitempty"`
}
type JWTProfileAssertionClaims struct {
@ -352,4 +354,5 @@ type TokenExchangeResponse struct {
ExpiresIn uint64 `json:"expires_in,omitempty"`
Scopes SpaceDelimitedArray `json:"scope,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
StorageInfo map[string]string `json:"storage_info,omitempty" schema:"storage_info,omitempty"`
}

View file

@ -8,10 +8,11 @@ import (
"github.com/go-chi/chi/v5"
"github.com/rs/cors"
"github.com/zitadel/logging"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// RegisterServer registers an implementation of Server.

View file

@ -7,6 +7,7 @@ import (
"time"
"github.com/go-chi/chi/v5"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
@ -214,7 +215,11 @@ func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.A
if err != nil {
return nil, err
}
return NewResponse(resp), nil
ret := NewResponse(resp)
for k, v := range resp.StorageInfo {
ret.Header.Add(k, v)
}
return ret, nil
}
func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
@ -256,7 +261,11 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil
if err != nil {
return nil, err
}
return NewResponse(resp), nil
ret := NewResponse(resp)
for k, v := range resp.StorageInfo {
ret.Header.Add(k, v)
}
return ret, nil
}
func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
@ -271,7 +280,11 @@ func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.
if err != nil {
return nil, err
}
return NewResponse(resp), nil
ret := NewResponse(resp)
for k, v := range resp.StorageInfo {
ret.Header.Add(k, v)
}
return ret, nil
}
func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
@ -287,7 +300,11 @@ func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientR
if err != nil {
return nil, err
}
return NewResponse(resp), nil
ret := NewResponse(resp)
for k, v := range resp.StorageInfo {
ret.Header.Add(k, v)
}
return ret, nil
}
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
@ -312,7 +329,11 @@ func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.De
if err != nil {
return nil, err
}
return NewResponse(resp), nil
ret := NewResponse(resp)
for k, v := range resp.StorageInfo {
ret.Header.Add(k, v)
}
return ret, nil
}
func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) {

View file

@ -5,7 +5,7 @@ import (
"errors"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
@ -27,7 +27,7 @@ type AuthStorage interface {
// Grant: https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
//
// * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
CreateAccessToken(context.Context, TokenRequest) (accessTokenID string, expiration time.Time, err error)
CreateAccessToken(context.Context, TokenRequest) (accessTokenID string, expiration time.Time, storageInfo map[string]string, err error)
// The TokenRequest parameter of CreateAccessAndRefreshTokens can be any of:
//
@ -40,7 +40,7 @@ type AuthStorage interface {
// registered the refresh_token grant type in advance
//
// * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error)
CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, storageInfo map[string]string, err error)
TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (RefreshTokenRequest, error)
TerminateSession(ctx context.Context, userID string, clientID string) error

View file

@ -33,9 +33,10 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
var accessToken, newRefreshToken string
var validity time.Duration
var storageInfo map[string]string
if createAccessToken {
var err error
accessToken, newRefreshToken, validity, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client, refreshToken)
accessToken, newRefreshToken, validity, storageInfo, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client, refreshToken)
if err != nil {
return nil, err
}
@ -65,14 +66,15 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
TokenType: oidc.BearerToken,
ExpiresIn: exp,
State: state,
StorageInfo: storageInfo,
}, nil
}
func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) {
func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, storageInfo map[string]string, err error) {
if needsRefreshToken(tokenRequest, client) {
return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken)
}
id, exp, err = storage.CreateAccessToken(ctx, tokenRequest)
id, exp, storageInfo, err = storage.CreateAccessToken(ctx, tokenRequest)
return
}
@ -89,13 +91,13 @@ func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool
}
}
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) {
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, storageInfo map[string]string, err error) {
ctx, span := tracer.Start(ctx, "CreateAccessToken")
defer span.End()
id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client)
id, newRefreshToken, exp, storageInfo, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client)
if err != nil {
return "", "", 0, err
return "", "", 0, nil, err
}
var clockSkew time.Duration
if client != nil {

View file

@ -111,7 +111,7 @@ func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest Toke
ctx, span := tracer.Start(ctx, "CreateClientCredentialsTokenResponse")
defer span.End()
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
accessToken, _, validity, storageInfo, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
if err != nil {
return nil, err
}
@ -120,5 +120,6 @@ func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest Toke
AccessToken: accessToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
StorageInfo: storageInfo,
}, nil
}

View file

@ -363,11 +363,12 @@ func CreateTokenExchangeResponse(
var (
token, refreshToken, tokenType string
validity time.Duration
storageInfo map[string]string
)
switch tokenExchangeRequest.GetRequestedTokenType() {
case oidc.AccessTokenType, oidc.RefreshTokenType:
token, refreshToken, validity, err = CreateAccessToken(ctx, tokenExchangeRequest, client.AccessTokenType(), creator, client, "")
token, refreshToken, validity, storageInfo, err = CreateAccessToken(ctx, tokenExchangeRequest, client.AccessTokenType(), creator, client, "")
if err != nil {
return nil, err
}
@ -396,6 +397,7 @@ func CreateTokenExchangeResponse(
ExpiresIn: exp,
RefreshToken: refreshToken,
Scopes: tokenExchangeRequest.GetScopes(),
StorageInfo: storageInfo,
}, nil
}

View file

@ -81,7 +81,7 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
}
}
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
accessToken, _, validity, storageInfo, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
if err != nil {
return nil, err
}
@ -89,6 +89,7 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
AccessToken: accessToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
StorageInfo: storageInfo,
}, nil
}