Merge branch 'next' into main-next
prepare the merge of next into main by resolving merge conflicts.
This commit is contained in:
commit
0476b5946e
122 changed files with 8195 additions and 2858 deletions
|
@ -4,22 +4,22 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
loginPath = "/login"
|
||||
)
|
||||
|
||||
func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens {
|
||||
func CodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens[C] {
|
||||
codeflowCtx, codeflowCancel := context.WithCancel(ctx)
|
||||
defer codeflowCancel()
|
||||
|
||||
tokenChan := make(chan *oidc.Tokens, 1)
|
||||
tokenChan := make(chan *oidc.Tokens[C], 1)
|
||||
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) {
|
||||
tokenChan <- tokens
|
||||
msg := "<p><strong>Success!</strong></p>"
|
||||
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// DelegationTokenRequest is an implementation of TokenExchangeRequest
|
||||
// it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional
|
||||
//"urn:ietf:params:oauth:token-type:access_token" actor token for an
|
||||
//"urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
// "urn:ietf:params:oauth:token-type:access_token" actor token for an
|
||||
// "urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
|
||||
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
|
||||
}
|
||||
|
|
62
pkg/client/rp/device.go
Normal file
62
pkg/client/rp/device.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
|
||||
confg := rp.OAuthConfig()
|
||||
req := &oidc.ClientCredentialsRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
Scope: scopes,
|
||||
ClientID: confg.ClientID,
|
||||
ClientSecret: confg.ClientSecret,
|
||||
}
|
||||
|
||||
if signer := rp.Signer(); signer != nil {
|
||||
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build assertion: %w", err)
|
||||
}
|
||||
req.ClientAssertion = assertion
|
||||
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// DeviceAuthorization starts a new Device Authorization flow as defined
|
||||
// in RFC 8628, section 3.1 and 3.2:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.CallDeviceAuthorizationEndpoint(req, rp)
|
||||
}
|
||||
|
||||
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
|
||||
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
||||
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
||||
req := &client.DeviceAccessTokenRequest{
|
||||
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
DeviceCode: deviceCode,
|
||||
},
|
||||
}
|
||||
|
||||
req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
|
||||
}
|
|
@ -1,279 +0,0 @@
|
|||
package rp_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/example/server/storage"
|
||||
|
||||
"github.com/jeremija/gosubmit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestRelyingPartySession(t *testing.T) {
|
||||
t.Log("------- start example OP ------")
|
||||
ctx := context.Background()
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore())
|
||||
var dh deferredHandler
|
||||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(ctx, opServer.URL, exampleStorage)
|
||||
|
||||
targetURL := "http://local-site"
|
||||
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
|
||||
require.NoError(t, err, "local url")
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
client := storage.WebClient(clientID, "secret", targetURL)
|
||||
storage.RegisterClients(client)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
require.NoError(t, err, "create cookie jar")
|
||||
httpClient := &http.Client{
|
||||
Timeout: time.Second * 5,
|
||||
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
|
||||
t.Log("------- create RP ------")
|
||||
key := []byte("test1234test1234")
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
provider, err := rp.NewRelyingPartyOIDC(
|
||||
opServer.URL,
|
||||
clientID,
|
||||
"secret",
|
||||
targetURL,
|
||||
[]string{"openid", "email", "profile", "offline_access"},
|
||||
rp.WithPKCE(cookieHandler),
|
||||
rp.WithVerifierOpts(
|
||||
rp.WithIssuedAtOffset(5*time.Second),
|
||||
rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
|
||||
),
|
||||
)
|
||||
|
||||
t.Log("------- get redirect from local client (rp) to OP ------")
|
||||
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
capturedW := httptest.NewRecorder()
|
||||
get := httptest.NewRequest("GET", localURL.String(), nil)
|
||||
rp.AuthURLHandler(func() string { return state }, provider,
|
||||
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
|
||||
rp.WithURLParam("custom", "param"),
|
||||
)(capturedW, get)
|
||||
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
t.Log("response body (redirect from RP to OP)", capturedW.Body.String())
|
||||
}
|
||||
}()
|
||||
require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
|
||||
require.Less(t, capturedW.Code, 400, "captured response code")
|
||||
require.Contains(t, capturedW.Body.String(), `prompt=Hello%2C+World%21+Goodbye%2C+World%21`)
|
||||
require.Contains(t, capturedW.Body.String(), `custom=param`)
|
||||
|
||||
//nolint:bodyclose
|
||||
resp := capturedW.Result()
|
||||
jar.SetCookies(localURL, resp.Cookies())
|
||||
|
||||
startAuthURL, err := resp.Location()
|
||||
require.NoError(t, err, "get redirect")
|
||||
assert.NotEmpty(t, startAuthURL, "login url")
|
||||
t.Log("Starting auth at", startAuthURL)
|
||||
|
||||
t.Log("------- get redirect to OP to login page ------")
|
||||
loginPageURL := getRedirect(t, "get redirect to login page", httpClient, startAuthURL)
|
||||
t.Log("login page URL", loginPageURL)
|
||||
|
||||
t.Log("------- get login form ------")
|
||||
form := getForm(t, "get login form", httpClient, loginPageURL)
|
||||
t.Log("login form (unfilled)", string(form))
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
t.Logf("login form (unfilled): %s", string(form))
|
||||
}
|
||||
}()
|
||||
|
||||
t.Log("------- post to login form, get redirect to OP ------")
|
||||
postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL,
|
||||
gosubmit.Set("username", "test-user"),
|
||||
gosubmit.Set("password", "verysecure"))
|
||||
t.Logf("Get redirect from %s", postLoginRedirectURL)
|
||||
|
||||
t.Log("------- redirect from OP back to RP ------")
|
||||
codeBearingURL := getRedirect(t, "get redirect with code", httpClient, postLoginRedirectURL)
|
||||
t.Logf("Redirect with code %s", codeBearingURL)
|
||||
|
||||
t.Log("------- exchange code for tokens ------")
|
||||
capturedW = httptest.NewRecorder()
|
||||
get = httptest.NewRequest("GET", codeBearingURL.String(), nil)
|
||||
for _, cookie := range jar.Cookies(codeBearingURL) {
|
||||
get.Header["Cookie"] = append(get.Header["Cookie"], cookie.String())
|
||||
t.Logf("setting cookie %s", cookie)
|
||||
}
|
||||
|
||||
var accessToken, refreshToken, idToken, email string
|
||||
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
|
||||
require.NotNil(t, tokens, "tokens")
|
||||
require.NotNil(t, info, "info")
|
||||
t.Log("access token", tokens.AccessToken)
|
||||
t.Log("refresh token", tokens.RefreshToken)
|
||||
t.Log("id token", tokens.IDToken)
|
||||
t.Log("email", info.GetEmail())
|
||||
|
||||
accessToken = tokens.AccessToken
|
||||
refreshToken = tokens.RefreshToken
|
||||
idToken = tokens.IDToken
|
||||
email = info.GetEmail()
|
||||
http.Redirect(w, r, targetURL, 302)
|
||||
}
|
||||
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
|
||||
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
t.Log("token exchange response body", capturedW.Body.String())
|
||||
require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
|
||||
}
|
||||
}()
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
// TODO: how to check the custom header was sent to the server?
|
||||
|
||||
//nolint:bodyclose
|
||||
resp = capturedW.Result()
|
||||
|
||||
authorizedURL, err := resp.Location()
|
||||
require.NoError(t, err, "get fully-authorizied redirect location")
|
||||
require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location")
|
||||
|
||||
require.NotEmpty(t, idToken, "id token")
|
||||
assert.NotEmpty(t, refreshToken, "refresh token")
|
||||
assert.NotEmpty(t, accessToken, "access token")
|
||||
assert.NotEmpty(t, email, "email")
|
||||
|
||||
t.Log("------- refresh tokens ------")
|
||||
|
||||
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
require.NoError(t, err, "refresh token")
|
||||
assert.NotNil(t, newTokens, "access token")
|
||||
t.Logf("new access token %s", newTokens.AccessToken)
|
||||
t.Logf("new refresh token %s", newTokens.RefreshToken)
|
||||
t.Logf("new token type %s", newTokens.TokenType)
|
||||
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
|
||||
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
} else {
|
||||
t.Logf("no redirect")
|
||||
}
|
||||
|
||||
t.Log("------ attempt refresh again (should fail) ------")
|
||||
t.Log("trying original refresh token", refreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with original")
|
||||
if newTokens.RefreshToken != "" {
|
||||
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with replacement")
|
||||
}
|
||||
|
||||
t.Run("WithPrompt", func(t *testing.T) {
|
||||
opts := rp.WithPrompt("foo", "bar")()
|
||||
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
|
||||
|
||||
require.Contains(t, url, "prompt=foo+bar")
|
||||
})
|
||||
}
|
||||
|
||||
type deferredHandler struct {
|
||||
http.Handler
|
||||
}
|
||||
|
||||
func getRedirect(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) *url.URL {
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: uri,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
require.NoError(t, err, "GET "+uri.String())
|
||||
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Logf("%s: GET %s: body: %s", desc, uri, string(body))
|
||||
}
|
||||
}()
|
||||
|
||||
//nolint:errcheck
|
||||
defer resp.Body.Close()
|
||||
redirect, err := resp.Location()
|
||||
require.NoErrorf(t, err, "%s: get redirect %s", desc, uri)
|
||||
require.NotEmptyf(t, redirect, "%s: get redirect %s", desc, uri)
|
||||
return redirect
|
||||
}
|
||||
|
||||
func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) []byte {
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: uri,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
require.NoErrorf(t, err, "%s: GET %s", desc, uri)
|
||||
//nolint:errcheck
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err, "%s: read GET %s", desc, uri)
|
||||
return body
|
||||
}
|
||||
|
||||
func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL {
|
||||
// TODO: switch to io.NopCloser when go1.15 support is dropped
|
||||
req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
|
||||
append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)...,
|
||||
)
|
||||
if req.URL.Scheme == "" {
|
||||
req.URL = uri
|
||||
t.Log("request lost it's proto..., adding back... request now", req.URL)
|
||||
}
|
||||
req.RequestURI = "" // bug in gosubmit?
|
||||
resp, err := httpClient.Do(req)
|
||||
require.NoErrorf(t, err, "%s: POST %s", desc, uri)
|
||||
|
||||
//nolint:errcheck
|
||||
defer resp.Body.Close()
|
||||
defer func() {
|
||||
if t.Failed() {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Logf("%s: GET %s: body: %s", desc, uri, string(body))
|
||||
}
|
||||
}()
|
||||
|
||||
redirect, err := resp.Location()
|
||||
require.NoErrorf(t, err, "%s: redirect for POST %s", desc, uri)
|
||||
return redirect
|
||||
}
|
|
@ -9,8 +9,8 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/zitadel/oidc/pkg/rp Verifier
|
|
@ -1,67 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/rp (interfaces: Verifier)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
// MockVerifier is a mock of Verifier interface
|
||||
type MockVerifier struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockVerifierMockRecorder
|
||||
}
|
||||
|
||||
// MockVerifierMockRecorder is the mock recorder for MockVerifier
|
||||
type MockVerifierMockRecorder struct {
|
||||
mock *MockVerifier
|
||||
}
|
||||
|
||||
// NewMockVerifier creates a new mock instance
|
||||
func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier {
|
||||
mock := &MockVerifier{ctrl: ctrl}
|
||||
mock.recorder = &MockVerifierMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Verify mocks base method
|
||||
func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*oidc.IDTokenClaims)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Verify indicates an expected call of Verify
|
||||
func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// VerifyIDToken mocks base method
|
||||
func (m *MockVerifier) VerifyIDToken(arg0 context.Context, arg1 string) (*oidc.IDTokenClaims, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "VerifyIDToken", arg0, arg1)
|
||||
ret0, _ := ret[0].(*oidc.IDTokenClaims)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// VerifyIDToken indicates an expected call of VerifyIDToken
|
||||
func (mr *MockVerifierMockRecorder) VerifyIDToken(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyIDToken", reflect.TypeOf((*MockVerifier)(nil).VerifyIDToken), arg0, arg1)
|
||||
}
|
|
@ -14,9 +14,9 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -54,11 +54,15 @@ type RelyingParty interface {
|
|||
GetEndSessionEndpoint() string
|
||||
|
||||
// GetRevokeEndpoint returns the endpoint to revoke a specific token
|
||||
// "GetRevokeEndpoint() string" will be added in a future release
|
||||
GetRevokeEndpoint() string
|
||||
|
||||
// UserinfoEndpoint returns the userinfo
|
||||
UserinfoEndpoint() string
|
||||
|
||||
// GetDeviceAuthorizationEndpoint returns the enpoint which can
|
||||
// be used to start a DeviceAuthorization flow.
|
||||
GetDeviceAuthorizationEndpoint() string
|
||||
|
||||
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
|
||||
IDTokenVerifier() IDTokenVerifier
|
||||
// ErrorHandler returns the handler used for callback errors
|
||||
|
@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
|
|||
return rp.endpoints.UserinfoURL
|
||||
}
|
||||
|
||||
func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
|
||||
return rp.endpoints.DeviceAuthorizationURL
|
||||
}
|
||||
|
||||
func (rp *relyingParty) GetEndSessionEndpoint() string {
|
||||
return rp.endpoints.EndSessionURL
|
||||
}
|
||||
|
@ -371,7 +379,7 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
|
|||
|
||||
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
||||
// returning it parsed together with the oauth2 tokens (access, refresh)
|
||||
func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
|
||||
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
|
@ -384,7 +392,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
|
|||
}
|
||||
|
||||
if rp.IsOAuth2Only() {
|
||||
return &oidc.Tokens{Token: token}, nil
|
||||
return &oidc.Tokens[C]{Token: token}, nil
|
||||
}
|
||||
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
|
@ -392,21 +400,21 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
|
|||
return nil, errors.New("id_token missing")
|
||||
}
|
||||
|
||||
idToken, err := VerifyTokens(ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
|
||||
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
|
||||
|
||||
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
||||
// including cookie handling for secure `state` transfer
|
||||
// and optional PKCE code verifier checking.
|
||||
// Custom paramaters can optionally be set to the token URL.
|
||||
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := tryReadStateCookie(w, r, rp)
|
||||
if err != nil {
|
||||
|
@ -439,7 +447,7 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
|
|||
}
|
||||
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
|
||||
}
|
||||
tokens, err := CodeExchange(r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -448,13 +456,13 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
|
|||
}
|
||||
}
|
||||
|
||||
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
|
||||
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo)
|
||||
|
||||
// UserinfoCallback wraps the callback function of the CodeExchangeHandler
|
||||
// and calls the userinfo endpoint with the access token
|
||||
// on success it will pass the userinfo into its callback function as well
|
||||
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
|
||||
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
if err != nil {
|
||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
||||
|
@ -465,17 +473,17 @@ func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
|
|||
}
|
||||
|
||||
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) {
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
|
||||
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("authorization", tokenType+" "+token)
|
||||
userinfo := oidc.NewUserInfo()
|
||||
userinfo := new(oidc.UserInfo)
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userinfo.GetSubject() != subject {
|
||||
if userinfo.Subject != subject {
|
||||
return nil, ErrUserInfoSubNotMatching
|
||||
}
|
||||
return userinfo, nil
|
||||
|
@ -506,11 +514,12 @@ type OptionFunc func(RelyingParty)
|
|||
|
||||
type Endpoints struct {
|
||||
oauth2.Endpoint
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
EndSessionURL string
|
||||
RevokeURL string
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
EndSessionURL string
|
||||
RevokeURL string
|
||||
DeviceAuthorizationURL string
|
||||
}
|
||||
|
||||
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||
|
@ -520,11 +529,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
|||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
TokenURL: discoveryConfig.TokenEndpoint,
|
||||
},
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
||||
RevokeURL: discoveryConfig.RevocationEndpoint,
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
||||
RevokeURL: discoveryConfig.RevocationEndpoint,
|
||||
DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type IDTokenVerifier interface {
|
||||
|
@ -20,76 +20,78 @@ type IDTokenVerifier interface {
|
|||
}
|
||||
|
||||
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||
idToken, err := VerifyIDToken(ctx, idTokenString, v)
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
|
||||
return nil, err
|
||||
if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
return idToken, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyIDToken validates the id token according to
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||
claims := oidc.EmptyIDTokenClaims()
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
payload, err := oidc.ParseToken(decrypted, claims)
|
||||
payload, err := oidc.ParseToken(decrypted, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckSubject(claims); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyAccessToken validates the access token according to
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||
if atHash == "" {
|
||||
return nil
|
||||
|
@ -112,7 +114,7 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
|
|||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
keySet: keySet,
|
||||
offset: 1 * time.Second,
|
||||
offset: time.Second,
|
||||
nonce: func(_ context.Context) string {
|
||||
return ""
|
||||
},
|
||||
|
@ -139,7 +141,7 @@ func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
|
|||
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
||||
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.maxAge = maxAge
|
||||
v.maxAgeIAT = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
|
|
339
pkg/client/rp/verifier_test.go
Normal file
339
pkg/client/rp/verifier_test.go
Normal file
|
@ -0,0 +1,339 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestVerifyTokens(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
clientID: tu.ValidClientID,
|
||||
}
|
||||
accessToken, _ := tu.ValidAccessToken()
|
||||
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
idTokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "without access token",
|
||||
idTokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "with access token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||
)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired id token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong access token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
idToken, want := tt.idTokenClaims()
|
||||
got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyIDToken(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty subject",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, "", tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong clientID",
|
||||
clientID: "foo",
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong IAT",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong acr",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired auth",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong nonce",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, "foo",
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
verifier.clientID = tt.clientID
|
||||
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
token, _ := tu.ValidAccessToken()
|
||||
hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
accessToken string
|
||||
atHash string
|
||||
sigAlgorithm jose.SignatureAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty hash",
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: hash,
|
||||
sigAlgorithm: tu.SignatureAlgorithm,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid algorithm",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: hash,
|
||||
sigAlgorithm: "foo",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "mismatch",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: "~~",
|
||||
sigAlgorithm: tu.SignatureAlgorithm,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewIDTokenVerifier(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
clientID string
|
||||
keySet oidc.KeySet
|
||||
options []VerifierOption
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want IDTokenVerifier
|
||||
}{
|
||||
{
|
||||
name: "nil nonce", // otherwise assert.Equal will fail on the function
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
clientID: tu.ValidClientID,
|
||||
keySet: tu.KeySet{},
|
||||
options: []VerifierOption{
|
||||
WithIssuedAtOffset(time.Minute),
|
||||
WithIssuedAtMaxAge(time.Hour),
|
||||
WithNonce(nil), // otherwise assert.Equal will fail on the function
|
||||
WithACRVerifier(nil),
|
||||
WithAuthTimeMaxAge(2 * time.Hour),
|
||||
WithSupportedSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
offset: time.Minute,
|
||||
maxAgeIAT: time.Hour,
|
||||
clientID: tu.ValidClientID,
|
||||
keySet: tu.KeySet{},
|
||||
nonce: nil,
|
||||
acr: nil,
|
||||
maxAge: 2 * time.Hour,
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
86
pkg/client/rp/verifier_tokens_example_test.go
Normal file
86
pkg/client/rp/verifier_tokens_example_test.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package rp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// MyCustomClaims extends the TokenClaims base,
|
||||
// so it implmeents the oidc.Claims interface.
|
||||
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
|
||||
type MyCustomClaims struct {
|
||||
oidc.TokenClaims
|
||||
NotBefore oidc.Time `json:"nbf,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
Foo string `json:"foo,omitempty"`
|
||||
Bar *Nested `json:"bar,omitempty"`
|
||||
}
|
||||
|
||||
// GetAccessTokenHash is required to implement
|
||||
// the oidc.IDClaims interface.
|
||||
func (c *MyCustomClaims) GetAccessTokenHash() string {
|
||||
return c.AccessTokenHash
|
||||
}
|
||||
|
||||
// Nested struct types are also possible.
|
||||
type Nested struct {
|
||||
Count int `json:"count,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
/*
|
||||
idToken carries the following claims. foo and bar are custom claims
|
||||
|
||||
{
|
||||
"acr": "something",
|
||||
"amr": [
|
||||
"foo",
|
||||
"bar"
|
||||
],
|
||||
"at_hash": "2dzbm_vIxy-7eRtqUIGPPw",
|
||||
"aud": [
|
||||
"unit",
|
||||
"test",
|
||||
"555666"
|
||||
],
|
||||
"auth_time": 1678100961,
|
||||
"azp": "555666",
|
||||
"bar": {
|
||||
"count": 22,
|
||||
"tags": [
|
||||
"some",
|
||||
"tags"
|
||||
]
|
||||
},
|
||||
"client_id": "555666",
|
||||
"exp": 4802238682,
|
||||
"foo": "Hello, World!",
|
||||
"iat": 1678101021,
|
||||
"iss": "local.com",
|
||||
"jti": "9876",
|
||||
"nbf": 1678101021,
|
||||
"nonce": "12345",
|
||||
"sub": "tim@local.com"
|
||||
}
|
||||
*/
|
||||
const idToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF0X2hhc2giOiIyZHpibV92SXh5LTdlUnRxVUlHUFB3IiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImF1dGhfdGltZSI6MTY3ODEwMDk2MSwiYXpwIjoiNTU1NjY2IiwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiY2xpZW50X2lkIjoiNTU1NjY2IiwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJub25jZSI6IjEyMzQ1Iiwic3ViIjoidGltQGxvY2FsLmNvbSJ9.t3GXSfVNNwiW1Suv9_84v0sdn2_-RWHVxhphhRozDXnsO7SDNOlGnEioemXABESxSzMclM7gB7mYy5Qah2ZUNx7eP5t2njoxEYfavgHwx7UJZ2NCg8NDPQyr-hlxelEcfdXK-I0oTd-FRDvF4rqPkD9Us52IpnplChCxnHFgh4wKwPqZZjv2IXVCtn0ilKW3hff1rMOYKEuLRcN2YP0gkyuqyHvcf2dMmjod0t4sLOTJ82rsCbMBC5CLpqv3nIC9HOGITkt1Kd-Am0n1LrdZvWwTo6RFe8AnzF0gpqjcB5Wg4Qeh58DIjZOz4f_8wnmJ_gCqyRh5vfSW4XHdbum0Tw`
|
||||
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.Zrz3LWSRjCMJZUMaI5dUbW4vGdSmEeJQ3ouhaX0bcW9rdFFLgBI4K2FWJhNivq8JDmCGSxwLu3mI680GWmDaEoAx1M5sCO9lqfIZHGZh-lfAXk27e6FPLlkTDBq8Bx4o4DJ9Fw0hRJGjUTjnYv5cq1vo2-UqldasL6CwTbkzNC_4oQFfRtuodC4Ql7dZ1HRv5LXuYx7KPkOssLZtV9cwtJp5nFzKjcf2zEE_tlbjcpynMwypornRUp1EhCWKRUGkJhJeiP71ECY5pQhShfjBu9Nc5wDpSnZmnk2S4YsPrRK3QkE-iEkas8BfsOCrGoErHjEJexAIDjasGO5PFLWfCA`
|
||||
|
||||
func ExampleVerifyTokens_customClaims() {
|
||||
v := rp.NewIDTokenVerifier("local.com", "555666", tu.KeySet{},
|
||||
rp.WithNonce(func(ctx context.Context) string { return "12345" }),
|
||||
)
|
||||
|
||||
// VerifyAccessToken can be called with the *MyCustomClaims.
|
||||
claims, err := rp.VerifyTokens[*MyCustomClaims](context.TODO(), accessToken, idToken, v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Here we have typesafe access to the custom claims
|
||||
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
|
||||
// Output: Hello, World! 22 [some tags]
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue