diff --git a/go.mod b/go.mod index c7562ea..bd7a0e5 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/zitadel/logging v0.6.0 github.com/zitadel/schema v1.3.0 go.opentelemetry.io/otel v1.24.0 - go.opentelemetry.io/otel/trace v1.24.0 golang.org/x/oauth2 v0.18.0 golang.org/x/text v0.14.0 ) @@ -32,10 +31,11 @@ require ( github.com/google/go-querystring v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect + go.opentelemetry.io/otel/trace v1.24.0 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/net v0.22.0 // indirect golang.org/x/sys v0.18.0 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/protobuf v1.31.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2b1eb7e..58029da 100644 --- a/go.sum +++ b/go.sum @@ -136,8 +136,8 @@ google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6 google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pkg/client/client.go b/pkg/client/client.go index b329b3d..8b60264 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -12,19 +12,25 @@ import ( "time" jose "github.com/go-jose/go-jose/v3" - "golang.org/x/oauth2" - "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/crypto" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "go.opentelemetry.io/otel" + "golang.org/x/oauth2" ) -var Encoder = httphelper.Encoder(oidc.NewEncoder()) +var ( + Encoder = httphelper.Encoder(oidc.NewEncoder()) + Tracer = otel.Tracer("github.com/zitadel/oidc/pkg/client") +) // Discover calls the discovery endpoint of the provided issuer and returns its configuration // It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { + ctx, span := Tracer.Start(ctx, "Discover") + defer span.End() + wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" { wellKnown = wellKnownUrl[0] @@ -58,6 +64,9 @@ func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCal } func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + ctx, span := Tracer.Start(ctx, "callTokenEndpoint") + defer span.End() + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err @@ -86,6 +95,9 @@ type EndSessionCaller interface { } func CallEndSessionEndpoint(ctx context.Context, request any, authFn any, caller EndSessionCaller) (*url.URL, error) { + ctx, span := Tracer.Start(ctx, "CallEndSessionEndpoint") + defer span.End() + req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn) if err != nil { return nil, err @@ -129,6 +141,9 @@ type RevokeRequest struct { } func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error { + ctx, span := Tracer.Start(ctx, "CallRevokeEndpoint") + defer span.End() + req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn) if err != nil { return err @@ -157,6 +172,9 @@ func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller Rev } func CallTokenExchangeEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { + ctx, span := Tracer.Start(ctx, "CallTokenExchangeEndpoint") + defer span.End() + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err @@ -198,6 +216,9 @@ type DeviceAuthorizationCaller interface { } func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx, span := Tracer.Start(ctx, "CallDeviceAuthorizationEndpoint") + defer span.End() + req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, authFn) if err != nil { return nil, err @@ -219,6 +240,9 @@ type DeviceAccessTokenRequest struct { } func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { + ctx, span := Tracer.Start(ctx, "CallDeviceAccessTokenEndpoint") + defer span.End() + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil) if err != nil { return nil, err @@ -249,6 +273,9 @@ func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTok } func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { + ctx, span := Tracer.Start(ctx, "PollDeviceAccessTokenEndpoint") + defer span.End() + for { timer := time.After(interval) select { diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 02c647e..c2d1f8a 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -33,6 +33,9 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx, span := client.Tracer.Start(ctx, "DeviceAuthorization") + defer span.End() + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization") req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { @@ -46,6 +49,9 @@ func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, // by means of polling as defined in RFC, section 3.3 and 3.4: // https://www.rfc-editor.org/rfc/rfc8628#section-3.4 func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { + ctx, span := client.Tracer.Start(ctx, "DeviceAccessToken") + defer span.End() + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken") req := &client.DeviceAccessTokenRequest{ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 28aec9b..a061777 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -9,6 +9,7 @@ import ( jose "github.com/go-jose/go-jose/v3" + "github.com/zitadel/oidc/v3/pkg/client" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" ) @@ -83,6 +84,9 @@ func (i *inflight) result() ([]jose.JSONWebKey, error) { } func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { + ctx, span := client.Tracer.Start(ctx, "VerifySignature") + defer span.End() + keyID, alg := oidc.GetKeyIDAndAlg(jws) if alg == "" { alg = r.defaultAlg @@ -135,6 +139,9 @@ func (r *remoteKeySet) exactMatch(jwkID, jwsID string) bool { } func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) { + ctx, span := client.Tracer.Start(ctx, "verifySignatureRemote") + defer span.End() + keys, err := r.keysFromRemote(ctx) if err != nil { return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err) @@ -159,6 +166,9 @@ func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) { // keysFromRemote syncs the key set from the remote set, records the values in the // cache, and returns the key set. func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) { + ctx, span := client.Tracer.Start(ctx, "keysFromRemote") + defer span.End() + // Need to lock to inspect the inflight request field. r.mu.Lock() // If there's not a current inflight request, create one. @@ -182,6 +192,9 @@ func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e } func (r *remoteKeySet) updateKeys(ctx context.Context) { + ctx, span := client.Tracer.Start(ctx, "updateKeys") + defer span.End() + // Sync keys and finish inflight when that's done. keys, err := r.fetchRemoteKeys(ctx) @@ -201,6 +214,9 @@ func (r *remoteKeySet) updateKeys(ctx context.Context) { } func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) { + ctx, span := client.Tracer.Start(ctx, "fetchRemoteKeys") + defer span.End() + req, err := http.NewRequest("GET", r.jwksURL, nil) if err != nil { return nil, fmt.Errorf("oidc: can't create request: %v", err) diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 74da71e..62c650e 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -11,10 +11,10 @@ import ( "github.com/go-jose/go-jose/v3" "github.com/google/uuid" - "github.com/zitadel/logging" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/client" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -428,6 +428,9 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri var ErrMissingIDToken = errors.New("id_token missing") func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) { + ctx, span := client.Tracer.Start(ctx, "verifyTokenResponse") + defer span.End() + if rp.IsOAuth2Only() { return &oidc.Tokens[C]{Token: token}, nil } @@ -445,6 +448,9 @@ func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Tok // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { + ctx, codeExchangeSpan := client.Tracer.Start(ctx, "CodeExchange") + defer codeExchangeSpan.End() + ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange") ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) codeOpts := make([]oauth2.AuthCodeOption, 0) @@ -452,10 +458,12 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP codeOpts = append(codeOpts, opt()...) } + ctx, oauthExchangeSpan := client.Tracer.Start(ctx, "OAuthExchange") token, err := rp.OAuthConfig().Exchange(ctx, code, codeOpts...) if err != nil { return nil, err } + oauthExchangeSpan.End() return verifyTokenResponse[C](ctx, token, rp) } @@ -469,6 +477,9 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP // [RFC 6749, section 4.4]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 func ClientCredentials(ctx context.Context, rp RelyingParty, endpointParams url.Values) (token *oauth2.Token, err error) { ctx = logCtxWithRPData(ctx, rp, "function", "ClientCredentials") + ctx, span := client.Tracer.Start(ctx, "ClientCredentials") + defer span.End() + ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) config := clientcredentials.Config{ ClientID: rp.OAuthConfig().ClientID, @@ -489,6 +500,10 @@ type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.R // Custom parameters can optionally be set to the token URL. func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + ctx, span := client.Tracer.Start(r.Context(), "CodeExchangeHandler") + r = r.WithContext(ctx) + defer span.End() + state, err := tryReadStateCookie(w, r, rp) if err != nil { unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp) @@ -540,6 +555,10 @@ type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http. // on success it will pass the userinfo into its callback function as well func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { + ctx, span := client.Tracer.Start(r.Context(), "UserinfoCallback") + r = r.WithContext(ctx) + defer span.End() + info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp) @@ -558,6 +577,8 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) { var nilU U ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo") + ctx, span := client.Tracer.Start(ctx, "Userinfo") + defer span.End() req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { @@ -716,6 +737,9 @@ type RefreshTokenRequest struct { // the IDToken and AccessToken will be verfied // and the IDToken and IDTokenClaims fields will be populated in the returned object. func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { + ctx, span := client.Tracer.Start(ctx, "RefreshTokens") + defer span.End() + ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens") request := RefreshTokenRequest{ RefreshToken: refreshToken, @@ -741,6 +765,9 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { ctx = logCtxWithRPData(ctx, rp, "function", "EndSession") + ctx, span := client.Tracer.Start(ctx, "RefreshTokens") + defer span.End() + request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, @@ -757,6 +784,8 @@ func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectU // tokenTypeHint should be either "id_token" or "refresh_token". func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken") + ctx, span := client.Tracer.Start(ctx, "RefreshTokens") + defer span.End() request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint, diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index adf8872..94be079 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -6,12 +6,16 @@ import ( jose "github.com/go-jose/go-jose/v3" + "github.com/zitadel/oidc/v3/pkg/client" "github.com/zitadel/oidc/v3/pkg/oidc" ) // VerifyTokens implement the Token Response Validation as defined in OIDC specification // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) { + ctx, span := client.Tracer.Start(ctx, "VerifyTokens") + defer span.End() + var nilClaims C claims, err = VerifyIDToken[C](ctx, idToken, v) @@ -27,6 +31,9 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str // VerifyIDToken validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) { + ctx, span := client.Tracer.Start(ctx, "VerifyIDToken") + defer span.End() + var nilClaims C decrypted, err := oidc.DecryptToken(token) diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index 57925d5..962af7e 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -123,6 +123,9 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option { // // [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662 func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) { + ctx, span := client.Tracer.Start(ctx, "Introspect") + defer span.End() + if rp.IntrospectionURL() == "" { return resp, errors.New("resource server: introspection URL is empty") } diff --git a/pkg/client/rs/resource_server_test.go b/pkg/client/rs/resource_server_test.go index bb17c64..7a5ced9 100644 --- a/pkg/client/rs/resource_server_test.go +++ b/pkg/client/rs/resource_server_test.go @@ -201,7 +201,7 @@ func TestIntrospect(t *testing.T) { { name: "missing-introspect-url", args: args{ - ctx: nil, + ctx: context.Background(), rp: rp, token: "my-token", }, diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go index 7bc35a2..a2ea1bb 100644 --- a/pkg/client/tokenexchange/tokenexchange.go +++ b/pkg/client/tokenexchange/tokenexchange.go @@ -114,6 +114,9 @@ func ExchangeToken( Scopes []string, RequestedTokenType oidc.TokenType, ) (*oidc.TokenExchangeResponse, error) { + ctx, span := client.Tracer.Start(ctx, "ExchangeToken") + defer span.End() + if SubjectToken == "" { return nil, errors.New("empty subject_token") } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 18d8826..4b5837a 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -70,12 +70,15 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, * // Authorize handles the authorization request, including // parsing, validating, storing and finally redirecting to the login handler func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { + ctx, span := tracer.Start(r.Context(), "Authorize") + r = r.WithContext(ctx) + defer span.End() + authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { AuthRequestError(w, r, nil, err, authorizer) return } - ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { @@ -210,6 +213,9 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi // ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) { + ctx, span := tracer.Start(ctx, "ValidateAuthRequest") + defer span.End() + authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) if err != nil { return "", err @@ -310,7 +316,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res return checkURIAgainstRedirects(client, uri) } if client.ApplicationType() == ApplicationTypeNative { - return validateAuthReqRedirectURINative(client, uri, responseType) + return validateAuthReqRedirectURINative(client, uri) } if err := checkURIAgainstRedirects(client, uri); err != nil { return err @@ -330,7 +336,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res } // ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type -func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error { +func validateAuthReqRedirectURINative(client Client, uri string) error { parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri) isCustomSchema := !strings.HasPrefix(uri, "http://") if err := checkURIAgainstRedirects(client, uri); err == nil { @@ -362,8 +368,8 @@ func equalURI(url1, url2 *url.URL) bool { return url1.Path == url2.Path && url1.RawQuery == url2.RawQuery } -func HTTPLoopbackOrLocalhost(rawurl string) (*url.URL, bool) { - parsedURL, err := url.Parse(rawurl) +func HTTPLoopbackOrLocalhost(rawURL string) (*url.URL, bool) { + parsedURL, err := url.Parse(rawURL) if err != nil { return nil, false } @@ -409,6 +415,10 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * // AuthorizeCallback handles the callback after authentication in the Login UI func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { + ctx, span := tracer.Start(r.Context(), "AuthorizeCallback") + r = r.WithContext(ctx) + defer span.End() + id, err := ParseAuthorizeCallbackRequest(r) if err != nil { AuthRequestError(w, r, nil, err, authorizer) @@ -441,6 +451,10 @@ func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { // AuthResponse creates the successful authentication response (either code or tokens) func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { + ctx, span := tracer.Start(r.Context(), "AuthResponse") + r = r.WithContext(ctx) + defer span.End() + client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) @@ -455,6 +469,10 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri // AuthResponseCode creates the successful code authentication response func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { + ctx, span := tracer.Start(r.Context(), "AuthResponseCode") + r = r.WithContext(ctx) + defer span.End() + code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) @@ -519,6 +537,9 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque // CreateAuthRequestCode creates and stores a code for the auth code response func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) { + ctx, span := tracer.Start(ctx, "CreateAuthRequestCode") + defer span.End() + code, err := BuildAuthRequestCode(authReq, crypto) if err != nil { return "", err diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 76cb00d..45627a5 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -1072,7 +1072,7 @@ func TestAuthResponseCode(t *testing.T) { authorizer: func(t *testing.T) op.Authorizer { ctrl := gomock.NewController(t) storage := mock.NewMockStorage(ctrl) - storage.EXPECT().SaveAuthCode(context.Background(), "id1", "id1") + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") authorizer := mock.NewMockAuthorizer(ctrl) authorizer.EXPECT().Storage().Return(storage) @@ -1097,7 +1097,7 @@ func TestAuthResponseCode(t *testing.T) { authorizer: func(t *testing.T) op.Authorizer { ctrl := gomock.NewController(t) storage := mock.NewMockStorage(ctrl) - storage.EXPECT().SaveAuthCode(context.Background(), "id1", "id1") + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") authorizer := mock.NewMockAuthorizer(ctrl) authorizer.EXPECT().Storage().Return(storage) @@ -1124,7 +1124,7 @@ func TestAuthResponseCode(t *testing.T) { authorizer: func(t *testing.T) op.Authorizer { ctrl := gomock.NewController(t) storage := mock.NewMockStorage(ctrl) - storage.EXPECT().SaveAuthCode(context.Background(), "id1", "id1") + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") authorizer := mock.NewMockAuthorizer(ctrl) authorizer.EXPECT().Storage().Return(storage) diff --git a/pkg/op/client.go b/pkg/op/client.go index 0574afa..913944c 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -92,6 +92,9 @@ type ClientJWTProfile interface { } func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { + ctx, span := tracer.Start(ctx, "ClientJWTAuth") + defer span.End() + if ca.ClientAssertion == "" { return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) } @@ -104,6 +107,10 @@ func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier } func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) { + ctx, span := tracer.Start(r.Context(), "ClientBasicAuth") + r = r.WithContext(ctx) + defer span.End() + clientID, clientSecret, ok := r.BasicAuth() if !ok { return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) @@ -151,6 +158,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) } + ctx, span := tracer.Start(r.Context(), "ClientIDFromRequest") + r = r.WithContext(ctx) + defer span.End() + data := new(clientData) if err = p.Decoder().Decode(data, r.Form); err != nil { return "", false, err @@ -171,7 +182,7 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au } // if the client did not send a Basic Auth Header, ignore the `ErrNoClientCredentials` // but return other errors immediately - if err != nil && !errors.Is(err, ErrNoClientCredentials) { + if !errors.Is(err, ErrNoClientCredentials) { return "", false, err } diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index 0321f88..b772ba5 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -108,7 +108,7 @@ func TestClientBasicAuth(t *testing.T) { }, storage: func() op.Storage { s := mock.NewMockStorage(gomock.NewController(t)) - s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong) + s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "wrong").Return(errWrong) return s }(), wantErr: errWrong, @@ -121,7 +121,7 @@ func TestClientBasicAuth(t *testing.T) { }, storage: func() op.Storage { s := mock.NewMockStorage(gomock.NewController(t)) - s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "bar").Return(nil) return s }(), wantClientID: "foo", @@ -207,7 +207,7 @@ func TestClientIDFromRequest(t *testing.T) { p: testClientProvider{ storage: func() op.Storage { s := mock.NewMockStorage(gomock.NewController(t)) - s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "bar").Return(nil) return s }(), }, diff --git a/pkg/op/device.go b/pkg/op/device.go index 1b86d04..11638b0 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -64,6 +64,10 @@ func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt } func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { + ctx, span := tracer.Start(r.Context(), "DeviceAuthorization") + r = r.WithContext(ctx) + defer span.End() + req, err := ParseDeviceCodeRequest(r, o) if err != nil { return err @@ -78,6 +82,9 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide } func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) { + ctx, span := tracer.Start(ctx, "createDeviceAuthorization") + defer span.End() + storage, err := assertDeviceStorage(o.Storage()) if err != nil { return nil, err @@ -127,6 +134,10 @@ func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizatio } func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { + ctx, span := tracer.Start(r.Context(), "ParseDeviceCodeRequest") + r = r.WithContext(ctx) + defer span.End() + clientID, _, err := ClientIDFromRequest(r, o) if err != nil { return nil, err @@ -288,6 +299,9 @@ func (r *DeviceAuthorizationState) GetSubject() string { } func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { + ctx, span := tracer.Start(ctx, "CheckDeviceAuthorizationState") + defer span.End() + storage, err := assertDeviceStorage(exchanger.Storage()) if err != nil { return nil, err diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 6af1674..7b5ecbe 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -135,6 +135,9 @@ func SubjectTypes(c Configuration) []string { } func SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string { + ctx, span := tracer.Start(ctx, "SigAlgorithms") + defer span.End() + algorithms, err := storage.SignatureAlgorithms(ctx) if err != nil { return nil diff --git a/pkg/op/keys.go b/pkg/op/keys.go index fe111f0..d55c8d1 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -20,6 +20,10 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) { } func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { + ctx, span := tracer.Start(r.Context(), "Keys") + r = r.WithContext(ctx) + defer span.End() + keySet, err := k.KeySet(r.Context()) if err != nil { httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError) diff --git a/pkg/op/op.go b/pkg/op/op.go index 326737a..3248317 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -12,7 +12,6 @@ import ( "github.com/rs/cors" "github.com/zitadel/schema" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/trace" "golang.org/x/text/language" httphelper "github.com/zitadel/oidc/v3/pkg/http" @@ -97,11 +96,7 @@ var ( } ) -var tracer trace.Tracer - -func init() { - tracer = otel.Tracer("github.com/zitadel/oidc/pkg/op") -} +var tracer = otel.Tracer("github.com/zitadel/oidc/pkg/op") type OpenIDProvider interface { http.Handler diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index b2a758c..83032d4 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -181,7 +181,7 @@ func TestRoutes(t *testing.T) { }, }, { - // This call will fail. A successfull test is already + // This call will fail. A successful test is already // part of client/integration_test.go name: "code exchange", method: http.MethodGet, @@ -235,7 +235,7 @@ func TestRoutes(t *testing.T) { contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, }, { - // This call will fail. A successfull test is already + // This call will fail. A successful test is already // part of device_test.go name: "device token", method: http.MethodPost, diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 0a5e469..725dd64 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -124,7 +124,13 @@ func (s *webServer) createRouter() { func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) { if e != nil { - s.router.HandleFunc(e.Relative(), hf) + traceHandler := func(w http.ResponseWriter, r *http.Request) { + ctx, span := tracer.Start(r.Context(), e.Relative()) + r = r.WithContext(ctx) + hf(w, r) + defer span.End() + } + s.router.HandleFunc(e.Relative(), traceHandler) s.logger.Info("registered route", "endpoint", e.Relative()) } } @@ -133,6 +139,10 @@ type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + ctx, span := tracer.Start(r.Context(), r.URL.Path) + defer span.End() + r = r.WithContext(ctx) + client, err := s.verifyRequestClient(r) if err != nil { WriteError(w, r, err, s.getLogger(r.Context())) diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go index c50e989..8b3fa02 100644 --- a/pkg/op/server_http_routes_test.go +++ b/pkg/op/server_http_routes_test.go @@ -177,7 +177,7 @@ func TestServerRoutes(t *testing.T) { contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, }, { - // This call will fail. A successfull test is already + // This call will fail. A successful test is already // part of device_test.go name: "device token", method: http.MethodPost, diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index f99d15d..6b6d4b3 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -79,6 +79,9 @@ func (s *LegacyServer) Endpoints() Endpoints { // AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login func (s *LegacyServer) AuthCallbackURL() func(context.Context, string) string { return func(ctx context.Context, requestID string) string { + ctx, span := tracer.Start(ctx, "LegacyServer.AuthCallbackURL") + defer span.End() + return s.endpoints.Authorization.Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID } } @@ -98,12 +101,18 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon } func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Discovery") + defer span.End() + return NewResponse( createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints), ), nil } func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Keys") + defer span.End() + keys, err := s.provider.Storage().KeySet(ctx) if err != nil { return nil, AsStatusError(err, http.StatusInternalServerError) @@ -117,6 +126,9 @@ var ( ) func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + ctx, span := tracer.Start(ctx, "LegacyServer.VerifyAuthRequest") + defer span.End() + if r.Data.RequestParam != "" { if !s.provider.RequestObjectSupported() { return nil, oidc.ErrRequestNotSupported() @@ -141,6 +153,9 @@ func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.Au } func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Authorize") + defer span.End() + userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx)) if err != nil { return nil, err @@ -153,6 +168,9 @@ func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.Auth } func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.DeviceAuthorization") + defer span.End() + response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) if err != nil { return nil, AsStatusError(err, http.StatusInternalServerError) @@ -161,6 +179,9 @@ func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest } func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.VerifyClient") + defer span.End() + if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials { storage, ok := s.provider.Storage().(ClientCredentialsStorage) if !ok { @@ -201,6 +222,9 @@ func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCreden } func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.CodeExchange") + defer span.End() + authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code) if err != nil { return nil, err @@ -221,6 +245,9 @@ func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.A } func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.RefreshToken") + defer span.End() + if !s.provider.GrantTypeRefreshTokenSupported() { return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) } @@ -242,6 +269,9 @@ func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.R } func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.JWTProfile") + defer span.End() + exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) if !ok { return nil, unimplementedGrantError(oidc.GrantTypeBearer) @@ -263,6 +293,9 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil } func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.TokenExchange") + defer span.End() + if !s.provider.GrantTypeTokenExchangeSupported() { return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) } @@ -278,6 +311,9 @@ func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc. } func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.ClientCredentialsExchange") + defer span.End() + storage, ok := s.provider.Storage().(ClientCredentialsStorage) if !ok { return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) @@ -294,6 +330,9 @@ func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientR } func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.DeviceToken") + defer span.End() + if !s.provider.GrantTypeDeviceCodeSupported() { return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) } @@ -314,6 +353,9 @@ func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.De } func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.authenticateResourceClient") + defer span.End() + if cc.ClientAssertion != "" { if jp, ok := s.provider.(ClientJWTProfile); ok { return ClientJWTAuth(ctx, oidc.ClientAssertionParams{ClientAssertion: cc.ClientAssertion}, jp) @@ -327,6 +369,9 @@ func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *Clien } func (s *LegacyServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Introspect") + defer span.End() + clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials) if err != nil { return nil, err @@ -345,6 +390,9 @@ func (s *LegacyServer) Introspect(ctx context.Context, r *Request[IntrospectionR } func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.UserInfo") + defer span.End() + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken) if !ok { return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized) @@ -358,6 +406,9 @@ func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoReq } func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Revocation") + defer span.End() + var subject string doDecrypt := true if r.Data.TokenTypeHint != "access_token" { @@ -387,6 +438,9 @@ func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.Rev } func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.EndSession") + defer span.End() + session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider) if err != nil { return nil, err diff --git a/pkg/op/session.go b/pkg/op/session.go index 6af7d7c..8ac530d 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -27,6 +27,10 @@ func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Reque } func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { + ctx, span := tracer.Start(r.Context(), "EndSession") + defer span.End() + r = r.WithContext(ctx) + req, err := ParseEndSessionRequest(r, ender.Decoder()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -64,6 +68,9 @@ func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc. } func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) { + ctx, span := tracer.Start(ctx, "ValidateEndSessionRequest") + defer span.End() + session := &EndSessionRequest{ RedirectURI: ender.DefaultLogoutRedirectURI(), } diff --git a/pkg/op/token.go b/pkg/op/token.go index a055eb7..b45789b 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -69,6 +69,9 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli } func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) { + ctx, span := tracer.Start(ctx, "createTokens") + defer span.End() + if needsRefreshToken(tokenRequest, client) { return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken) } diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index db3e468..fcb4468 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -193,6 +193,9 @@ func ValidateTokenExchangeRequest( clientID, clientSecret string, exchanger Exchanger, ) (TokenExchangeRequest, Client, error) { + ctx, span := tracer.Start(ctx, "ValidateTokenExchangeRequest") + defer span.End() + if oidcTokenExchangeRequest.SubjectToken == "" { return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token missing") } @@ -231,6 +234,9 @@ func CreateTokenExchangeRequest( client Client, exchanger Exchanger, ) (TokenExchangeRequest, error) { + ctx, span := tracer.Start(ctx, "CreateTokenExchangeRequest") + defer span.End() + teStorage, ok := exchanger.Storage().(TokenExchangeStorage) if !ok { return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) @@ -294,6 +300,9 @@ func GetTokenIDAndSubjectFromToken( tokenType oidc.TokenType, isActor bool, ) (tokenIDOrToken, subject string, claims map[string]any, ok bool) { + ctx, span := tracer.Start(ctx, "GetTokenIDAndSubjectFromToken") + defer span.End() + switch tokenType { case oidc.AccessTokenType: var accessTokenClaims *oidc.AccessTokenClaims @@ -341,6 +350,9 @@ func GetTokenIDAndSubjectFromToken( // AuthorizeTokenExchangeClient authorizes a client by validating the client_id and client_secret func AuthorizeTokenExchangeClient(ctx context.Context, clientID, clientSecret string, exchanger Exchanger) (client Client, err error) { + ctx, span := tracer.Start(ctx, "AuthorizeTokenExchangeClient") + defer span.End() + if err := AuthorizeClientIDSecret(ctx, clientID, clientSecret, exchanger.Storage()); err != nil { return nil, err } @@ -359,6 +371,8 @@ func CreateTokenExchangeResponse( client Client, creator TokenCreator, ) (_ *oidc.TokenExchangeResponse, err error) { + ctx, span := tracer.Start(ctx, "CreateTokenExchangeResponse") + defer span.End() var ( token, refreshToken, tokenType string diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index 9c45ef8..29234e1 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -28,6 +28,10 @@ func introspectionHandler(introspector Introspector) func(http.ResponseWriter, * } func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) { + ctx, span := tracer.Start(r.Context(), "Introspect") + defer span.End() + r = r.WithContext(ctx) + response := new(oidc.IntrospectionResponse) token, clientID, err := ParseTokenIntrospectionRequest(r, introspector) if err != nil { diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index afca3bf..92ef476 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -141,6 +141,9 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ // RefreshTokenRequestByRefreshToken returns the RefreshTokenRequest (data representing the original auth request) // corresponding to the refresh_token from Storage or an error func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) { + ctx, span := tracer.Start(ctx, "RefreshTokenRequestByRefreshToken") + defer span.End() + request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken) if err != nil { return nil, oidc.ErrInvalidGrant().WithParent(err) diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 2006725..85e2270 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -37,6 +37,10 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque // Exchange performs a token exchange appropriate for the grant type func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + ctx, span := tracer.Start(r.Context(), "Exchange") + r = r.WithContext(ctx) + defer span.End() + grantType := r.FormValue("grant_type") switch grantType { case string(oidc.GrantTypeCode): @@ -115,6 +119,9 @@ func ParseAuthenticatedTokenRequest(r *http.Request, decoder httphelper.Decoder, // AuthorizeClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST) func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error { + ctx, span := tracer.Start(ctx, "AuthorizeClientIDSecret") + defer span.End() + err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) if err != nil { return oidc.ErrInvalidClient().WithDescription("invalid client_id / client_secret").WithParent(err) diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index d19c7f7..a86a481 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -32,6 +32,10 @@ func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) } func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) { + ctx, span := tracer.Start(r.Context(), "Revoke") + r = r.WithContext(ctx) + defer span.End() + token, tokenTypeHint, clientID, err := ParseTokenRevocationRequest(r, revoker) if err != nil { RevocationRequestError(w, r, err) @@ -68,6 +72,10 @@ func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) { } func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, tokenTypeHint, clientID string, err error) { + ctx, span := tracer.Start(r.Context(), "ParseTokenRevocationRequest") + r = r.WithContext(ctx) + defer span.End() + err = r.ParseForm() if err != nil { return "", "", "", oidc.ErrInvalidRequest().WithDescription("unable to parse request").WithParent(err) @@ -148,6 +156,9 @@ func RevocationError(err error) StatusError { } func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + ctx, span := tracer.Start(ctx, "getTokenIDAndSubjectForRevocation") + defer span.End() + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) if err == nil { splitToken := strings.Split(tokenIDSubject, ":") diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 86205b5..839b139 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -24,6 +24,10 @@ func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter } func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) { + ctx, span := tracer.Start(r.Context(), "Userinfo") + r = r.WithContext(ctx) + defer span.End() + accessToken, err := ParseUserinfoRequest(r, userinfoProvider.Decoder()) if err != nil { http.Error(w, "access token missing", http.StatusUnauthorized) @@ -44,6 +48,10 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP } func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, error) { + ctx, span := tracer.Start(r.Context(), "ParseUserinfoRequest") + r = r.WithContext(ctx) + defer span.End() + accessToken, err := getAccessToken(r) if err == nil { return accessToken, nil @@ -61,6 +69,10 @@ func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, } func getAccessToken(r *http.Request) (string, error) { + ctx, span := tracer.Start(r.Context(), "getAccessToken") + r = r.WithContext(ctx) + defer span.End() + authHeader := r.Header.Get("authorization") if authHeader == "" { return "", errors.New("no auth header") @@ -73,6 +85,9 @@ func getAccessToken(r *http.Request) (string, error) { } func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + ctx, span := tracer.Start(ctx, "getTokenIDAndSubject") + defer span.End() + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) if err == nil { splitToken := strings.Split(tokenIDSubject, ":") diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 120bfa7..6ac29f2 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -30,6 +30,9 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok // VerifyAccessToken validates the access token (issuer, signature and expiration). func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) { + ctx, span := tracer.Start(ctx, "VerifyAccessToken") + defer span.End() + var nilClaims C decrypted, err := oidc.DecryptToken(token) diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index b5ec72e..331c64c 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -46,6 +46,9 @@ func (e IDTokenHintExpiredError) Is(err error) bool { // is returned of type [IDTokenHintExpiredError]. In that case the caller can choose to still // trust the token for cases like logout, as signature and other verifications succeeded. func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) { + ctx, span := tracer.Start(ctx, "VerifyIDTokenHint") + defer span.End() + var nilClaims C decrypted, err := oidc.DecryptToken(token) diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index 38b8ee4..ced99ad 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -118,6 +118,9 @@ type jwtProfileKeySet struct { // VerifySignature implements oidc.KeySet by getting the public key from Storage implementation func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { + ctx, span := tracer.Start(ctx, "VerifySignature") + defer span.End() + keyID, _ := oidc.GetKeyIDAndAlg(jws) key, err := k.storage.GetKeyByIDAndClientID(ctx, keyID, k.clientID) if err != nil {