feat: add context to all client calls (#345)
BREAKING CHANGE closes #309
This commit is contained in:
parent
33c716ddcf
commit
6af94fded0
14 changed files with 124 additions and 99 deletions
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
@ -27,7 +28,7 @@ func main() {
|
||||||
port := os.Getenv("PORT")
|
port := os.Getenv("PORT")
|
||||||
issuer := os.Getenv("ISSUER")
|
issuer := os.Getenv("ISSUER")
|
||||||
|
|
||||||
provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath)
|
provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalf("error creating provider %s", err.Error())
|
logrus.Fatalf("error creating provider %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -43,7 +44,7 @@ func main() {
|
||||||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||||
}
|
}
|
||||||
|
|
||||||
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalf("error creating provider %s", err.Error())
|
logrus.Fatalf("error creating provider %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,13 +39,13 @@ func main() {
|
||||||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||||
}
|
}
|
||||||
|
|
||||||
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...)
|
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalf("error creating provider %s", err.Error())
|
logrus.Fatalf("error creating provider %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
logrus.Info("starting device authorization flow")
|
logrus.Info("starting device authorization flow")
|
||||||
resp, err := rp.DeviceAuthorization(scopes, provider)
|
resp, err := rp.DeviceAuthorization(ctx, scopes, provider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatal(err)
|
logrus.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ func main() {
|
||||||
scopes := strings.Split(os.Getenv("SCOPES"), " ")
|
scopes := strings.Split(os.Getenv("SCOPES"), " ")
|
||||||
|
|
||||||
if keyPath != "" {
|
if keyPath != "" {
|
||||||
ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes)
|
ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalf("error creating token source %s", err.Error())
|
logrus.Fatalf("error creating token source %s", err.Error())
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,7 @@ func main() {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes)
|
ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|
|
@ -23,12 +23,12 @@ var Encoder = httphelper.Encoder(oidc.NewEncoder())
|
||||||
|
|
||||||
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
|
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
|
||||||
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
|
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
|
||||||
func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
|
func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
|
||||||
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
||||||
if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" {
|
if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" {
|
||||||
wellKnown = wellKnownUrl[0]
|
wellKnown = wellKnownUrl[0]
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -48,12 +48,12 @@ type TokenEndpointCaller interface {
|
||||||
HttpClient() *http.Client
|
HttpClient() *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallTokenEndpoint(request interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
func CallTokenEndpoint(ctx context.Context, request interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
||||||
return callTokenEndpoint(request, nil, caller)
|
return callTokenEndpoint(ctx, request, nil, caller)
|
||||||
}
|
}
|
||||||
|
|
||||||
func callTokenEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
func callTokenEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
||||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
|
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -74,8 +74,8 @@ type EndSessionCaller interface {
|
||||||
HttpClient() *http.Client
|
HttpClient() *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndSessionCaller) (*url.URL, error) {
|
func CallEndSessionEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller EndSessionCaller) (*url.URL, error) {
|
||||||
req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn)
|
req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -117,8 +117,8 @@ type RevokeRequest struct {
|
||||||
ClientSecret string `schema:"client_secret"`
|
ClientSecret string `schema:"client_secret"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCaller) error {
|
func CallRevokeEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller RevokeCaller) error {
|
||||||
req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn)
|
req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -145,8 +145,8 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
|
func CallTokenExchangeEndpoint(ctx context.Context, request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
|
||||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
|
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -186,8 +186,8 @@ type DeviceAuthorizationCaller interface {
|
||||||
HttpClient() *http.Client
|
HttpClient() *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
|
func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
|
||||||
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
|
req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -208,7 +208,7 @@ type DeviceAccessTokenRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
|
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package client_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
@ -10,7 +11,9 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -27,6 +30,18 @@ import (
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var CTX context.Context
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
os.Exit(func() int {
|
||||||
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
|
||||||
|
defer cancel()
|
||||||
|
CTX, cancel = context.WithTimeout(ctx, time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
return m.Run()
|
||||||
|
}())
|
||||||
|
}
|
||||||
|
|
||||||
func TestRelyingPartySession(t *testing.T) {
|
func TestRelyingPartySession(t *testing.T) {
|
||||||
t.Log("------- start example OP ------")
|
t.Log("------- start example OP ------")
|
||||||
targetURL := "http://local-site"
|
targetURL := "http://local-site"
|
||||||
|
@ -45,7 +60,7 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
|
|
||||||
t.Log("------- refresh tokens ------")
|
t.Log("------- refresh tokens ------")
|
||||||
|
|
||||||
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
newTokens, err := rp.RefreshAccessToken(CTX, provider, refreshToken, "", "")
|
||||||
require.NoError(t, err, "refresh token")
|
require.NoError(t, err, "refresh token")
|
||||||
assert.NotNil(t, newTokens, "access token")
|
assert.NotNil(t, newTokens, "access token")
|
||||||
t.Logf("new access token %s", newTokens.AccessToken)
|
t.Logf("new access token %s", newTokens.AccessToken)
|
||||||
|
@ -56,7 +71,7 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
|
|
||||||
t.Log("------ end session (logout) ------")
|
t.Log("------ end session (logout) ------")
|
||||||
|
|
||||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
newLoc, err := rp.EndSession(CTX, provider, idToken, "", "")
|
||||||
require.NoError(t, err, "logout")
|
require.NoError(t, err, "logout")
|
||||||
if newLoc != nil {
|
if newLoc != nil {
|
||||||
t.Logf("redirect to %s", newLoc)
|
t.Logf("redirect to %s", newLoc)
|
||||||
|
@ -66,11 +81,11 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
|
|
||||||
t.Log("------ attempt refresh again (should fail) ------")
|
t.Log("------ attempt refresh again (should fail) ------")
|
||||||
t.Log("trying original refresh token", refreshToken)
|
t.Log("trying original refresh token", refreshToken)
|
||||||
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
|
_, err = rp.RefreshAccessToken(CTX, provider, refreshToken, "", "")
|
||||||
assert.Errorf(t, err, "refresh with original")
|
assert.Errorf(t, err, "refresh with original")
|
||||||
if newTokens.RefreshToken != "" {
|
if newTokens.RefreshToken != "" {
|
||||||
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
||||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
_, err = rp.RefreshAccessToken(CTX, provider, newTokens.RefreshToken, "", "")
|
||||||
assert.Errorf(t, err, "refresh with replacement")
|
assert.Errorf(t, err, "refresh with replacement")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,12 +107,13 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
||||||
t.Log("------- run authorization code flow ------")
|
t.Log("------- run authorization code flow ------")
|
||||||
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
|
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
|
||||||
|
|
||||||
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret)
|
resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
|
||||||
require.NoError(t, err, "new resource server")
|
require.NoError(t, err, "new resource server")
|
||||||
|
|
||||||
t.Log("------- exchage refresh tokens (impersonation) ------")
|
t.Log("------- exchage refresh tokens (impersonation) ------")
|
||||||
|
|
||||||
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
|
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
|
||||||
|
CTX,
|
||||||
resourceServer,
|
resourceServer,
|
||||||
refreshToken,
|
refreshToken,
|
||||||
oidc.RefreshTokenType,
|
oidc.RefreshTokenType,
|
||||||
|
@ -117,7 +133,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
||||||
|
|
||||||
t.Log("------ end session (logout) ------")
|
t.Log("------ end session (logout) ------")
|
||||||
|
|
||||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
newLoc, err := rp.EndSession(CTX, provider, idToken, "", "")
|
||||||
require.NoError(t, err, "logout")
|
require.NoError(t, err, "logout")
|
||||||
if newLoc != nil {
|
if newLoc != nil {
|
||||||
t.Logf("redirect to %s", newLoc)
|
t.Logf("redirect to %s", newLoc)
|
||||||
|
@ -128,6 +144,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
||||||
t.Log("------- attempt exchage again (should fail) ------")
|
t.Log("------- attempt exchage again (should fail) ------")
|
||||||
|
|
||||||
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
|
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
|
||||||
|
CTX,
|
||||||
resourceServer,
|
resourceServer,
|
||||||
refreshToken,
|
refreshToken,
|
||||||
oidc.RefreshTokenType,
|
oidc.RefreshTokenType,
|
||||||
|
@ -166,6 +183,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
||||||
key := []byte("test1234test1234")
|
key := []byte("test1234test1234")
|
||||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||||
provider, err = rp.NewRelyingPartyOIDC(
|
provider, err = rp.NewRelyingPartyOIDC(
|
||||||
|
CTX,
|
||||||
opServer.URL,
|
opServer.URL,
|
||||||
clientID,
|
clientID,
|
||||||
clientSecret,
|
clientSecret,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -10,8 +11,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// JWTProfileExchange handles the oauth2 jwt profile exchange
|
// JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||||
func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
|
func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
|
||||||
return CallTokenEndpoint(jwtProfileGrantRequest, caller)
|
return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption {
|
func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption {
|
||||||
|
|
|
@ -10,7 +10,7 @@ const (
|
||||||
applicationKey = "application"
|
applicationKey = "application"
|
||||||
)
|
)
|
||||||
|
|
||||||
type keyFile struct {
|
type KeyFile struct {
|
||||||
Type string `json:"type"` // serviceaccount or application
|
Type string `json:"type"` // serviceaccount or application
|
||||||
KeyID string `json:"keyId"`
|
KeyID string `json:"keyId"`
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
|
@ -23,7 +23,7 @@ type keyFile struct {
|
||||||
ClientID string `json:"clientId"`
|
ClientID string `json:"clientId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConfigFromKeyFile(path string) (*keyFile, error) {
|
func ConfigFromKeyFile(path string) (*KeyFile, error) {
|
||||||
data, err := ioutil.ReadFile(path)
|
data, err := ioutil.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -31,8 +31,8 @@ func ConfigFromKeyFile(path string) (*keyFile, error) {
|
||||||
return ConfigFromKeyFileData(data)
|
return ConfigFromKeyFileData(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConfigFromKeyFileData(data []byte) (*keyFile, error) {
|
func ConfigFromKeyFileData(data []byte) (*KeyFile, error) {
|
||||||
var f keyFile
|
var f KeyFile
|
||||||
if err := json.Unmarshal(data, &f); err != nil {
|
if err := json.Unmarshal(data, &f); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package profile
|
package profile
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -11,9 +12,12 @@ import (
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// jwtProfileTokenSource implement the oauth2.TokenSource
|
type TokenSource interface {
|
||||||
// it will request a token using the OAuth2 JWT Profile Grant
|
oauth2.TokenSource
|
||||||
// therefore sending an `assertion` by singing a JWT with the provided private key
|
TokenCtx(context.Context) (*oauth2.Token, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtProfileTokenSource implements the TokenSource
|
||||||
type jwtProfileTokenSource struct {
|
type jwtProfileTokenSource struct {
|
||||||
clientID string
|
clientID string
|
||||||
audience []string
|
audience []string
|
||||||
|
@ -23,23 +27,38 @@ type jwtProfileTokenSource struct {
|
||||||
tokenEndpoint string
|
tokenEndpoint string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
|
// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource
|
||||||
keyData, err := client.ConfigFromKeyFile(keyPath)
|
// It will request a token using the OAuth2 JWT Profile Grant,
|
||||||
|
// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile.
|
||||||
|
//
|
||||||
|
// The passed context is only used for the call to the Discover endpoint.
|
||||||
|
func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
|
||||||
|
keyData, err := client.ConfigFromKeyFile(jsonFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
|
// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource
|
||||||
keyData, err := client.ConfigFromKeyFileData(data)
|
// It will request a token using the OAuth2 JWT Profile Grant,
|
||||||
|
// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData.
|
||||||
|
//
|
||||||
|
// The passed context is only used for the call to the Discover endpoint.
|
||||||
|
func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
|
||||||
|
keyData, err := client.ConfigFromKeyFileData(jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
|
// NewJWTProfileSource returns an implementation of oauth2.TokenSource
|
||||||
|
// It will request a token using the OAuth2 JWT Profile Grant,
|
||||||
|
// therefore sending an `assertion` by singing a JWT with the provided private key.
|
||||||
|
//
|
||||||
|
// The passed context is only used for the call to the Discover endpoint.
|
||||||
|
func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
|
||||||
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
|
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -55,7 +74,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
|
||||||
opt(source)
|
opt(source)
|
||||||
}
|
}
|
||||||
if source.tokenEndpoint == "" {
|
if source.tokenEndpoint == "" {
|
||||||
config, err := client.Discover(issuer, source.httpClient)
|
config, err := client.Discover(ctx, issuer, source.httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -64,13 +83,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
|
||||||
return source, nil
|
return source, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) {
|
func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) {
|
||||||
return func(source *jwtProfileTokenSource) {
|
return func(source *jwtProfileTokenSource) {
|
||||||
source.httpClient = client
|
source.httpClient = client
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) {
|
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) {
|
||||||
return func(source *jwtProfileTokenSource) {
|
return func(source *jwtProfileTokenSource) {
|
||||||
source.tokenEndpoint = tokenEndpoint
|
source.tokenEndpoint = tokenEndpoint
|
||||||
}
|
}
|
||||||
|
@ -85,9 +104,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) {
|
func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) {
|
||||||
|
return j.TokenCtx(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) {
|
||||||
assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer)
|
assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
|
return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,13 +33,13 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.
|
||||||
// DeviceAuthorization starts a new Device Authorization flow as defined
|
// DeviceAuthorization starts a new Device Authorization flow as defined
|
||||||
// in RFC 8628, section 3.1 and 3.2:
|
// in RFC 8628, section 3.1 and 3.2:
|
||||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||||
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
|
func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
|
||||||
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return client.CallDeviceAuthorizationEndpoint(req, rp)
|
return client.CallDeviceAuthorizationEndpoint(ctx, req, rp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
|
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -177,7 +176,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart
|
||||||
// NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given
|
// NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given
|
||||||
// issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions
|
// issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions
|
||||||
// it will run discovery on the provided issuer and use the found endpoints
|
// it will run discovery on the provided issuer and use the found endpoints
|
||||||
func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
|
func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
|
||||||
rp := &relyingParty{
|
rp := &relyingParty{
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
oauthConfig: &oauth2.Config{
|
oauthConfig: &oauth2.Config{
|
||||||
|
@ -195,7 +194,7 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
|
discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -310,26 +309,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
|
|
||||||
//
|
|
||||||
// deprecated: use client.Discover
|
|
||||||
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
|
||||||
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
|
||||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
|
||||||
if err != nil {
|
|
||||||
return Endpoints{}, err
|
|
||||||
}
|
|
||||||
discoveryConfig := new(oidc.DiscoveryConfiguration)
|
|
||||||
err = httphelper.HttpRequest(httpClient, req, &discoveryConfig)
|
|
||||||
if err != nil {
|
|
||||||
return Endpoints{}, err
|
|
||||||
}
|
|
||||||
if discoveryConfig.Issuer != issuer {
|
|
||||||
return Endpoints{}, oidc.ErrIssuerInvalid
|
|
||||||
}
|
|
||||||
return GetEndpoints(discoveryConfig), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthURL returns the auth request url
|
// AuthURL returns the auth request url
|
||||||
// (wrapping the oauth2 `AuthCodeURL`)
|
// (wrapping the oauth2 `AuthCodeURL`)
|
||||||
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
||||||
|
@ -463,7 +442,7 @@ type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r
|
||||||
// on success it will pass the userinfo into its callback function as well
|
// on success it will pass the userinfo into its callback function as well
|
||||||
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
|
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
|
||||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||||
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
info, err := Userinfo(r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
|
@ -473,8 +452,8 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
|
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
|
||||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
|
func Userinfo(ctx context.Context, token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
|
||||||
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -620,7 +599,7 @@ type RefreshTokenRequest struct {
|
||||||
GrantType oidc.GrantType `schema:"grant_type"`
|
GrantType oidc.GrantType `schema:"grant_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
|
func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
|
||||||
request := RefreshTokenRequest{
|
request := RefreshTokenRequest{
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
Scopes: rp.OAuthConfig().Scopes,
|
Scopes: rp.OAuthConfig().Scopes,
|
||||||
|
@ -630,17 +609,17 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs
|
||||||
ClientAssertionType: clientAssertionType,
|
ClientAssertionType: clientAssertionType,
|
||||||
GrantType: oidc.GrantTypeRefreshToken,
|
GrantType: oidc.GrantTypeRefreshToken,
|
||||||
}
|
}
|
||||||
return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp})
|
return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
|
||||||
}
|
}
|
||||||
|
|
||||||
func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
||||||
request := oidc.EndSessionRequest{
|
request := oidc.EndSessionRequest{
|
||||||
IdTokenHint: idToken,
|
IdTokenHint: idToken,
|
||||||
ClientID: rp.OAuthConfig().ClientID,
|
ClientID: rp.OAuthConfig().ClientID,
|
||||||
PostLogoutRedirectURI: optionalRedirectURI,
|
PostLogoutRedirectURI: optionalRedirectURI,
|
||||||
State: optionalState,
|
State: optionalState,
|
||||||
}
|
}
|
||||||
return client.CallEndSessionEndpoint(request, nil, rp)
|
return client.CallEndSessionEndpoint(ctx, request, nil, rp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty
|
// RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty
|
||||||
|
@ -648,7 +627,7 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str
|
||||||
// NewRelyingPartyOAuth() does not.
|
// NewRelyingPartyOAuth() does not.
|
||||||
//
|
//
|
||||||
// tokenTypeHint should be either "id_token" or "refresh_token".
|
// tokenTypeHint should be either "id_token" or "refresh_token".
|
||||||
func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
|
func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
|
||||||
request := client.RevokeRequest{
|
request := client.RevokeRequest{
|
||||||
Token: token,
|
Token: token,
|
||||||
TokenTypeHint: tokenTypeHint,
|
TokenTypeHint: tokenTypeHint,
|
||||||
|
@ -656,7 +635,7 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
|
||||||
ClientSecret: rp.OAuthConfig().ClientSecret,
|
ClientSecret: rp.OAuthConfig().ClientSecret,
|
||||||
}
|
}
|
||||||
if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
|
if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
|
||||||
return client.CallRevokeEndpoint(request, nil, rc)
|
return client.CallRevokeEndpoint(ctx, request, nil, rc)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("RelyingParty does not support RevokeCaller")
|
return fmt.Errorf("RelyingParty does not support RevokeCaller")
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,14 +42,14 @@ func (r *resourceServer) AuthFn() (interface{}, error) {
|
||||||
return r.authFn()
|
return r.authFn()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
|
func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
|
||||||
authorizer := func() (interface{}, error) {
|
authorizer := func() (interface{}, error) {
|
||||||
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
||||||
}
|
}
|
||||||
return newResourceServer(issuer, authorizer, option...)
|
return newResourceServer(ctx, issuer, authorizer, option...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
|
func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
|
||||||
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
|
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -61,10 +61,10 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt
|
||||||
}
|
}
|
||||||
return client.ClientAssertionFormAuthorization(assertion), nil
|
return client.ClientAssertionFormAuthorization(assertion), nil
|
||||||
}
|
}
|
||||||
return newResourceServer(issuer, authorizer, options...)
|
return newResourceServer(ctx, issuer, authorizer, options...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResourceServer(issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) {
|
func newResourceServer(ctx context.Context, issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) {
|
||||||
rs := &resourceServer{
|
rs := &resourceServer{
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
httpClient: httphelper.DefaultHTTPClient,
|
httpClient: httphelper.DefaultHTTPClient,
|
||||||
|
@ -73,7 +73,7 @@ func newResourceServer(issuer string, authorizer func() (interface{}, error), op
|
||||||
optFunc(rs)
|
optFunc(rs)
|
||||||
}
|
}
|
||||||
if rs.introspectURL == "" || rs.tokenURL == "" {
|
if rs.introspectURL == "" || rs.tokenURL == "" {
|
||||||
config, err := client.Discover(rs.issuer, rs.httpClient)
|
config, err := client.Discover(ctx, rs.issuer, rs.httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -87,12 +87,12 @@ func newResourceServer(issuer string, authorizer func() (interface{}, error), op
|
||||||
return rs, nil
|
return rs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) {
|
func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) {
|
||||||
c, err := client.ConfigFromKeyFile(path)
|
c, err := client.ConfigFromKeyFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
|
return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Option func(*resourceServer)
|
type Option func(*resourceServer)
|
||||||
|
@ -117,7 +117,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.Int
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
|
req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package tokenexchange
|
package tokenexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
@ -21,18 +22,18 @@ type OAuthTokenExchange struct {
|
||||||
authFn func() (interface{}, error)
|
authFn func() (interface{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||||
return newOAuthTokenExchange(issuer, nil, options...)
|
return newOAuthTokenExchange(ctx, issuer, nil, options...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||||
authorizer := func() (interface{}, error) {
|
authorizer := func() (interface{}, error) {
|
||||||
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
||||||
}
|
}
|
||||||
return newOAuthTokenExchange(issuer, authorizer, options...)
|
return newOAuthTokenExchange(ctx, issuer, authorizer, options...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
|
func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
|
||||||
te := &OAuthTokenExchange{
|
te := &OAuthTokenExchange{
|
||||||
httpClient: httphelper.DefaultHTTPClient,
|
httpClient: httphelper.DefaultHTTPClient,
|
||||||
}
|
}
|
||||||
|
@ -41,7 +42,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if te.tokenEndpoint == "" {
|
if te.tokenEndpoint == "" {
|
||||||
config, err := client.Discover(issuer, te.httpClient)
|
config, err := client.Discover(ctx, issuer, te.httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -89,6 +90,7 @@ func (te *OAuthTokenExchange) AuthFn() (interface{}, error) {
|
||||||
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
|
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
|
||||||
// SubjectToken and SubjectTokenType are required parameters.
|
// SubjectToken and SubjectTokenType are required parameters.
|
||||||
func ExchangeToken(
|
func ExchangeToken(
|
||||||
|
ctx context.Context,
|
||||||
te TokenExchanger,
|
te TokenExchanger,
|
||||||
SubjectToken string,
|
SubjectToken string,
|
||||||
SubjectTokenType oidc.TokenType,
|
SubjectTokenType oidc.TokenType,
|
||||||
|
@ -123,5 +125,5 @@ func ExchangeToken(
|
||||||
RequestedTokenType: RequestedTokenType,
|
RequestedTokenType: RequestedTokenType,
|
||||||
}
|
}
|
||||||
|
|
||||||
return client.CallTokenExchangeEndpoint(request, authFn, te)
|
return client.CallTokenExchangeEndpoint(ctx, request, authFn, te)
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) {
|
func FormRequest(ctx context.Context, endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
if err := encoder.Encode(request, form); err != nil {
|
if err := encoder.Encode(request, form); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -42,7 +42,7 @@ func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn i
|
||||||
fn(form)
|
fn(form)
|
||||||
}
|
}
|
||||||
body := strings.NewReader(form.Encode())
|
body := strings.NewReader(form.Encode())
|
||||||
req, err := http.NewRequest("POST", endpoint, body)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue