From 6af94fded0a1d5ddb448799358b029733b77d7a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 23 Mar 2023 16:31:38 +0200 Subject: [PATCH] feat: add context to all client calls (#345) BREAKING CHANGE closes #309 --- example/client/api/api.go | 3 +- example/client/app/app.go | 3 +- example/client/device/device.go | 4 +- example/client/service/service.go | 4 +- pkg/client/client.go | 30 ++++++------- pkg/client/integration_test.go | 30 ++++++++++--- pkg/client/jwt_profile.go | 5 ++- pkg/client/key.go | 8 ++-- pkg/client/profile/jwt_profile.go | 51 ++++++++++++++++------- pkg/client/rp/device.go | 4 +- pkg/client/rp/relying_party.go | 43 +++++-------------- pkg/client/rs/resource_server.go | 18 ++++---- pkg/client/tokenexchange/tokenexchange.go | 16 +++---- pkg/http/http.go | 4 +- 14 files changed, 124 insertions(+), 99 deletions(-) diff --git a/example/client/api/api.go b/example/client/api/api.go index 95e84e7..83ec2a1 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "log" @@ -27,7 +28,7 @@ func main() { port := os.Getenv("PORT") issuer := os.Getenv("ISSUER") - provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath) + provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } diff --git a/example/client/app/app.go b/example/client/app/app.go index 446c17b..2cb5dfa 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "net/http" @@ -43,7 +44,7 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...) + provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } diff --git a/example/client/device/device.go b/example/client/device/device.go index 88ecfe9..c186b34 100644 --- a/example/client/device/device.go +++ b/example/client/device/device.go @@ -39,13 +39,13 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } logrus.Info("starting device authorization flow") - resp, err := rp.DeviceAuthorization(scopes, provider) + resp, err := rp.DeviceAuthorization(ctx, scopes, provider) if err != nil { logrus.Fatal(err) } diff --git a/example/client/service/service.go b/example/client/service/service.go index 4908b09..ffcdccb 100644 --- a/example/client/service/service.go +++ b/example/client/service/service.go @@ -25,7 +25,7 @@ func main() { scopes := strings.Split(os.Getenv("SCOPES"), " ") if keyPath != "" { - ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes) if err != nil { logrus.Fatalf("error creating token source %s", err.Error()) } @@ -76,7 +76,7 @@ func main() { http.Error(w, err.Error(), http.StatusInternalServerError) return } - ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/pkg/client/client.go b/pkg/client/client.go index e9af8ce..b9580ff 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -23,12 +23,12 @@ var Encoder = httphelper.Encoder(oidc.NewEncoder()) // Discover calls the discovery endpoint of the provided issuer and returns its configuration // It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url -func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { +func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" { wellKnown = wellKnownUrl[0] } - req, err := http.NewRequest("GET", wellKnown, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil) if err != nil { return nil, err } @@ -48,12 +48,12 @@ type TokenEndpointCaller interface { HttpClient() *http.Client } -func CallTokenEndpoint(request interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - return callTokenEndpoint(request, nil, caller) +func CallTokenEndpoint(ctx context.Context, request interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + return callTokenEndpoint(ctx, request, nil, caller) } -func callTokenEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func callTokenEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -74,8 +74,8 @@ type EndSessionCaller interface { HttpClient() *http.Client } -func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndSessionCaller) (*url.URL, error) { - req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn) +func CallEndSessionEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller EndSessionCaller) (*url.URL, error) { + req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -117,8 +117,8 @@ type RevokeRequest struct { ClientSecret string `schema:"client_secret"` } -func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCaller) error { - req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn) +func CallRevokeEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller RevokeCaller) error { + req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn) if err != nil { return err } @@ -145,8 +145,8 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa return nil } -func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func CallTokenExchangeEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -186,8 +186,8 @@ type DeviceAuthorizationCaller interface { HttpClient() *http.Client } -func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { - req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) +func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { + req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) if err != nil { return nil, err } @@ -208,7 +208,7 @@ type DeviceAccessTokenRequest struct { } func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil) + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil) if err != nil { return nil, err } diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 709d5a1..2c3ef62 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -2,6 +2,7 @@ package client_test import ( "bytes" + "context" "io" "io/ioutil" "math/rand" @@ -10,7 +11,9 @@ import ( "net/http/httptest" "net/url" "os" + "os/signal" "strconv" + "syscall" "testing" "time" @@ -27,6 +30,18 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) +var CTX context.Context + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT) + defer cancel() + CTX, cancel = context.WithTimeout(ctx, time.Minute) + defer cancel() + return m.Run() + }()) +} + func TestRelyingPartySession(t *testing.T) { t.Log("------- start example OP ------") targetURL := "http://local-site" @@ -45,7 +60,7 @@ func TestRelyingPartySession(t *testing.T) { t.Log("------- refresh tokens ------") - newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "") + newTokens, err := rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") require.NoError(t, err, "refresh token") assert.NotNil(t, newTokens, "access token") t.Logf("new access token %s", newTokens.AccessToken) @@ -56,7 +71,7 @@ func TestRelyingPartySession(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -66,11 +81,11 @@ func TestRelyingPartySession(t *testing.T) { t.Log("------ attempt refresh again (should fail) ------") t.Log("trying original refresh token", refreshToken) - _, err = rp.RefreshAccessToken(provider, refreshToken, "", "") + _, err = rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") assert.Errorf(t, err, "refresh with original") if newTokens.RefreshToken != "" { t.Log("trying replacement refresh token", newTokens.RefreshToken) - _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") + _, err = rp.RefreshAccessToken(CTX, provider, newTokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with replacement") } } @@ -92,12 +107,13 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------- run authorization code flow ------") provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) - resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret) + resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) require.NoError(t, err, "new resource server") t.Log("------- exchage refresh tokens (impersonation) ------") tokenExchangeResponse, err := tokenexchange.ExchangeToken( + CTX, resourceServer, refreshToken, oidc.RefreshTokenType, @@ -117,7 +133,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -128,6 +144,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------- attempt exchage again (should fail) ------") tokenExchangeResponse, err = tokenexchange.ExchangeToken( + CTX, resourceServer, refreshToken, oidc.RefreshTokenType, @@ -166,6 +183,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, key := []byte("test1234test1234") cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) provider, err = rp.NewRelyingPartyOIDC( + CTX, opServer.URL, clientID, clientSecret, diff --git a/pkg/client/jwt_profile.go b/pkg/client/jwt_profile.go index 486d998..0a5d9ec 100644 --- a/pkg/client/jwt_profile.go +++ b/pkg/client/jwt_profile.go @@ -1,6 +1,7 @@ package client import ( + "context" "net/url" "golang.org/x/oauth2" @@ -10,8 +11,8 @@ import ( ) // JWTProfileExchange handles the oauth2 jwt profile exchange -func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { - return CallTokenEndpoint(jwtProfileGrantRequest, caller) +func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { + return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller) } func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption { diff --git a/pkg/client/key.go b/pkg/client/key.go index 740c6d3..0c01dd2 100644 --- a/pkg/client/key.go +++ b/pkg/client/key.go @@ -10,7 +10,7 @@ const ( applicationKey = "application" ) -type keyFile struct { +type KeyFile struct { Type string `json:"type"` // serviceaccount or application KeyID string `json:"keyId"` Key string `json:"key"` @@ -23,7 +23,7 @@ type keyFile struct { ClientID string `json:"clientId"` } -func ConfigFromKeyFile(path string) (*keyFile, error) { +func ConfigFromKeyFile(path string) (*KeyFile, error) { data, err := ioutil.ReadFile(path) if err != nil { return nil, err @@ -31,8 +31,8 @@ func ConfigFromKeyFile(path string) (*keyFile, error) { return ConfigFromKeyFileData(data) } -func ConfigFromKeyFileData(data []byte) (*keyFile, error) { - var f keyFile +func ConfigFromKeyFileData(data []byte) (*KeyFile, error) { + var f KeyFile if err := json.Unmarshal(data, &f); err != nil { return nil, err } diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index bb18570..668f749 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -1,6 +1,7 @@ package profile import ( + "context" "net/http" "time" @@ -11,9 +12,12 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -// jwtProfileTokenSource implement the oauth2.TokenSource -// it will request a token using the OAuth2 JWT Profile Grant -// therefore sending an `assertion` by singing a JWT with the provided private key +type TokenSource interface { + oauth2.TokenSource + TokenCtx(context.Context) (*oauth2.Token, error) +} + +// jwtProfileTokenSource implements the TokenSource type jwtProfileTokenSource struct { clientID string audience []string @@ -23,23 +27,38 @@ type jwtProfileTokenSource struct { tokenEndpoint string } -func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFile(keyPath) +// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFile(jsonFile) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFileData(data) +// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFileData(jsonData) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { +// NewJWTProfileSource returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err @@ -55,7 +74,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes opt(source) } if source.tokenEndpoint == "" { - config, err := client.Discover(issuer, source.httpClient) + config, err := client.Discover(ctx, issuer, source.httpClient) if err != nil { return nil, err } @@ -64,13 +83,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes return source, nil } -func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) { +func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.httpClient = client } } -func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) { +func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.tokenEndpoint = tokenEndpoint } @@ -85,9 +104,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client { } func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) { + return j.TokenCtx(context.Background()) +} + +func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) { assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer) if err != nil { return nil, err } - return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) + return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) } diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 9cfc41e..b2c5be6 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -33,13 +33,13 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // DeviceAuthorization starts a new Device Authorization flow as defined // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 -func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { +func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err } - return client.CallDeviceAuthorizationEndpoint(req, rp) + return client.CallDeviceAuthorizationEndpoint(ctx, req, rp) } // DeviceAccessToken attempts to obtain tokens from a Device Authorization, diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index bd96e16..820107f 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "time" "github.com/google/uuid" @@ -177,7 +176,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart // NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given // issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions // it will run discovery on the provided issuer and use the found endpoints -func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { +func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { rp := &relyingParty{ issuer: issuer, oauthConfig: &oauth2.Config{ @@ -195,7 +194,7 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco return nil, err } } - discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) + discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) if err != nil { return nil, err } @@ -310,26 +309,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey { } } -// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints -// -// deprecated: use client.Discover -func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { - wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint - req, err := http.NewRequest("GET", wellKnown, nil) - if err != nil { - return Endpoints{}, err - } - discoveryConfig := new(oidc.DiscoveryConfiguration) - err = httphelper.HttpRequest(httpClient, req, &discoveryConfig) - if err != nil { - return Endpoints{}, err - } - if discoveryConfig.Issuer != issuer { - return Endpoints{}, oidc.ErrIssuerInvalid - } - return GetEndpoints(discoveryConfig), nil -} - // AuthURL returns the auth request url // (wrapping the oauth2 `AuthCodeURL`) func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { @@ -463,7 +442,7 @@ type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r // on success it will pass the userinfo into its callback function as well func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { - info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) + info, err := Userinfo(r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) return @@ -473,8 +452,8 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx } // Userinfo will call the OIDC Userinfo Endpoint with the provided token -func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { - req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil) +func Userinfo(ctx context.Context, token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { return nil, err } @@ -620,7 +599,7 @@ type RefreshTokenRequest struct { GrantType oidc.GrantType `schema:"grant_type"` } -func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { +func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -630,17 +609,17 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs ClientAssertionType: clientAssertionType, GrantType: oidc.GrantTypeRefreshToken, } - return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp}) + return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) } -func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { +func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, PostLogoutRedirectURI: optionalRedirectURI, State: optionalState, } - return client.CallEndSessionEndpoint(request, nil, rp) + return client.CallEndSessionEndpoint(ctx, request, nil, rp) } // RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty @@ -648,7 +627,7 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str // NewRelyingPartyOAuth() does not. // // tokenTypeHint should be either "id_token" or "refresh_token". -func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { +func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint, @@ -656,7 +635,7 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { ClientSecret: rp.OAuthConfig().ClientSecret, } if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" { - return client.CallRevokeEndpoint(request, nil, rc) + return client.CallRevokeEndpoint(ctx, request, nil, rc) } return fmt.Errorf("RelyingParty does not support RevokeCaller") } diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index f0e0e0a..054dfbe 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -42,14 +42,14 @@ func (r *resourceServer) AuthFn() (interface{}, error) { return r.authFn() } -func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { +func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { authorizer := func() (interface{}, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newResourceServer(issuer, authorizer, option...) + return newResourceServer(ctx, issuer, authorizer, option...) } -func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { +func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err @@ -61,10 +61,10 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt } return client.ClientAssertionFormAuthorization(assertion), nil } - return newResourceServer(issuer, authorizer, options...) + return newResourceServer(ctx, issuer, authorizer, options...) } -func newResourceServer(issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) { +func newResourceServer(ctx context.Context, issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) { rs := &resourceServer{ issuer: issuer, httpClient: httphelper.DefaultHTTPClient, @@ -73,7 +73,7 @@ func newResourceServer(issuer string, authorizer func() (interface{}, error), op optFunc(rs) } if rs.introspectURL == "" || rs.tokenURL == "" { - config, err := client.Discover(rs.issuer, rs.httpClient) + config, err := client.Discover(ctx, rs.issuer, rs.httpClient) if err != nil { return nil, err } @@ -87,12 +87,12 @@ func newResourceServer(issuer string, authorizer func() (interface{}, error), op return rs, nil } -func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) { +func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) { c, err := client.ConfigFromKeyFile(path) if err != nil { return nil, err } - return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) + return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) } type Option func(*resourceServer) @@ -117,7 +117,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.Int if err != nil { return nil, err } - req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) + req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) if err != nil { return nil, err } diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go index ce665cd..1c10df2 100644 --- a/pkg/client/tokenexchange/tokenexchange.go +++ b/pkg/client/tokenexchange/tokenexchange.go @@ -1,6 +1,7 @@ package tokenexchange import ( + "context" "errors" "net/http" @@ -21,18 +22,18 @@ type OAuthTokenExchange struct { authFn func() (interface{}, error) } -func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { - return newOAuthTokenExchange(issuer, nil, options...) +func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + return newOAuthTokenExchange(ctx, issuer, nil, options...) } -func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { +func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { authorizer := func() (interface{}, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newOAuthTokenExchange(issuer, authorizer, options...) + return newOAuthTokenExchange(ctx, issuer, authorizer, options...) } -func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { +func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { te := &OAuthTokenExchange{ httpClient: httphelper.DefaultHTTPClient, } @@ -41,7 +42,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error) } if te.tokenEndpoint == "" { - config, err := client.Discover(issuer, te.httpClient) + config, err := client.Discover(ctx, issuer, te.httpClient) if err != nil { return nil, err } @@ -89,6 +90,7 @@ func (te *OAuthTokenExchange) AuthFn() (interface{}, error) { // ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint. // SubjectToken and SubjectTokenType are required parameters. func ExchangeToken( + ctx context.Context, te TokenExchanger, SubjectToken string, SubjectTokenType oidc.TokenType, @@ -123,5 +125,5 @@ func ExchangeToken( RequestedTokenType: RequestedTokenType, } - return client.CallTokenExchangeEndpoint(request, authFn, te) + return client.CallTokenExchangeEndpoint(ctx, request, authFn, te) } diff --git a/pkg/http/http.go b/pkg/http/http.go index d3c5b4f..9771888 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -33,7 +33,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization { } } -func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) { +func FormRequest(ctx context.Context, endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) { form := url.Values{} if err := encoder.Encode(request, form); err != nil { return nil, err @@ -42,7 +42,7 @@ func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn i fn(form) } body := strings.NewReader(form.Encode()) - req, err := http.NewRequest("POST", endpoint, body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) if err != nil { return nil, err }