feat(client): tracing in rp
This commit is contained in:
parent
d18aba8cb3
commit
bdcccc3303
5 changed files with 65 additions and 0 deletions
|
@ -33,6 +33,9 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.
|
||||||
// in RFC 8628, section 3.1 and 3.2:
|
// in RFC 8628, section 3.1 and 3.2:
|
||||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||||
func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
|
func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "DeviceAuthorization")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")
|
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")
|
||||||
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -46,6 +49,9 @@ func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty,
|
||||||
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
||||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
||||||
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "DeviceAccessToken")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")
|
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")
|
||||||
req := &client.DeviceAccessTokenRequest{
|
req := &client.DeviceAccessTokenRequest{
|
||||||
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
||||||
|
|
|
@ -83,6 +83,9 @@ func (i *inflight) result() ([]jose.JSONWebKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "VerifySignature")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
keyID, alg := oidc.GetKeyIDAndAlg(jws)
|
keyID, alg := oidc.GetKeyIDAndAlg(jws)
|
||||||
if alg == "" {
|
if alg == "" {
|
||||||
alg = r.defaultAlg
|
alg = r.defaultAlg
|
||||||
|
@ -135,6 +138,9 @@ func (r *remoteKeySet) exactMatch(jwkID, jwsID string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
|
func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "verifySignatureRemote")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
keys, err := r.keysFromRemote(ctx)
|
keys, err := r.keysFromRemote(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err)
|
return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err)
|
||||||
|
@ -159,6 +165,9 @@ func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
|
||||||
// keysFromRemote syncs the key set from the remote set, records the values in the
|
// keysFromRemote syncs the key set from the remote set, records the values in the
|
||||||
// cache, and returns the key set.
|
// cache, and returns the key set.
|
||||||
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
|
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "keysFromRemote")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
// Need to lock to inspect the inflight request field.
|
// Need to lock to inspect the inflight request field.
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
// If there's not a current inflight request, create one.
|
// If there's not a current inflight request, create one.
|
||||||
|
@ -182,6 +191,9 @@ func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *remoteKeySet) updateKeys(ctx context.Context) {
|
func (r *remoteKeySet) updateKeys(ctx context.Context) {
|
||||||
|
ctx, span := tracer.Start(ctx, "updateKeys")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
// Sync keys and finish inflight when that's done.
|
// Sync keys and finish inflight when that's done.
|
||||||
keys, err := r.fetchRemoteKeys(ctx)
|
keys, err := r.fetchRemoteKeys(ctx)
|
||||||
|
|
||||||
|
@ -201,6 +213,9 @@ func (r *remoteKeySet) updateKeys(ctx context.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
|
func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "fetchRemoteKeys")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", r.jwksURL, nil)
|
req, err := http.NewRequest("GET", r.jwksURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oidc: can't create request: %v", err)
|
return nil, fmt.Errorf("oidc: can't create request: %v", err)
|
||||||
|
|
|
@ -12,6 +12,8 @@ import (
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/zitadel/logging"
|
"github.com/zitadel/logging"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/clientcredentials"
|
"golang.org/x/oauth2/clientcredentials"
|
||||||
|
|
||||||
|
@ -28,6 +30,12 @@ const (
|
||||||
|
|
||||||
var ErrUserInfoSubNotMatching = errors.New("sub from userinfo does not match the sub from the id_token")
|
var ErrUserInfoSubNotMatching = errors.New("sub from userinfo does not match the sub from the id_token")
|
||||||
|
|
||||||
|
var tracer trace.Tracer
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tracer = otel.Tracer("github.com/zitadel/oidc/pkg/client/rp")
|
||||||
|
}
|
||||||
|
|
||||||
// RelyingParty declares the minimal interface for oidc clients
|
// RelyingParty declares the minimal interface for oidc clients
|
||||||
type RelyingParty interface {
|
type RelyingParty interface {
|
||||||
// OAuthConfig returns the oauth2 Config
|
// OAuthConfig returns the oauth2 Config
|
||||||
|
@ -428,6 +436,9 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
|
||||||
var ErrMissingIDToken = errors.New("id_token missing")
|
var ErrMissingIDToken = errors.New("id_token missing")
|
||||||
|
|
||||||
func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
|
func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "verifyTokenResponse")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
if rp.IsOAuth2Only() {
|
if rp.IsOAuth2Only() {
|
||||||
return &oidc.Tokens[C]{Token: token}, nil
|
return &oidc.Tokens[C]{Token: token}, nil
|
||||||
}
|
}
|
||||||
|
@ -445,6 +456,9 @@ func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Tok
|
||||||
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
||||||
// returning it parsed together with the oauth2 tokens (access, refresh)
|
// returning it parsed together with the oauth2 tokens (access, refresh)
|
||||||
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
||||||
|
ctx, codeExchangeSpan := tracer.Start(ctx, "CodeExchange")
|
||||||
|
defer codeExchangeSpan.End()
|
||||||
|
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
|
ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||||
|
@ -452,10 +466,12 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
|
||||||
codeOpts = append(codeOpts, opt()...)
|
codeOpts = append(codeOpts, opt()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, oauthExchangeSpan := tracer.Start(ctx, "OAuthExchange")
|
||||||
token, err := rp.OAuthConfig().Exchange(ctx, code, codeOpts...)
|
token, err := rp.OAuthConfig().Exchange(ctx, code, codeOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
oauthExchangeSpan.End()
|
||||||
return verifyTokenResponse[C](ctx, token, rp)
|
return verifyTokenResponse[C](ctx, token, rp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -469,6 +485,9 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
|
||||||
// [RFC 6749, section 4.4]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
|
// [RFC 6749, section 4.4]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
|
||||||
func ClientCredentials(ctx context.Context, rp RelyingParty, endpointParams url.Values) (token *oauth2.Token, err error) {
|
func ClientCredentials(ctx context.Context, rp RelyingParty, endpointParams url.Values) (token *oauth2.Token, err error) {
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "ClientCredentials")
|
ctx = logCtxWithRPData(ctx, rp, "function", "ClientCredentials")
|
||||||
|
ctx, span := tracer.Start(ctx, "ClientCredentials")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||||
config := clientcredentials.Config{
|
config := clientcredentials.Config{
|
||||||
ClientID: rp.OAuthConfig().ClientID,
|
ClientID: rp.OAuthConfig().ClientID,
|
||||||
|
@ -489,6 +508,10 @@ type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.R
|
||||||
// Custom parameters can optionally be set to the token URL.
|
// Custom parameters can optionally be set to the token URL.
|
||||||
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, span := tracer.Start(r.Context(), "CodeExchangeHandler")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
state, err := tryReadStateCookie(w, r, rp)
|
state, err := tryReadStateCookie(w, r, rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
|
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
|
||||||
|
@ -540,6 +563,10 @@ type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.
|
||||||
// on success it will pass the userinfo into its callback function as well
|
// on success it will pass the userinfo into its callback function as well
|
||||||
func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] {
|
func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] {
|
||||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||||
|
ctx, span := tracer.Start(r.Context(), "UserinfoCallback")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)
|
unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)
|
||||||
|
@ -558,6 +585,8 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
|
||||||
func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
|
func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
|
||||||
var nilU U
|
var nilU U
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")
|
ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")
|
||||||
|
ctx, span := tracer.Start(ctx, "Userinfo")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -716,6 +745,9 @@ type RefreshTokenRequest struct {
|
||||||
// the IDToken and AccessToken will be verfied
|
// the IDToken and AccessToken will be verfied
|
||||||
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
|
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
|
||||||
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
|
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "RefreshTokens")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
|
ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
|
||||||
request := RefreshTokenRequest{
|
request := RefreshTokenRequest{
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
|
@ -741,6 +773,9 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres
|
||||||
|
|
||||||
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
|
ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
|
||||||
|
ctx, span := tracer.Start(ctx, "RefreshTokens")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
request := oidc.EndSessionRequest{
|
request := oidc.EndSessionRequest{
|
||||||
IdTokenHint: idToken,
|
IdTokenHint: idToken,
|
||||||
ClientID: rp.OAuthConfig().ClientID,
|
ClientID: rp.OAuthConfig().ClientID,
|
||||||
|
@ -757,6 +792,8 @@ func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectU
|
||||||
// tokenTypeHint should be either "id_token" or "refresh_token".
|
// tokenTypeHint should be either "id_token" or "refresh_token".
|
||||||
func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
|
func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
|
||||||
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
|
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
|
||||||
|
ctx, span := tracer.Start(ctx, "RefreshTokens")
|
||||||
|
defer span.End()
|
||||||
request := client.RevokeRequest{
|
request := client.RevokeRequest{
|
||||||
Token: token,
|
Token: token,
|
||||||
TokenTypeHint: tokenTypeHint,
|
TokenTypeHint: tokenTypeHint,
|
||||||
|
|
|
@ -12,6 +12,9 @@ import (
|
||||||
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
|
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "VerifyTokens")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
var nilClaims C
|
var nilClaims C
|
||||||
|
|
||||||
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
||||||
|
@ -27,6 +30,9 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str
|
||||||
// VerifyIDToken validates the id token according to
|
// VerifyIDToken validates the id token according to
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||||
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
|
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
|
||||||
|
ctx, span := tracer.Start(ctx, "VerifyIDToken")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
var nilClaims C
|
var nilClaims C
|
||||||
|
|
||||||
decrypted, err := oidc.DecryptToken(token)
|
decrypted, err := oidc.DecryptToken(token)
|
||||||
|
|
|
@ -135,6 +135,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router
|
||||||
} else {
|
} else {
|
||||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
router.Use(intercept(o.IssuerFromRequest, interceptors...))
|
router.Use(intercept(o.IssuerFromRequest, interceptors...))
|
||||||
router.HandleFunc(healthEndpoint, healthHandler)
|
router.HandleFunc(healthEndpoint, healthHandler)
|
||||||
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue