Merge branch 'next' into main-next

prepare the merge of next into main by resolving merge conflicts.
This commit is contained in:
Tim Möhlmann 2023-03-15 16:26:32 +02:00
commit 0476b5946e
122 changed files with 8195 additions and 2858 deletions

View file

@ -1,31 +1,25 @@
package client
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/gorilla/schema"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/crypto"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/crypto"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
var Encoder = func() httphelper.Encoder {
e := schema.NewEncoder()
e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(oidc.SpaceDelimitedArray).Encode()
})
return e
}()
var Encoder = httphelper.Encoder(oidc.NewEncoder())
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
@ -90,6 +84,9 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS
return http.ErrUseLastResponse
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
body, err := io.ReadAll(resp.Body)
@ -148,6 +145,18 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa
return nil
}
func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
tokenRes := new(oidc.TokenExchangeResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err
}
return tokenRes, nil
}
func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
privateKey, err := crypto.BytesToPrivateKey(key)
if err != nil {
@ -167,7 +176,98 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
Issuer: clientID,
Subject: clientID,
Audience: audience,
ExpiresAt: oidc.Time(exp),
IssuedAt: oidc.Time(iat),
ExpiresAt: oidc.FromTime(exp),
IssuedAt: oidc.FromTime(iat),
}, signer)
}
type DeviceAuthorizationCaller interface {
GetDeviceAuthorizationEndpoint() string
HttpClient() *http.Client
}
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}
if request.ClientSecret != "" {
req.SetBasicAuth(request.ClientID, request.ClientSecret)
}
resp := new(oidc.DeviceAuthorizationResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
return nil, err
}
return resp, nil
}
type DeviceAccessTokenRequest struct {
*oidc.ClientCredentialsRequest
oidc.DeviceAccessTokenRequest
}
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}
if request.ClientSecret != "" {
req.SetBasicAuth(request.ClientID, request.ClientSecret)
}
httpResp, err := caller.HttpClient().Do(req)
if err != nil {
return nil, err
}
defer httpResp.Body.Close()
resp := new(struct {
*oidc.AccessTokenResponse
*oidc.Error
})
if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
return nil, err
}
if httpResp.StatusCode == http.StatusOK {
return resp.AccessTokenResponse, nil
}
return nil, resp.Error
}
func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
for {
timer := time.After(interval)
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer:
}
ctx, cancel := context.WithTimeout(ctx, interval)
defer cancel()
resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
if err == nil {
return resp, nil
}
if errors.Is(err, context.DeadlineExceeded) {
interval += 5 * time.Second
}
var target *oidc.Error
if !errors.As(err, &target) {
return nil, err
}
switch target.ErrorType {
case oidc.AuthorizationPending:
continue
case oidc.SlowDown:
interval += 5 * time.Second
continue
default:
return nil, err
}
}
}

View file

@ -1,8 +1,7 @@
package rp_test
package client_test
import (
"bytes"
"context"
"io"
"io/ioutil"
"math/rand"
@ -15,34 +14,142 @@ import (
"testing"
"time"
"github.com/zitadel/oidc/example/server/exampleop"
"github.com/zitadel/oidc/example/server/storage"
"github.com/jeremija/gosubmit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/example/server/exampleop"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/client/rs"
"github.com/zitadel/oidc/v2/pkg/client/tokenexchange"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func TestRelyingPartySession(t *testing.T) {
t.Log("------- start example OP ------")
ctx := context.Background()
exampleStorage := storage.NewStorage(storage.NewUserStore())
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
var dh deferredHandler
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(ctx, opServer.URL, exampleStorage)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
t.Log("------- refresh tokens ------")
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken)
t.Logf("new refresh token %s", newTokens.RefreshToken)
t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
} else {
t.Logf("no redirect")
}
t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken)
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement")
}
}
func TestResourceServerTokenExchange(t *testing.T) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
var dh deferredHandler
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
clientSecret := "secret"
t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret)
require.NoError(t, err, "new resource server")
t.Log("------- exchage refresh tokens (impersonation) ------")
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
resourceServer,
refreshToken,
oidc.RefreshTokenType,
"",
"",
[]string{},
[]string{},
[]string{"profile", "custom_scope:impersonate:id2"},
oidc.RefreshTokenType,
)
require.NoError(t, err, "refresh token")
require.NotNil(t, tokenExchangeResponse, "token exchange response")
assert.Equal(t, tokenExchangeResponse.IssuedTokenType, oidc.RefreshTokenType)
assert.NotEmpty(t, tokenExchangeResponse.AccessToken, "access token")
assert.NotEmpty(t, tokenExchangeResponse.RefreshToken, "refresh token")
assert.Equal(t, []string(tokenExchangeResponse.Scopes), []string{"profile", "custom_scope:impersonate:id2"})
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
} else {
t.Logf("no redirect")
}
t.Log("------- attempt exchage again (should fail) ------")
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
resourceServer,
refreshToken,
oidc.RefreshTokenType,
"",
"",
[]string{},
[]string{},
[]string{"profile", "custom_scope:impersonate:id2"},
oidc.RefreshTokenType,
)
require.Error(t, err, "refresh token")
assert.Contains(t, err.Error(), "subject_token is invalid")
require.Nil(t, tokenExchangeResponse, "token exchange response")
}
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {
targetURL := "http://local-site"
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
require.NoError(t, err, "local url")
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
client := storage.WebClient(clientID, "secret", targetURL)
client := storage.WebClient(clientID, clientSecret, targetURL)
storage.RegisterClients(client)
jar, err := cookiejar.New(nil)
@ -58,10 +165,10 @@ func TestRelyingPartySession(t *testing.T) {
t.Log("------- create RP ------")
key := []byte("test1234test1234")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
provider, err := rp.NewRelyingPartyOIDC(
provider, err = rp.NewRelyingPartyOIDC(
opServer.URL,
clientID,
"secret",
clientSecret,
targetURL,
[]string{"openid", "email", "profile", "offline_access"},
rp.WithPKCE(cookieHandler),
@ -70,8 +177,10 @@ func TestRelyingPartySession(t *testing.T) {
rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
),
)
require.NoError(t, err, "new rp")
t.Log("------- get redirect from local client (rp) to OP ------")
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
capturedW := httptest.NewRecorder()
get := httptest.NewRequest("GET", localURL.String(), nil)
@ -114,7 +223,7 @@ func TestRelyingPartySession(t *testing.T) {
t.Log("------- post to login form, get redirect to OP ------")
postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL,
gosubmit.Set("username", "test-user"),
gosubmit.Set("username", "test-user@local-site"),
gosubmit.Set("password", "verysecure"))
t.Logf("Get redirect from %s", postLoginRedirectURL)
@ -130,19 +239,19 @@ func TestRelyingPartySession(t *testing.T) {
t.Logf("setting cookie %s", cookie)
}
var accessToken, refreshToken, idToken, email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
var email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
require.NotNil(t, tokens, "tokens")
require.NotNil(t, info, "info")
t.Log("access token", tokens.AccessToken)
t.Log("refresh token", tokens.RefreshToken)
t.Log("id token", tokens.IDToken)
t.Log("email", info.GetEmail())
t.Log("email", info.Email)
accessToken = tokens.AccessToken
refreshToken = tokens.RefreshToken
idToken = tokens.IDToken
email = info.GetEmail()
email = info.Email
http.Redirect(w, r, targetURL, 302)
}
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
@ -154,7 +263,6 @@ func TestRelyingPartySession(t *testing.T) {
}
}()
require.Less(t, capturedW.Code, 400, "token exchange response code")
require.Less(t, capturedW.Code, 400, "token exchange response code")
// TODO: how to check the custom header was sent to the server?
//nolint:bodyclose
@ -169,43 +277,7 @@ func TestRelyingPartySession(t *testing.T) {
assert.NotEmpty(t, accessToken, "access token")
assert.NotEmpty(t, email, "email")
t.Log("------- refresh tokens ------")
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken)
t.Logf("new refresh token %s", newTokens.RefreshToken)
t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
} else {
t.Logf("no redirect")
}
t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken)
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement")
}
t.Run("WithPrompt", func(t *testing.T) {
opts := rp.WithPrompt("foo", "bar")()
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
require.Contains(t, url, "prompt=foo+bar")
})
return provider, accessToken, refreshToken, idToken
}
type deferredHandler struct {

View file

@ -5,8 +5,8 @@ import (
"golang.org/x/oauth2"
"github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// JWTProfileExchange handles the oauth2 jwt profile exchange

View file

@ -7,8 +7,8 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/client"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// jwtProfileTokenSource implement the oauth2.TokenSource

View file

@ -4,22 +4,22 @@ import (
"context"
"net/http"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
const (
loginPath = "/login"
)
func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens {
func CodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens[C] {
codeflowCtx, codeflowCancel := context.WithCancel(ctx)
defer codeflowCancel()
tokenChan := make(chan *oidc.Tokens, 1)
tokenChan := make(chan *oidc.Tokens[C], 1)
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) {
tokenChan <- tokens
msg := "<p><strong>Success!</strong></p>"
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"

View file

@ -1,13 +1,13 @@
package rp
import (
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
)
// DelegationTokenRequest is an implementation of TokenExchangeRequest
// it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional
//"urn:ietf:params:oauth:token-type:access_token" actor token for an
//"urn:ietf:params:oauth:token-type:access_token" delegation token
// "urn:ietf:params:oauth:token-type:access_token" actor token for an
// "urn:ietf:params:oauth:token-type:access_token" delegation token
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
}

62
pkg/client/rp/device.go Normal file
View file

@ -0,0 +1,62 @@
package rp
import (
"context"
"fmt"
"time"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
confg := rp.OAuthConfig()
req := &oidc.ClientCredentialsRequest{
GrantType: oidc.GrantTypeDeviceCode,
Scope: scopes,
ClientID: confg.ClientID,
ClientSecret: confg.ClientSecret,
}
if signer := rp.Signer(); signer != nil {
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
if err != nil {
return nil, fmt.Errorf("failed to build assertion: %w", err)
}
req.ClientAssertion = assertion
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
}
return req, nil
}
// DeviceAuthorization starts a new Device Authorization flow as defined
// in RFC 8628, section 3.1 and 3.2:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
req, err := newDeviceClientCredentialsRequest(scopes, rp)
if err != nil {
return nil, err
}
return client.CallDeviceAuthorizationEndpoint(req, rp)
}
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
// by means of polling as defined in RFC, section 3.3 and 3.4:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
req := &client.DeviceAccessTokenRequest{
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
GrantType: oidc.GrantTypeDeviceCode,
DeviceCode: deviceCode,
},
}
req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
if err != nil {
return nil, err
}
return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
}

View file

@ -9,8 +9,8 @@ import (
"gopkg.in/square/go-jose.v2"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {

View file

@ -1,3 +0,0 @@
package mock
//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/zitadel/oidc/pkg/rp Verifier

View file

@ -1,67 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/rp (interfaces: Verifier)
// Package mock is a generated GoMock package.
package mock
import (
"context"
"reflect"
"github.com/golang/mock/gomock"
"github.com/zitadel/oidc/pkg/oidc"
)
// MockVerifier is a mock of Verifier interface
type MockVerifier struct {
ctrl *gomock.Controller
recorder *MockVerifierMockRecorder
}
// MockVerifierMockRecorder is the mock recorder for MockVerifier
type MockVerifierMockRecorder struct {
mock *MockVerifier
}
// NewMockVerifier creates a new mock instance
func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier {
mock := &MockVerifier{ctrl: ctrl}
mock.recorder = &MockVerifierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder {
return m.recorder
}
// Verify mocks base method
func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.IDTokenClaims)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Verify indicates an expected call of Verify
func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2)
}
// VerifyIDToken mocks base method
func (m *MockVerifier) VerifyIDToken(arg0 context.Context, arg1 string) (*oidc.IDTokenClaims, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "VerifyIDToken", arg0, arg1)
ret0, _ := ret[0].(*oidc.IDTokenClaims)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// VerifyIDToken indicates an expected call of VerifyIDToken
func (mr *MockVerifierMockRecorder) VerifyIDToken(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyIDToken", reflect.TypeOf((*MockVerifier)(nil).VerifyIDToken), arg0, arg1)
}

View file

@ -14,9 +14,9 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/client"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
const (
@ -54,11 +54,15 @@ type RelyingParty interface {
GetEndSessionEndpoint() string
// GetRevokeEndpoint returns the endpoint to revoke a specific token
// "GetRevokeEndpoint() string" will be added in a future release
GetRevokeEndpoint() string
// UserinfoEndpoint returns the userinfo
UserinfoEndpoint() string
// GetDeviceAuthorizationEndpoint returns the enpoint which can
// be used to start a DeviceAuthorization flow.
GetDeviceAuthorizationEndpoint() string
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier
// ErrorHandler returns the handler used for callback errors
@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
return rp.endpoints.UserinfoURL
}
func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
return rp.endpoints.DeviceAuthorizationURL
}
func (rp *relyingParty) GetEndSessionEndpoint() string {
return rp.endpoints.EndSessionURL
}
@ -371,7 +379,7 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
// returning it parsed together with the oauth2 tokens (access, refresh)
func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
codeOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
@ -384,7 +392,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
}
if rp.IsOAuth2Only() {
return &oidc.Tokens{Token: token}, nil
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
@ -392,21 +400,21 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
return nil, errors.New("id_token missing")
}
idToken, err := VerifyTokens(ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
// including cookie handling for secure `state` transfer
// and optional PKCE code verifier checking.
// Custom paramaters can optionally be set to the token URL.
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp)
if err != nil {
@ -439,7 +447,7 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
}
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
}
tokens, err := CodeExchange(r.Context(), params.Get("code"), rp, codeOpts...)
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
@ -448,13 +456,13 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
}
}
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo)
// UserinfoCallback wraps the callback function of the CodeExchangeHandler
// and calls the userinfo endpoint with the access token
// on success it will pass the userinfo into its callback function as well
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
@ -465,17 +473,17 @@ func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
}
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) {
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
if err != nil {
return nil, err
}
req.Header.Set("authorization", tokenType+" "+token)
userinfo := oidc.NewUserInfo()
userinfo := new(oidc.UserInfo)
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
return nil, err
}
if userinfo.GetSubject() != subject {
if userinfo.Subject != subject {
return nil, ErrUserInfoSubNotMatching
}
return userinfo, nil
@ -506,11 +514,12 @@ type OptionFunc func(RelyingParty)
type Endpoints struct {
oauth2.Endpoint
IntrospectURL string
UserinfoURL string
JKWsURL string
EndSessionURL string
RevokeURL string
IntrospectURL string
UserinfoURL string
JKWsURL string
EndSessionURL string
RevokeURL string
DeviceAuthorizationURL string
}
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
@ -520,11 +529,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
AuthStyle: oauth2.AuthStyleAutoDetect,
TokenURL: discoveryConfig.TokenEndpoint,
},
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
UserinfoURL: discoveryConfig.UserinfoEndpoint,
JKWsURL: discoveryConfig.JwksURI,
EndSessionURL: discoveryConfig.EndSessionEndpoint,
RevokeURL: discoveryConfig.RevocationEndpoint,
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
UserinfoURL: discoveryConfig.UserinfoEndpoint,
JKWsURL: discoveryConfig.JwksURI,
EndSessionURL: discoveryConfig.EndSessionEndpoint,
RevokeURL: discoveryConfig.RevocationEndpoint,
DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
}
}

View file

@ -5,7 +5,7 @@ import (
"golang.org/x/oauth2"
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
)
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`

View file

@ -6,7 +6,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type IDTokenVerifier interface {
@ -20,76 +20,78 @@ type IDTokenVerifier interface {
}
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v)
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
var nilClaims C
claims, err = VerifyIDToken[C](ctx, idToken, v)
if err != nil {
return nil, err
return nilClaims, err
}
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err
if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil {
return nilClaims, err
}
return idToken, nil
return claims, nil
}
// VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
claims := oidc.EmptyIDTokenClaims()
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
var nilClaims C
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
return nilClaims, err
}
payload, err := oidc.ParseToken(decrypted, claims)
payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil {
return nil, err
return nilClaims, err
}
if err := oidc.CheckSubject(claims); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
return nil, err
return nilClaims, err
}
return claims, nil
}
// VerifyAccessToken validates the access token according to
//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
if atHash == "" {
return nil
@ -112,7 +114,7 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
issuer: issuer,
clientID: clientID,
keySet: keySet,
offset: 1 * time.Second,
offset: time.Second,
nonce: func(_ context.Context) string {
return ""
},
@ -139,7 +141,7 @@ func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
return func(v *idTokenVerifier) {
v.maxAge = maxAge
v.maxAgeIAT = maxAge
}
}

View file

@ -0,0 +1,339 @@
package rp
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
"gopkg.in/square/go-jose.v2"
)
func TestVerifyTokens(t *testing.T) {
verifier := &idTokenVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{},
maxAge: 2 * time.Minute,
acr: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce },
clientID: tu.ValidClientID,
}
accessToken, _ := tu.ValidAccessToken()
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
require.NoError(t, err)
tests := []struct {
name string
accessToken string
idTokenClaims func() (string, *oidc.IDTokenClaims)
wantErr bool
}{
{
name: "without access token",
idTokenClaims: tu.ValidIDToken,
},
{
name: "with access token",
accessToken: accessToken,
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
)
},
},
{
name: "expired id token",
accessToken: accessToken,
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
)
},
wantErr: true,
},
{
name: "wrong access token",
accessToken: accessToken,
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~",
)
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
idToken, want := tt.idTokenClaims()
got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, got, want)
})
}
}
func TestVerifyIDToken(t *testing.T) {
verifier := &idTokenVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{},
maxAge: 2 * time.Minute,
acr: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce },
}
tests := []struct {
name string
clientID string
tokenClaims func() (string, *oidc.IDTokenClaims)
wantErr bool
}{
{
name: "success",
clientID: tu.ValidClientID,
tokenClaims: tu.ValidIDToken,
},
{
name: "parse err",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
wantErr: true,
},
{
name: "invalid signature",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
wantErr: true,
},
{
name: "empty subject",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, "", tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong issuer",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
"foo", tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong clientID",
clientID: "foo",
tokenClaims: tu.ValidIDToken,
wantErr: true,
},
{
name: "expired",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong IAT",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
)
},
wantErr: true,
},
{
name: "wrong acr",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "expired auth",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong nonce",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, "foo",
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims()
verifier.clientID = tt.clientID
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, got, want)
})
}
}
func TestVerifyAccessToken(t *testing.T) {
token, _ := tu.ValidAccessToken()
hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm)
require.NoError(t, err)
type args struct {
accessToken string
atHash string
sigAlgorithm jose.SignatureAlgorithm
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "empty hash",
},
{
name: "success",
args: args{
accessToken: token,
atHash: hash,
sigAlgorithm: tu.SignatureAlgorithm,
},
},
{
name: "invalid algorithm",
args: args{
accessToken: token,
atHash: hash,
sigAlgorithm: "foo",
},
wantErr: true,
},
{
name: "mismatch",
args: args{
accessToken: token,
atHash: "~~",
sigAlgorithm: tu.SignatureAlgorithm,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func TestNewIDTokenVerifier(t *testing.T) {
type args struct {
issuer string
clientID string
keySet oidc.KeySet
options []VerifierOption
}
tests := []struct {
name string
args args
want IDTokenVerifier
}{
{
name: "nil nonce", // otherwise assert.Equal will fail on the function
args: args{
issuer: tu.ValidIssuer,
clientID: tu.ValidClientID,
keySet: tu.KeySet{},
options: []VerifierOption{
WithIssuedAtOffset(time.Minute),
WithIssuedAtMaxAge(time.Hour),
WithNonce(nil), // otherwise assert.Equal will fail on the function
WithACRVerifier(nil),
WithAuthTimeMaxAge(2 * time.Hour),
WithSupportedSigningAlgorithms("ABC", "DEF"),
},
},
want: &idTokenVerifier{
issuer: tu.ValidIssuer,
offset: time.Minute,
maxAgeIAT: time.Hour,
clientID: tu.ValidClientID,
keySet: tu.KeySet{},
nonce: nil,
acr: nil,
maxAge: 2 * time.Hour,
supportedSignAlgs: []string{"ABC", "DEF"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -0,0 +1,86 @@
package rp_test
import (
"context"
"fmt"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// MyCustomClaims extends the TokenClaims base,
// so it implmeents the oidc.Claims interface.
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
type MyCustomClaims struct {
oidc.TokenClaims
NotBefore oidc.Time `json:"nbf,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
Foo string `json:"foo,omitempty"`
Bar *Nested `json:"bar,omitempty"`
}
// GetAccessTokenHash is required to implement
// the oidc.IDClaims interface.
func (c *MyCustomClaims) GetAccessTokenHash() string {
return c.AccessTokenHash
}
// Nested struct types are also possible.
type Nested struct {
Count int `json:"count,omitempty"`
Tags []string `json:"tags,omitempty"`
}
/*
idToken carries the following claims. foo and bar are custom claims
{
"acr": "something",
"amr": [
"foo",
"bar"
],
"at_hash": "2dzbm_vIxy-7eRtqUIGPPw",
"aud": [
"unit",
"test",
"555666"
],
"auth_time": 1678100961,
"azp": "555666",
"bar": {
"count": 22,
"tags": [
"some",
"tags"
]
},
"client_id": "555666",
"exp": 4802238682,
"foo": "Hello, World!",
"iat": 1678101021,
"iss": "local.com",
"jti": "9876",
"nbf": 1678101021,
"nonce": "12345",
"sub": "tim@local.com"
}
*/
const idToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF0X2hhc2giOiIyZHpibV92SXh5LTdlUnRxVUlHUFB3IiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImF1dGhfdGltZSI6MTY3ODEwMDk2MSwiYXpwIjoiNTU1NjY2IiwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiY2xpZW50X2lkIjoiNTU1NjY2IiwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJub25jZSI6IjEyMzQ1Iiwic3ViIjoidGltQGxvY2FsLmNvbSJ9.t3GXSfVNNwiW1Suv9_84v0sdn2_-RWHVxhphhRozDXnsO7SDNOlGnEioemXABESxSzMclM7gB7mYy5Qah2ZUNx7eP5t2njoxEYfavgHwx7UJZ2NCg8NDPQyr-hlxelEcfdXK-I0oTd-FRDvF4rqPkD9Us52IpnplChCxnHFgh4wKwPqZZjv2IXVCtn0ilKW3hff1rMOYKEuLRcN2YP0gkyuqyHvcf2dMmjod0t4sLOTJ82rsCbMBC5CLpqv3nIC9HOGITkt1Kd-Am0n1LrdZvWwTo6RFe8AnzF0gpqjcB5Wg4Qeh58DIjZOz4f_8wnmJ_gCqyRh5vfSW4XHdbum0Tw`
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.Zrz3LWSRjCMJZUMaI5dUbW4vGdSmEeJQ3ouhaX0bcW9rdFFLgBI4K2FWJhNivq8JDmCGSxwLu3mI680GWmDaEoAx1M5sCO9lqfIZHGZh-lfAXk27e6FPLlkTDBq8Bx4o4DJ9Fw0hRJGjUTjnYv5cq1vo2-UqldasL6CwTbkzNC_4oQFfRtuodC4Ql7dZ1HRv5LXuYx7KPkOssLZtV9cwtJp5nFzKjcf2zEE_tlbjcpynMwypornRUp1EhCWKRUGkJhJeiP71ECY5pQhShfjBu9Nc5wDpSnZmnk2S4YsPrRK3QkE-iEkas8BfsOCrGoErHjEJexAIDjasGO5PFLWfCA`
func ExampleVerifyTokens_customClaims() {
v := rp.NewIDTokenVerifier("local.com", "555666", tu.KeySet{},
rp.WithNonce(func(ctx context.Context) string { return "12345" }),
)
// VerifyAccessToken can be called with the *MyCustomClaims.
claims, err := rp.VerifyTokens[*MyCustomClaims](context.TODO(), accessToken, idToken, v)
if err != nil {
panic(err)
}
// Here we have typesafe access to the custom claims
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
// Output: Hello, World! 22 [some tags]
}

View file

@ -6,13 +6,14 @@ import (
"net/http"
"time"
"github.com/zitadel/oidc/pkg/client"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type ResourceServer interface {
IntrospectionURL() string
TokenEndpoint() string
HttpClient() *http.Client
AuthFn() (interface{}, error)
}
@ -29,6 +30,10 @@ func (r *resourceServer) IntrospectionURL() string {
return r.introspectURL
}
func (r *resourceServer) TokenEndpoint() string {
return r.tokenURL
}
func (r *resourceServer) HttpClient() *http.Client {
return r.httpClient
}
@ -107,7 +112,7 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
}
}
func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.IntrospectionResponse, error) {
func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) {
authFn, err := rp.AuthFn()
if err != nil {
return nil, err
@ -116,7 +121,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.Intr
if err != nil {
return nil, err
}
resp := oidc.NewIntrospectionResponse()
resp := new(oidc.IntrospectionResponse)
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
return nil, err
}

View file

@ -0,0 +1,127 @@
package tokenexchange
import (
"errors"
"net/http"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type TokenExchanger interface {
TokenEndpoint() string
HttpClient() *http.Client
AuthFn() (interface{}, error)
}
type OAuthTokenExchange struct {
httpClient *http.Client
tokenEndpoint string
authFn func() (interface{}, error)
}
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
return newOAuthTokenExchange(issuer, nil, options...)
}
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
authorizer := func() (interface{}, error) {
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
}
return newOAuthTokenExchange(issuer, authorizer, options...)
}
func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
te := &OAuthTokenExchange{
httpClient: httphelper.DefaultHTTPClient,
}
for _, opt := range options {
opt(te)
}
if te.tokenEndpoint == "" {
config, err := client.Discover(issuer, te.httpClient)
if err != nil {
return nil, err
}
te.tokenEndpoint = config.TokenEndpoint
}
if te.tokenEndpoint == "" {
return nil, errors.New("tokenURL is empty: please provide with either `WithStaticTokenEndpoint` or a discovery url")
}
te.authFn = authorizer
return te, nil
}
func WithHTTPClient(client *http.Client) func(*OAuthTokenExchange) {
return func(source *OAuthTokenExchange) {
source.httpClient = client
}
}
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*OAuthTokenExchange) {
return func(source *OAuthTokenExchange) {
source.tokenEndpoint = tokenEndpoint
}
}
func (te *OAuthTokenExchange) TokenEndpoint() string {
return te.tokenEndpoint
}
func (te *OAuthTokenExchange) HttpClient() *http.Client {
return te.httpClient
}
func (te *OAuthTokenExchange) AuthFn() (interface{}, error) {
if te.authFn != nil {
return te.authFn()
}
return nil, nil
}
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
// SubjectToken and SubjectTokenType are required parameters.
func ExchangeToken(
te TokenExchanger,
SubjectToken string,
SubjectTokenType oidc.TokenType,
ActorToken string,
ActorTokenType oidc.TokenType,
Resource []string,
Audience []string,
Scopes []string,
RequestedTokenType oidc.TokenType,
) (*oidc.TokenExchangeResponse, error) {
if SubjectToken == "" {
return nil, errors.New("empty subject_token")
}
if SubjectTokenType == "" {
return nil, errors.New("empty subject_token_type")
}
authFn, err := te.AuthFn()
if err != nil {
return nil, err
}
request := oidc.TokenExchangeRequest{
GrantType: oidc.GrantTypeTokenExchange,
SubjectToken: SubjectToken,
SubjectTokenType: SubjectTokenType,
ActorToken: ActorToken,
ActorTokenType: ActorTokenType,
Resource: Resource,
Audience: Audience,
Scopes: Scopes,
RequestedTokenType: RequestedTokenType,
}
return client.CallTokenExchangeEndpoint(request, authFn, te)
}