From dcd3f46f0274a37e43a9c0bee5a1a1a413c3b880 Mon Sep 17 00:00:00 2001 From: Stefan Benz Date: Sat, 25 Nov 2023 17:14:21 +0100 Subject: [PATCH] feat: add storage info to token responses --- pkg/oidc/token.go | 17 +++++++++------- pkg/op/server_http.go | 5 +++-- pkg/op/server_legacy.go | 31 +++++++++++++++++++++++++----- pkg/op/storage.go | 6 +++--- pkg/op/token.go | 14 ++++++++------ pkg/op/token_client_credentials.go | 3 ++- pkg/op/token_exchange.go | 4 +++- pkg/op/token_jwt_profile.go | 3 ++- 8 files changed, 57 insertions(+), 26 deletions(-) diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index b4cb6b6..4fccc9d 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -5,10 +5,11 @@ import ( "os" "time" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" "github.com/muhlemmer/gu" + "github.com/zitadel/oidc/v3/pkg/crypto" ) @@ -205,12 +206,13 @@ func (i *IDTokenClaims) UnmarshalJSON(data []byte) error { } type AccessTokenResponse struct { - AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` - TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` - RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` - ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` - IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` - State string `json:"state,omitempty" schema:"state,omitempty"` + AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` + ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` + IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` + State string `json:"state,omitempty" schema:"state,omitempty"` + StorageInfo map[string]string `json:"storage_info,omitempty" schema:"storage_info,omitempty"` } type JWTProfileAssertionClaims struct { @@ -352,4 +354,5 @@ type TokenExchangeResponse struct { ExpiresIn uint64 `json:"expires_in,omitempty"` Scopes SpaceDelimitedArray `json:"scope,omitempty"` RefreshToken string `json:"refresh_token,omitempty"` + StorageInfo map[string]string `json:"storage_info,omitempty" schema:"storage_info,omitempty"` } diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 2220e44..879e11e 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -8,10 +8,11 @@ import ( "github.com/go-chi/chi/v5" "github.com/rs/cors" "github.com/zitadel/logging" - httphelper "github.com/zitadel/oidc/v3/pkg/http" - "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/schema" "golang.org/x/exp/slog" + + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" ) // RegisterServer registers an implementation of Server. diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index deb1abc..df879e4 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/zitadel/oidc/v3/pkg/oidc" ) @@ -214,7 +215,11 @@ func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.A if err != nil { return nil, err } - return NewResponse(resp), nil + ret := NewResponse(resp) + for k, v := range resp.StorageInfo { + ret.Header.Add(k, v) + } + return ret, nil } func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { @@ -256,7 +261,11 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil if err != nil { return nil, err } - return NewResponse(resp), nil + ret := NewResponse(resp) + for k, v := range resp.StorageInfo { + ret.Header.Add(k, v) + } + return ret, nil } func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { @@ -271,7 +280,11 @@ func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc. if err != nil { return nil, err } - return NewResponse(resp), nil + ret := NewResponse(resp) + for k, v := range resp.StorageInfo { + ret.Header.Add(k, v) + } + return ret, nil } func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { @@ -287,7 +300,11 @@ func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientR if err != nil { return nil, err } - return NewResponse(resp), nil + ret := NewResponse(resp) + for k, v := range resp.StorageInfo { + ret.Header.Add(k, v) + } + return ret, nil } func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { @@ -312,7 +329,11 @@ func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.De if err != nil { return nil, err } - return NewResponse(resp), nil + ret := NewResponse(resp) + for k, v := range resp.StorageInfo { + ret.Header.Add(k, v) + } + return ret, nil } func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) { diff --git a/pkg/op/storage.go b/pkg/op/storage.go index d083a31..932a8d0 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -5,7 +5,7 @@ import ( "errors" "time" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "github.com/zitadel/oidc/v3/pkg/oidc" ) @@ -27,7 +27,7 @@ type AuthStorage interface { // Grant: https://datatracker.ietf.org/doc/html/rfc7523#section-2.1 // // * TokenExchangeRequest as returned by ValidateTokenExchangeRequest - CreateAccessToken(context.Context, TokenRequest) (accessTokenID string, expiration time.Time, err error) + CreateAccessToken(context.Context, TokenRequest) (accessTokenID string, expiration time.Time, storageInfo map[string]string, err error) // The TokenRequest parameter of CreateAccessAndRefreshTokens can be any of: // @@ -40,7 +40,7 @@ type AuthStorage interface { // registered the refresh_token grant type in advance // // * TokenExchangeRequest as returned by ValidateTokenExchangeRequest - CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error) + CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, storageInfo map[string]string, err error) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (RefreshTokenRequest, error) TerminateSession(ctx context.Context, userID string, clientID string) error diff --git a/pkg/op/token.go b/pkg/op/token.go index 63a01a6..1028804 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -33,9 +33,10 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli var accessToken, newRefreshToken string var validity time.Duration + var storageInfo map[string]string if createAccessToken { var err error - accessToken, newRefreshToken, validity, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client, refreshToken) + accessToken, newRefreshToken, validity, storageInfo, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client, refreshToken) if err != nil { return nil, err } @@ -65,14 +66,15 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli TokenType: oidc.BearerToken, ExpiresIn: exp, State: state, + StorageInfo: storageInfo, }, nil } -func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) { +func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, storageInfo map[string]string, err error) { if needsRefreshToken(tokenRequest, client) { return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken) } - id, exp, err = storage.CreateAccessToken(ctx, tokenRequest) + id, exp, storageInfo, err = storage.CreateAccessToken(ctx, tokenRequest) return } @@ -89,13 +91,13 @@ func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool } } -func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) { +func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, storageInfo map[string]string, err error) { ctx, span := tracer.Start(ctx, "CreateAccessToken") defer span.End() - id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client) + id, newRefreshToken, exp, storageInfo, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client) if err != nil { - return "", "", 0, err + return "", "", 0, nil, err } var clockSkew time.Duration if client != nil { diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index 7f1debe..f692197 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -111,7 +111,7 @@ func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest Toke ctx, span := tracer.Start(ctx, "CreateClientCredentialsTokenResponse") defer span.End() - accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "") + accessToken, _, validity, storageInfo, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "") if err != nil { return nil, err } @@ -120,5 +120,6 @@ func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest Toke AccessToken: accessToken, TokenType: oidc.BearerToken, ExpiresIn: uint64(validity.Seconds()), + StorageInfo: storageInfo, }, nil } diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index db3e468..42bc1f9 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -363,11 +363,12 @@ func CreateTokenExchangeResponse( var ( token, refreshToken, tokenType string validity time.Duration + storageInfo map[string]string ) switch tokenExchangeRequest.GetRequestedTokenType() { case oidc.AccessTokenType, oidc.RefreshTokenType: - token, refreshToken, validity, err = CreateAccessToken(ctx, tokenExchangeRequest, client.AccessTokenType(), creator, client, "") + token, refreshToken, validity, storageInfo, err = CreateAccessToken(ctx, tokenExchangeRequest, client.AccessTokenType(), creator, client, "") if err != nil { return nil, err } @@ -396,6 +397,7 @@ func CreateTokenExchangeResponse( ExpiresIn: exp, RefreshToken: refreshToken, Scopes: tokenExchangeRequest.GetScopes(), + StorageInfo: storageInfo, }, nil } diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 96ce1ed..4619dc0 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -81,7 +81,7 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea } } - accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") + accessToken, _, validity, storageInfo, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") if err != nil { return nil, err } @@ -89,6 +89,7 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea AccessToken: accessToken, TokenType: oidc.BearerToken, ExpiresIn: uint64(validity.Seconds()), + StorageInfo: storageInfo, }, nil }