Merge branch 'next' into main-next
prepare the merge of next into main by resolving merge conflicts.
This commit is contained in:
commit
0476b5946e
122 changed files with 8195 additions and 2858 deletions
|
@ -1,31 +1,25 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/crypto"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
var Encoder = func() httphelper.Encoder {
|
||||
e := schema.NewEncoder()
|
||||
e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string {
|
||||
return value.Interface().(oidc.SpaceDelimitedArray).Encode()
|
||||
})
|
||||
return e
|
||||
}()
|
||||
var Encoder = httphelper.Encoder(oidc.NewEncoder())
|
||||
|
||||
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
|
||||
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
|
||||
|
@ -90,6 +84,9 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS
|
|||
return http.ErrUseLastResponse
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
@ -148,6 +145,18 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa
|
|||
return nil
|
||||
}
|
||||
|
||||
func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
|
||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenRes := new(oidc.TokenExchangeResponse)
|
||||
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokenRes, nil
|
||||
}
|
||||
|
||||
func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
|
||||
privateKey, err := crypto.BytesToPrivateKey(key)
|
||||
if err != nil {
|
||||
|
@ -167,7 +176,98 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
|
|||
Issuer: clientID,
|
||||
Subject: clientID,
|
||||
Audience: audience,
|
||||
ExpiresAt: oidc.Time(exp),
|
||||
IssuedAt: oidc.Time(iat),
|
||||
ExpiresAt: oidc.FromTime(exp),
|
||||
IssuedAt: oidc.FromTime(iat),
|
||||
}, signer)
|
||||
}
|
||||
|
||||
type DeviceAuthorizationCaller interface {
|
||||
GetDeviceAuthorizationEndpoint() string
|
||||
HttpClient() *http.Client
|
||||
}
|
||||
|
||||
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.ClientSecret != "" {
|
||||
req.SetBasicAuth(request.ClientID, request.ClientSecret)
|
||||
}
|
||||
|
||||
resp := new(oidc.DeviceAuthorizationResponse)
|
||||
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type DeviceAccessTokenRequest struct {
|
||||
*oidc.ClientCredentialsRequest
|
||||
oidc.DeviceAccessTokenRequest
|
||||
}
|
||||
|
||||
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.ClientSecret != "" {
|
||||
req.SetBasicAuth(request.ClientID, request.ClientSecret)
|
||||
}
|
||||
|
||||
httpResp, err := caller.HttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
resp := new(struct {
|
||||
*oidc.AccessTokenResponse
|
||||
*oidc.Error
|
||||
})
|
||||
if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if httpResp.StatusCode == http.StatusOK {
|
||||
return resp.AccessTokenResponse, nil
|
||||
}
|
||||
|
||||
return nil, resp.Error
|
||||
}
|
||||
|
||||
func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||
for {
|
||||
timer := time.After(interval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-timer:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, interval)
|
||||
defer cancel()
|
||||
|
||||
resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
interval += 5 * time.Second
|
||||
}
|
||||
var target *oidc.Error
|
||||
if !errors.As(err, &target) {
|
||||
return nil, err
|
||||
}
|
||||
switch target.ErrorType {
|
||||
case oidc.AuthorizationPending:
|
||||
continue
|
||||
case oidc.SlowDown:
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
package rp_test
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
|
@ -15,34 +14,142 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/example/server/storage"
|
||||
|
||||
"github.com/jeremija/gosubmit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/oidc/v2/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rs"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/tokenexchange"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestRelyingPartySession(t *testing.T) {
|
||||
t.Log("------- start example OP ------")
|
||||
ctx := context.Background()
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore())
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
var dh deferredHandler
|
||||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(ctx, opServer.URL, exampleStorage)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
|
||||
t.Log("------- run authorization code flow ------")
|
||||
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
|
||||
|
||||
t.Log("------- refresh tokens ------")
|
||||
|
||||
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
require.NoError(t, err, "refresh token")
|
||||
assert.NotNil(t, newTokens, "access token")
|
||||
t.Logf("new access token %s", newTokens.AccessToken)
|
||||
t.Logf("new refresh token %s", newTokens.RefreshToken)
|
||||
t.Logf("new token type %s", newTokens.TokenType)
|
||||
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
|
||||
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
} else {
|
||||
t.Logf("no redirect")
|
||||
}
|
||||
|
||||
t.Log("------ attempt refresh again (should fail) ------")
|
||||
t.Log("trying original refresh token", refreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with original")
|
||||
if newTokens.RefreshToken != "" {
|
||||
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with replacement")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceServerTokenExchange(t *testing.T) {
|
||||
t.Log("------- start example OP ------")
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
var dh deferredHandler
|
||||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
clientSecret := "secret"
|
||||
|
||||
t.Log("------- run authorization code flow ------")
|
||||
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
|
||||
|
||||
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret)
|
||||
require.NoError(t, err, "new resource server")
|
||||
|
||||
t.Log("------- exchage refresh tokens (impersonation) ------")
|
||||
|
||||
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
|
||||
resourceServer,
|
||||
refreshToken,
|
||||
oidc.RefreshTokenType,
|
||||
"",
|
||||
"",
|
||||
[]string{},
|
||||
[]string{},
|
||||
[]string{"profile", "custom_scope:impersonate:id2"},
|
||||
oidc.RefreshTokenType,
|
||||
)
|
||||
require.NoError(t, err, "refresh token")
|
||||
require.NotNil(t, tokenExchangeResponse, "token exchange response")
|
||||
assert.Equal(t, tokenExchangeResponse.IssuedTokenType, oidc.RefreshTokenType)
|
||||
assert.NotEmpty(t, tokenExchangeResponse.AccessToken, "access token")
|
||||
assert.NotEmpty(t, tokenExchangeResponse.RefreshToken, "refresh token")
|
||||
assert.Equal(t, []string(tokenExchangeResponse.Scopes), []string{"profile", "custom_scope:impersonate:id2"})
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
} else {
|
||||
t.Logf("no redirect")
|
||||
}
|
||||
|
||||
t.Log("------- attempt exchage again (should fail) ------")
|
||||
|
||||
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
|
||||
resourceServer,
|
||||
refreshToken,
|
||||
oidc.RefreshTokenType,
|
||||
"",
|
||||
"",
|
||||
[]string{},
|
||||
[]string{},
|
||||
[]string{"profile", "custom_scope:impersonate:id2"},
|
||||
oidc.RefreshTokenType,
|
||||
)
|
||||
require.Error(t, err, "refresh token")
|
||||
assert.Contains(t, err.Error(), "subject_token is invalid")
|
||||
require.Nil(t, tokenExchangeResponse, "token exchange response")
|
||||
|
||||
}
|
||||
|
||||
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {
|
||||
targetURL := "http://local-site"
|
||||
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
|
||||
require.NoError(t, err, "local url")
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
client := storage.WebClient(clientID, "secret", targetURL)
|
||||
client := storage.WebClient(clientID, clientSecret, targetURL)
|
||||
storage.RegisterClients(client)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
|
@ -58,10 +165,10 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
t.Log("------- create RP ------")
|
||||
key := []byte("test1234test1234")
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
provider, err := rp.NewRelyingPartyOIDC(
|
||||
provider, err = rp.NewRelyingPartyOIDC(
|
||||
opServer.URL,
|
||||
clientID,
|
||||
"secret",
|
||||
clientSecret,
|
||||
targetURL,
|
||||
[]string{"openid", "email", "profile", "offline_access"},
|
||||
rp.WithPKCE(cookieHandler),
|
||||
|
@ -70,8 +177,10 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err, "new rp")
|
||||
|
||||
t.Log("------- get redirect from local client (rp) to OP ------")
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
capturedW := httptest.NewRecorder()
|
||||
get := httptest.NewRequest("GET", localURL.String(), nil)
|
||||
|
@ -114,7 +223,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
|
||||
t.Log("------- post to login form, get redirect to OP ------")
|
||||
postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL,
|
||||
gosubmit.Set("username", "test-user"),
|
||||
gosubmit.Set("username", "test-user@local-site"),
|
||||
gosubmit.Set("password", "verysecure"))
|
||||
t.Logf("Get redirect from %s", postLoginRedirectURL)
|
||||
|
||||
|
@ -130,19 +239,19 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
t.Logf("setting cookie %s", cookie)
|
||||
}
|
||||
|
||||
var accessToken, refreshToken, idToken, email string
|
||||
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
|
||||
var email string
|
||||
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
|
||||
require.NotNil(t, tokens, "tokens")
|
||||
require.NotNil(t, info, "info")
|
||||
t.Log("access token", tokens.AccessToken)
|
||||
t.Log("refresh token", tokens.RefreshToken)
|
||||
t.Log("id token", tokens.IDToken)
|
||||
t.Log("email", info.GetEmail())
|
||||
t.Log("email", info.Email)
|
||||
|
||||
accessToken = tokens.AccessToken
|
||||
refreshToken = tokens.RefreshToken
|
||||
idToken = tokens.IDToken
|
||||
email = info.GetEmail()
|
||||
email = info.Email
|
||||
http.Redirect(w, r, targetURL, 302)
|
||||
}
|
||||
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
|
||||
|
@ -154,7 +263,6 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
// TODO: how to check the custom header was sent to the server?
|
||||
|
||||
//nolint:bodyclose
|
||||
|
@ -169,43 +277,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
assert.NotEmpty(t, accessToken, "access token")
|
||||
assert.NotEmpty(t, email, "email")
|
||||
|
||||
t.Log("------- refresh tokens ------")
|
||||
|
||||
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
require.NoError(t, err, "refresh token")
|
||||
assert.NotNil(t, newTokens, "access token")
|
||||
t.Logf("new access token %s", newTokens.AccessToken)
|
||||
t.Logf("new refresh token %s", newTokens.RefreshToken)
|
||||
t.Logf("new token type %s", newTokens.TokenType)
|
||||
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
|
||||
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
} else {
|
||||
t.Logf("no redirect")
|
||||
}
|
||||
|
||||
t.Log("------ attempt refresh again (should fail) ------")
|
||||
t.Log("trying original refresh token", refreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with original")
|
||||
if newTokens.RefreshToken != "" {
|
||||
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with replacement")
|
||||
}
|
||||
|
||||
t.Run("WithPrompt", func(t *testing.T) {
|
||||
opts := rp.WithPrompt("foo", "bar")()
|
||||
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
|
||||
|
||||
require.Contains(t, url, "prompt=foo+bar")
|
||||
})
|
||||
return provider, accessToken, refreshToken, idToken
|
||||
}
|
||||
|
||||
type deferredHandler struct {
|
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// jwtProfileTokenSource implement the oauth2.TokenSource
|
||||
|
|
|
@ -4,22 +4,22 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
loginPath = "/login"
|
||||
)
|
||||
|
||||
func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens {
|
||||
func CodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens[C] {
|
||||
codeflowCtx, codeflowCancel := context.WithCancel(ctx)
|
||||
defer codeflowCancel()
|
||||
|
||||
tokenChan := make(chan *oidc.Tokens, 1)
|
||||
tokenChan := make(chan *oidc.Tokens[C], 1)
|
||||
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) {
|
||||
tokenChan <- tokens
|
||||
msg := "<p><strong>Success!</strong></p>"
|
||||
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// DelegationTokenRequest is an implementation of TokenExchangeRequest
|
||||
// it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional
|
||||
//"urn:ietf:params:oauth:token-type:access_token" actor token for an
|
||||
//"urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
// "urn:ietf:params:oauth:token-type:access_token" actor token for an
|
||||
// "urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
|
||||
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
|
||||
}
|
||||
|
|
62
pkg/client/rp/device.go
Normal file
62
pkg/client/rp/device.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
|
||||
confg := rp.OAuthConfig()
|
||||
req := &oidc.ClientCredentialsRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
Scope: scopes,
|
||||
ClientID: confg.ClientID,
|
||||
ClientSecret: confg.ClientSecret,
|
||||
}
|
||||
|
||||
if signer := rp.Signer(); signer != nil {
|
||||
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build assertion: %w", err)
|
||||
}
|
||||
req.ClientAssertion = assertion
|
||||
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// DeviceAuthorization starts a new Device Authorization flow as defined
|
||||
// in RFC 8628, section 3.1 and 3.2:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.CallDeviceAuthorizationEndpoint(req, rp)
|
||||
}
|
||||
|
||||
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
|
||||
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
||||
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
||||
req := &client.DeviceAccessTokenRequest{
|
||||
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
DeviceCode: deviceCode,
|
||||
},
|
||||
}
|
||||
|
||||
req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
|
||||
}
|
|
@ -9,8 +9,8 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/zitadel/oidc/pkg/rp Verifier
|
|
@ -1,67 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/rp (interfaces: Verifier)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
// MockVerifier is a mock of Verifier interface
|
||||
type MockVerifier struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockVerifierMockRecorder
|
||||
}
|
||||
|
||||
// MockVerifierMockRecorder is the mock recorder for MockVerifier
|
||||
type MockVerifierMockRecorder struct {
|
||||
mock *MockVerifier
|
||||
}
|
||||
|
||||
// NewMockVerifier creates a new mock instance
|
||||
func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier {
|
||||
mock := &MockVerifier{ctrl: ctrl}
|
||||
mock.recorder = &MockVerifierMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Verify mocks base method
|
||||
func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*oidc.IDTokenClaims)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Verify indicates an expected call of Verify
|
||||
func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// VerifyIDToken mocks base method
|
||||
func (m *MockVerifier) VerifyIDToken(arg0 context.Context, arg1 string) (*oidc.IDTokenClaims, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "VerifyIDToken", arg0, arg1)
|
||||
ret0, _ := ret[0].(*oidc.IDTokenClaims)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// VerifyIDToken indicates an expected call of VerifyIDToken
|
||||
func (mr *MockVerifierMockRecorder) VerifyIDToken(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyIDToken", reflect.TypeOf((*MockVerifier)(nil).VerifyIDToken), arg0, arg1)
|
||||
}
|
|
@ -14,9 +14,9 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -54,11 +54,15 @@ type RelyingParty interface {
|
|||
GetEndSessionEndpoint() string
|
||||
|
||||
// GetRevokeEndpoint returns the endpoint to revoke a specific token
|
||||
// "GetRevokeEndpoint() string" will be added in a future release
|
||||
GetRevokeEndpoint() string
|
||||
|
||||
// UserinfoEndpoint returns the userinfo
|
||||
UserinfoEndpoint() string
|
||||
|
||||
// GetDeviceAuthorizationEndpoint returns the enpoint which can
|
||||
// be used to start a DeviceAuthorization flow.
|
||||
GetDeviceAuthorizationEndpoint() string
|
||||
|
||||
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
|
||||
IDTokenVerifier() IDTokenVerifier
|
||||
// ErrorHandler returns the handler used for callback errors
|
||||
|
@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
|
|||
return rp.endpoints.UserinfoURL
|
||||
}
|
||||
|
||||
func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
|
||||
return rp.endpoints.DeviceAuthorizationURL
|
||||
}
|
||||
|
||||
func (rp *relyingParty) GetEndSessionEndpoint() string {
|
||||
return rp.endpoints.EndSessionURL
|
||||
}
|
||||
|
@ -371,7 +379,7 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
|
|||
|
||||
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
||||
// returning it parsed together with the oauth2 tokens (access, refresh)
|
||||
func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
|
||||
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
|
@ -384,7 +392,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
|
|||
}
|
||||
|
||||
if rp.IsOAuth2Only() {
|
||||
return &oidc.Tokens{Token: token}, nil
|
||||
return &oidc.Tokens[C]{Token: token}, nil
|
||||
}
|
||||
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
|
@ -392,21 +400,21 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
|
|||
return nil, errors.New("id_token missing")
|
||||
}
|
||||
|
||||
idToken, err := VerifyTokens(ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
|
||||
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
|
||||
|
||||
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
||||
// including cookie handling for secure `state` transfer
|
||||
// and optional PKCE code verifier checking.
|
||||
// Custom paramaters can optionally be set to the token URL.
|
||||
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := tryReadStateCookie(w, r, rp)
|
||||
if err != nil {
|
||||
|
@ -439,7 +447,7 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
|
|||
}
|
||||
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
|
||||
}
|
||||
tokens, err := CodeExchange(r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -448,13 +456,13 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
|
|||
}
|
||||
}
|
||||
|
||||
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
|
||||
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo)
|
||||
|
||||
// UserinfoCallback wraps the callback function of the CodeExchangeHandler
|
||||
// and calls the userinfo endpoint with the access token
|
||||
// on success it will pass the userinfo into its callback function as well
|
||||
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
|
||||
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
if err != nil {
|
||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
||||
|
@ -465,17 +473,17 @@ func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
|
|||
}
|
||||
|
||||
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) {
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
|
||||
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("authorization", tokenType+" "+token)
|
||||
userinfo := oidc.NewUserInfo()
|
||||
userinfo := new(oidc.UserInfo)
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userinfo.GetSubject() != subject {
|
||||
if userinfo.Subject != subject {
|
||||
return nil, ErrUserInfoSubNotMatching
|
||||
}
|
||||
return userinfo, nil
|
||||
|
@ -506,11 +514,12 @@ type OptionFunc func(RelyingParty)
|
|||
|
||||
type Endpoints struct {
|
||||
oauth2.Endpoint
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
EndSessionURL string
|
||||
RevokeURL string
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
EndSessionURL string
|
||||
RevokeURL string
|
||||
DeviceAuthorizationURL string
|
||||
}
|
||||
|
||||
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||
|
@ -520,11 +529,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
|||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
TokenURL: discoveryConfig.TokenEndpoint,
|
||||
},
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
||||
RevokeURL: discoveryConfig.RevocationEndpoint,
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
||||
RevokeURL: discoveryConfig.RevocationEndpoint,
|
||||
DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type IDTokenVerifier interface {
|
||||
|
@ -20,76 +20,78 @@ type IDTokenVerifier interface {
|
|||
}
|
||||
|
||||
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||
idToken, err := VerifyIDToken(ctx, idTokenString, v)
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
|
||||
return nil, err
|
||||
if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
return idToken, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyIDToken validates the id token according to
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||
claims := oidc.EmptyIDTokenClaims()
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
payload, err := oidc.ParseToken(decrypted, claims)
|
||||
payload, err := oidc.ParseToken(decrypted, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckSubject(claims); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyAccessToken validates the access token according to
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||
if atHash == "" {
|
||||
return nil
|
||||
|
@ -112,7 +114,7 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
|
|||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
keySet: keySet,
|
||||
offset: 1 * time.Second,
|
||||
offset: time.Second,
|
||||
nonce: func(_ context.Context) string {
|
||||
return ""
|
||||
},
|
||||
|
@ -139,7 +141,7 @@ func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
|
|||
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
||||
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.maxAge = maxAge
|
||||
v.maxAgeIAT = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
|
|
339
pkg/client/rp/verifier_test.go
Normal file
339
pkg/client/rp/verifier_test.go
Normal file
|
@ -0,0 +1,339 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestVerifyTokens(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
clientID: tu.ValidClientID,
|
||||
}
|
||||
accessToken, _ := tu.ValidAccessToken()
|
||||
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
idTokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "without access token",
|
||||
idTokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "with access token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||
)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired id token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong access token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
idToken, want := tt.idTokenClaims()
|
||||
got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyIDToken(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty subject",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, "", tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong clientID",
|
||||
clientID: "foo",
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong IAT",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong acr",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired auth",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong nonce",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, "foo",
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
verifier.clientID = tt.clientID
|
||||
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
token, _ := tu.ValidAccessToken()
|
||||
hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
accessToken string
|
||||
atHash string
|
||||
sigAlgorithm jose.SignatureAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty hash",
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: hash,
|
||||
sigAlgorithm: tu.SignatureAlgorithm,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid algorithm",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: hash,
|
||||
sigAlgorithm: "foo",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "mismatch",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: "~~",
|
||||
sigAlgorithm: tu.SignatureAlgorithm,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewIDTokenVerifier(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
clientID string
|
||||
keySet oidc.KeySet
|
||||
options []VerifierOption
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want IDTokenVerifier
|
||||
}{
|
||||
{
|
||||
name: "nil nonce", // otherwise assert.Equal will fail on the function
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
clientID: tu.ValidClientID,
|
||||
keySet: tu.KeySet{},
|
||||
options: []VerifierOption{
|
||||
WithIssuedAtOffset(time.Minute),
|
||||
WithIssuedAtMaxAge(time.Hour),
|
||||
WithNonce(nil), // otherwise assert.Equal will fail on the function
|
||||
WithACRVerifier(nil),
|
||||
WithAuthTimeMaxAge(2 * time.Hour),
|
||||
WithSupportedSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
offset: time.Minute,
|
||||
maxAgeIAT: time.Hour,
|
||||
clientID: tu.ValidClientID,
|
||||
keySet: tu.KeySet{},
|
||||
nonce: nil,
|
||||
acr: nil,
|
||||
maxAge: 2 * time.Hour,
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
86
pkg/client/rp/verifier_tokens_example_test.go
Normal file
86
pkg/client/rp/verifier_tokens_example_test.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package rp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// MyCustomClaims extends the TokenClaims base,
|
||||
// so it implmeents the oidc.Claims interface.
|
||||
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
|
||||
type MyCustomClaims struct {
|
||||
oidc.TokenClaims
|
||||
NotBefore oidc.Time `json:"nbf,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
Foo string `json:"foo,omitempty"`
|
||||
Bar *Nested `json:"bar,omitempty"`
|
||||
}
|
||||
|
||||
// GetAccessTokenHash is required to implement
|
||||
// the oidc.IDClaims interface.
|
||||
func (c *MyCustomClaims) GetAccessTokenHash() string {
|
||||
return c.AccessTokenHash
|
||||
}
|
||||
|
||||
// Nested struct types are also possible.
|
||||
type Nested struct {
|
||||
Count int `json:"count,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
/*
|
||||
idToken carries the following claims. foo and bar are custom claims
|
||||
|
||||
{
|
||||
"acr": "something",
|
||||
"amr": [
|
||||
"foo",
|
||||
"bar"
|
||||
],
|
||||
"at_hash": "2dzbm_vIxy-7eRtqUIGPPw",
|
||||
"aud": [
|
||||
"unit",
|
||||
"test",
|
||||
"555666"
|
||||
],
|
||||
"auth_time": 1678100961,
|
||||
"azp": "555666",
|
||||
"bar": {
|
||||
"count": 22,
|
||||
"tags": [
|
||||
"some",
|
||||
"tags"
|
||||
]
|
||||
},
|
||||
"client_id": "555666",
|
||||
"exp": 4802238682,
|
||||
"foo": "Hello, World!",
|
||||
"iat": 1678101021,
|
||||
"iss": "local.com",
|
||||
"jti": "9876",
|
||||
"nbf": 1678101021,
|
||||
"nonce": "12345",
|
||||
"sub": "tim@local.com"
|
||||
}
|
||||
*/
|
||||
const idToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF0X2hhc2giOiIyZHpibV92SXh5LTdlUnRxVUlHUFB3IiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImF1dGhfdGltZSI6MTY3ODEwMDk2MSwiYXpwIjoiNTU1NjY2IiwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiY2xpZW50X2lkIjoiNTU1NjY2IiwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJub25jZSI6IjEyMzQ1Iiwic3ViIjoidGltQGxvY2FsLmNvbSJ9.t3GXSfVNNwiW1Suv9_84v0sdn2_-RWHVxhphhRozDXnsO7SDNOlGnEioemXABESxSzMclM7gB7mYy5Qah2ZUNx7eP5t2njoxEYfavgHwx7UJZ2NCg8NDPQyr-hlxelEcfdXK-I0oTd-FRDvF4rqPkD9Us52IpnplChCxnHFgh4wKwPqZZjv2IXVCtn0ilKW3hff1rMOYKEuLRcN2YP0gkyuqyHvcf2dMmjod0t4sLOTJ82rsCbMBC5CLpqv3nIC9HOGITkt1Kd-Am0n1LrdZvWwTo6RFe8AnzF0gpqjcB5Wg4Qeh58DIjZOz4f_8wnmJ_gCqyRh5vfSW4XHdbum0Tw`
|
||||
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.Zrz3LWSRjCMJZUMaI5dUbW4vGdSmEeJQ3ouhaX0bcW9rdFFLgBI4K2FWJhNivq8JDmCGSxwLu3mI680GWmDaEoAx1M5sCO9lqfIZHGZh-lfAXk27e6FPLlkTDBq8Bx4o4DJ9Fw0hRJGjUTjnYv5cq1vo2-UqldasL6CwTbkzNC_4oQFfRtuodC4Ql7dZ1HRv5LXuYx7KPkOssLZtV9cwtJp5nFzKjcf2zEE_tlbjcpynMwypornRUp1EhCWKRUGkJhJeiP71ECY5pQhShfjBu9Nc5wDpSnZmnk2S4YsPrRK3QkE-iEkas8BfsOCrGoErHjEJexAIDjasGO5PFLWfCA`
|
||||
|
||||
func ExampleVerifyTokens_customClaims() {
|
||||
v := rp.NewIDTokenVerifier("local.com", "555666", tu.KeySet{},
|
||||
rp.WithNonce(func(ctx context.Context) string { return "12345" }),
|
||||
)
|
||||
|
||||
// VerifyAccessToken can be called with the *MyCustomClaims.
|
||||
claims, err := rp.VerifyTokens[*MyCustomClaims](context.TODO(), accessToken, idToken, v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Here we have typesafe access to the custom claims
|
||||
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
|
||||
// Output: Hello, World! 22 [some tags]
|
||||
}
|
|
@ -6,13 +6,14 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type ResourceServer interface {
|
||||
IntrospectionURL() string
|
||||
TokenEndpoint() string
|
||||
HttpClient() *http.Client
|
||||
AuthFn() (interface{}, error)
|
||||
}
|
||||
|
@ -29,6 +30,10 @@ func (r *resourceServer) IntrospectionURL() string {
|
|||
return r.introspectURL
|
||||
}
|
||||
|
||||
func (r *resourceServer) TokenEndpoint() string {
|
||||
return r.tokenURL
|
||||
}
|
||||
|
||||
func (r *resourceServer) HttpClient() *http.Client {
|
||||
return r.httpClient
|
||||
}
|
||||
|
@ -107,7 +112,7 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.IntrospectionResponse, error) {
|
||||
func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) {
|
||||
authFn, err := rp.AuthFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -116,7 +121,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.Intr
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := oidc.NewIntrospectionResponse()
|
||||
resp := new(oidc.IntrospectionResponse)
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
127
pkg/client/tokenexchange/tokenexchange.go
Normal file
127
pkg/client/tokenexchange/tokenexchange.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package tokenexchange
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type TokenExchanger interface {
|
||||
TokenEndpoint() string
|
||||
HttpClient() *http.Client
|
||||
AuthFn() (interface{}, error)
|
||||
}
|
||||
|
||||
type OAuthTokenExchange struct {
|
||||
httpClient *http.Client
|
||||
tokenEndpoint string
|
||||
authFn func() (interface{}, error)
|
||||
}
|
||||
|
||||
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
return newOAuthTokenExchange(issuer, nil, options...)
|
||||
}
|
||||
|
||||
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
authorizer := func() (interface{}, error) {
|
||||
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
||||
}
|
||||
return newOAuthTokenExchange(issuer, authorizer, options...)
|
||||
}
|
||||
|
||||
func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
|
||||
te := &OAuthTokenExchange{
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(te)
|
||||
}
|
||||
|
||||
if te.tokenEndpoint == "" {
|
||||
config, err := client.Discover(issuer, te.httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
te.tokenEndpoint = config.TokenEndpoint
|
||||
}
|
||||
|
||||
if te.tokenEndpoint == "" {
|
||||
return nil, errors.New("tokenURL is empty: please provide with either `WithStaticTokenEndpoint` or a discovery url")
|
||||
}
|
||||
|
||||
te.authFn = authorizer
|
||||
|
||||
return te, nil
|
||||
}
|
||||
|
||||
func WithHTTPClient(client *http.Client) func(*OAuthTokenExchange) {
|
||||
return func(source *OAuthTokenExchange) {
|
||||
source.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*OAuthTokenExchange) {
|
||||
return func(source *OAuthTokenExchange) {
|
||||
source.tokenEndpoint = tokenEndpoint
|
||||
}
|
||||
}
|
||||
|
||||
func (te *OAuthTokenExchange) TokenEndpoint() string {
|
||||
return te.tokenEndpoint
|
||||
}
|
||||
|
||||
func (te *OAuthTokenExchange) HttpClient() *http.Client {
|
||||
return te.httpClient
|
||||
}
|
||||
|
||||
func (te *OAuthTokenExchange) AuthFn() (interface{}, error) {
|
||||
if te.authFn != nil {
|
||||
return te.authFn()
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
|
||||
// SubjectToken and SubjectTokenType are required parameters.
|
||||
func ExchangeToken(
|
||||
te TokenExchanger,
|
||||
SubjectToken string,
|
||||
SubjectTokenType oidc.TokenType,
|
||||
ActorToken string,
|
||||
ActorTokenType oidc.TokenType,
|
||||
Resource []string,
|
||||
Audience []string,
|
||||
Scopes []string,
|
||||
RequestedTokenType oidc.TokenType,
|
||||
) (*oidc.TokenExchangeResponse, error) {
|
||||
if SubjectToken == "" {
|
||||
return nil, errors.New("empty subject_token")
|
||||
}
|
||||
if SubjectTokenType == "" {
|
||||
return nil, errors.New("empty subject_token_type")
|
||||
}
|
||||
|
||||
authFn, err := te.AuthFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request := oidc.TokenExchangeRequest{
|
||||
GrantType: oidc.GrantTypeTokenExchange,
|
||||
SubjectToken: SubjectToken,
|
||||
SubjectTokenType: SubjectTokenType,
|
||||
ActorToken: ActorToken,
|
||||
ActorTokenType: ActorTokenType,
|
||||
Resource: Resource,
|
||||
Audience: Audience,
|
||||
Scopes: Scopes,
|
||||
RequestedTokenType: RequestedTokenType,
|
||||
}
|
||||
|
||||
return client.CallTokenExchangeEndpoint(request, authFn, te)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue