Merge branch 'next' into main-next

prepare the merge of next into main by resolving merge conflicts.
This commit is contained in:
Tim Möhlmann 2023-03-15 16:26:32 +02:00
commit 0476b5946e
122 changed files with 8195 additions and 2858 deletions

View file

@ -1,31 +1,25 @@
package client
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/gorilla/schema"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/crypto"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/crypto"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
var Encoder = func() httphelper.Encoder {
e := schema.NewEncoder()
e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(oidc.SpaceDelimitedArray).Encode()
})
return e
}()
var Encoder = httphelper.Encoder(oidc.NewEncoder())
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
@ -90,6 +84,9 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS
return http.ErrUseLastResponse
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
body, err := io.ReadAll(resp.Body)
@ -148,6 +145,18 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa
return nil
}
func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
tokenRes := new(oidc.TokenExchangeResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err
}
return tokenRes, nil
}
func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
privateKey, err := crypto.BytesToPrivateKey(key)
if err != nil {
@ -167,7 +176,98 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
Issuer: clientID,
Subject: clientID,
Audience: audience,
ExpiresAt: oidc.Time(exp),
IssuedAt: oidc.Time(iat),
ExpiresAt: oidc.FromTime(exp),
IssuedAt: oidc.FromTime(iat),
}, signer)
}
type DeviceAuthorizationCaller interface {
GetDeviceAuthorizationEndpoint() string
HttpClient() *http.Client
}
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}
if request.ClientSecret != "" {
req.SetBasicAuth(request.ClientID, request.ClientSecret)
}
resp := new(oidc.DeviceAuthorizationResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
return nil, err
}
return resp, nil
}
type DeviceAccessTokenRequest struct {
*oidc.ClientCredentialsRequest
oidc.DeviceAccessTokenRequest
}
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}
if request.ClientSecret != "" {
req.SetBasicAuth(request.ClientID, request.ClientSecret)
}
httpResp, err := caller.HttpClient().Do(req)
if err != nil {
return nil, err
}
defer httpResp.Body.Close()
resp := new(struct {
*oidc.AccessTokenResponse
*oidc.Error
})
if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
return nil, err
}
if httpResp.StatusCode == http.StatusOK {
return resp.AccessTokenResponse, nil
}
return nil, resp.Error
}
func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
for {
timer := time.After(interval)
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer:
}
ctx, cancel := context.WithTimeout(ctx, interval)
defer cancel()
resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
if err == nil {
return resp, nil
}
if errors.Is(err, context.DeadlineExceeded) {
interval += 5 * time.Second
}
var target *oidc.Error
if !errors.As(err, &target) {
return nil, err
}
switch target.ErrorType {
case oidc.AuthorizationPending:
continue
case oidc.SlowDown:
interval += 5 * time.Second
continue
default:
return nil, err
}
}
}

View file

@ -1,8 +1,7 @@
package rp_test
package client_test
import (
"bytes"
"context"
"io"
"io/ioutil"
"math/rand"
@ -15,34 +14,142 @@ import (
"testing"
"time"
"github.com/zitadel/oidc/example/server/exampleop"
"github.com/zitadel/oidc/example/server/storage"
"github.com/jeremija/gosubmit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/example/server/exampleop"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/client/rs"
"github.com/zitadel/oidc/v2/pkg/client/tokenexchange"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func TestRelyingPartySession(t *testing.T) {
t.Log("------- start example OP ------")
ctx := context.Background()
exampleStorage := storage.NewStorage(storage.NewUserStore())
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
var dh deferredHandler
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(ctx, opServer.URL, exampleStorage)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
t.Log("------- refresh tokens ------")
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken)
t.Logf("new refresh token %s", newTokens.RefreshToken)
t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
} else {
t.Logf("no redirect")
}
t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken)
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement")
}
}
func TestResourceServerTokenExchange(t *testing.T) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
var dh deferredHandler
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
clientSecret := "secret"
t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret)
require.NoError(t, err, "new resource server")
t.Log("------- exchage refresh tokens (impersonation) ------")
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
resourceServer,
refreshToken,
oidc.RefreshTokenType,
"",
"",
[]string{},
[]string{},
[]string{"profile", "custom_scope:impersonate:id2"},
oidc.RefreshTokenType,
)
require.NoError(t, err, "refresh token")
require.NotNil(t, tokenExchangeResponse, "token exchange response")
assert.Equal(t, tokenExchangeResponse.IssuedTokenType, oidc.RefreshTokenType)
assert.NotEmpty(t, tokenExchangeResponse.AccessToken, "access token")
assert.NotEmpty(t, tokenExchangeResponse.RefreshToken, "refresh token")
assert.Equal(t, []string(tokenExchangeResponse.Scopes), []string{"profile", "custom_scope:impersonate:id2"})
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
} else {
t.Logf("no redirect")
}
t.Log("------- attempt exchage again (should fail) ------")
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
resourceServer,
refreshToken,
oidc.RefreshTokenType,
"",
"",
[]string{},
[]string{},
[]string{"profile", "custom_scope:impersonate:id2"},
oidc.RefreshTokenType,
)
require.Error(t, err, "refresh token")
assert.Contains(t, err.Error(), "subject_token is invalid")
require.Nil(t, tokenExchangeResponse, "token exchange response")
}
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {
targetURL := "http://local-site"
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
require.NoError(t, err, "local url")
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
client := storage.WebClient(clientID, "secret", targetURL)
client := storage.WebClient(clientID, clientSecret, targetURL)
storage.RegisterClients(client)
jar, err := cookiejar.New(nil)
@ -58,10 +165,10 @@ func TestRelyingPartySession(t *testing.T) {
t.Log("------- create RP ------")
key := []byte("test1234test1234")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
provider, err := rp.NewRelyingPartyOIDC(
provider, err = rp.NewRelyingPartyOIDC(
opServer.URL,
clientID,
"secret",
clientSecret,
targetURL,
[]string{"openid", "email", "profile", "offline_access"},
rp.WithPKCE(cookieHandler),
@ -70,8 +177,10 @@ func TestRelyingPartySession(t *testing.T) {
rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
),
)
require.NoError(t, err, "new rp")
t.Log("------- get redirect from local client (rp) to OP ------")
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
capturedW := httptest.NewRecorder()
get := httptest.NewRequest("GET", localURL.String(), nil)
@ -114,7 +223,7 @@ func TestRelyingPartySession(t *testing.T) {
t.Log("------- post to login form, get redirect to OP ------")
postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL,
gosubmit.Set("username", "test-user"),
gosubmit.Set("username", "test-user@local-site"),
gosubmit.Set("password", "verysecure"))
t.Logf("Get redirect from %s", postLoginRedirectURL)
@ -130,19 +239,19 @@ func TestRelyingPartySession(t *testing.T) {
t.Logf("setting cookie %s", cookie)
}
var accessToken, refreshToken, idToken, email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
var email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
require.NotNil(t, tokens, "tokens")
require.NotNil(t, info, "info")
t.Log("access token", tokens.AccessToken)
t.Log("refresh token", tokens.RefreshToken)
t.Log("id token", tokens.IDToken)
t.Log("email", info.GetEmail())
t.Log("email", info.Email)
accessToken = tokens.AccessToken
refreshToken = tokens.RefreshToken
idToken = tokens.IDToken
email = info.GetEmail()
email = info.Email
http.Redirect(w, r, targetURL, 302)
}
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
@ -154,7 +263,6 @@ func TestRelyingPartySession(t *testing.T) {
}
}()
require.Less(t, capturedW.Code, 400, "token exchange response code")
require.Less(t, capturedW.Code, 400, "token exchange response code")
// TODO: how to check the custom header was sent to the server?
//nolint:bodyclose
@ -169,43 +277,7 @@ func TestRelyingPartySession(t *testing.T) {
assert.NotEmpty(t, accessToken, "access token")
assert.NotEmpty(t, email, "email")
t.Log("------- refresh tokens ------")
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken)
t.Logf("new refresh token %s", newTokens.RefreshToken)
t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
} else {
t.Logf("no redirect")
}
t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken)
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement")
}
t.Run("WithPrompt", func(t *testing.T) {
opts := rp.WithPrompt("foo", "bar")()
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
require.Contains(t, url, "prompt=foo+bar")
})
return provider, accessToken, refreshToken, idToken
}
type deferredHandler struct {

View file

@ -5,8 +5,8 @@ import (
"golang.org/x/oauth2"
"github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// JWTProfileExchange handles the oauth2 jwt profile exchange

View file

@ -7,8 +7,8 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/client"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// jwtProfileTokenSource implement the oauth2.TokenSource

View file

@ -4,22 +4,22 @@ import (
"context"
"net/http"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
const (
loginPath = "/login"
)
func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens {
func CodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens[C] {
codeflowCtx, codeflowCancel := context.WithCancel(ctx)
defer codeflowCancel()
tokenChan := make(chan *oidc.Tokens, 1)
tokenChan := make(chan *oidc.Tokens[C], 1)
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) {
tokenChan <- tokens
msg := "<p><strong>Success!</strong></p>"
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"

View file

@ -1,13 +1,13 @@
package rp
import (
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
)
// DelegationTokenRequest is an implementation of TokenExchangeRequest
// it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional
//"urn:ietf:params:oauth:token-type:access_token" actor token for an
//"urn:ietf:params:oauth:token-type:access_token" delegation token
// "urn:ietf:params:oauth:token-type:access_token" actor token for an
// "urn:ietf:params:oauth:token-type:access_token" delegation token
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
}

62
pkg/client/rp/device.go Normal file
View file

@ -0,0 +1,62 @@
package rp
import (
"context"
"fmt"
"time"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
confg := rp.OAuthConfig()
req := &oidc.ClientCredentialsRequest{
GrantType: oidc.GrantTypeDeviceCode,
Scope: scopes,
ClientID: confg.ClientID,
ClientSecret: confg.ClientSecret,
}
if signer := rp.Signer(); signer != nil {
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
if err != nil {
return nil, fmt.Errorf("failed to build assertion: %w", err)
}
req.ClientAssertion = assertion
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
}
return req, nil
}
// DeviceAuthorization starts a new Device Authorization flow as defined
// in RFC 8628, section 3.1 and 3.2:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
req, err := newDeviceClientCredentialsRequest(scopes, rp)
if err != nil {
return nil, err
}
return client.CallDeviceAuthorizationEndpoint(req, rp)
}
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
// by means of polling as defined in RFC, section 3.3 and 3.4:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
req := &client.DeviceAccessTokenRequest{
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
GrantType: oidc.GrantTypeDeviceCode,
DeviceCode: deviceCode,
},
}
req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
if err != nil {
return nil, err
}
return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
}

View file

@ -9,8 +9,8 @@ import (
"gopkg.in/square/go-jose.v2"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {

View file

@ -1,3 +0,0 @@
package mock
//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/zitadel/oidc/pkg/rp Verifier

View file

@ -1,67 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/rp (interfaces: Verifier)
// Package mock is a generated GoMock package.
package mock
import (
"context"
"reflect"
"github.com/golang/mock/gomock"
"github.com/zitadel/oidc/pkg/oidc"
)
// MockVerifier is a mock of Verifier interface
type MockVerifier struct {
ctrl *gomock.Controller
recorder *MockVerifierMockRecorder
}
// MockVerifierMockRecorder is the mock recorder for MockVerifier
type MockVerifierMockRecorder struct {
mock *MockVerifier
}
// NewMockVerifier creates a new mock instance
func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier {
mock := &MockVerifier{ctrl: ctrl}
mock.recorder = &MockVerifierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder {
return m.recorder
}
// Verify mocks base method
func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.IDTokenClaims)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Verify indicates an expected call of Verify
func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2)
}
// VerifyIDToken mocks base method
func (m *MockVerifier) VerifyIDToken(arg0 context.Context, arg1 string) (*oidc.IDTokenClaims, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "VerifyIDToken", arg0, arg1)
ret0, _ := ret[0].(*oidc.IDTokenClaims)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// VerifyIDToken indicates an expected call of VerifyIDToken
func (mr *MockVerifierMockRecorder) VerifyIDToken(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyIDToken", reflect.TypeOf((*MockVerifier)(nil).VerifyIDToken), arg0, arg1)
}

View file

@ -14,9 +14,9 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/client"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
const (
@ -54,11 +54,15 @@ type RelyingParty interface {
GetEndSessionEndpoint() string
// GetRevokeEndpoint returns the endpoint to revoke a specific token
// "GetRevokeEndpoint() string" will be added in a future release
GetRevokeEndpoint() string
// UserinfoEndpoint returns the userinfo
UserinfoEndpoint() string
// GetDeviceAuthorizationEndpoint returns the enpoint which can
// be used to start a DeviceAuthorization flow.
GetDeviceAuthorizationEndpoint() string
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier
// ErrorHandler returns the handler used for callback errors
@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
return rp.endpoints.UserinfoURL
}
func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
return rp.endpoints.DeviceAuthorizationURL
}
func (rp *relyingParty) GetEndSessionEndpoint() string {
return rp.endpoints.EndSessionURL
}
@ -371,7 +379,7 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
// returning it parsed together with the oauth2 tokens (access, refresh)
func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
codeOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
@ -384,7 +392,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
}
if rp.IsOAuth2Only() {
return &oidc.Tokens{Token: token}, nil
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
@ -392,21 +400,21 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
return nil, errors.New("id_token missing")
}
idToken, err := VerifyTokens(ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
// including cookie handling for secure `state` transfer
// and optional PKCE code verifier checking.
// Custom paramaters can optionally be set to the token URL.
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp)
if err != nil {
@ -439,7 +447,7 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
}
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
}
tokens, err := CodeExchange(r.Context(), params.Get("code"), rp, codeOpts...)
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
@ -448,13 +456,13 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
}
}
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo)
// UserinfoCallback wraps the callback function of the CodeExchangeHandler
// and calls the userinfo endpoint with the access token
// on success it will pass the userinfo into its callback function as well
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
@ -465,17 +473,17 @@ func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
}
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) {
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
if err != nil {
return nil, err
}
req.Header.Set("authorization", tokenType+" "+token)
userinfo := oidc.NewUserInfo()
userinfo := new(oidc.UserInfo)
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
return nil, err
}
if userinfo.GetSubject() != subject {
if userinfo.Subject != subject {
return nil, ErrUserInfoSubNotMatching
}
return userinfo, nil
@ -506,11 +514,12 @@ type OptionFunc func(RelyingParty)
type Endpoints struct {
oauth2.Endpoint
IntrospectURL string
UserinfoURL string
JKWsURL string
EndSessionURL string
RevokeURL string
IntrospectURL string
UserinfoURL string
JKWsURL string
EndSessionURL string
RevokeURL string
DeviceAuthorizationURL string
}
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
@ -520,11 +529,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
AuthStyle: oauth2.AuthStyleAutoDetect,
TokenURL: discoveryConfig.TokenEndpoint,
},
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
UserinfoURL: discoveryConfig.UserinfoEndpoint,
JKWsURL: discoveryConfig.JwksURI,
EndSessionURL: discoveryConfig.EndSessionEndpoint,
RevokeURL: discoveryConfig.RevocationEndpoint,
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
UserinfoURL: discoveryConfig.UserinfoEndpoint,
JKWsURL: discoveryConfig.JwksURI,
EndSessionURL: discoveryConfig.EndSessionEndpoint,
RevokeURL: discoveryConfig.RevocationEndpoint,
DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
}
}

View file

@ -5,7 +5,7 @@ import (
"golang.org/x/oauth2"
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
)
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`

View file

@ -6,7 +6,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type IDTokenVerifier interface {
@ -20,76 +20,78 @@ type IDTokenVerifier interface {
}
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v)
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
var nilClaims C
claims, err = VerifyIDToken[C](ctx, idToken, v)
if err != nil {
return nil, err
return nilClaims, err
}
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err
if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil {
return nilClaims, err
}
return idToken, nil
return claims, nil
}
// VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
claims := oidc.EmptyIDTokenClaims()
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
var nilClaims C
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
return nilClaims, err
}
payload, err := oidc.ParseToken(decrypted, claims)
payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil {
return nil, err
return nilClaims, err
}
if err := oidc.CheckSubject(claims); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
return nil, err
return nilClaims, err
}
return claims, nil
}
// VerifyAccessToken validates the access token according to
//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
if atHash == "" {
return nil
@ -112,7 +114,7 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
issuer: issuer,
clientID: clientID,
keySet: keySet,
offset: 1 * time.Second,
offset: time.Second,
nonce: func(_ context.Context) string {
return ""
},
@ -139,7 +141,7 @@ func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
return func(v *idTokenVerifier) {
v.maxAge = maxAge
v.maxAgeIAT = maxAge
}
}

View file

@ -0,0 +1,339 @@
package rp
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
"gopkg.in/square/go-jose.v2"
)
func TestVerifyTokens(t *testing.T) {
verifier := &idTokenVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{},
maxAge: 2 * time.Minute,
acr: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce },
clientID: tu.ValidClientID,
}
accessToken, _ := tu.ValidAccessToken()
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
require.NoError(t, err)
tests := []struct {
name string
accessToken string
idTokenClaims func() (string, *oidc.IDTokenClaims)
wantErr bool
}{
{
name: "without access token",
idTokenClaims: tu.ValidIDToken,
},
{
name: "with access token",
accessToken: accessToken,
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
)
},
},
{
name: "expired id token",
accessToken: accessToken,
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
)
},
wantErr: true,
},
{
name: "wrong access token",
accessToken: accessToken,
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~",
)
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
idToken, want := tt.idTokenClaims()
got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, got, want)
})
}
}
func TestVerifyIDToken(t *testing.T) {
verifier := &idTokenVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{},
maxAge: 2 * time.Minute,
acr: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce },
}
tests := []struct {
name string
clientID string
tokenClaims func() (string, *oidc.IDTokenClaims)
wantErr bool
}{
{
name: "success",
clientID: tu.ValidClientID,
tokenClaims: tu.ValidIDToken,
},
{
name: "parse err",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
wantErr: true,
},
{
name: "invalid signature",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
wantErr: true,
},
{
name: "empty subject",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, "", tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong issuer",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
"foo", tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong clientID",
clientID: "foo",
tokenClaims: tu.ValidIDToken,
wantErr: true,
},
{
name: "expired",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong IAT",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
)
},
wantErr: true,
},
{
name: "wrong acr",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "expired auth",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong nonce",
clientID: tu.ValidClientID,
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, "foo",
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims()
verifier.clientID = tt.clientID
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, got, want)
})
}
}
func TestVerifyAccessToken(t *testing.T) {
token, _ := tu.ValidAccessToken()
hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm)
require.NoError(t, err)
type args struct {
accessToken string
atHash string
sigAlgorithm jose.SignatureAlgorithm
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "empty hash",
},
{
name: "success",
args: args{
accessToken: token,
atHash: hash,
sigAlgorithm: tu.SignatureAlgorithm,
},
},
{
name: "invalid algorithm",
args: args{
accessToken: token,
atHash: hash,
sigAlgorithm: "foo",
},
wantErr: true,
},
{
name: "mismatch",
args: args{
accessToken: token,
atHash: "~~",
sigAlgorithm: tu.SignatureAlgorithm,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func TestNewIDTokenVerifier(t *testing.T) {
type args struct {
issuer string
clientID string
keySet oidc.KeySet
options []VerifierOption
}
tests := []struct {
name string
args args
want IDTokenVerifier
}{
{
name: "nil nonce", // otherwise assert.Equal will fail on the function
args: args{
issuer: tu.ValidIssuer,
clientID: tu.ValidClientID,
keySet: tu.KeySet{},
options: []VerifierOption{
WithIssuedAtOffset(time.Minute),
WithIssuedAtMaxAge(time.Hour),
WithNonce(nil), // otherwise assert.Equal will fail on the function
WithACRVerifier(nil),
WithAuthTimeMaxAge(2 * time.Hour),
WithSupportedSigningAlgorithms("ABC", "DEF"),
},
},
want: &idTokenVerifier{
issuer: tu.ValidIssuer,
offset: time.Minute,
maxAgeIAT: time.Hour,
clientID: tu.ValidClientID,
keySet: tu.KeySet{},
nonce: nil,
acr: nil,
maxAge: 2 * time.Hour,
supportedSignAlgs: []string{"ABC", "DEF"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -0,0 +1,86 @@
package rp_test
import (
"context"
"fmt"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// MyCustomClaims extends the TokenClaims base,
// so it implmeents the oidc.Claims interface.
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
type MyCustomClaims struct {
oidc.TokenClaims
NotBefore oidc.Time `json:"nbf,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
Foo string `json:"foo,omitempty"`
Bar *Nested `json:"bar,omitempty"`
}
// GetAccessTokenHash is required to implement
// the oidc.IDClaims interface.
func (c *MyCustomClaims) GetAccessTokenHash() string {
return c.AccessTokenHash
}
// Nested struct types are also possible.
type Nested struct {
Count int `json:"count,omitempty"`
Tags []string `json:"tags,omitempty"`
}
/*
idToken carries the following claims. foo and bar are custom claims
{
"acr": "something",
"amr": [
"foo",
"bar"
],
"at_hash": "2dzbm_vIxy-7eRtqUIGPPw",
"aud": [
"unit",
"test",
"555666"
],
"auth_time": 1678100961,
"azp": "555666",
"bar": {
"count": 22,
"tags": [
"some",
"tags"
]
},
"client_id": "555666",
"exp": 4802238682,
"foo": "Hello, World!",
"iat": 1678101021,
"iss": "local.com",
"jti": "9876",
"nbf": 1678101021,
"nonce": "12345",
"sub": "tim@local.com"
}
*/
const idToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF0X2hhc2giOiIyZHpibV92SXh5LTdlUnRxVUlHUFB3IiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImF1dGhfdGltZSI6MTY3ODEwMDk2MSwiYXpwIjoiNTU1NjY2IiwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiY2xpZW50X2lkIjoiNTU1NjY2IiwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJub25jZSI6IjEyMzQ1Iiwic3ViIjoidGltQGxvY2FsLmNvbSJ9.t3GXSfVNNwiW1Suv9_84v0sdn2_-RWHVxhphhRozDXnsO7SDNOlGnEioemXABESxSzMclM7gB7mYy5Qah2ZUNx7eP5t2njoxEYfavgHwx7UJZ2NCg8NDPQyr-hlxelEcfdXK-I0oTd-FRDvF4rqPkD9Us52IpnplChCxnHFgh4wKwPqZZjv2IXVCtn0ilKW3hff1rMOYKEuLRcN2YP0gkyuqyHvcf2dMmjod0t4sLOTJ82rsCbMBC5CLpqv3nIC9HOGITkt1Kd-Am0n1LrdZvWwTo6RFe8AnzF0gpqjcB5Wg4Qeh58DIjZOz4f_8wnmJ_gCqyRh5vfSW4XHdbum0Tw`
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.Zrz3LWSRjCMJZUMaI5dUbW4vGdSmEeJQ3ouhaX0bcW9rdFFLgBI4K2FWJhNivq8JDmCGSxwLu3mI680GWmDaEoAx1M5sCO9lqfIZHGZh-lfAXk27e6FPLlkTDBq8Bx4o4DJ9Fw0hRJGjUTjnYv5cq1vo2-UqldasL6CwTbkzNC_4oQFfRtuodC4Ql7dZ1HRv5LXuYx7KPkOssLZtV9cwtJp5nFzKjcf2zEE_tlbjcpynMwypornRUp1EhCWKRUGkJhJeiP71ECY5pQhShfjBu9Nc5wDpSnZmnk2S4YsPrRK3QkE-iEkas8BfsOCrGoErHjEJexAIDjasGO5PFLWfCA`
func ExampleVerifyTokens_customClaims() {
v := rp.NewIDTokenVerifier("local.com", "555666", tu.KeySet{},
rp.WithNonce(func(ctx context.Context) string { return "12345" }),
)
// VerifyAccessToken can be called with the *MyCustomClaims.
claims, err := rp.VerifyTokens[*MyCustomClaims](context.TODO(), accessToken, idToken, v)
if err != nil {
panic(err)
}
// Here we have typesafe access to the custom claims
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
// Output: Hello, World! 22 [some tags]
}

View file

@ -6,13 +6,14 @@ import (
"net/http"
"time"
"github.com/zitadel/oidc/pkg/client"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type ResourceServer interface {
IntrospectionURL() string
TokenEndpoint() string
HttpClient() *http.Client
AuthFn() (interface{}, error)
}
@ -29,6 +30,10 @@ func (r *resourceServer) IntrospectionURL() string {
return r.introspectURL
}
func (r *resourceServer) TokenEndpoint() string {
return r.tokenURL
}
func (r *resourceServer) HttpClient() *http.Client {
return r.httpClient
}
@ -107,7 +112,7 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
}
}
func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.IntrospectionResponse, error) {
func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) {
authFn, err := rp.AuthFn()
if err != nil {
return nil, err
@ -116,7 +121,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.Intr
if err != nil {
return nil, err
}
resp := oidc.NewIntrospectionResponse()
resp := new(oidc.IntrospectionResponse)
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
return nil, err
}

View file

@ -0,0 +1,127 @@
package tokenexchange
import (
"errors"
"net/http"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type TokenExchanger interface {
TokenEndpoint() string
HttpClient() *http.Client
AuthFn() (interface{}, error)
}
type OAuthTokenExchange struct {
httpClient *http.Client
tokenEndpoint string
authFn func() (interface{}, error)
}
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
return newOAuthTokenExchange(issuer, nil, options...)
}
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
authorizer := func() (interface{}, error) {
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
}
return newOAuthTokenExchange(issuer, authorizer, options...)
}
func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
te := &OAuthTokenExchange{
httpClient: httphelper.DefaultHTTPClient,
}
for _, opt := range options {
opt(te)
}
if te.tokenEndpoint == "" {
config, err := client.Discover(issuer, te.httpClient)
if err != nil {
return nil, err
}
te.tokenEndpoint = config.TokenEndpoint
}
if te.tokenEndpoint == "" {
return nil, errors.New("tokenURL is empty: please provide with either `WithStaticTokenEndpoint` or a discovery url")
}
te.authFn = authorizer
return te, nil
}
func WithHTTPClient(client *http.Client) func(*OAuthTokenExchange) {
return func(source *OAuthTokenExchange) {
source.httpClient = client
}
}
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*OAuthTokenExchange) {
return func(source *OAuthTokenExchange) {
source.tokenEndpoint = tokenEndpoint
}
}
func (te *OAuthTokenExchange) TokenEndpoint() string {
return te.tokenEndpoint
}
func (te *OAuthTokenExchange) HttpClient() *http.Client {
return te.httpClient
}
func (te *OAuthTokenExchange) AuthFn() (interface{}, error) {
if te.authFn != nil {
return te.authFn()
}
return nil, nil
}
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
// SubjectToken and SubjectTokenType are required parameters.
func ExchangeToken(
te TokenExchanger,
SubjectToken string,
SubjectTokenType oidc.TokenType,
ActorToken string,
ActorTokenType oidc.TokenType,
Resource []string,
Audience []string,
Scopes []string,
RequestedTokenType oidc.TokenType,
) (*oidc.TokenExchangeResponse, error) {
if SubjectToken == "" {
return nil, errors.New("empty subject_token")
}
if SubjectTokenType == "" {
return nil, errors.New("empty subject_token_type")
}
authFn, err := te.AuthFn()
if err != nil {
return nil, err
}
request := oidc.TokenExchangeRequest{
GrantType: oidc.GrantTypeTokenExchange,
SubjectToken: SubjectToken,
SubjectTokenType: SubjectTokenType,
ActorToken: ActorToken,
ActorTokenType: ActorTokenType,
Resource: Resource,
Audience: Audience,
Scopes: Scopes,
RequestedTokenType: RequestedTokenType,
}
return client.CallTokenExchangeEndpoint(request, authFn, te)
}

View file

@ -3,7 +3,7 @@ package oidc
import (
"crypto/sha256"
"github.com/zitadel/oidc/pkg/crypto"
"github.com/zitadel/oidc/v2/pkg/crypto"
)
const (

View file

@ -0,0 +1,29 @@
package oidc
// DeviceAuthorizationRequest implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1,
// 3.1 Device Authorization Request.
type DeviceAuthorizationRequest struct {
Scopes SpaceDelimitedArray `schema:"scope"`
ClientID string `schema:"client_id"`
}
// DeviceAuthorizationResponse implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.2
// 3.2. Device Authorization Response.
type DeviceAuthorizationResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval,omitempty"`
}
// DeviceAccessTokenRequest implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
// Device Access Token Request.
type DeviceAccessTokenRequest struct {
GrantType GrantType `json:"grant_type" schema:"grant_type"`
DeviceCode string `json:"device_code" schema:"device_code"`
}

View file

@ -30,6 +30,8 @@ type DiscoveryConfiguration struct {
// EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
// CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
CheckSessionIframe string `json:"check_session_iframe,omitempty"`

View file

@ -18,6 +18,14 @@ const (
InteractionRequired errorType = "interaction_required"
LoginRequired errorType = "login_required"
RequestNotSupported errorType = "request_not_supported"
// Additional error codes as defined in
// https://www.rfc-editor.org/rfc/rfc8628#section-3.5
// Device Access Token Response
AuthorizationPending errorType = "authorization_pending"
SlowDown errorType = "slow_down"
AccessDenied errorType = "access_denied"
ExpiredToken errorType = "expired_token"
)
var (
@ -77,6 +85,32 @@ var (
ErrorType: RequestNotSupported,
}
}
// Device Access Token errors:
ErrAuthorizationPending = func() *Error {
return &Error{
ErrorType: AuthorizationPending,
Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.",
}
}
ErrSlowDown = func() *Error {
return &Error{
ErrorType: SlowDown,
Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.",
}
}
ErrAccessDenied = func() *Error {
return &Error{
ErrorType: AccessDenied,
Description: "The authorization request was denied.",
}
}
ErrExpiredDeviceCode = func() *Error {
return &Error{
ErrorType: ExpiredToken,
Description: "The \"device_code\" has expired.",
}
}
)
type Error struct {

View file

@ -1,12 +1,6 @@
package oidc
import (
"encoding/json"
"fmt"
"time"
"golang.org/x/text/language"
)
import "github.com/muhlemmer/gu"
type IntrospectionRequest struct {
Token string `schema:"token"`
@ -17,36 +11,11 @@ type ClientAssertionParams struct {
ClientAssertionType string `schema:"client_assertion_type"`
}
type IntrospectionResponse interface {
UserInfoSetter
IsActive() bool
SetActive(bool)
SetScopes(scopes []string)
SetClientID(id string)
SetTokenType(tokenType string)
SetExpiration(exp time.Time)
SetIssuedAt(iat time.Time)
SetNotBefore(nbf time.Time)
SetAudience(audience []string)
SetIssuer(issuer string)
SetJWTID(id string)
GetScope() []string
GetClientID() string
GetTokenType() string
GetExpiration() time.Time
GetIssuedAt() time.Time
GetNotBefore() time.Time
GetSubject() string
GetAudience() []string
GetIssuer() string
GetJWTID() string
}
func NewIntrospectionResponse() IntrospectionResponse {
return &introspectionResponse{}
}
type introspectionResponse struct {
// IntrospectionResponse implements RFC 7662, section 2.2 and
// OpenID Connect Core 1.0, section 5.1 (UserInfo).
// https://www.rfc-editor.org/rfc/rfc7662.html#section-2.2.
// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims.
type IntrospectionResponse struct {
Active bool `json:"active"`
Scope SpaceDelimitedArray `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
@ -58,323 +27,50 @@ type introspectionResponse struct {
Audience Audience `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"`
JWTID string `json:"jti,omitempty"`
userInfoProfile
userInfoEmail
userInfoPhone
Username string `json:"username,omitempty"`
UserInfoProfile
UserInfoEmail
UserInfoPhone
Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{}
Address *UserInfoAddress `json:"address,omitempty"`
Claims map[string]any `json:"-"`
}
func (i *introspectionResponse) IsActive() bool {
return i.Active
// SetUserInfo copies all relevant fields from UserInfo
// into the IntroSpectionResponse.
func (i *IntrospectionResponse) SetUserInfo(u *UserInfo) {
i.Subject = u.Subject
i.Username = u.PreferredUsername
i.Address = gu.PtrCopy(u.Address)
i.UserInfoProfile = u.UserInfoProfile
i.UserInfoEmail = u.UserInfoEmail
i.UserInfoPhone = u.UserInfoPhone
if i.Claims == nil {
i.Claims = gu.MapCopy(u.Claims)
} else {
gu.MapMerge(u.Claims, i.Claims)
}
}
func (i *introspectionResponse) GetSubject() string {
return i.Subject
}
func (i *introspectionResponse) GetName() string {
return i.Name
}
func (i *introspectionResponse) GetGivenName() string {
return i.GivenName
}
func (i *introspectionResponse) GetFamilyName() string {
return i.FamilyName
}
func (i *introspectionResponse) GetMiddleName() string {
return i.MiddleName
}
func (i *introspectionResponse) GetNickname() string {
return i.Nickname
}
func (i *introspectionResponse) GetProfile() string {
return i.Profile
}
func (i *introspectionResponse) GetPicture() string {
return i.Picture
}
func (i *introspectionResponse) GetWebsite() string {
return i.Website
}
func (i *introspectionResponse) GetGender() Gender {
return i.Gender
}
func (i *introspectionResponse) GetBirthdate() string {
return i.Birthdate
}
func (i *introspectionResponse) GetZoneinfo() string {
return i.Zoneinfo
}
func (i *introspectionResponse) GetLocale() language.Tag {
return i.Locale
}
func (i *introspectionResponse) GetPreferredUsername() string {
return i.PreferredUsername
}
func (i *introspectionResponse) GetEmail() string {
return i.Email
}
func (i *introspectionResponse) IsEmailVerified() bool {
return bool(i.EmailVerified)
}
func (i *introspectionResponse) GetPhoneNumber() string {
return i.PhoneNumber
}
func (i *introspectionResponse) IsPhoneNumberVerified() bool {
return i.PhoneNumberVerified
}
func (i *introspectionResponse) GetAddress() UserInfoAddress {
// GetAddress is a safe getter that takes
// care of a possible nil value.
func (i *IntrospectionResponse) GetAddress() *UserInfoAddress {
if i.Address == nil {
return new(UserInfoAddress)
}
return i.Address
}
func (i *introspectionResponse) GetClaim(key string) interface{} {
return i.claims[key]
}
// introspectionResponseAlias prevents loops on the JSON methods
type introspectionResponseAlias IntrospectionResponse
func (i *introspectionResponse) GetClaims() map[string]interface{} {
return i.claims
}
func (i *introspectionResponse) GetScope() []string {
return []string(i.Scope)
}
func (i *introspectionResponse) GetClientID() string {
return i.ClientID
}
func (i *introspectionResponse) GetTokenType() string {
return i.TokenType
}
func (i *introspectionResponse) GetExpiration() time.Time {
return time.Time(i.Expiration)
}
func (i *introspectionResponse) GetIssuedAt() time.Time {
return time.Time(i.IssuedAt)
}
func (i *introspectionResponse) GetNotBefore() time.Time {
return time.Time(i.NotBefore)
}
func (i *introspectionResponse) GetAudience() []string {
return []string(i.Audience)
}
func (i *introspectionResponse) GetIssuer() string {
return i.Issuer
}
func (i *introspectionResponse) GetJWTID() string {
return i.JWTID
}
func (i *introspectionResponse) SetActive(active bool) {
i.Active = active
}
func (i *introspectionResponse) SetScopes(scope []string) {
i.Scope = scope
}
func (i *introspectionResponse) SetClientID(id string) {
i.ClientID = id
}
func (i *introspectionResponse) SetTokenType(tokenType string) {
i.TokenType = tokenType
}
func (i *introspectionResponse) SetExpiration(exp time.Time) {
i.Expiration = Time(exp)
}
func (i *introspectionResponse) SetIssuedAt(iat time.Time) {
i.IssuedAt = Time(iat)
}
func (i *introspectionResponse) SetNotBefore(nbf time.Time) {
i.NotBefore = Time(nbf)
}
func (i *introspectionResponse) SetAudience(audience []string) {
i.Audience = audience
}
func (i *introspectionResponse) SetIssuer(issuer string) {
i.Issuer = issuer
}
func (i *introspectionResponse) SetJWTID(id string) {
i.JWTID = id
}
func (i *introspectionResponse) SetSubject(sub string) {
i.Subject = sub
}
func (i *introspectionResponse) SetName(name string) {
i.Name = name
}
func (i *introspectionResponse) SetGivenName(name string) {
i.GivenName = name
}
func (i *introspectionResponse) SetFamilyName(name string) {
i.FamilyName = name
}
func (i *introspectionResponse) SetMiddleName(name string) {
i.MiddleName = name
}
func (i *introspectionResponse) SetNickname(name string) {
i.Nickname = name
}
func (i *introspectionResponse) SetUpdatedAt(date time.Time) {
i.UpdatedAt = Time(date)
}
func (i *introspectionResponse) SetProfile(profile string) {
i.Profile = profile
}
func (i *introspectionResponse) SetPicture(picture string) {
i.Picture = picture
}
func (i *introspectionResponse) SetWebsite(website string) {
i.Website = website
}
func (i *introspectionResponse) SetGender(gender Gender) {
i.Gender = gender
}
func (i *introspectionResponse) SetBirthdate(birthdate string) {
i.Birthdate = birthdate
}
func (i *introspectionResponse) SetZoneinfo(zoneInfo string) {
i.Zoneinfo = zoneInfo
}
func (i *introspectionResponse) SetLocale(locale language.Tag) {
i.Locale = locale
}
func (i *introspectionResponse) SetPreferredUsername(name string) {
i.PreferredUsername = name
}
func (i *introspectionResponse) SetEmail(email string, verified bool) {
i.Email = email
i.EmailVerified = boolString(verified)
}
func (i *introspectionResponse) SetPhone(phone string, verified bool) {
i.PhoneNumber = phone
i.PhoneNumberVerified = verified
}
func (i *introspectionResponse) SetAddress(address UserInfoAddress) {
i.Address = address
}
func (i *introspectionResponse) AppendClaims(key string, value interface{}) {
if i.claims == nil {
i.claims = make(map[string]interface{})
func (i *IntrospectionResponse) MarshalJSON() ([]byte, error) {
if i.Username == "" {
i.Username = i.PreferredUsername
}
i.claims[key] = value
return mergeAndMarshalClaims((*introspectionResponseAlias)(i), i.Claims)
}
func (i *introspectionResponse) MarshalJSON() ([]byte, error) {
type Alias introspectionResponse
a := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
Username string `json:"username,omitempty"`
}{
Alias: (*Alias)(i),
}
if !i.Locale.IsRoot() {
a.Locale = i.Locale
}
if !time.Time(i.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
}
if !time.Time(i.Expiration).IsZero() {
a.Expiration = time.Time(i.Expiration).Unix()
}
if !time.Time(i.IssuedAt).IsZero() {
a.IssuedAt = time.Time(i.IssuedAt).Unix()
}
if !time.Time(i.NotBefore).IsZero() {
a.NotBefore = time.Time(i.NotBefore).Unix()
}
a.Username = i.PreferredUsername
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(i.claims) == 0 {
return b, nil
}
err = json.Unmarshal(b, &i.claims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims)
}
return json.Marshal(i.claims)
}
func (i *introspectionResponse) UnmarshalJSON(data []byte) error {
type Alias introspectionResponse
a := &struct {
*Alias
UpdatedAt int64 `json:"update_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if err := json.Unmarshal(data, &a); err != nil {
return err
}
i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
if err := json.Unmarshal(data, &i.claims); err != nil {
return err
}
return nil
func (i *IntrospectionResponse) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*introspectionResponseAlias)(i), &i.Claims)
}

View file

@ -0,0 +1,78 @@
package oidc
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
tests := []struct {
name string
start *IntrospectionResponse
want *IntrospectionResponse
}{
{
name: "nil claims",
start: &IntrospectionResponse{},
want: &IntrospectionResponse{
Subject: userInfoData.Subject,
Username: userInfoData.PreferredUsername,
Address: userInfoData.Address,
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Claims: userInfoData.Claims,
},
},
{
name: "merge claims",
start: &IntrospectionResponse{
Claims: map[string]any{
"hello": "world",
},
},
want: &IntrospectionResponse{
Subject: userInfoData.Subject,
Username: userInfoData.PreferredUsername,
Address: userInfoData.Address,
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Claims: map[string]any{
"foo": "bar",
"hello": "world",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.start.SetUserInfo(userInfoData)
assert.Equal(t, tt.want, tt.start)
})
}
}
func TestIntrospectionResponse_GetAddress(t *testing.T) {
// nil address
i := new(IntrospectionResponse)
assert.Equal(t, &UserInfoAddress{}, i.GetAddress())
i.Address = &UserInfoAddress{PostalCode: "1234"}
assert.Equal(t, i.Address, i.GetAddress())
}
func TestIntrospectionResponse_MarshalJSON(t *testing.T) {
got, err := json.Marshal(&IntrospectionResponse{
UserInfoProfile: UserInfoProfile{
PreferredUsername: "muhlemmer",
},
})
require.NoError(t, err)
assert.Equal(t, string(got), `{"active":false,"username":"muhlemmer","preferred_username":"muhlemmer"}`)
}

View file

@ -0,0 +1,50 @@
//go:build !create_regression_data
package oidc
import (
"encoding/json"
"io"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_assert_regression verifies current output from
// json.Marshal to stored regression data.
// These tests are only ran when the create_regression_data
// tag is NOT set.
func Test_assert_regression(t *testing.T) {
buf := new(strings.Builder)
for _, obj := range regressionData {
name := jsonFilename(obj)
t.Run(name, func(t *testing.T) {
file, err := os.Open(name)
require.NoError(t, err)
defer file.Close()
_, err = io.Copy(buf, file)
require.NoError(t, err)
want := buf.String()
buf.Reset()
encodeJSON(t, buf, obj)
first := buf.String()
buf.Reset()
assert.JSONEq(t, want, first)
require.NoError(t,
json.Unmarshal([]byte(first), obj),
)
second, err := json.Marshal(obj)
require.NoError(t, err)
assert.JSONEq(t, want, string(second))
})
}
}

View file

@ -0,0 +1,24 @@
//go:build create_regression_data
package oidc
import (
"os"
"testing"
"github.com/stretchr/testify/require"
)
// Test_create_regression generates the regression data.
// It is excluded from regular testing, unless
// called with the create_regression_data tag:
// go test -tags="create_regression_data" ./pkg/oidc
func Test_create_regression(t *testing.T) {
for _, obj := range regressionData {
file, err := os.Create(jsonFilename(obj))
require.NoError(t, err)
defer file.Close()
encodeJSON(t, file, obj)
}
}

View file

@ -0,0 +1,23 @@
{
"iss": "zitadel",
"sub": "hello@me.com",
"aud": [
"foo",
"bar"
],
"jti": "900",
"azp": "just@me.com",
"nonce": "6969",
"acr": "something",
"amr": [
"some",
"methods"
],
"scope": "email phone",
"client_id": "777",
"exp": 12345,
"iat": 12000,
"nbf": 12000,
"auth_time": 12000,
"foo": "bar"
}

View file

@ -0,0 +1,51 @@
{
"iss": "zitadel",
"aud": [
"foo",
"bar"
],
"jti": "900",
"azp": "just@me.com",
"nonce": "6969",
"at_hash": "acthashhash",
"c_hash": "hashhash",
"acr": "something",
"amr": [
"some",
"methods"
],
"sid": "666",
"client_id": "777",
"exp": 12345,
"iat": 12000,
"nbf": 12000,
"auth_time": 12000,
"address": {
"country": "Moon",
"formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
"locality": "Smallvile",
"postal_code": "666-666",
"region": "Outer space",
"street_address": "Sesame street 666"
},
"birthdate": "1st of April",
"email": "tim@zitadel.com",
"email_verified": true,
"family_name": "Möhlmann",
"foo": "bar",
"gender": "male",
"given_name": "Tim",
"locale": "nl",
"middle_name": "Danger",
"name": "Tim Möhlmann",
"nickname": "muhlemmer",
"phone_number": "+1234567890",
"phone_number_verified": true,
"picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
"preferred_username": "muhlemmer",
"profile": "https://github.com/muhlemmer",
"sub": "hello@me.com",
"updated_at": 1,
"website": "https://zitadel.com",
"zoneinfo": "Europe/Amsterdam"
}

View file

@ -0,0 +1,44 @@
{
"active": true,
"address": {
"country": "Moon",
"formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
"locality": "Smallvile",
"postal_code": "666-666",
"region": "Outer space",
"street_address": "Sesame street 666"
},
"aud": [
"foo",
"bar"
],
"birthdate": "1st of April",
"client_id": "777",
"email": "tim@zitadel.com",
"email_verified": true,
"exp": 12345,
"family_name": "Möhlmann",
"foo": "bar",
"gender": "male",
"given_name": "Tim",
"iat": 12000,
"iss": "zitadel",
"jti": "900",
"locale": "nl",
"middle_name": "Danger",
"name": "Tim Möhlmann",
"nbf": 12000,
"nickname": "muhlemmer",
"phone_number": "+1234567890",
"phone_number_verified": true,
"picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
"preferred_username": "muhlemmer",
"profile": "https://github.com/muhlemmer",
"scope": "email phone",
"sub": "hello@me.com",
"token_type": "idtoken",
"updated_at": 1,
"username": "muhlemmer",
"website": "https://zitadel.com",
"zoneinfo": "Europe/Amsterdam"
}

View file

@ -0,0 +1,11 @@
{
"aud": [
"foo",
"bar"
],
"exp": 12345,
"foo": "bar",
"iat": 12000,
"iss": "zitadel",
"sub": "hello@me.com"
}

View file

@ -0,0 +1,30 @@
{
"address": {
"country": "Moon",
"formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
"locality": "Smallvile",
"postal_code": "666-666",
"region": "Outer space",
"street_address": "Sesame street 666"
},
"birthdate": "1st of April",
"email": "tim@zitadel.com",
"email_verified": true,
"family_name": "Möhlmann",
"foo": "bar",
"gender": "male",
"given_name": "Tim",
"locale": "nl",
"middle_name": "Danger",
"name": "Tim Möhlmann",
"nickname": "muhlemmer",
"phone_number": "+1234567890",
"phone_number_verified": true,
"picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
"preferred_username": "muhlemmer",
"profile": "https://github.com/muhlemmer",
"sub": "hello@me.com",
"updated_at": 1,
"website": "https://zitadel.com",
"zoneinfo": "Europe/Amsterdam"
}

View file

@ -0,0 +1,40 @@
package oidc
// This file contains common functions and data for regression testing
import (
"encoding/json"
"fmt"
"io"
"path"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
const dataDir = "regression_data"
// jsonFilename builds a filename for the regression testdata.
// dataDir/<type_name>.json
func jsonFilename(obj interface{}) string {
name := fmt.Sprintf("%T.json", obj)
return path.Join(
dataDir,
strings.TrimPrefix(name, "*"),
)
}
func encodeJSON(t *testing.T, w io.Writer, obj interface{}) {
enc := json.NewEncoder(w)
enc.SetIndent("", "\t")
require.NoError(t, enc.Encode(obj))
}
var regressionData = []interface{}{
accessTokenData,
idTokenData,
introspectionResponseData,
userInfoData,
jwtProfileAssertionData,
}

View file

@ -2,15 +2,13 @@ package oidc
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"time"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/crypto"
"github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/v2/pkg/crypto"
)
const (
@ -20,374 +18,174 @@ const (
PrefixBearer = BearerToken + " "
)
type Tokens struct {
type Tokens[C IDClaims] struct {
*oauth2.Token
IDTokenClaims IDTokenClaims
IDTokenClaims C
IDToken string
}
type AccessTokenClaims interface {
Claims
GetSubject() string
GetTokenID() string
SetPrivateClaims(map[string]interface{})
// TokenClaims contains the base Claims used all tokens.
// It implements OpenID Connect Core 1.0, section 2.
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
// And RFC 9068: JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens,
// section 2.2. https://datatracker.ietf.org/doc/html/rfc9068#name-data-structure
//
// TokenClaims implements the Claims interface,
// and can be used to extend larger claim types by embedding.
type TokenClaims struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
ClientID string `json:"client_id,omitempty"`
JWTID string `json:"jti,omitempty"`
// Additional information set by this framework
SignatureAlg jose.SignatureAlgorithm `json:"-"`
}
type IDTokenClaims interface {
Claims
GetNotBefore() time.Time
GetJWTID() string
GetAccessTokenHash() string
GetCodeHash() string
GetAuthenticationMethodsReferences() []string
GetClientID() string
GetSignatureAlgorithm() jose.SignatureAlgorithm
SetAccessTokenHash(hash string)
SetUserinfo(userinfo UserInfo)
SetCodeHash(hash string)
UserInfo
func (c *TokenClaims) GetIssuer() string {
return c.Issuer
}
func EmptyAccessTokenClaims() AccessTokenClaims {
return new(accessTokenClaims)
func (c *TokenClaims) GetSubject() string {
return c.Subject
}
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id, clientID string, skew time.Duration) AccessTokenClaims {
func (c *TokenClaims) GetAudience() []string {
return c.Audience
}
func (c *TokenClaims) GetExpiration() time.Time {
return c.Expiration.AsTime()
}
func (c *TokenClaims) GetIssuedAt() time.Time {
return c.IssuedAt.AsTime()
}
func (c *TokenClaims) GetNonce() string {
return c.Nonce
}
func (c *TokenClaims) GetAuthTime() time.Time {
return c.AuthTime.AsTime()
}
func (c *TokenClaims) GetAuthorizedParty() string {
return c.AuthorizedParty
}
func (c *TokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
return c.SignatureAlg
}
func (c *TokenClaims) GetAuthenticationContextClassReference() string {
return c.AuthenticationContextClassReference
}
func (c *TokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
c.SignatureAlg = algorithm
}
type AccessTokenClaims struct {
TokenClaims
Scopes SpaceDelimitedArray `json:"scope,omitempty"`
Claims map[string]any `json:"-"`
}
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration) *AccessTokenClaims {
now := time.Now().UTC().Add(-skew)
if len(audience) == 0 {
audience = append(audience, clientID)
}
return &accessTokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(now),
NotBefore: Time(now),
JWTID: id,
return &AccessTokenClaims{
TokenClaims: TokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: FromTime(expiration),
IssuedAt: FromTime(now),
NotBefore: FromTime(now),
JWTID: jwtid,
},
}
}
type accessTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
SessionID string `json:"sid,omitempty"`
Scopes SpaceDelimitedArray `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
type atcAlias AccessTokenClaims
claims map[string]interface{} `json:"-"`
signatureAlg jose.SignatureAlgorithm `json:"-"`
func (a *AccessTokenClaims) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*atcAlias)(a), a.Claims)
}
// GetIssuer implements the Claims interface
func (a *accessTokenClaims) GetIssuer() string {
return a.Issuer
func (a *AccessTokenClaims) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*atcAlias)(a), &a.Claims)
}
// GetAudience implements the Claims interface
func (a *accessTokenClaims) GetAudience() []string {
return a.Audience
}
// GetExpiration implements the Claims interface
func (a *accessTokenClaims) GetExpiration() time.Time {
return time.Time(a.Expiration)
}
// GetIssuedAt implements the Claims interface
func (a *accessTokenClaims) GetIssuedAt() time.Time {
return time.Time(a.IssuedAt)
}
// GetNonce implements the Claims interface
func (a *accessTokenClaims) GetNonce() string {
return a.Nonce
}
// GetAuthenticationContextClassReference implements the Claims interface
func (a *accessTokenClaims) GetAuthenticationContextClassReference() string {
return a.AuthenticationContextClassReference
}
// GetAuthTime implements the Claims interface
func (a *accessTokenClaims) GetAuthTime() time.Time {
return time.Time(a.AuthTime)
}
// GetAuthorizedParty implements the Claims interface
func (a *accessTokenClaims) GetAuthorizedParty() string {
return a.AuthorizedParty
}
// SetSignatureAlgorithm implements the Claims interface
func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
a.signatureAlg = algorithm
}
// GetSubject implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetSubject() string {
return a.Subject
}
// GetTokenID implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetTokenID() string {
return a.JWTID
}
// SetPrivateClaims implements the AccessTokenClaims interface
func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) {
a.claims = claims
}
func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
type Alias accessTokenClaims
s := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(a),
}
if !time.Time(a.Expiration).IsZero() {
s.Expiration = time.Time(a.Expiration).Unix()
}
if !time.Time(a.IssuedAt).IsZero() {
s.IssuedAt = time.Time(a.IssuedAt).Unix()
}
if !time.Time(a.NotBefore).IsZero() {
s.NotBefore = time.Time(a.NotBefore).Unix()
}
if !time.Time(a.AuthTime).IsZero() {
s.AuthTime = time.Time(a.AuthTime).Unix()
}
b, err := json.Marshal(s)
if err != nil {
return nil, err
}
if a.claims == nil {
return b, nil
}
info, err := json.Marshal(a.claims)
if err != nil {
return nil, err
}
return http.ConcatenateJSON(b, info)
}
func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
type Alias accessTokenClaims
if err := json.Unmarshal(data, (*Alias)(a)); err != nil {
return err
}
claims := make(map[string]interface{})
if err := json.Unmarshal(data, &claims); err != nil {
return err
}
a.claims = claims
return nil
}
func EmptyIDTokenClaims() IDTokenClaims {
return new(idTokenClaims)
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) IDTokenClaims {
audience = AppendClientIDToAudience(clientID, audience)
return &idTokenClaims{
Issuer: issuer,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(time.Now().UTC().Add(-skew)),
AuthTime: Time(authTime.Add(-skew)),
Nonce: nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID,
UserInfo: &userinfo{Subject: subject},
}
}
type idTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
ClientID string `json:"client_id,omitempty"`
UserInfo `json:"-"`
signatureAlg jose.SignatureAlgorithm
}
// GetIssuer implements the Claims interface
func (t *idTokenClaims) GetIssuer() string {
return t.Issuer
}
// GetAudience implements the Claims interface
func (t *idTokenClaims) GetAudience() []string {
return t.Audience
}
// GetExpiration implements the Claims interface
func (t *idTokenClaims) GetExpiration() time.Time {
return time.Time(t.Expiration)
}
// GetIssuedAt implements the Claims interface
func (t *idTokenClaims) GetIssuedAt() time.Time {
return time.Time(t.IssuedAt)
}
// GetNonce implements the Claims interface
func (t *idTokenClaims) GetNonce() string {
return t.Nonce
}
// GetAuthenticationContextClassReference implements the Claims interface
func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
// GetAuthTime implements the Claims interface
func (t *idTokenClaims) GetAuthTime() time.Time {
return time.Time(t.AuthTime)
}
// GetAuthorizedParty implements the Claims interface
func (t *idTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
// SetSignatureAlgorithm implements the Claims interface
func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
t.signatureAlg = alg
}
// GetNotBefore implements the IDTokenClaims interface
func (t *idTokenClaims) GetNotBefore() time.Time {
return time.Time(t.NotBefore)
}
// GetJWTID implements the IDTokenClaims interface
func (t *idTokenClaims) GetJWTID() string {
return t.JWTID
// IDTokenClaims extends TokenClaims by further implementing
// OpenID Connect Core 1.0, sections 3.1.3.6 (Code flow),
// 3.2.2.10 (implicit), 3.3.2.11 (Hybrid) and 5.1 (UserInfo).
// https://openid.net/specs/openid-connect-core-1_0.html#toc
type IDTokenClaims struct {
TokenClaims
NotBefore Time `json:"nbf,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
SessionID string `json:"sid,omitempty"`
UserInfoProfile
UserInfoEmail
UserInfoPhone
Address *UserInfoAddress `json:"address,omitempty"`
Claims map[string]any `json:"-"`
}
// GetAccessTokenHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetAccessTokenHash() string {
func (t *IDTokenClaims) GetAccessTokenHash() string {
return t.AccessTokenHash
}
// GetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetCodeHash() string {
return t.CodeHash
func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
t.Subject = i.Subject
t.UserInfoProfile = i.UserInfoProfile
t.UserInfoEmail = i.UserInfoEmail
t.UserInfoPhone = i.UserInfoPhone
t.Address = i.Address
}
// GetAuthenticationMethodsReferences implements the IDTokenClaims interface
func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
return t.AuthenticationMethodsReferences
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims {
audience = AppendClientIDToAudience(clientID, audience)
return &IDTokenClaims{
TokenClaims: TokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: FromTime(expiration),
IssuedAt: FromTime(time.Now().Add(-skew)),
AuthTime: FromTime(authTime.Add(-skew)),
Nonce: nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID,
ClientID: clientID,
},
}
}
// GetClientID implements the IDTokenClaims interface
func (t *idTokenClaims) GetClientID() string {
return t.ClientID
type itcAlias IDTokenClaims
func (i *IDTokenClaims) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*itcAlias)(i), i.Claims)
}
// GetSignatureAlgorithm implements the IDTokenClaims interface
func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
return t.signatureAlg
}
// SetAccessTokenHash implements the IDTokenClaims interface
func (t *idTokenClaims) SetAccessTokenHash(hash string) {
t.AccessTokenHash = hash
}
// SetUserinfo implements the IDTokenClaims interface
func (t *idTokenClaims) SetUserinfo(info UserInfo) {
t.UserInfo = info
}
// SetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) SetCodeHash(hash string) {
t.CodeHash = hash
}
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
type Alias idTokenClaims
a := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(t),
}
if !time.Time(t.Expiration).IsZero() {
a.Expiration = time.Time(t.Expiration).Unix()
}
if !time.Time(t.IssuedAt).IsZero() {
a.IssuedAt = time.Time(t.IssuedAt).Unix()
}
if !time.Time(t.NotBefore).IsZero() {
a.NotBefore = time.Time(t.NotBefore).Unix()
}
if !time.Time(t.AuthTime).IsZero() {
a.AuthTime = time.Time(t.AuthTime).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if t.UserInfo == nil {
return b, nil
}
info, err := json.Marshal(t.UserInfo)
if err != nil {
return nil, err
}
return http.ConcatenateJSON(b, info)
}
func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
type Alias idTokenClaims
if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
return err
}
userinfo := new(userinfo)
if err := json.Unmarshal(data, userinfo); err != nil {
return err
}
t.UserInfo = userinfo
return nil
func (i *IDTokenClaims) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*itcAlias)(i), &i.Claims)
}
type AccessTokenResponse struct {
@ -399,19 +197,7 @@ type AccessTokenResponse struct {
State string `json:"state,omitempty" schema:"state,omitempty"`
}
type JWTProfileAssertionClaims interface {
GetKeyID() string
GetPrivateKey() []byte
GetIssuer() string
GetSubject() string
GetAudience() []string
GetExpiration() time.Time
GetIssuedAt() time.Time
SetCustomClaim(key string, value interface{})
GetCustomClaim(key string) interface{}
}
type jwtProfileAssertion struct {
type JWTProfileAssertionClaims struct {
PrivateKeyID string `json:"-"`
PrivateKey []byte `json:"-"`
Issuer string `json:"iss"`
@ -420,91 +206,21 @@ type jwtProfileAssertion struct {
Expiration Time `json:"exp"`
IssuedAt Time `json:"iat"`
customClaims map[string]interface{}
Claims map[string]interface{} `json:"-"`
}
func (j *jwtProfileAssertion) MarshalJSON() ([]byte, error) {
type Alias jwtProfileAssertion
a := (*Alias)(j)
type jpaAlias JWTProfileAssertionClaims
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(j.customClaims) == 0 {
return b, nil
}
err = json.Unmarshal(b, &j.customClaims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", j.customClaims)
}
return json.Marshal(j.customClaims)
func (j *JWTProfileAssertionClaims) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*jpaAlias)(j), j.Claims)
}
func (j *jwtProfileAssertion) UnmarshalJSON(data []byte) error {
type Alias jwtProfileAssertion
a := (*Alias)(j)
err := json.Unmarshal(data, a)
if err != nil {
return err
}
err = json.Unmarshal(data, &j.customClaims)
if err != nil {
return err
}
return nil
func (j *JWTProfileAssertionClaims) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*jpaAlias)(j), &j.Claims)
}
func (j *jwtProfileAssertion) GetKeyID() string {
return j.PrivateKeyID
}
func (j *jwtProfileAssertion) GetPrivateKey() []byte {
return j.PrivateKey
}
func (j *jwtProfileAssertion) SetCustomClaim(key string, value interface{}) {
if j.customClaims == nil {
j.customClaims = make(map[string]interface{})
}
j.customClaims[key] = value
}
func (j *jwtProfileAssertion) GetCustomClaim(key string) interface{} {
if j.customClaims == nil {
return nil
}
return j.customClaims[key]
}
func (j *jwtProfileAssertion) GetIssuer() string {
return j.Issuer
}
func (j *jwtProfileAssertion) GetSubject() string {
return j.Subject
}
func (j *jwtProfileAssertion) GetAudience() []string {
return j.Audience
}
func (j *jwtProfileAssertion) GetExpiration() time.Time {
return time.Time(j.Expiration)
}
func (j *jwtProfileAssertion) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) {
data, err := ioutil.ReadFile(filename)
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
data, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
@ -524,19 +240,19 @@ func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string, op
return GenerateJWTProfileToken(NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...))
}
func JWTProfileDelegatedSubject(sub string) func(*jwtProfileAssertion) {
return func(j *jwtProfileAssertion) {
func JWTProfileDelegatedSubject(sub string) func(*JWTProfileAssertionClaims) {
return func(j *JWTProfileAssertionClaims) {
j.Subject = sub
}
}
func JWTProfileCustomClaim(key string, value interface{}) func(*jwtProfileAssertion) {
return func(j *jwtProfileAssertion) {
j.customClaims[key] = value
func JWTProfileCustomClaim(key string, value interface{}) func(*JWTProfileAssertionClaims) {
return func(j *JWTProfileAssertionClaims) {
j.Claims[key] = value
}
}
func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) {
func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
keyData := new(struct {
KeyID string `json:"keyId"`
Key string `json:"key"`
@ -549,18 +265,18 @@ func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...
return NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...), nil
}
type AssertionOption func(*jwtProfileAssertion)
type AssertionOption func(*JWTProfileAssertionClaims)
func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) JWTProfileAssertionClaims {
j := &jwtProfileAssertion{
func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) *JWTProfileAssertionClaims {
j := &JWTProfileAssertionClaims{
PrivateKey: key,
PrivateKeyID: keyID,
Issuer: userID,
Subject: userID,
IssuedAt: Time(time.Now().UTC()),
Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
IssuedAt: FromTime(time.Now().UTC()),
Expiration: FromTime(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience,
customClaims: make(map[string]interface{}),
Claims: make(map[string]interface{}),
}
for _, opt := range opts {
@ -588,14 +304,14 @@ func AppendClientIDToAudience(clientID string, audience []string) []string {
return append(audience, clientID)
}
func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) {
privateKey, err := crypto.BytesToPrivateKey(assertion.GetPrivateKey())
func GenerateJWTProfileToken(assertion *JWTProfileAssertionClaims) (string, error) {
privateKey, err := crypto.BytesToPrivateKey(assertion.PrivateKey)
if err != nil {
return "", err
}
key := jose.SigningKey{
Algorithm: jose.RS256,
Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.GetKeyID()},
Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.PrivateKeyID},
}
signer, err := jose.NewSigner(key, &jose.SignerOptions{})
if err != nil {
@ -612,3 +328,12 @@ func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error
}
return signedAssertion.CompactSerialize()
}
type TokenExchangeResponse struct {
AccessToken string `json:"access_token"` // Can be access token or ID token
IssuedTokenType TokenType `json:"issued_token_type"`
TokenType string `json:"token_type"`
ExpiresIn uint64 `json:"expires_in,omitempty"`
Scopes SpaceDelimitedArray `json:"scope,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
}

View file

@ -27,6 +27,9 @@ const (
// GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
GrantTypeImplicit GrantType = "implicit"
// GrantTypeDeviceCode
GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code"
// ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
// used for the OAuth JWT Profile Client Authentication
ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
@ -35,11 +38,34 @@ const (
var AllGrantTypes = []GrantType{
GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
ClientAssertionTypeJWTAssertion,
GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion,
}
type GrantType string
const (
AccessTokenType TokenType = "urn:ietf:params:oauth:token-type:access_token"
RefreshTokenType TokenType = "urn:ietf:params:oauth:token-type:refresh_token"
IDTokenType TokenType = "urn:ietf:params:oauth:token-type:id_token"
JWTTokenType TokenType = "urn:ietf:params:oauth:token-type:jwt"
)
var AllTokenTypes = []TokenType{
AccessTokenType, RefreshTokenType, IDTokenType, JWTTokenType,
}
type TokenType string
func (t TokenType) IsSupported() bool {
for _, tt := range AllTokenTypes {
if t == tt {
return true
}
}
return false
}
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType
@ -161,12 +187,12 @@ func (j *JWTTokenRequest) GetAudience() []string {
// GetExpiration implements the Claims interface
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
return j.ExpiresAt.AsTime()
}
// GetIssuedAt implements the Claims interface
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
return j.ExpiresAt.AsTime()
}
// GetNonce implements the Claims interface
@ -203,19 +229,22 @@ func (j *JWTTokenRequest) GetScopes() []string {
}
type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"`
audience Audience `schema:"audience"`
Scope SpaceDelimitedArray `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"`
GrantType GrantType `schema:"grant_type"`
SubjectToken string `schema:"subject_token"`
SubjectTokenType TokenType `schema:"subject_token_type"`
ActorToken string `schema:"actor_token"`
ActorTokenType TokenType `schema:"actor_token_type"`
Resource []string `schema:"resource"`
Audience Audience `schema:"audience"`
Scopes SpaceDelimitedArray `schema:"scope"`
RequestedTokenType TokenType `schema:"requested_token_type"`
}
type ClientCredentialsRequest struct {
GrantType GrantType `schema:"grant_type"`
Scope SpaceDelimitedArray `schema:"scope"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
GrantType GrantType `schema:"grant_type"`
Scope SpaceDelimitedArray `schema:"scope"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
ClientAssertion string `schema:"client_assertion"`
ClientAssertionType string `schema:"client_assertion_type"`
}

227
pkg/oidc/token_test.go Normal file
View file

@ -0,0 +1,227 @@
package oidc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
var (
tokenClaimsData = TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
JWTID: "900",
AuthorizedParty: "just@me.com",
Nonce: "6969",
AuthTime: 12000,
NotBefore: 12000,
AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"},
ClientID: "777",
SignatureAlg: jose.ES256,
}
accessTokenData = &AccessTokenClaims{
TokenClaims: tokenClaimsData,
Scopes: []string{"email", "phone"},
Claims: map[string]interface{}{
"foo": "bar",
},
}
idTokenData = &IDTokenClaims{
TokenClaims: tokenClaimsData,
NotBefore: 12000,
AccessTokenHash: "acthashhash",
CodeHash: "hashhash",
SessionID: "666",
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
Claims: map[string]interface{}{
"foo": "bar",
},
}
introspectionResponseData = &IntrospectionResponse{
Active: true,
Scope: SpaceDelimitedArray{"email", "phone"},
ClientID: "777",
TokenType: "idtoken",
Expiration: 12345,
IssuedAt: 12000,
NotBefore: 12000,
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Issuer: "zitadel",
JWTID: "900",
Username: "muhlemmer",
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
Claims: map[string]interface{}{
"foo": "bar",
},
}
userInfoData = &UserInfo{
Subject: "hello@me.com",
UserInfoProfile: UserInfoProfile{
Name: "Tim Möhlmann",
GivenName: "Tim",
FamilyName: "Möhlmann",
MiddleName: "Danger",
Nickname: "muhlemmer",
Profile: "https://github.com/muhlemmer",
Picture: "https://avatars.githubusercontent.com/u/5411563?v=4",
Website: "https://zitadel.com",
Gender: "male",
Birthdate: "1st of April",
Zoneinfo: "Europe/Amsterdam",
Locale: NewLocale(language.Dutch),
UpdatedAt: 1,
PreferredUsername: "muhlemmer",
},
UserInfoEmail: UserInfoEmail{
Email: "tim@zitadel.com",
EmailVerified: true,
},
UserInfoPhone: UserInfoPhone{
PhoneNumber: "+1234567890",
PhoneNumberVerified: true,
},
Address: &UserInfoAddress{
Formatted: "Sesame street 666\n666-666, Smallvile\nMoon",
StreetAddress: "Sesame street 666",
Locality: "Smallvile",
Region: "Outer space",
PostalCode: "666-666",
Country: "Moon",
},
Claims: map[string]interface{}{
"foo": "bar",
},
}
jwtProfileAssertionData = &JWTProfileAssertionClaims{
PrivateKeyID: "8888",
PrivateKey: []byte("qwerty"),
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"},
Expiration: 12345,
IssuedAt: 12000,
Claims: map[string]interface{}{
"foo": "bar",
},
}
)
func TestTokenClaims(t *testing.T) {
claims := tokenClaimsData
assert.Equal(t, claims.Issuer, tokenClaimsData.GetIssuer())
assert.Equal(t, claims.Subject, tokenClaimsData.GetSubject())
assert.Equal(t, []string(claims.Audience), tokenClaimsData.GetAudience())
assert.Equal(t, claims.Expiration.AsTime(), tokenClaimsData.GetExpiration())
assert.Equal(t, claims.IssuedAt.AsTime(), tokenClaimsData.GetIssuedAt())
assert.Equal(t, claims.Nonce, tokenClaimsData.GetNonce())
assert.Equal(t, claims.AuthTime.AsTime(), tokenClaimsData.GetAuthTime())
assert.Equal(t, claims.AuthorizedParty, tokenClaimsData.GetAuthorizedParty())
assert.Equal(t, claims.SignatureAlg, tokenClaimsData.GetSignatureAlgorithm())
assert.Equal(t, claims.AuthenticationContextClassReference, tokenClaimsData.GetAuthenticationContextClassReference())
claims.SetSignatureAlgorithm(jose.ES384)
assert.Equal(t, jose.ES384, claims.SignatureAlg)
}
func TestNewAccessTokenClaims(t *testing.T) {
want := &AccessTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo"},
Expiration: 12345,
JWTID: "900",
},
}
got := NewAccessTokenClaims(
want.Issuer, want.Subject, nil,
want.Expiration.AsTime(), want.JWTID, "foo", time.Second,
)
// test if the dynamic timestamps are around now,
// allowing for a delta of 1, just in case we flip on
// either side of a second boundry.
nowMinusSkew := NowTime() - 1
assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1)
assert.InDelta(t, int64(nowMinusSkew), int64(got.NotBefore), 1)
// Make equal not fail on dynamic timestamp
got.IssuedAt = 0
got.NotBefore = 0
assert.Equal(t, want, got)
}
func TestIDTokenClaims_GetAccessTokenHash(t *testing.T) {
assert.Equal(t, idTokenData.AccessTokenHash, idTokenData.GetAccessTokenHash())
}
func TestIDTokenClaims_SetUserInfo(t *testing.T) {
want := IDTokenClaims{
TokenClaims: TokenClaims{
Subject: userInfoData.Subject,
},
UserInfoProfile: userInfoData.UserInfoProfile,
UserInfoEmail: userInfoData.UserInfoEmail,
UserInfoPhone: userInfoData.UserInfoPhone,
Address: userInfoData.Address,
}
var got IDTokenClaims
got.SetUserInfo(userInfoData)
assert.Equal(t, want, got)
}
func TestNewIDTokenClaims(t *testing.T) {
want := &IDTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "just@me.com"},
Expiration: 12345,
AuthTime: 12000,
Nonce: "6969",
AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"},
AuthorizedParty: "just@me.com",
ClientID: "just@me.com",
},
}
got := NewIDTokenClaims(
want.Issuer, want.Subject, want.Audience,
want.Expiration.AsTime(),
want.AuthTime.AsTime().Add(time.Second),
want.Nonce, want.AuthenticationContextClassReference,
want.AuthenticationMethodsReferences, want.AuthorizedParty,
time.Second,
)
// test if the dynamic timestamp is around now,
// allowing for a delta of 1, just in case we flip on
// either side of a second boundry.
nowMinusSkew := NowTime() - 1
assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1)
// Make equal not fail on dynamic timestamp
got.IssuedAt = 0
assert.Equal(t, want, got)
}

View file

@ -4,9 +4,11 @@ import (
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
"github.com/gorilla/schema"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
@ -44,6 +46,39 @@ func (d *Display) UnmarshalText(text []byte) error {
type Gender string
type Locale struct {
tag language.Tag
}
func NewLocale(tag language.Tag) *Locale {
return &Locale{tag: tag}
}
func (l *Locale) Tag() language.Tag {
if l == nil {
return language.Und
}
return l.tag
}
func (l *Locale) String() string {
return l.Tag().String()
}
func (l *Locale) MarshalJSON() ([]byte, error) {
tag := l.Tag()
if tag.IsRoot() {
return []byte("null"), nil
}
return json.Marshal(tag)
}
func (l *Locale) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &l.tag)
}
type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error {
@ -125,19 +160,52 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
return strings.Join(s, " "), nil
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
// NewEncoder returns a schema Encoder with
// a registered encoder for SpaceDelimitedArray.
func NewEncoder() *schema.Encoder {
e := schema.NewEncoder()
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(SpaceDelimitedArray).Encode()
})
return e
}
func (t *Time) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Time(*t).UTC().Unix())
type Time int64
func (ts Time) AsTime() time.Time {
return time.Unix(int64(ts), 0)
}
func FromTime(tt time.Time) Time {
return Time(tt.Unix())
}
func NowTime() Time {
return FromTime(time.Now())
}
func (ts *Time) UnmarshalJSON(data []byte) error {
var v any
if err := json.Unmarshal(data, &v); err != nil {
return fmt.Errorf("oidc.Time: %w", err)
}
switch x := v.(type) {
case float64:
*ts = Time(x)
case string:
// Compatibility with Auth0:
// https://github.com/zitadel/oidc/issues/292
tt, err := time.Parse(time.RFC3339, x)
if err != nil {
return fmt.Errorf("oidc.Time: %w", err)
}
*ts = FromTime(tt)
case nil:
*ts = 0
default:
return fmt.Errorf("oidc.Time: unable to parse type %T with value %v", x, x)
}
return nil
}
type RequestObject struct {
@ -150,5 +218,4 @@ func (r *RequestObject) GetIssuer() string {
return r.Issuer
}
func (r *RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
}
func (*RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {}

View file

@ -3,11 +3,14 @@ package oidc
import (
"bytes"
"encoding/json"
"net/url"
"strconv"
"strings"
"testing"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
)
@ -109,6 +112,117 @@ func TestDisplay_UnmarshalText(t *testing.T) {
}
}
func TestLocale_Tag(t *testing.T) {
tests := []struct {
name string
l *Locale
want language.Tag
}{
{
name: "nil",
l: nil,
want: language.Und,
},
{
name: "Und",
l: NewLocale(language.Und),
want: language.Und,
},
{
name: "language",
l: NewLocale(language.Afrikaans),
want: language.Afrikaans,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, tt.l.Tag())
})
}
}
func TestLocale_String(t *testing.T) {
tests := []struct {
name string
l *Locale
want language.Tag
}{
{
name: "nil",
l: nil,
want: language.Und,
},
{
name: "Und",
l: NewLocale(language.Und),
want: language.Und,
},
{
name: "language",
l: NewLocale(language.Afrikaans),
want: language.Afrikaans,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want.String(), tt.l.String())
})
}
}
func TestLocale_MarshalJSON(t *testing.T) {
tests := []struct {
name string
l *Locale
want string
wantErr bool
}{
{
name: "nil",
l: nil,
want: "null",
},
{
name: "und",
l: NewLocale(language.Und),
want: "null",
},
{
name: "language",
l: NewLocale(language.Afrikaans),
want: `"af"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := json.Marshal(tt.l)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, string(got))
})
}
}
func TestLocale_UnmarshalJSON(t *testing.T) {
type a struct {
Locale *Locale `json:"locale,omitempty"`
}
want := a{
Locale: NewLocale(language.Afrikaans),
}
const input = `{"locale": "af"}`
var got a
require.NoError(t,
json.Unmarshal([]byte(input), &got),
)
assert.Equal(t, want, got)
}
func TestLocales_UnmarshalText(t *testing.T) {
type args struct {
text []byte
@ -335,3 +449,74 @@ func TestSpaceDelimitatedArray_ValuerNil(t *testing.T) {
assert.Equal(t, SpaceDelimitedArray(nil), reversed, "scan nil")
}
}
func TestNewEncoder(t *testing.T) {
type request struct {
Scopes SpaceDelimitedArray `schema:"scope"`
}
a := request{
Scopes: SpaceDelimitedArray{"foo", "bar"},
}
values := make(url.Values)
NewEncoder().Encode(a, values)
assert.Equal(t, url.Values{"scope": []string{"foo bar"}}, values)
var b request
schema.NewDecoder().Decode(&b, values)
assert.Equal(t, a, b)
}
func TestTime_UnmarshalJSON(t *testing.T) {
type dst struct {
UpdatedAt Time `json:"updated_at"`
}
tests := []struct {
name string
json string
want dst
wantErr bool
}{
{
name: "RFC3339", // https://github.com/zitadel/oidc/issues/292
json: `{"updated_at": "2021-05-11T21:13:25.566Z"}`,
want: dst{UpdatedAt: 1620767605},
},
{
name: "int",
json: `{"updated_at":1620767605}`,
want: dst{UpdatedAt: 1620767605},
},
{
name: "time parse error",
json: `{"updated_at":"foo"}`,
wantErr: true,
},
{
name: "null",
json: `{"updated_at":null}`,
},
{
name: "invalid type",
json: `{"updated_at":["foo","bar"]}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got dst
err := json.Unmarshal([]byte(tt.json), &got)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
}
t.Run("syntax error", func(t *testing.T) {
var ts Time
err := ts.UnmarshalJSON([]byte{'~'})
assert.Error(t, err)
})
}

View file

@ -1,320 +1,73 @@
package oidc
import (
"encoding/json"
"fmt"
"time"
"golang.org/x/text/language"
)
type UserInfo interface {
GetSubject() string
// UserInfo implements OpenID Connect Core 1.0, section 5.1.
// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims.
type UserInfo struct {
Subject string `json:"sub,omitempty"`
UserInfoProfile
UserInfoEmail
UserInfoPhone
GetAddress() UserInfoAddress
GetClaim(key string) interface{}
GetClaims() map[string]interface{}
Address *UserInfoAddress `json:"address,omitempty"`
Claims map[string]any `json:"-"`
}
type UserInfoProfile interface {
GetName() string
GetGivenName() string
GetFamilyName() string
GetMiddleName() string
GetNickname() string
GetProfile() string
GetPicture() string
GetWebsite() string
GetGender() Gender
GetBirthdate() string
GetZoneinfo() string
GetLocale() language.Tag
GetPreferredUsername() string
func (u *UserInfo) AppendClaims(k string, v any) {
if u.Claims == nil {
u.Claims = make(map[string]any)
}
u.Claims[k] = v
}
type UserInfoEmail interface {
GetEmail() string
IsEmailVerified() bool
}
type UserInfoPhone interface {
GetPhoneNumber() string
IsPhoneNumberVerified() bool
}
type UserInfoAddress interface {
GetFormatted() string
GetStreetAddress() string
GetLocality() string
GetRegion() string
GetPostalCode() string
GetCountry() string
}
type UserInfoSetter interface {
UserInfo
SetSubject(sub string)
UserInfoProfileSetter
SetEmail(email string, verified bool)
SetPhone(phone string, verified bool)
SetAddress(address UserInfoAddress)
AppendClaims(key string, values interface{})
}
type UserInfoProfileSetter interface {
SetName(name string)
SetGivenName(name string)
SetFamilyName(name string)
SetMiddleName(name string)
SetNickname(name string)
SetUpdatedAt(date time.Time)
SetProfile(profile string)
SetPicture(profile string)
SetWebsite(website string)
SetGender(gender Gender)
SetBirthdate(birthdate string)
SetZoneinfo(zoneInfo string)
SetLocale(locale language.Tag)
SetPreferredUsername(name string)
}
func NewUserInfo() UserInfoSetter {
return &userinfo{}
}
type userinfo struct {
Subject string `json:"sub,omitempty"`
userInfoProfile
userInfoEmail
userInfoPhone
Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{}
}
func (u *userinfo) GetSubject() string {
return u.Subject
}
func (u *userinfo) GetName() string {
return u.Name
}
func (u *userinfo) GetGivenName() string {
return u.GivenName
}
func (u *userinfo) GetFamilyName() string {
return u.FamilyName
}
func (u *userinfo) GetMiddleName() string {
return u.MiddleName
}
func (u *userinfo) GetNickname() string {
return u.Nickname
}
func (u *userinfo) GetProfile() string {
return u.Profile
}
func (u *userinfo) GetPicture() string {
return u.Picture
}
func (u *userinfo) GetWebsite() string {
return u.Website
}
func (u *userinfo) GetGender() Gender {
return u.Gender
}
func (u *userinfo) GetBirthdate() string {
return u.Birthdate
}
func (u *userinfo) GetZoneinfo() string {
return u.Zoneinfo
}
func (u *userinfo) GetLocale() language.Tag {
return u.Locale
}
func (u *userinfo) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *userinfo) GetEmail() string {
return u.Email
}
func (u *userinfo) IsEmailVerified() bool {
return bool(u.EmailVerified)
}
func (u *userinfo) GetPhoneNumber() string {
return u.PhoneNumber
}
func (u *userinfo) IsPhoneNumberVerified() bool {
return u.PhoneNumberVerified
}
func (u *userinfo) GetAddress() UserInfoAddress {
// GetAddress is a safe getter that takes
// care of a possible nil value.
func (u *UserInfo) GetAddress() *UserInfoAddress {
if u.Address == nil {
return &userInfoAddress{}
return new(UserInfoAddress)
}
return u.Address
}
func (u *userinfo) GetClaim(key string) interface{} {
return u.claims[key]
type uiAlias UserInfo
func (u *UserInfo) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*uiAlias)(u), u.Claims)
}
func (u *userinfo) GetClaims() map[string]interface{} {
return u.claims
func (u *UserInfo) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*uiAlias)(u), &u.Claims)
}
func (u *userinfo) SetSubject(sub string) {
u.Subject = sub
type UserInfoProfile struct {
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender Gender `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale *Locale `json:"locale,omitempty"`
UpdatedAt Time `json:"updated_at,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
}
func (u *userinfo) SetName(name string) {
u.Name = name
}
func (u *userinfo) SetGivenName(name string) {
u.GivenName = name
}
func (u *userinfo) SetFamilyName(name string) {
u.FamilyName = name
}
func (u *userinfo) SetMiddleName(name string) {
u.MiddleName = name
}
func (u *userinfo) SetNickname(name string) {
u.Nickname = name
}
func (u *userinfo) SetUpdatedAt(date time.Time) {
u.UpdatedAt = Time(date)
}
func (u *userinfo) SetProfile(profile string) {
u.Profile = profile
}
func (u *userinfo) SetPicture(picture string) {
u.Picture = picture
}
func (u *userinfo) SetWebsite(website string) {
u.Website = website
}
func (u *userinfo) SetGender(gender Gender) {
u.Gender = gender
}
func (u *userinfo) SetBirthdate(birthdate string) {
u.Birthdate = birthdate
}
func (u *userinfo) SetZoneinfo(zoneInfo string) {
u.Zoneinfo = zoneInfo
}
func (u *userinfo) SetLocale(locale language.Tag) {
u.Locale = locale
}
func (u *userinfo) SetPreferredUsername(name string) {
u.PreferredUsername = name
}
func (u *userinfo) SetEmail(email string, verified bool) {
u.Email = email
u.EmailVerified = boolString(verified)
}
func (u *userinfo) SetPhone(phone string, verified bool) {
u.PhoneNumber = phone
u.PhoneNumberVerified = verified
}
func (u *userinfo) SetAddress(address UserInfoAddress) {
u.Address = address
}
func (u *userinfo) AppendClaims(key string, value interface{}) {
if u.claims == nil {
u.claims = make(map[string]interface{})
}
u.claims[key] = value
}
func (u *userInfoAddress) GetFormatted() string {
return u.Formatted
}
func (u *userInfoAddress) GetStreetAddress() string {
return u.StreetAddress
}
func (u *userInfoAddress) GetLocality() string {
return u.Locality
}
func (u *userInfoAddress) GetRegion() string {
return u.Region
}
func (u *userInfoAddress) GetPostalCode() string {
return u.PostalCode
}
func (u *userInfoAddress) GetCountry() string {
return u.Country
}
type userInfoProfile struct {
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender Gender `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale language.Tag `json:"locale,omitempty"`
UpdatedAt Time `json:"updated_at,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
}
type userInfoEmail struct {
type UserInfoEmail struct {
Email string `json:"email,omitempty"`
// Handle providers that return email_verified as a string
// https://forums.aws.amazon.com/thread.jspa?messageID=949441&#949441
// https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11
EmailVerified boolString `json:"email_verified,omitempty"`
EmailVerified Bool `json:"email_verified,omitempty"`
}
type boolString bool
type Bool bool
func (bs *boolString) UnmarshalJSON(data []byte) error {
func (bs *Bool) UnmarshalJSON(data []byte) error {
if string(data) == "true" || string(data) == `"true"` {
*bs = true
}
@ -322,12 +75,12 @@ func (bs *boolString) UnmarshalJSON(data []byte) error {
return nil
}
type userInfoPhone struct {
type UserInfoPhone struct {
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
}
type userInfoAddress struct {
type UserInfoAddress struct {
Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"`
@ -336,76 +89,6 @@ type userInfoAddress struct {
Country string `json:"country,omitempty"`
}
func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
return &userInfoAddress{
StreetAddress: streetAddress,
Locality: locality,
Region: region,
PostalCode: postalCode,
Country: country,
Formatted: formatted,
}
}
func (u *userinfo) MarshalJSON() ([]byte, error) {
type Alias userinfo
a := &struct {
*Alias
Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
}{
Alias: (*Alias)(u),
}
if !u.Locale.IsRoot() {
a.Locale = u.Locale
}
if !time.Time(u.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(u.UpdatedAt).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(u.claims) == 0 {
return b, nil
}
err = json.Unmarshal(b, &u.claims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", u.claims)
}
return json.Marshal(u.claims)
}
func (u *userinfo) UnmarshalJSON(data []byte) error {
type Alias userinfo
a := &struct {
Address *userInfoAddress `json:"address,omitempty"`
*Alias
UpdatedAt int64 `json:"update_at,omitempty"`
}{
Alias: (*Alias)(u),
}
if err := json.Unmarshal(data, &a); err != nil {
return err
}
if a.Address != nil {
u.Address = a.Address
}
u.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
if err := json.Unmarshal(data, &u.claims); err != nil {
return err
}
return nil
}
type UserInfoRequest struct {
AccessToken string `schema:"access_token"`
}

View file

@ -7,21 +7,54 @@ import (
"github.com/stretchr/testify/assert"
)
func TestUserInfo_AppendClaims(t *testing.T) {
u := new(UserInfo)
u.AppendClaims("a", "b")
want := map[string]any{"a": "b"}
assert.Equal(t, want, u.Claims)
u.AppendClaims("d", "e")
want["d"] = "e"
assert.Equal(t, want, u.Claims)
}
func TestUserInfo_GetAddress(t *testing.T) {
// nil address
u := new(UserInfo)
assert.Equal(t, &UserInfoAddress{}, u.GetAddress())
u.Address = &UserInfoAddress{PostalCode: "1234"}
assert.Equal(t, u.Address, u.GetAddress())
}
func TestUserInfoMarshal(t *testing.T) {
userinfo := NewUserInfo()
userinfo.SetSubject("test")
userinfo.SetAddress(NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
userinfo.SetEmail("test", true)
userinfo.SetPhone("0791234567", true)
userinfo.SetName("Test")
userinfo.AppendClaims("private_claim", "test")
userinfo := &UserInfo{
Subject: "test",
Address: &UserInfoAddress{
StreetAddress: "Test 789\nPostfach 2",
},
UserInfoEmail: UserInfoEmail{
Email: "test",
EmailVerified: true,
},
UserInfoPhone: UserInfoPhone{
PhoneNumber: "0791234567",
PhoneNumberVerified: true,
},
UserInfoProfile: UserInfoProfile{
Name: "Test",
},
Claims: map[string]any{"private_claim": "test"},
}
marshal, err := json.Marshal(userinfo)
out := NewUserInfo()
assert.NoError(t, err)
out := new(UserInfo)
assert.NoError(t, json.Unmarshal(marshal, out))
assert.Equal(t, userinfo.GetAddress(), out.GetAddress())
assert.Equal(t, userinfo, out)
expected, err := json.Marshal(out)
assert.NoError(t, err)
assert.Equal(t, expected, marshal)
}
@ -29,91 +62,55 @@ func TestUserInfoMarshal(t *testing.T) {
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
t.Parallel()
t.Run("unmarsha email_verified from json bool true", func(t *testing.T) {
t.Run("unmarshal email_verified from json bool true", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": true}`)
var uie userInfoEmail
var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err)
assert.Equal(t, userInfoEmail{
assert.Equal(t, UserInfoEmail{
Email: "my@email.com",
EmailVerified: true,
}, uie)
})
t.Run("unmarsha email_verified from json string true", func(t *testing.T) {
t.Run("unmarshal email_verified from json string true", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": "true"}`)
var uie userInfoEmail
var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err)
assert.Equal(t, userInfoEmail{
assert.Equal(t, UserInfoEmail{
Email: "my@email.com",
EmailVerified: true,
}, uie)
})
t.Run("unmarsha email_verified from json bool false", func(t *testing.T) {
t.Run("unmarshal email_verified from json bool false", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": false}`)
var uie userInfoEmail
var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err)
assert.Equal(t, userInfoEmail{
assert.Equal(t, UserInfoEmail{
Email: "my@email.com",
EmailVerified: false,
}, uie)
})
t.Run("unmarsha email_verified from json string false", func(t *testing.T) {
t.Run("unmarshal email_verified from json string false", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": "false"}`)
var uie userInfoEmail
var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err)
assert.Equal(t, userInfoEmail{
assert.Equal(t, UserInfoEmail{
Email: "my@email.com",
EmailVerified: false,
}, uie)
})
}
// issue 203 test case.
func Test_userinfo_GetAddress_issue_203(t *testing.T) {
tests := []struct {
name string
data string
}{
{
name: "with address",
data: `{"address":{"street_address":"Test 789\nPostfach 2"},"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
},
{
name: "without address",
data: `{"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
},
{
name: "null address",
data: `{"address":null,"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := &userinfo{}
err := json.Unmarshal([]byte(tt.data), info)
assert.NoError(t, err)
info.GetAddress().GetCountry() //<- used to panic
// now shortly assure that a marshalling still produces the same as was parsed into the struct
marshal, err := json.Marshal(info)
assert.NoError(t, err)
assert.Equal(t, tt.data, string(marshal))
})
}
}

49
pkg/oidc/util.go Normal file
View file

@ -0,0 +1,49 @@
package oidc
import (
"bytes"
"encoding/json"
"fmt"
)
// mergeAndMarshalClaims merges registered and the custom
// claims map into a single JSON object.
// Registered fields overwrite custom claims.
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) {
// Use a buffer for memory re-use, instead off letting
// json allocate a new []byte for every step.
buf := new(bytes.Buffer)
// Marshal the registered claims into JSON
if err := json.NewEncoder(buf).Encode(registered); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
if len(claims) > 0 {
// Merge JSON data into custom claims.
// The full-read action by the decoder resets the buffer
// to zero len, while retaining underlaying cap.
if err := json.NewDecoder(buf).Decode(&claims); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
// Marshal the final result.
if err := json.NewEncoder(buf).Encode(claims); err != nil {
return nil, fmt.Errorf("oidc custom claims: %w", err)
}
}
return buf.Bytes(), nil
}
// unmarshalJSONMulti unmarshals the same JSON data into multiple destinations.
// Each destination must be a pointer, as per json.Unmarshal rules.
// Returns on the first error and destinations may be partly filled with data.
func unmarshalJSONMulti(data []byte, destinations ...any) error {
for _, dst := range destinations {
if err := json.Unmarshal(data, dst); err != nil {
return fmt.Errorf("oidc: %w into %T", err, dst)
}
}
return nil
}

147
pkg/oidc/util_test.go Normal file
View file

@ -0,0 +1,147 @@
package oidc
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type jsonErrorTest struct{}
func (jsonErrorTest) MarshalJSON() ([]byte, error) {
return nil, errors.New("test")
}
func Test_mergeAndMarshalClaims(t *testing.T) {
type args struct {
registered any
claims map[string]any
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "encoder error",
args: args{
registered: jsonErrorTest{},
},
wantErr: true,
},
{
name: "no claims",
args: args{
registered: struct {
Foo string `json:"foo,omitempty"`
}{
Foo: "bar",
},
},
want: "{\"foo\":\"bar\"}\n",
},
{
name: "with claims",
args: args{
registered: struct {
Foo string `json:"foo,omitempty"`
}{
Foo: "bar",
},
claims: map[string]any{
"bar": "foo",
},
},
want: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
},
{
name: "registered overwrites custom",
args: args{
registered: struct {
Foo string `json:"foo,omitempty"`
}{
Foo: "bar",
},
claims: map[string]any{
"foo": "Hello, World!",
},
},
want: "{\"foo\":\"bar\"}\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := mergeAndMarshalClaims(tt.args.registered, tt.args.claims)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, string(got))
})
}
}
func Test_unmarshalJSONMulti(t *testing.T) {
type dst struct {
Foo string `json:"foo,omitempty"`
}
type args struct {
data string
destinations []any
}
tests := []struct {
name string
args args
want []any
wantErr bool
}{
{
name: "error",
args: args{
data: "~!~~",
destinations: []any{
&dst{},
&map[string]any{},
},
},
want: []any{
&dst{},
&map[string]any{},
},
wantErr: true,
},
{
name: "success",
args: args{
data: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
destinations: []any{
&dst{},
&map[string]any{},
},
},
want: []any{
&dst{Foo: "bar"},
&map[string]any{
"foo": "bar",
"bar": "foo",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := unmarshalJSONMulti([]byte(tt.args.data), tt.args.destinations...)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, tt.args.destinations)
})
}
}

View file

@ -12,7 +12,7 @@ import (
"gopkg.in/square/go-jose.v2"
str "github.com/zitadel/oidc/pkg/strings"
str "github.com/zitadel/oidc/v2/pkg/strings"
)
type Claims interface {
@ -32,6 +32,12 @@ type ClaimsSignature interface {
SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
}
type IDClaims interface {
Claims
GetSignatureAlgorithm() jose.SignatureAlgorithm
GetAccessTokenHash() string
}
var (
ErrParse = errors.New("parsing of request failed")
ErrIssuerInvalid = errors.New("issuer does not match")

View file

@ -12,9 +12,9 @@ import (
"github.com/gorilla/mux"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
str "github.com/zitadel/oidc/pkg/strings"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
str "github.com/zitadel/oidc/v2/pkg/strings"
)
type AuthRequest interface {
@ -39,10 +39,8 @@ type Authorizer interface {
Storage() Storage
Decoder() httphelper.Decoder
Encoder() httphelper.Encoder
Signer() Signer
IDTokenHintVerifier() IDTokenHintVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
Crypto() Crypto
Issuer() string
RequestObjectSupported() bool
}
@ -73,8 +71,9 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
ctx := r.Context()
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(r.Context(), authReq, authorizer.Storage(), authorizer.Issuer())
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
@ -92,7 +91,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest
}
userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier())
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
@ -101,12 +100,12 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID)
req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
if err != nil {
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
return
@ -390,7 +389,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if idTokenHint == "" {
return "", nil
}
claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier)
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier)
if err != nil {
return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
"If you have any questions, you may contact the administrator of the application.")

View file

@ -13,10 +13,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/op"
"github.com/zitadel/oidc/pkg/op/mock"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
)
//

View file

@ -1,13 +1,19 @@
package op
import (
"context"
"errors"
"net/http"
"net/url"
"time"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
//go:generate go get github.com/dmarkham/enumer
//go:generate go run github.com/dmarkham/enumer -linecomment -sql -json -text -yaml -gqlgen -type=ApplicationType,AccessTokenType
//go:generate go mod tidy
const (
ApplicationTypeWeb ApplicationType = iota // web
@ -67,3 +73,95 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT
func IsConfidentialType(c Client) bool {
return c.ApplicationType() == ApplicationTypeWeb
}
var (
ErrInvalidAuthHeader = errors.New("invalid basic auth header")
ErrNoClientCredentials = errors.New("no client credentials provided")
ErrMissingClientID = errors.New("client_id missing from request")
)
type ClientJWTProfile interface {
JWTProfileVerifier(context.Context) JWTProfileVerifier
}
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
if ca.ClientAssertion == "" {
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
if err != nil {
return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
}
return profile.Issuer, nil
}
func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
clientID, clientSecret, ok := r.BasicAuth()
if !ok {
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
}
if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
return "", oidc.ErrUnauthorizedClient().WithParent(err)
}
return clientID, nil
}
type ClientProvider interface {
Decoder() httphelper.Decoder
Storage() Storage
}
type clientData struct {
ClientID string `schema:"client_id"`
oidc.ClientAssertionParams
}
// ClientIDFromRequest parses the request form and tries to obtain the client ID
// and reports if it is authenticated, using a JWT or static client secrets over
// http basic auth.
//
// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
// present in the form data, JWT assertion will be verified and the
// client ID is taken from there.
// If any of them is absent, basic auth is attempted.
// In absence of basic auth data, the unauthenticated client id from the form
// data is returned.
//
// If no client id can be obtained by any method, oidc.ErrInvalidClient
// is returned with ErrMissingClientID wrapped in it.
func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
err = r.ParseForm()
if err != nil {
return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
data := new(clientData)
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
return "", false, err
}
JWTProfile, ok := p.(ClientJWTProfile)
if ok {
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
}
if !ok || errors.Is(err, ErrNoClientCredentials) {
clientID, err = ClientBasicAuth(r, p.Storage())
}
if err == nil {
return clientID, true, nil
}
if data.ClientID == "" {
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
}
return data.ClientID, false, nil
}

253
pkg/op/client_test.go Normal file
View file

@ -0,0 +1,253 @@
package op_test
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
)
type testClientJWTProfile struct{}
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
func TestClientJWTAuth(t *testing.T) {
type args struct {
ctx context.Context
ca oidc.ClientAssertionParams
verifier op.ClientJWTProfile
}
tests := []struct {
name string
args args
wantClientID string
wantErr error
}{
{
name: "empty assertion",
args: args{
context.Background(),
oidc.ClientAssertionParams{},
testClientJWTProfile{},
},
wantErr: op.ErrNoClientCredentials,
},
{
name: "verification error",
args: args{
context.Background(),
oidc.ClientAssertionParams{
ClientAssertion: "foo",
},
testClientJWTProfile{},
},
wantErr: oidc.ErrParse,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantClientID, gotClientID)
})
}
}
func TestClientBasicAuth(t *testing.T) {
errWrong := errors.New("wrong secret")
type args struct {
username string
password string
}
tests := []struct {
name string
args *args
storage op.Storage
wantClientID string
wantErr error
}{
{
name: "no args",
wantErr: op.ErrNoClientCredentials,
},
{
name: "username unescape err",
args: &args{
username: "%",
password: "bar",
},
wantErr: op.ErrInvalidAuthHeader,
},
{
name: "password unescape err",
args: &args{
username: "foo",
password: "%",
},
wantErr: op.ErrInvalidAuthHeader,
},
{
name: "auth error",
args: &args{
username: "foo",
password: "wrong",
},
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong)
return s
}(),
wantErr: errWrong,
},
{
name: "auth error",
args: &args{
username: "foo",
password: "bar",
},
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
return s
}(),
wantClientID: "foo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/foo", nil)
if tt.args != nil {
r.SetBasicAuth(tt.args.username, tt.args.password)
}
gotClientID, err := op.ClientBasicAuth(r, tt.storage)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantClientID, gotClientID)
})
}
}
type errReader struct{}
func (errReader) Read([]byte) (int, error) {
return 0, io.ErrNoProgress
}
type testClientProvider struct {
storage op.Storage
}
func (testClientProvider) Decoder() httphelper.Decoder {
return schema.NewDecoder()
}
func (p testClientProvider) Storage() op.Storage {
return p.storage
}
func TestClientIDFromRequest(t *testing.T) {
type args struct {
body io.Reader
p op.ClientProvider
}
type basicAuth struct {
username string
password string
}
tests := []struct {
name string
args args
basicAuth *basicAuth
wantClientID string
wantAuthenticated bool
wantErr bool
}{
{
name: "parse error",
args: args{
body: errReader{},
},
wantErr: true,
},
{
name: "unauthenticated",
args: args{
body: strings.NewReader(
url.Values{
"client_id": []string{"foo"},
}.Encode(),
),
p: testClientProvider{
storage: mock.NewStorage(t),
},
},
wantClientID: "foo",
wantAuthenticated: false,
},
{
name: "authenticated",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
return s
}(),
},
},
basicAuth: &basicAuth{
username: "foo",
password: "bar",
},
wantClientID: "foo",
wantAuthenticated: true,
},
{
name: "missing client id",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: mock.NewStorage(t),
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if tt.basicAuth != nil {
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantClientID, gotClientID)
assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
})
}
}

View file

@ -2,20 +2,24 @@ package op
import (
"errors"
"net/http"
"net/url"
"os"
"strings"
"golang.org/x/text/language"
)
const (
OidcDevMode = "ZITADEL_OIDC_DEV"
// deprecated: use OidcDevMode (ZITADEL_OIDC_DEV=true)
devMode = "CAOS_OIDC_DEV"
var (
ErrInvalidIssuerPath = errors.New("no fragments or query allowed for issuer")
ErrInvalidIssuerNoIssuer = errors.New("missing issuer")
ErrInvalidIssuerURL = errors.New("invalid url for issuer")
ErrInvalidIssuerMissingHost = errors.New("host for issuer missing")
ErrInvalidIssuerHTTPS = errors.New("scheme for issuer must be `https`")
)
type Configuration interface {
Issuer() string
IssuerFromRequest(r *http.Request) string
Insecure() bool
AuthorizationEndpoint() Endpoint
TokenEndpoint() Endpoint
IntrospectionEndpoint() Endpoint
@ -23,6 +27,7 @@ type Configuration interface {
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool
@ -32,6 +37,7 @@ type Configuration interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
IntrospectionEndpointSigningAlgorithmsSupported() []string
RevocationAuthMethodPrivateKeyJWTSupported() bool
@ -40,38 +46,77 @@ type Configuration interface {
RequestObjectSigningAlgorithmsSupported() []string
SupportedUILocales() []language.Tag
DeviceAuthorization() DeviceAuthorizationConfig
}
func ValidateIssuer(issuer string) error {
type IssuerFromRequest func(r *http.Request) string
func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) {
return func(allowInsecure bool) (IssuerFromRequest, error) {
issuerPath, err := url.Parse(path)
if err != nil {
return nil, ErrInvalidIssuerURL
}
if err := ValidateIssuerPath(issuerPath); err != nil {
return nil, err
}
return func(r *http.Request) string {
return dynamicIssuer(r.Host, path, allowInsecure)
}, nil
}
}
func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) {
return func(allowInsecure bool) (IssuerFromRequest, error) {
if err := ValidateIssuer(issuer, allowInsecure); err != nil {
return nil, err
}
return func(_ *http.Request) string {
return issuer
}, nil
}
}
func ValidateIssuer(issuer string, allowInsecure bool) error {
if issuer == "" {
return errors.New("missing issuer")
return ErrInvalidIssuerNoIssuer
}
u, err := url.Parse(issuer)
if err != nil {
return errors.New("invalid url for issuer")
return ErrInvalidIssuerURL
}
if u.Host == "" {
return errors.New("host for issuer missing")
return ErrInvalidIssuerMissingHost
}
if u.Scheme != "https" {
if !devLocalAllowed(u) {
return errors.New("scheme for issuer must be `https`")
if !devLocalAllowed(u, allowInsecure) {
return ErrInvalidIssuerHTTPS
}
}
if u.Fragment != "" || len(u.Query()) > 0 {
return errors.New("no fragments or query allowed for issuer")
return ValidateIssuerPath(u)
}
func ValidateIssuerPath(issuer *url.URL) error {
if issuer.Fragment != "" || len(issuer.Query()) > 0 {
return ErrInvalidIssuerPath
}
return nil
}
func devLocalAllowed(url *url.URL) bool {
_, b := os.LookupEnv(OidcDevMode)
if !b {
// check the old / current env var as well
_, b = os.LookupEnv(devMode)
if !b {
return b
}
func devLocalAllowed(url *url.URL, allowInsecure bool) bool {
if !allowInsecure {
return false
}
return url.Scheme == "http"
}
func dynamicIssuer(issuer, path string, allowInsecure bool) string {
schema := "https"
if allowInsecure {
schema = "http"
}
if len(path) > 0 && !strings.HasPrefix(path, "/") {
path = "/" + path
}
return schema + "://" + issuer + path
}

View file

@ -1,13 +1,17 @@
package op
import (
"os"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateIssuer(t *testing.T) {
type args struct {
issuer string
issuer string
allowInsecure bool
}
tests := []struct {
name string
@ -16,65 +20,97 @@ func TestValidateIssuer(t *testing.T) {
}{
{
"missing issuer fails",
args{""},
args{
issuer: "",
},
true,
},
{
"invalid url for issuer fails",
args{":issuer"},
true,
},
{
"invalid url for issuer fails",
args{":issuer"},
args{
issuer: ":issuer",
},
true,
},
{
"host for issuer missing fails",
args{"https:///issuer"},
true,
},
{
"host for not https fails",
args{"http://issuer.com"},
args{
issuer: "https:///issuer",
},
true,
},
{
"host with fragment fails",
args{"https://issuer.com/#issuer"},
args{
issuer: "https://issuer.com/#issuer",
},
true,
},
{
"host with query fails",
args{"https://issuer.com?issuer=me"},
args{
issuer: "https://issuer.com?issuer=me",
},
true,
},
{
"host with http fails",
args{
issuer: "http://issuer.com",
},
true,
},
{
"host with https ok",
args{"https://issuer.com"},
args{
issuer: "https://issuer.com",
},
false,
},
{
"localhost with http fails",
args{"http://localhost:9999"},
"custom scheme fails",
args{
issuer: "custom://localhost:9999",
},
true,
},
{
"http with allowInsecure ok",
args{
issuer: "http://localhost:9999",
allowInsecure: true,
},
false,
},
{
"https with allowInsecure ok",
args{
issuer: "https://localhost:9999",
allowInsecure: true,
},
false,
},
{
"custom scheme with allowInsecure fails",
args{
issuer: "custom://localhost:9999",
allowInsecure: true,
},
true,
},
}
// ensure env is not set
//nolint:errcheck
os.Unsetenv(OidcDevMode)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
if err := ValidateIssuer(tt.args.issuer, tt.args.allowInsecure); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
func TestValidateIssuerPath(t *testing.T) {
type args struct {
issuer string
issuerPath *url.URL
}
tests := []struct {
name string
@ -82,17 +118,217 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
wantErr bool
}{
{
"localhost with http with dev ok",
args{"http://localhost:9999"},
"empty ok",
args{func() *url.URL {
u, _ := url.Parse("")
return u
}()},
false,
},
{
"custom ok",
args{func() *url.URL {
u, _ := url.Parse("/custom")
return u
}()},
false,
},
{
"fragment fails",
args{func() *url.URL {
u, _ := url.Parse("#fragment")
return u
}()},
true,
},
{
"query fails",
args{func() *url.URL {
u, _ := url.Parse("?query=value")
return u
}()},
true,
},
}
//nolint:errcheck
os.Setenv(OidcDevMode, "true")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
if err := ValidateIssuerPath(tt.args.issuerPath); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuerPath() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestIssuerFromHost(t *testing.T) {
type args struct {
path string
allowInsecure bool
target string
}
type res struct {
issuer string
err error
}
tests := []struct {
name string
args args
res res
}{
{
"invalid issuer path",
args{
path: "/#fragment",
allowInsecure: false,
},
res{
issuer: "",
err: ErrInvalidIssuerPath,
},
},
{
"empty path secure",
args{
path: "",
allowInsecure: false,
target: "https://issuer.com",
},
res{
issuer: "https://issuer.com",
err: nil,
},
},
{
"custom path secure",
args{
path: "/custom/",
allowInsecure: false,
target: "https://issuer.com",
},
res{
issuer: "https://issuer.com/custom/",
err: nil,
},
},
{
"custom path no leading slash",
args{
path: "custom/",
allowInsecure: false,
target: "https://issuer.com",
},
res{
issuer: "https://issuer.com/custom/",
err: nil,
},
},
{
"empty path unsecure",
args{
path: "",
allowInsecure: true,
target: "http://issuer.com",
},
res{
issuer: "http://issuer.com",
err: nil,
},
},
{
"custom path unsecure",
args{
path: "/custom/",
allowInsecure: true,
target: "http://issuer.com",
},
res{
issuer: "http://issuer.com/custom/",
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
issuer, err := IssuerFromHost(tt.args.path)(tt.args.allowInsecure)
if tt.res.err == nil {
assert.NoError(t, err)
req := httptest.NewRequest("", tt.args.target, nil)
assert.Equal(t, tt.res.issuer, issuer(req))
}
if tt.res.err != nil {
assert.ErrorIs(t, err, tt.res.err)
}
})
}
}
func TestStaticIssuer(t *testing.T) {
type args struct {
issuer string
allowInsecure bool
}
type res struct {
issuer string
err error
}
tests := []struct {
name string
args args
res res
}{
{
"invalid issuer",
args{
issuer: "",
allowInsecure: false,
},
res{
issuer: "",
err: ErrInvalidIssuerNoIssuer,
},
},
{
"empty path secure",
args{
issuer: "https://issuer.com",
allowInsecure: false,
},
res{
issuer: "https://issuer.com",
err: nil,
},
},
{
"custom path secure",
args{
issuer: "https://issuer.com/custom/",
allowInsecure: false,
},
res{
issuer: "https://issuer.com/custom/",
err: nil,
},
},
{
"unsecure",
args{
issuer: "http://issuer.com",
allowInsecure: true,
},
res{
issuer: "http://issuer.com",
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
issuer, err := StaticIssuer(tt.args.issuer)(tt.args.allowInsecure)
if tt.res.err == nil {
assert.NoError(t, err)
assert.Equal(t, tt.res.issuer, issuer(nil))
}
if tt.res.err != nil {
assert.ErrorIs(t, err, tt.res.err)
}
})
}

53
pkg/op/context.go Normal file
View file

@ -0,0 +1,53 @@
package op
import (
"context"
"net/http"
)
type key int
const (
issuerKey key = 0
)
type IssuerInterceptor struct {
issuerFromRequest IssuerFromRequest
}
// NewIssuerInterceptor will set the issuer into the context
// by the provided IssuerFromRequest (e.g. returned from StaticIssuer or IssuerFromHost)
func NewIssuerInterceptor(issuerFromRequest IssuerFromRequest) *IssuerInterceptor {
return &IssuerInterceptor{
issuerFromRequest: issuerFromRequest,
}
}
func (i *IssuerInterceptor) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
i.setIssuerCtx(w, r, next)
})
}
func (i *IssuerInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
i.setIssuerCtx(w, r, next)
}
}
// IssuerFromContext reads the issuer from the context (set by an IssuerInterceptor)
// it will return an empty string if not found
func IssuerFromContext(ctx context.Context) string {
ctxIssuer, _ := ctx.Value(issuerKey).(string)
return ctxIssuer
}
// ContextWithIssuer returns a new context with issuer set to it.
func ContextWithIssuer(ctx context.Context, issuer string) context.Context {
return context.WithValue(ctx, issuerKey, issuer)
}
func (i *IssuerInterceptor) setIssuerCtx(w http.ResponseWriter, r *http.Request, next http.Handler) {
r = r.WithContext(ContextWithIssuer(r.Context(), i.issuerFromRequest(r)))
next.ServeHTTP(w, r)
}

76
pkg/op/context_test.go Normal file
View file

@ -0,0 +1,76 @@
package op
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIssuerInterceptor(t *testing.T) {
type fields struct {
issuerFromRequest IssuerFromRequest
}
type args struct {
r *http.Request
next http.Handler
}
type res struct {
issuer string
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"empty",
fields{
func(r *http.Request) string {
return ""
},
},
args{},
res{
issuer: "",
},
},
{
"static",
fields{
func(r *http.Request) string {
return "static"
},
},
args{},
res{
issuer: "static",
},
},
{
"host",
fields{
func(r *http.Request) string {
return r.Host
},
},
args{},
res{
issuer: "issuer.com",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := NewIssuerInterceptor(tt.fields.issuerFromRequest)
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
assert.Equal(t, tt.res.issuer, IssuerFromContext(r.Context()))
})
req := httptest.NewRequest("", "https://issuer.com", nil)
i.Handler(next).ServeHTTP(nil, req)
i.HandlerFunc(next).ServeHTTP(nil, req)
})
}
}

View file

@ -1,7 +1,7 @@
package op
import (
"github.com/zitadel/oidc/pkg/crypto"
"github.com/zitadel/oidc/v2/pkg/crypto"
)
type Crypto interface {

265
pkg/op/device.go Normal file
View file

@ -0,0 +1,265 @@
package op
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type DeviceAuthorizationConfig struct {
Lifetime time.Duration
PollInterval time.Duration
UserFormURL string // the URL where the user must go to authorize the device
UserCode UserCodeConfig
}
type UserCodeConfig struct {
CharSet string
CharAmount int
DashInterval int
}
const (
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
CharSetDigits = "0123456789"
)
var (
UserCodeBase20 = UserCodeConfig{
CharSet: CharSetBase20,
CharAmount: 8,
DashInterval: 4,
}
UserCodeDigits = UserCodeConfig{
CharSet: CharSetDigits,
CharAmount: 9,
DashInterval: 3,
}
)
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if err := DeviceAuthorization(w, r, o); err != nil {
RequestError(w, r, err)
}
}
}
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return err
}
req, err := ParseDeviceCodeRequest(r, o)
if err != nil {
return err
}
config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil {
return err
}
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
if err != nil {
return err
}
expires := time.Now().Add(config.Lifetime)
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
if err != nil {
return err
}
response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: config.UserFormURL,
ExpiresIn: int(config.Lifetime / time.Second),
Interval: int(config.PollInterval / time.Second),
}
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
httphelper.MarshalJSON(w, response)
return nil
}
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
clientID, _, err := ClientIDFromRequest(r, o)
if err != nil {
return nil, err
}
req := new(oidc.DeviceAuthorizationRequest)
if err := o.Decoder().Decode(req, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
}
req.ClientID = clientID
return req, nil
}
// 16 bytes gives 128 bit of entropy.
// results in a 22 character base64 encoded string.
const RecommendedDeviceCodeBytes = 16
func NewDeviceCode(nBytes int) (string, error) {
bytes := make([]byte, nBytes)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("%w getting entropy for device code", err)
}
return base64.RawURLEncoding.EncodeToString(bytes), nil
}
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
var buf strings.Builder
if dashInterval > 0 {
buf.Grow(charAmount + charAmount/dashInterval - 1)
} else {
buf.Grow(charAmount)
}
max := big.NewInt(int64(len(charSet)))
for i := 0; i < charAmount; i++ {
if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
buf.WriteByte('-')
}
bi, err := rand.Int(rand.Reader, max)
if err != nil {
return "", fmt.Errorf("%w getting entropy for user code", err)
}
buf.WriteRune(charSet[int(bi.Int64())])
}
return buf.String(), nil
}
type deviceAccessTokenRequest struct {
subject string
audience []string
scopes []string
}
func (r *deviceAccessTokenRequest) GetSubject() string {
return r.subject
}
func (r *deviceAccessTokenRequest) GetAudience() []string {
return r.audience
}
func (r *deviceAccessTokenRequest) GetScopes() []string {
return r.scopes
}
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
if err := deviceAccessToken(w, r, exchanger); err != nil {
RequestError(w, r, err)
}
}
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
// use a limited context timeout shorter as the default
// poll interval of 5 seconds.
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
defer cancel()
r = r.WithContext(ctx)
clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
if err != nil {
return err
}
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
if err != nil {
return err
}
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
if err != nil {
return err
}
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
if err != nil {
return err
}
if clientAuthenticated != IsConfidentialType(client) {
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
WithDescription("confidential client requires authentication")
}
tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{clientID},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
if err != nil {
return err
}
httphelper.MarshalJSON(w, resp)
return nil
}
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
req := new(oidc.DeviceAccessTokenRequest)
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
return nil, err
}
return req, nil
}
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
storage, err := assertDeviceStorage(exchanger.Storage())
if err != nil {
return nil, err
}
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err)
}
if err != nil {
return nil, oidc.ErrAccessDenied().WithParent(err)
}
if state.Denied {
return state, oidc.ErrAccessDenied()
}
if state.Done {
return state, nil
}
if time.Now().After(state.Expires) {
return state, oidc.ErrExpiredDeviceCode()
}
return state, oidc.ErrAuthorizationPending()
}
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "")
if err != nil {
return nil, err
}
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
}, nil
}

407
pkg/op/device_test.go Normal file
View file

@ -0,0 +1,407 @@
package op_test
import (
"context"
"crypto/rand"
"encoding/base64"
"io"
mr "math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
)
func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"},
ClientID: "web",
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(testProvider)(w, r)
})
result := w.Result()
assert.Less(t, result.StatusCode, 300)
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
}
func TestParseDeviceCodeRequest(t *testing.T) {
tests := []struct {
name string
req *oidc.DeviceAuthorizationRequest
wantErr bool
}{
{
name: "empty request",
wantErr: true,
},
{
name: "success",
req: &oidc.DeviceAuthorizationRequest{
Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
ClientID: "web",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var body io.Reader
if tt.req != nil {
values := make(url.Values)
testProvider.Encoder().Encode(tt.req, values)
body = strings.NewReader(values.Encode())
}
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
got, err := op.ParseDeviceCodeRequest(r, testProvider)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.req, got)
})
}
}
func runWithRandReader(r io.Reader, f func()) {
originalReader := rand.Reader
rand.Reader = r
defer func() {
rand.Reader = originalReader
}()
f()
}
func TestNewDeviceCode(t *testing.T) {
t.Run("reader error", func(t *testing.T) {
runWithRandReader(errReader{}, func() {
_, err := op.NewDeviceCode(16)
require.Error(t, err)
})
})
t.Run("different lengths, rand reader", func(t *testing.T) {
for i := 1; i <= 32; i++ {
got, err := op.NewDeviceCode(i)
require.NoError(t, err)
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
}
})
}
func TestNewUserCode(t *testing.T) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
want string
wantErr bool
}{
{
name: "reader error",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: errReader{},
wantErr: true,
},
{
name: "base20",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
want: "XKCD-HTTD",
},
{
name: "digits",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
want: "271-256-225",
},
{
name: "no dashes",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
},
reader: mr.New(mr.NewSource(1)),
want: "271256225",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
runWithRandReader(tt.reader, func() {
got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
if tt.wantErr {
require.ErrorIs(t, err, io.ErrNoProgress)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
})
}
t.Run("crypto/rand", func(t *testing.T) {
const testN = 100000
for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} {
t.Run(c.CharSet, func(t *testing.T) {
results := make(map[string]int)
for i := 0; i < testN; i++ {
code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
require.NoError(t, err)
results[code]++
}
t.Log(results)
var duplicates int
for code, count := range results {
assert.Less(t, count, 3, code)
if count == 2 {
duplicates++
}
}
})
}
})
}
func BenchmarkNewUserCode(b *testing.B) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
}{
{
name: "math rand, base20",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "math rand, digits",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "crypto rand, base20",
args: args{
charset: []rune(op.CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: rand.Reader,
},
{
name: "crypto rand, digits",
args: args{
charset: []rune(op.CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: rand.Reader,
},
}
for _, tt := range tests {
runWithRandReader(tt.reader, func() {
b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
require.NoError(b, err)
}
})
})
}
}
func TestDeviceAccessToken(t *testing.T) {
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
values := make(url.Values)
values.Set("client_id", "native")
values.Set("grant_type", string(oidc.GrantTypeDeviceCode))
values.Set("device_code", "qwerty")
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
op.DeviceAccessToken(w, r, testProvider)
result := w.Result()
got, _ := io.ReadAll(result.Body)
t.Log(string(got))
assert.Less(t, result.StatusCode, 300)
assert.NotEmpty(t, string(got))
}
func TestCheckDeviceAuthorizationState(t *testing.T) {
now := time.Now()
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"})
storage.DenyDeviceAuthorization(context.Background(), "denied")
storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim")
exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
defer cancel()
type args struct {
ctx context.Context
clientID string
deviceCode string
}
tests := []struct {
name string
args args
want *op.DeviceAuthorizationState
wantErr error
}{
{
name: "pending",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "pending",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(time.Minute),
},
wantErr: oidc.ErrAuthorizationPending(),
},
{
name: "slow down",
args: args{
ctx: exceededCtx,
clientID: "native",
deviceCode: "ok",
},
wantErr: oidc.ErrSlowDown(),
},
{
name: "wrong client",
args: args{
ctx: context.Background(),
clientID: "foo",
deviceCode: "ok",
},
wantErr: oidc.ErrAccessDenied(),
},
{
name: "denied",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "denied",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(time.Minute),
Denied: true,
},
wantErr: oidc.ErrAccessDenied(),
},
{
name: "completed",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "completed",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(time.Minute),
Subject: "tim",
Done: true,
},
},
{
name: "expired",
args: args{
ctx: context.Background(),
clientID: "native",
deviceCode: "expired",
},
want: &op.DeviceAuthorizationState{
ClientID: "native",
Scopes: []string{"foo"},
Expires: now.Add(-time.Minute),
},
wantErr: oidc.ErrExpiredDeviceCode(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -1,49 +1,17 @@
package op
import (
"context"
"net/http"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"gopkg.in/square/go-jose.v2"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(c, s))
}
}
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
httphelper.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
return &oidc.DiscoveryConfiguration{
Issuer: c.Issuer(),
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()),
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
RevocationEndpoint: c.RevocationEndpoint().Absolute(c.Issuer()),
EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
ScopesSupported: Scopes(c),
ResponseTypesSupported: ResponseTypes(c),
GrantTypesSupported: GrantTypes(c),
SubjectTypesSupported: SubjectTypes(c),
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(c),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(c),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(c),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(c),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(c),
ClaimsSupported: SupportedClaims(c),
CodeChallengeMethodsSupported: CodeChallengeMethods(c),
UILocalesSupported: c.SupportedUILocales(),
RequestParameterSupported: c.RequestObjectSupported(),
}
type DiscoverStorage interface {
SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
}
var DefaultSupportedScopes = []string{
@ -55,6 +23,47 @@ var DefaultSupportedScopes = []string{
oidc.ScopeOfflineAccess,
}
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(r, c, s))
}
}
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
httphelper.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
issuer := config.IssuerFromRequest(r)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
TokenEndpoint: config.TokenEndpoint().Absolute(issuer),
IntrospectionEndpoint: config.IntrospectionEndpoint().Absolute(issuer),
UserinfoEndpoint: config.UserinfoEndpoint().Absolute(issuer),
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer),
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
ClaimsSupported: SupportedClaims(config),
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
UILocalesSupported: config.SupportedUILocales(),
RequestParameterSupported: config.RequestObjectSupported(),
}
}
func Scopes(c Configuration) []string {
return DefaultSupportedScopes // TODO: config
}
@ -84,9 +93,94 @@ func GrantTypes(c Configuration) []oidc.GrantType {
if c.GrantTypeJWTAuthorizationSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeBearer)
}
if c.GrantTypeDeviceCodeSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
}
return grantTypes
}
func SubjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config
}
func SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string {
algorithms, err := storage.SignatureAlgorithms(ctx)
if err != nil {
return nil
}
algs := make([]string, len(algorithms))
for i, algorithm := range algorithms {
algs[i] = string(algorithm)
}
return algs
}
func RequestObjectSigAlgorithms(c Configuration) []string {
if !c.RequestObjectSupported() {
return nil
}
return c.RequestObjectSigningAlgorithmsSupported()
}
func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func TokenSigAlgorithms(c Configuration) []string {
if !c.AuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.TokenEndpointSigningAlgorithmsSupported()
}
func IntrospectionSigAlgorithms(c Configuration) []string {
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.IntrospectionEndpointSigningAlgorithmsSupported()
}
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodBasic,
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func RevocationSigAlgorithms(c Configuration) []string {
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.RevocationEndpointSigningAlgorithmsSupported()
}
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func SupportedClaims(c Configuration) []string {
return []string{ // TODO: config
"sub",
@ -116,59 +210,6 @@ func SupportedClaims(c Configuration) []string {
}
}
func SigAlgorithms(s Signer) []string {
return []string{string(s.SignatureAlgorithm())}
}
func SubjectTypes(c Configuration) []string {
return []string{"public"} // TODO: config
}
func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func TokenSigAlgorithms(c Configuration) []string {
if !c.AuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.TokenEndpointSigningAlgorithmsSupported()
}
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodBasic,
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
codeMethods := make([]oidc.CodeChallengeMethod, 0, 1)
if c.CodeMethodS256Supported() {
@ -176,24 +217,3 @@ func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
}
return codeMethods
}
func IntrospectionSigAlgorithms(c Configuration) []string {
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.IntrospectionEndpointSigningAlgorithmsSupported()
}
func RevocationSigAlgorithms(c Configuration) []string {
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.RevocationEndpointSigningAlgorithmsSupported()
}
func RequestObjectSigAlgorithms(c Configuration) []string {
if !c.RequestObjectSupported() {
return nil
}
return c.RequestObjectSigningAlgorithmsSupported()
}

View file

@ -1,18 +1,19 @@
package op_test
import (
"context"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/op"
"github.com/zitadel/oidc/pkg/op/mock"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
)
func TestDiscover(t *testing.T) {
@ -47,8 +48,9 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) {
type args struct {
c op.Configuration
s op.Signer
request *http.Request
c op.Configuration
s op.DiscoverStorage
}
tests := []struct {
name string
@ -59,9 +61,8 @@ func TestCreateDiscoveryConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
}
got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s)
assert.Equal(t, tt.want, got)
})
}
}
@ -83,9 +84,8 @@ func Test_scopes(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("scopes() = %v, want %v", got, tt.want)
}
got := op.Scopes(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
@ -99,13 +99,16 @@ func Test_ResponseTypes(t *testing.T) {
args args
want []string
}{
// TODO: Add test cases.
{
"code and implicit flow",
args{},
[]string{"code", "id_token", "id_token token"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
}
got := op.ResponseTypes(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
@ -117,63 +120,53 @@ func Test_GrantTypes(t *testing.T) {
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
}
})
}
}
func TestSupportedClaims(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("SupportedClaims() = %v, want %v", got, tt.want)
}
})
}
}
func Test_SigAlgorithms(t *testing.T) {
m := mock.NewMockSigner(gomock.NewController(t))
type args struct {
s op.Signer
}
tests := []struct {
name string
args args
want []string
want []oidc.GrantType
}{
{
"",
args{func() op.Signer {
m.EXPECT().SignatureAlgorithm().Return(jose.RS256)
return m
}()},
[]string{"RS256"},
"code and implicit flow",
args{
func() op.Configuration {
c := mock.NewMockConfiguration(gomock.NewController(t))
c.EXPECT().GrantTypeRefreshTokenSupported().Return(false)
c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
c.EXPECT().GrantTypeClientCredentialsSupported().Return(false)
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
return c
}(),
},
[]oidc.GrantType{
oidc.GrantTypeCode,
oidc.GrantTypeImplicit,
},
},
{
"code, implicit flow, refresh token, token exchange, jwt profile, client_credentials",
args{
func() op.Configuration {
c := mock.NewMockConfiguration(gomock.NewController(t))
c.EXPECT().GrantTypeRefreshTokenSupported().Return(true)
c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
c.EXPECT().GrantTypeClientCredentialsSupported().Return(true)
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
return c
}(),
},
[]oidc.GrantType{
oidc.GrantTypeCode,
oidc.GrantTypeImplicit,
oidc.GrantTypeRefreshToken,
oidc.GrantTypeClientCredentials,
oidc.GrantTypeTokenExchange,
oidc.GrantTypeBearer,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
}
got := op.GrantTypes(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
@ -195,9 +188,80 @@ func Test_SubjectTypes(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
}
got := op.SubjectTypes(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_SigAlgorithms(t *testing.T) {
m := mock.NewMockDiscoverStorage(gomock.NewController(t))
type args struct {
s op.DiscoverStorage
}
tests := []struct {
name string
args args
want []string
}{
{
"",
args{func() op.DiscoverStorage {
m.EXPECT().SignatureAlgorithms(gomock.Any()).Return([]jose.SignatureAlgorithm{jose.RS256}, nil)
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.SigAlgorithms(context.Background(), tt.args.s)
assert.Equal(t, tt.want, got)
})
}
}
func Test_RequestObjectSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().RequestObjectSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().RequestObjectSupported().Return(true)
m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().RequestObjectSupported().Return(true)
m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.RequestObjectSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
@ -244,9 +308,311 @@ func Test_AuthMethodsTokenEndpoint(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.AuthMethodsTokenEndpoint(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("authMethods() = %v, want %v", got, tt.want)
}
got := op.AuthMethodsTokenEndpoint(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_TokenSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.TokenSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_IntrospectionSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.IntrospectionSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_AuthMethodsIntrospectionEndpoint(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []oidc.AuthMethod
}{
{
"basic only",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodBasic},
},
{
"basic and private_key_jwt",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodBasic, oidc.AuthMethodPrivateKeyJWT},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.AuthMethodsIntrospectionEndpoint(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_RevocationSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.RevocationSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_AuthMethodsRevocationEndpoint(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []oidc.AuthMethod
}{
{
"none and basic",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPostSupported().Return(false)
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic},
},
{
"none, basic and post",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPostSupported().Return(true)
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost},
},
{
"none, basic, post and private_key_jwt",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPostSupported().Return(true)
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost, oidc.AuthMethodPrivateKeyJWT},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.AuthMethodsRevocationEndpoint(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func TestSupportedClaims(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"scopes",
args{},
[]string{
"sub",
"aud",
"exp",
"iat",
"iss",
"auth_time",
"nonce",
"acr",
"amr",
"c_hash",
"at_hash",
"act",
"scopes",
"client_id",
"azp",
"preferred_username",
"name",
"family_name",
"given_name",
"locale",
"email",
"email_verified",
"phone_number",
"phone_number_verified",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.SupportedClaims(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_CodeChallengeMethods(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []oidc.CodeChallengeMethod
}{
{
"not supported",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().CodeMethodS256Supported().Return(false)
return m
}()},
[]oidc.CodeChallengeMethod{},
},
{
"S256",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().CodeMethodS256Supported().Return(true)
return m
}()},
[]oidc.CodeChallengeMethod{oidc.CodeChallengeMethodS256},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.CodeChallengeMethods(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -3,7 +3,7 @@ package op_test
import (
"testing"
"github.com/zitadel/oidc/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op"
)
func TestEndpoint_Path(t *testing.T) {

View file

@ -3,8 +3,8 @@ package op
import (
"net/http"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type ErrAuthRequest interface {

View file

@ -6,11 +6,11 @@ import (
"gopkg.in/square/go-jose.v2"
httphelper "github.com/zitadel/oidc/pkg/http"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
)
type KeyProvider interface {
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
KeySet(context.Context) ([]Key, error)
}
func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
@ -20,10 +20,23 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
}
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.GetKeySet(r.Context())
keySet, err := k.KeySet(r.Context())
if err != nil {
httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError)
return
}
httphelper.MarshalJSON(w, keySet)
httphelper.MarshalJSON(w, jsonWebKeySet(keySet))
}
func jsonWebKeySet(keys []Key) *jose.JSONWebKeySet {
webKeys := make([]jose.JSONWebKey, len(keys))
for i, key := range keys {
webKeys[i] = jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: string(key.Algorithm()),
Use: key.Use(),
Key: key.Key(),
}
}
return &jose.JSONWebKeySet{Keys: webKeys}
}

View file

@ -11,9 +11,9 @@ import (
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/op"
"github.com/zitadel/oidc/pkg/op/mock"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
)
func TestKeys(t *testing.T) {
@ -35,7 +35,7 @@ func TestKeys(t *testing.T) {
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
m.EXPECT().KeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
return m
}(),
},
@ -51,39 +51,39 @@ func TestKeys(t *testing.T) {
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, nil)
m.EXPECT().KeySet(gomock.Any()).Return(nil, nil)
return m
}(),
},
res: res{
statusCode: http.StatusOK,
contentType: "application/json",
body: `{"keys":[]}
`,
},
},
{
name: "list",
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(
&jose.JSONWebKeySet{Keys: []jose.JSONWebKey{
{
Key: &rsa.PublicKey{
N: big.NewInt(1),
E: 1,
},
KeyID: "id",
},
}},
nil,
)
ctrl := gomock.NewController(t)
m := mock.NewMockKeyProvider(ctrl)
k := mock.NewMockKey(ctrl)
k.EXPECT().Key().Return(&rsa.PublicKey{
N: big.NewInt(1),
E: 1,
})
k.EXPECT().ID().Return("id")
k.EXPECT().Algorithm().Return(jose.RS256)
k.EXPECT().Use().Return("sig")
m.EXPECT().KeySet(gomock.Any()).Return([]op.Key{k}, nil)
return m
}(),
},
res: res{
statusCode: http.StatusOK,
contentType: "application/json",
body: `{"keys":[{"kty":"RSA","kid":"id","n":"AQ","e":"AQ"}]}
body: `{"keys":[{"use":"sig","kty":"RSA","kid":"id","alg":"RS256","n":"AQ","e":"AQ"}]}
`,
},
},

View file

@ -1,15 +1,16 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Authorizer)
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
http "github.com/zitadel/oidc/pkg/http"
op "github.com/zitadel/oidc/pkg/op"
http "github.com/zitadel/oidc/v2/pkg/http"
op "github.com/zitadel/oidc/v2/pkg/op"
)
// MockAuthorizer is a mock of Authorizer interface.
@ -78,31 +79,17 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
}
// IDTokenHintVerifier mocks base method.
func (m *MockAuthorizer) IDTokenHintVerifier() op.IDTokenHintVerifier {
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenHintVerifier")
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
ret0, _ := ret[0].(op.IDTokenHintVerifier)
return ret0
}
// IDTokenHintVerifier indicates an expected call of IDTokenHintVerifier.
func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier() *gomock.Call {
func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier))
}
// Issuer mocks base method.
func (m *MockAuthorizer) Issuer() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer.
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
}
// RequestObjectSupported mocks base method.
@ -119,20 +106,6 @@ func (mr *MockAuthorizerMockRecorder) RequestObjectSupported() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported))
}
// Signer mocks base method.
func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Signer")
ret0, _ := ret[0].(op.Signer)
return ret0
}
// Signer indicates an expected call of Signer.
func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer))
}
// Storage mocks base method.
func (m *MockAuthorizer) Storage() op.Storage {
m.ctrl.T.Helper()

View file

@ -8,8 +8,8 @@ import (
"github.com/gorilla/schema"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/op"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
)
func NewAuthorizer(t *testing.T) op.Authorizer {
@ -20,23 +20,13 @@ func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer {
m := NewAuthorizer(t)
ExpectDecoder(m)
ExpectEncoder(m)
ExpectSigner(m, t)
//ExpectSigner(m, t)
ExpectStorage(m, t)
ExpectVerifier(m, t)
// ExpectErrorHandler(m, t, wantErr)
return m
}
// func NewAuthorizerExpectDecoderFails(t *testing.T) op.Authorizer {
// m := NewAuthorizer(t)
// ExpectDecoderFails(m)
// ExpectEncoder(m)
// ExpectSigner(m, t)
// ExpectStorage(m, t)
// ExpectErrorHandler(m, t)
// return m
// }
func ExpectDecoder(a op.Authorizer) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Decoder().AnyTimes().Return(schema.NewDecoder())
@ -47,17 +37,18 @@ func ExpectEncoder(a op.Authorizer) {
mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder())
}
func ExpectSigner(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Signer().DoAndReturn(
func() op.Signer {
return &Sig{}
})
}
//
//func ExpectSigner(a op.Authorizer, t *testing.T) {
// mockA := a.(*MockAuthorizer)
// mockA.EXPECT().Signer().DoAndReturn(
// func() op.Signer {
// return &Sig{}
// })
//}
func ExpectVerifier(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().IDTokenHintVerifier().DoAndReturn(
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
func() op.IDTokenHintVerifier {
return op.NewIDTokenHintVerifier("", nil)
})

View file

@ -5,8 +5,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/op"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
)
func NewClient(t *testing.T) op.Client {

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Client)
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client)
// Package mock is a generated GoMock package.
package mock
@ -9,8 +9,8 @@ import (
time "time"
gomock "github.com/golang/mock/gomock"
oidc "github.com/zitadel/oidc/pkg/oidc"
op "github.com/zitadel/oidc/pkg/op"
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
op "github.com/zitadel/oidc/v2/pkg/op"
)
// MockClient is a mock of Client interface.

View file

@ -1,14 +1,15 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Configuration)
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration)
// Package mock is a generated GoMock package.
package mock
import (
http "net/http"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
op "github.com/zitadel/oidc/pkg/op"
op "github.com/zitadel/oidc/v2/pkg/op"
language "golang.org/x/text/language"
)
@ -91,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
}
// DeviceAuthorization mocks base method.
func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorization")
ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
return ret0
}
// DeviceAuthorization indicates an expected call of DeviceAuthorization.
func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
}
// DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
}
// EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
m.ctrl.T.Helper()
@ -119,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
}
// GrantTypeDeviceCodeSupported mocks base method.
func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
}
// GrantTypeJWTAuthorizationSupported mocks base method.
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
m.ctrl.T.Helper()
@ -161,6 +204,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeTokenExchangeSupported() *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported))
}
// Insecure mocks base method.
func (m *MockConfiguration) Insecure() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Insecure")
ret0, _ := ret[0].(bool)
return ret0
}
// Insecure indicates an expected call of Insecure.
func (mr *MockConfigurationMockRecorder) Insecure() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insecure", reflect.TypeOf((*MockConfiguration)(nil).Insecure))
}
// IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method.
func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
m.ctrl.T.Helper()
@ -203,18 +260,18 @@ func (mr *MockConfigurationMockRecorder) IntrospectionEndpointSigningAlgorithmsS
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported))
}
// Issuer mocks base method.
func (m *MockConfiguration) Issuer() string {
// IssuerFromRequest mocks base method.
func (m *MockConfiguration) IssuerFromRequest(arg0 *http.Request) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret := m.ctrl.Call(m, "IssuerFromRequest", arg0)
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer.
func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call {
// IssuerFromRequest indicates an expected call of IssuerFromRequest.
func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssuerFromRequest", reflect.TypeOf((*MockConfiguration)(nil).IssuerFromRequest), arg0)
}
// KeysEndpoint mocks base method.

View file

@ -0,0 +1,51 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
)
// MockDiscoverStorage is a mock of DiscoverStorage interface.
type MockDiscoverStorage struct {
ctrl *gomock.Controller
recorder *MockDiscoverStorageMockRecorder
}
// MockDiscoverStorageMockRecorder is the mock recorder for MockDiscoverStorage.
type MockDiscoverStorageMockRecorder struct {
mock *MockDiscoverStorage
}
// NewMockDiscoverStorage creates a new mock instance.
func NewMockDiscoverStorage(ctrl *gomock.Controller) *MockDiscoverStorage {
mock := &MockDiscoverStorage{ctrl: ctrl}
mock.recorder = &MockDiscoverStorageMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDiscoverStorage) EXPECT() *MockDiscoverStorageMockRecorder {
return m.recorder
}
// SignatureAlgorithms mocks base method.
func (m *MockDiscoverStorage) SignatureAlgorithms(arg0 context.Context) ([]jose.SignatureAlgorithm, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithms", arg0)
ret0, _ := ret[0].([]jose.SignatureAlgorithm)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignatureAlgorithms indicates an expected call of SignatureAlgorithms.
func (mr *MockDiscoverStorageMockRecorder) SignatureAlgorithms(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithms", reflect.TypeOf((*MockDiscoverStorage)(nil).SignatureAlgorithms), arg0)
}

View file

@ -1,8 +1,10 @@
package mock
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/pkg/op Storage
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/pkg/op Authorizer
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/pkg/op Configuration
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/pkg/op Signer
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/pkg/op KeyProvider
//go:generate go install github.com/golang/mock/mockgen@v1.6.0
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/op (interfaces: KeyProvider)
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider)
// Package mock is a generated GoMock package.
package mock
@ -9,7 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
op "github.com/zitadel/oidc/v2/pkg/op"
)
// MockKeyProvider is a mock of KeyProvider interface.
@ -35,17 +35,17 @@ func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder {
return m.recorder
}
// GetKeySet mocks base method.
func (m *MockKeyProvider) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
// KeySet mocks base method.
func (m *MockKeyProvider) KeySet(arg0 context.Context) ([]op.Key, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*jose.JSONWebKeySet)
ret := m.ctrl.Call(m, "KeySet", arg0)
ret0, _ := ret[0].([]op.Key)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeySet indicates an expected call of GetKeySet.
func (mr *MockKeyProviderMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
// KeySet indicates an expected call of KeySet.
func (mr *MockKeyProviderMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockKeyProvider)(nil).GetKeySet), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockKeyProvider)(nil).KeySet), arg0)
}

View file

@ -1,56 +1,69 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Signer)
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
)
// MockSigner is a mock of Signer interface.
type MockSigner struct {
// MockSigningKey is a mock of SigningKey interface.
type MockSigningKey struct {
ctrl *gomock.Controller
recorder *MockSignerMockRecorder
recorder *MockSigningKeyMockRecorder
}
// MockSignerMockRecorder is the mock recorder for MockSigner.
type MockSignerMockRecorder struct {
mock *MockSigner
// MockSigningKeyMockRecorder is the mock recorder for MockSigningKey.
type MockSigningKeyMockRecorder struct {
mock *MockSigningKey
}
// NewMockSigner creates a new mock instance.
func NewMockSigner(ctrl *gomock.Controller) *MockSigner {
mock := &MockSigner{ctrl: ctrl}
mock.recorder = &MockSignerMockRecorder{mock}
// NewMockSigningKey creates a new mock instance.
func NewMockSigningKey(ctrl *gomock.Controller) *MockSigningKey {
mock := &MockSigningKey{ctrl: ctrl}
mock.recorder = &MockSigningKeyMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
func (m *MockSigningKey) EXPECT() *MockSigningKeyMockRecorder {
return m.recorder
}
// Health mocks base method.
func (m *MockSigner) Health(arg0 context.Context) error {
// ID mocks base method.
func (m *MockSigningKey) ID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0)
ret0, _ := ret[0].(error)
ret := m.ctrl.Call(m, "ID")
ret0, _ := ret[0].(string)
return ret0
}
// Health indicates an expected call of Health.
func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
// ID indicates an expected call of ID.
func (mr *MockSigningKeyMockRecorder) ID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockSigningKey)(nil).ID))
}
// Key mocks base method.
func (m *MockSigningKey) Key() interface{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Key")
ret0, _ := ret[0].(interface{})
return ret0
}
// Key indicates an expected call of Key.
func (mr *MockSigningKeyMockRecorder) Key() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockSigningKey)(nil).Key))
}
// SignatureAlgorithm mocks base method.
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
func (m *MockSigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithm")
ret0, _ := ret[0].(jose.SignatureAlgorithm)
@ -58,21 +71,86 @@ func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
}
// SignatureAlgorithm indicates an expected call of SignatureAlgorithm.
func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
func (mr *MockSigningKeyMockRecorder) SignatureAlgorithm() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigningKey)(nil).SignatureAlgorithm))
}
// Signer mocks base method.
func (m *MockSigner) Signer() jose.Signer {
// MockKey is a mock of Key interface.
type MockKey struct {
ctrl *gomock.Controller
recorder *MockKeyMockRecorder
}
// MockKeyMockRecorder is the mock recorder for MockKey.
type MockKeyMockRecorder struct {
mock *MockKey
}
// NewMockKey creates a new mock instance.
func NewMockKey(ctrl *gomock.Controller) *MockKey {
mock := &MockKey{ctrl: ctrl}
mock.recorder = &MockKeyMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockKey) EXPECT() *MockKeyMockRecorder {
return m.recorder
}
// Algorithm mocks base method.
func (m *MockKey) Algorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Signer")
ret0, _ := ret[0].(jose.Signer)
ret := m.ctrl.Call(m, "Algorithm")
ret0, _ := ret[0].(jose.SignatureAlgorithm)
return ret0
}
// Signer indicates an expected call of Signer.
func (mr *MockSignerMockRecorder) Signer() *gomock.Call {
// Algorithm indicates an expected call of Algorithm.
func (mr *MockKeyMockRecorder) Algorithm() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockKey)(nil).Algorithm))
}
// ID mocks base method.
func (m *MockKey) ID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ID")
ret0, _ := ret[0].(string)
return ret0
}
// ID indicates an expected call of ID.
func (mr *MockKeyMockRecorder) ID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockKey)(nil).ID))
}
// Key mocks base method.
func (m *MockKey) Key() interface{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Key")
ret0, _ := ret[0].(interface{})
return ret0
}
// Key indicates an expected call of Key.
func (mr *MockKeyMockRecorder) Key() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockKey)(nil).Key))
}
// Use mocks base method.
func (m *MockKey) Use() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Use")
ret0, _ := ret[0].(string)
return ret0
}
// Use indicates an expected call of Use.
func (mr *MockKeyMockRecorder) Use() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Use", reflect.TypeOf((*MockKey)(nil).Use))
}

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Storage)
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage)
// Package mock is a generated GoMock package.
package mock
@ -10,8 +10,8 @@ import (
time "time"
gomock "github.com/golang/mock/gomock"
oidc "github.com/zitadel/oidc/pkg/oidc"
op "github.com/zitadel/oidc/pkg/op"
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
op "github.com/zitadel/oidc/v2/pkg/op"
jose "gopkg.in/square/go-jose.v2"
)
@ -159,34 +159,19 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
}
// GetKeyByIDAndUserID mocks base method.
func (m *MockStorage) GetKeyByIDAndUserID(arg0 context.Context, arg1, arg2 string) (*jose.JSONWebKey, error) {
// GetKeyByIDAndClientID mocks base method.
func (m *MockStorage) GetKeyByIDAndClientID(arg0 context.Context, arg1, arg2 string) (*jose.JSONWebKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeyByIDAndUserID", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "GetKeyByIDAndClientID", arg0, arg1, arg2)
ret0, _ := ret[0].(*jose.JSONWebKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeyByIDAndUserID indicates an expected call of GetKeyByIDAndUserID.
func (mr *MockStorageMockRecorder) GetKeyByIDAndUserID(arg0, arg1, arg2 interface{}) *gomock.Call {
// GetKeyByIDAndClientID indicates an expected call of GetKeyByIDAndClientID.
func (mr *MockStorageMockRecorder) GetKeyByIDAndClientID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyByIDAndUserID", reflect.TypeOf((*MockStorage)(nil).GetKeyByIDAndUserID), arg0, arg1, arg2)
}
// GetKeySet mocks base method.
func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*jose.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeySet indicates an expected call of GetKeySet.
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyByIDAndClientID", reflect.TypeOf((*MockStorage)(nil).GetKeyByIDAndClientID), arg0, arg1, arg2)
}
// GetPrivateClaimsFromScopes mocks base method.
@ -204,16 +189,20 @@ func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
}
// GetSigningKey mocks base method.
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey) {
// GetRefreshTokenInfo mocks base method.
func (m *MockStorage) GetRefreshTokenInfo(arg0 context.Context, arg1, arg2 string) (string, string, error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "GetSigningKey", arg0, arg1)
ret := m.ctrl.Call(m, "GetRefreshTokenInfo", arg0, arg1, arg2)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(string)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetSigningKey indicates an expected call of GetSigningKey.
func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1 interface{}) *gomock.Call {
// GetRefreshTokenInfo indicates an expected call of GetRefreshTokenInfo.
func (mr *MockStorageMockRecorder) GetRefreshTokenInfo(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefreshTokenInfo", reflect.TypeOf((*MockStorage)(nil).GetRefreshTokenInfo), arg0, arg1, arg2)
}
// Health mocks base method.
@ -230,6 +219,21 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
}
// KeySet mocks base method.
func (m *MockStorage) KeySet(arg0 context.Context) ([]op.Key, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeySet", arg0)
ret0, _ := ret[0].([]op.Key)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// KeySet indicates an expected call of KeySet.
func (mr *MockStorageMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockStorage)(nil).KeySet), arg0)
}
// RevokeToken mocks base method.
func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error {
m.ctrl.T.Helper()
@ -259,7 +263,7 @@ func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *g
}
// SetIntrospectionFromToken mocks base method.
func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 *oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetIntrospectionFromToken", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
@ -273,7 +277,7 @@ func (mr *MockStorageMockRecorder) SetIntrospectionFromToken(arg0, arg1, arg2, a
}
// SetUserinfoFromScopes mocks base method.
func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3 string, arg4 []string) error {
func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3 string, arg4 []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetUserinfoFromScopes", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
@ -287,7 +291,7 @@ func (mr *MockStorageMockRecorder) SetUserinfoFromScopes(arg0, arg1, arg2, arg3,
}
// SetUserinfoFromToken mocks base method.
func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3, arg4 string) error {
func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3, arg4 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetUserinfoFromToken", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
@ -300,6 +304,36 @@ func (mr *MockStorageMockRecorder) SetUserinfoFromToken(arg0, arg1, arg2, arg3,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).SetUserinfoFromToken), arg0, arg1, arg2, arg3, arg4)
}
// SignatureAlgorithms mocks base method.
func (m *MockStorage) SignatureAlgorithms(arg0 context.Context) ([]jose.SignatureAlgorithm, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithms", arg0)
ret0, _ := ret[0].([]jose.SignatureAlgorithm)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignatureAlgorithms indicates an expected call of SignatureAlgorithms.
func (mr *MockStorageMockRecorder) SignatureAlgorithms(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithms", reflect.TypeOf((*MockStorage)(nil).SignatureAlgorithms), arg0)
}
// SigningKey mocks base method.
func (m *MockStorage) SigningKey(arg0 context.Context) (op.SigningKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SigningKey", arg0)
ret0, _ := ret[0].(op.SigningKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SigningKey indicates an expected call of SigningKey.
func (mr *MockStorageMockRecorder) SigningKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SigningKey", reflect.TypeOf((*MockStorage)(nil).SigningKey), arg0)
}
// TerminateSession mocks base method.
func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()

View file

@ -6,13 +6,10 @@ import (
"testing"
"time"
"github.com/zitadel/oidc/pkg/oidc"
"gopkg.in/square/go-jose.v2"
"github.com/golang/mock/gomock"
"github.com/zitadel/oidc/pkg/op"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
)
func NewStorage(t *testing.T) op.Storage {
@ -41,13 +38,13 @@ func NewMockStorageAny(t *testing.T) op.Storage {
func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
m := NewStorage(t)
ExpectSigningKeyInvalid(m)
//ExpectSigningKeyInvalid(m)
return m
}
func NewMockStorageSigningKey(t *testing.T) op.Storage {
m := NewStorage(t)
ExpectSigningKey(m)
//ExpectSigningKey(m)
return m
}
@ -85,24 +82,6 @@ func ExpectValidClientID(s op.Storage) {
})
}
func ExpectSigningKeyInvalid(s op.Storage) {
mockS := s.(*MockStorage)
mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, keyCh chan<- jose.SigningKey) {
keyCh <- jose.SigningKey{}
},
)
}
func ExpectSigningKey(s op.Storage) {
mockS := s.(*MockStorage)
mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, keyCh chan<- jose.SigningKey) {
keyCh <- jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}
},
)
}
type ConfClient struct {
id string
appType op.ApplicationType

View file

@ -12,8 +12,8 @@ import (
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
const (
@ -27,80 +27,84 @@ const (
defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys"
defaultDeviceAuthzEndpoint = "/device_authorization"
)
var DefaultEndpoints = &endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint),
}
var (
DefaultEndpoints = &endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint),
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
}
defaultCORSOptions = cors.Options{
AllowCredentials: true,
AllowedHeaders: []string{
"Origin",
"Accept",
"Accept-Language",
"Authorization",
"Content-Type",
"X-Requested-With",
},
AllowedMethods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPost,
},
ExposedHeaders: []string{
"Location",
"Content-Length",
},
AllowOriginFunc: func(_ string) bool {
return true
},
}
)
type OpenIDProvider interface {
Configuration
Storage() Storage
Decoder() httphelper.Decoder
Encoder() httphelper.Encoder
IDTokenHintVerifier() IDTokenHintVerifier
AccessTokenVerifier() AccessTokenVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
AccessTokenVerifier(context.Context) AccessTokenVerifier
Crypto() Crypto
DefaultLogoutRedirectURI() string
Signer() Signer
Probes() []ProbesFn
HttpHandler() http.Handler
}
type HttpInterceptor func(http.Handler) http.Handler
var defaultCORSOptions = cors.Options{
AllowCredentials: true,
AllowedHeaders: []string{
"Origin",
"Accept",
"Accept-Language",
"Authorization",
"Content-Type",
"X-Requested-With",
},
AllowedMethods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPost,
},
ExposedHeaders: []string{
"Location",
"Content-Length",
},
AllowOriginFunc: func(_ string) bool {
return true
},
}
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
intercept := buildInterceptor(interceptors...)
router := mux.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(intercept(o.IssuerFromRequest, interceptors...))
router.HandleFunc(healthEndpoint, healthHandler)
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Signer()))
router.Handle(o.AuthorizationEndpoint().Relative(), intercept(authorizeHandler(o)))
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").Handler(intercept(authorizeCallbackHandler(o)))
router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o)))
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o))
router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o)))
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o))
return router
}
// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
func AuthCallbackURL(o OpenIDProvider) func(string) string {
return func(requestID string) string {
return o.AuthorizationEndpoint().Absolute(o.Issuer()) + authCallbackPathSuffix + "?id=" + requestID
func AuthCallbackURL(o OpenIDProvider) func(context.Context, string) string {
return func(ctx context.Context, requestID string) string {
return o.AuthorizationEndpoint().Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID
}
}
@ -109,7 +113,6 @@ func authCallbackPath(o OpenIDProvider) string {
}
type Config struct {
Issuer string
CryptoKey [32]byte
DefaultLogoutRedirectURI string
CodeMethodS256 bool
@ -118,44 +121,52 @@ type Config struct {
GrantTypeRefreshToken bool
RequestObjectSupported bool
SupportedUILocales []language.Tag
DeviceAuthorization DeviceAuthorizationConfig
}
type endpoints struct {
Authorization Endpoint
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
Authorization Endpoint
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
DeviceAuthorization Endpoint
}
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
// a http.Router that handles a suite of endpoints (some paths can be overridden):
// /healthz
// /ready
// /.well-known/openid-configuration
// /oauth/token
// /oauth/introspect
// /callback
// /authorize
// /userinfo
// /revoke
// /end_session
// /keys
//
// /healthz
// /ready
// /.well-known/openid-configuration
// /oauth/token
// /oauth/introspect
// /callback
// /authorize
// /userinfo
// /revoke
// /end_session
// /keys
// /device_authorization
//
// This does not include login. Login is handled with a redirect that includes the
// request ID. The redirect for logins is specified per-client by Client.LoginURL().
// Successful logins should mark the request as authorized and redirect back to to
// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back
// to the AuthCallbackURL, the request id should be passed as the "id" parameter.
func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opOpts ...Option) (OpenIDProvider, error) {
err := ValidateIssuer(config.Issuer)
if err != nil {
return nil, err
}
func NewOpenIDProvider(issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
return newProvider(config, storage, StaticIssuer(issuer), opOpts...)
}
o := &openidProvider{
func NewDynamicOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
return newProvider(config, storage, IssuerFromHost(path), opOpts...)
}
func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) {
o := &Provider{
config: config,
storage: storage,
endpoints: DefaultEndpoints,
@ -168,36 +179,32 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
}
}
keyCh := make(chan jose.SigningKey)
go storage.GetSigningKey(ctx, keyCh)
o.signer = NewSigner(ctx, storage, keyCh)
o.issuer, err = issuer(o.insecure)
if err != nil {
return nil, err
}
o.httpHandler = CreateRouter(o, o.interceptors...)
o.decoder = schema.NewDecoder()
o.decoder.IgnoreUnknownKeys(true)
o.encoder = schema.NewEncoder()
o.encoder = oidc.NewEncoder()
o.crypto = NewAESCrypto(config.CryptoKey)
// Avoid potential race conditions by calling these early
_ = o.AccessTokenVerifier() // sets accessTokenVerifier
_ = o.IDTokenHintVerifier() // sets idTokenHintVerifier
_ = o.JWTProfileVerifier() // sets jwtProfileVerifier
_ = o.openIDKeySet() // sets keySet
_ = o.openIDKeySet() // sets keySet
return o, nil
}
type openidProvider struct {
type Provider struct {
config *Config
issuer IssuerFromRequest
insecure bool
endpoints *endpoints
storage Storage
signer Signer
idTokenHintVerifier IDTokenHintVerifier
jwtProfileVerifier JWTProfileVerifier
accessTokenVerifier AccessTokenVerifier
keySet *openIDKeySet
crypto Crypto
httpHandler http.Handler
@ -209,159 +216,163 @@ type openidProvider struct {
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
}
func (o *openidProvider) Issuer() string {
return o.config.Issuer
func (o *Provider) IssuerFromRequest(r *http.Request) string {
return o.issuer(r)
}
func (o *openidProvider) AuthorizationEndpoint() Endpoint {
func (o *Provider) Insecure() bool {
return o.insecure
}
func (o *Provider) AuthorizationEndpoint() Endpoint {
return o.endpoints.Authorization
}
func (o *openidProvider) TokenEndpoint() Endpoint {
func (o *Provider) TokenEndpoint() Endpoint {
return o.endpoints.Token
}
func (o *openidProvider) IntrospectionEndpoint() Endpoint {
func (o *Provider) IntrospectionEndpoint() Endpoint {
return o.endpoints.Introspection
}
func (o *openidProvider) UserinfoEndpoint() Endpoint {
func (o *Provider) UserinfoEndpoint() Endpoint {
return o.endpoints.Userinfo
}
func (o *openidProvider) RevocationEndpoint() Endpoint {
func (o *Provider) RevocationEndpoint() Endpoint {
return o.endpoints.Revocation
}
func (o *openidProvider) EndSessionEndpoint() Endpoint {
func (o *Provider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession
}
func (o *openidProvider) KeysEndpoint() Endpoint {
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) KeysEndpoint() Endpoint {
return o.endpoints.JwksURI
}
func (o *openidProvider) AuthMethodPostSupported() bool {
func (o *Provider) AuthMethodPostSupported() bool {
return o.config.AuthMethodPost
}
func (o *openidProvider) CodeMethodS256Supported() bool {
func (o *Provider) CodeMethodS256Supported() bool {
return o.config.CodeMethodS256
}
func (o *openidProvider) AuthMethodPrivateKeyJWTSupported() bool {
func (o *Provider) AuthMethodPrivateKeyJWTSupported() bool {
return o.config.AuthMethodPrivateKeyJWT
}
func (o *openidProvider) TokenEndpointSigningAlgorithmsSupported() []string {
func (o *Provider) TokenEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) GrantTypeRefreshTokenSupported() bool {
func (o *Provider) GrantTypeRefreshTokenSupported() bool {
return o.config.GrantTypeRefreshToken
}
func (o *openidProvider) GrantTypeTokenExchangeSupported() bool {
return false
func (o *Provider) GrantTypeTokenExchangeSupported() bool {
_, ok := o.storage.(TokenExchangeStorage)
return ok
}
func (o *openidProvider) GrantTypeJWTAuthorizationSupported() bool {
func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
return true
}
func (o *openidProvider) GrantTypeClientCredentialsSupported() bool {
func (o *Provider) GrantTypeDeviceCodeSupported() bool {
_, ok := o.storage.(DeviceAuthorizationStorage)
return ok
}
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true
}
func (o *Provider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *Provider) GrantTypeClientCredentialsSupported() bool {
_, ok := o.storage.(ClientCredentialsStorage)
return ok
}
func (o *openidProvider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
func (o *Provider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
return true
}
func (o *openidProvider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
func (o *Provider) RevocationEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
return true
}
func (o *openidProvider) RevocationEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) RequestObjectSupported() bool {
func (o *Provider) RequestObjectSupported() bool {
return o.config.RequestObjectSupported
}
func (o *openidProvider) RequestObjectSigningAlgorithmsSupported() []string {
func (o *Provider) RequestObjectSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) SupportedUILocales() []language.Tag {
func (o *Provider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales
}
func (o *openidProvider) Storage() Storage {
func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
return o.config.DeviceAuthorization
}
func (o *Provider) Storage() Storage {
return o.storage
}
func (o *openidProvider) Decoder() httphelper.Decoder {
func (o *Provider) Decoder() httphelper.Decoder {
return o.decoder
}
func (o *openidProvider) Encoder() httphelper.Encoder {
func (o *Provider) Encoder() httphelper.Encoder {
return o.encoder
}
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier {
if o.idTokenHintVerifier == nil {
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
}
return o.idTokenHintVerifier
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier {
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
}
func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier {
if o.jwtProfileVerifier == nil {
o.jwtProfileVerifier = NewJWTProfileVerifier(o.Storage(), o.Issuer(), 1*time.Hour, time.Second)
}
return o.jwtProfileVerifier
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier {
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
}
func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier {
if o.accessTokenVerifier == nil {
o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet(), o.accessTokenVerifierOpts...)
}
return o.accessTokenVerifier
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier {
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
}
func (o *openidProvider) openIDKeySet() oidc.KeySet {
func (o *Provider) openIDKeySet() oidc.KeySet {
if o.keySet == nil {
o.keySet = &openIDKeySet{o.Storage()}
}
return o.keySet
}
func (o *openidProvider) Crypto() Crypto {
func (o *Provider) Crypto() Crypto {
return o.crypto
}
func (o *openidProvider) DefaultLogoutRedirectURI() string {
func (o *Provider) DefaultLogoutRedirectURI() string {
return o.config.DefaultLogoutRedirectURI
}
func (o *openidProvider) Signer() Signer {
return o.signer
}
func (o *openidProvider) Probes() []ProbesFn {
func (o *Provider) Probes() []ProbesFn {
return []ProbesFn{
ReadySigner(o.Signer()),
ReadyStorage(o.Storage()),
}
}
func (o *openidProvider) HttpHandler() http.Handler {
func (o *Provider) HttpHandler() http.Handler {
return o.httpHandler
}
@ -372,22 +383,31 @@ type openIDKeySet struct {
// VerifySignature implements the oidc.KeySet interface
// providing an implementation for the keys stored in the OP Storage interface
func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
keySet, err := o.Storage.GetKeySet(ctx)
keySet, err := o.Storage.KeySet(ctx)
if err != nil {
return nil, fmt.Errorf("error fetching keys: %w", err)
}
keyID, alg := oidc.GetKeyIDAndAlg(jws)
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...)
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, jsonWebKeySet(keySet).Keys...)
if err != nil {
return nil, fmt.Errorf("invalid signature: %w", err)
}
return jws.Verify(&key)
}
type Option func(o *openidProvider) error
type Option func(o *Provider) error
// WithAllowInsecure allows the use of http (instead of https) for issuers
// this is not recommended for production use and violates the OIDC specification
func WithAllowInsecure() Option {
return func(o *Provider) error {
o.insecure = true
return nil
}
}
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
@ -397,7 +417,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
}
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
@ -407,7 +427,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
}
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
@ -417,7 +437,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
}
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
@ -427,7 +447,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
}
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
@ -437,7 +457,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
}
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
@ -447,7 +467,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
}
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
}
@ -457,7 +477,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
}
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
o.endpoints.Authorization = auth
o.endpoints.Token = token
o.endpoints.Userinfo = userInfo
@ -469,38 +489,32 @@ func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys End
}
func WithHttpInterceptors(interceptors ...HttpInterceptor) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
o.interceptors = append(o.interceptors, interceptors...)
return nil
}
}
func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
o.accessTokenVerifierOpts = opts
return nil
}
}
func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
return func(o *openidProvider) error {
return func(o *Provider) error {
o.idTokenHintVerifierOpts = opts
return nil
}
}
func buildInterceptor(interceptors ...HttpInterceptor) func(http.HandlerFunc) http.Handler {
return func(handlerFunc http.HandlerFunc) http.Handler {
handler := handlerFuncToHandler(handlerFunc)
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
issuerInterceptor := NewIssuerInterceptor(i)
return func(handler http.Handler) http.Handler {
for i := len(interceptors) - 1; i >= 0; i-- {
handler = interceptors[i](handler)
}
return handler
return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler))
}
}
func handlerFuncToHandler(handlerFunc http.HandlerFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerFunc(w, r)
})
}

392
pkg/op/op_test.go Normal file
View file

@ -0,0 +1,392 @@
package op_test
import (
"context"
"crypto/sha256"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"golang.org/x/text/language"
)
var testProvider op.OpenIDProvider
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
AuthMethodPost: true,
AuthMethodPrivateKeyJWT: true,
GrantTypeRefreshToken: true,
RequestObjectSupported: true,
SupportedUILocales: []language.Tag{language.English},
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device",
UserCode: op.UserCodeBase20,
},
}
storage.RegisterClients(
storage.NativeClient("native"),
storage.WebClient("web", "secret", "https://example.com"),
storage.WebClient("api", "secret"),
)
var err error
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
)
if err != nil {
panic(err)
}
}
type routesTestStorage interface {
op.Storage
AuthRequestDone(id string) error
}
func mapAsValues(m map[string]string) string {
values := make(url.Values, len(m))
for k, v := range m {
values.Set(k, v)
}
return values.Encode()
}
func TestRoutes(t *testing.T) {
storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
client, err := storage.GetClientByClientID(ctx, "web")
require.NoError(t, err)
oidcAuthReq := &oidc.AuthRequest{
ClientID: client.GetID(),
RedirectURI: "https://example.com",
MaxAge: gu.Ptr[uint](300),
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
ResponseType: oidc.ResponseTypeCode,
}
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
require.NoError(t, err)
storage.AuthRequestDone(authReq.GetID())
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
require.NoError(t, err)
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
require.NoError(t, err)
oidcAuthReq.IDTokenHint = idToken
serverURL, err := url.Parse(testIssuer)
require.NoError(t, err)
type basicAuth struct {
username, password string
}
tests := []struct {
name string
method string
path string
basicAuth *basicAuth
header map[string]string
values map[string]string
body map[string]string
wantCode int
headerContains map[string]string
json string // test for exact json output
contains []string // when the body output is not constant, we just check for snippets to be present in the response
}{
{
name: "health",
method: http.MethodGet,
path: "/healthz",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "ready",
method: http.MethodGet,
path: "/ready",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "discovery",
method: http.MethodGet,
path: oidc.DiscoveryEndpoint,
wantCode: http.StatusOK,
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
},
{
name: "authorization",
method: http.MethodGet,
path: testProvider.AuthorizationEndpoint().Relative(),
values: map[string]string{
"client_id": client.GetID(),
"redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"response_type": string(oidc.ResponseTypeCode),
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
},
{
name: "authorization callback",
method: http.MethodGet,
path: testProvider.AuthorizationEndpoint().Relative() + "/callback",
values: map[string]string{"id": authReq.GetID()},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "https://example.com?code="},
contains: []string{
`<a href="https://example.com?code=`,
">Found</a>.",
},
},
{
// This call will fail. A successfull test is already
// part of client/integration_test.go
name: "code exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeCode),
"code": "123",
},
wantCode: http.StatusUnauthorized,
json: `{"error":"invalid_client"}`,
},
{
name: "JWT authorization",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"assertion": jwtToken,
},
wantCode: http.StatusBadRequest,
json: "{\"error\":\"server_error\",\"error_description\":\"audience is not valid: Audience must contain client_id \\\"https://localhost:9998/\\\"\"}",
},
{
name: "Token exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType),
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
},
},
{
name: "Client credentials exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
},
{
// This call will fail. A successfull test is already
// part of device_test.go
name: "device token",
method: http.MethodPost,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
header: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
body: map[string]string{
"grant_type": string(oidc.GrantTypeDeviceCode),
"device_code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
},
{
name: "missing grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
},
{
name: "unsupported grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": "foo",
},
wantCode: http.StatusBadRequest,
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
},
{
name: "introspection",
method: http.MethodGet,
path: testProvider.IntrospectionEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessToken,
},
wantCode: http.StatusOK,
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "user info",
method: http.MethodGet,
path: testProvider.UserinfoEndpoint().Relative(),
header: map[string]string{
"authorization": "Bearer " + accessToken,
},
wantCode: http.StatusOK,
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "refresh token",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeRefreshToken),
"refresh_token": refreshToken,
"client_id": client.GetID(),
"client_secret": "secret",
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","token_type":"Bearer","refresh_token":"`,
`","expires_in":299,"id_token":"`,
},
},
{
name: "revoke",
method: http.MethodGet,
path: testProvider.RevocationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessTokenRevoke,
},
wantCode: http.StatusOK,
},
{
name: "end session",
method: http.MethodGet,
path: testProvider.EndSessionEndpoint().Relative(),
values: map[string]string{
"id_token_hint": idToken,
"client_id": "web",
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/logged-out"},
contains: []string{`<a href="/logged-out">Found</a>.`},
},
{
name: "keys",
method: http.MethodGet,
path: testProvider.KeysEndpoint().Relative(),
wantCode: http.StatusOK,
contains: []string{
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
},
},
{
name: "device authorization",
method: http.MethodGet,
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
},
wantCode: http.StatusOK,
contains: []string{
`{"device_code":"`, `","user_code":"`,
`","verification_uri":"https://localhost:9998/device"`,
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
`","expires_in":300,"interval":5}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u := gu.PtrCopy(serverURL)
u.Path = tt.path
if tt.values != nil {
u.RawQuery = mapAsValues(tt.values)
}
var body io.Reader
if tt.body != nil {
body = strings.NewReader(mapAsValues(tt.body))
}
req := httptest.NewRequest(tt.method, u.String(), body)
for k, v := range tt.header {
req.Header.Set(k, v)
}
if tt.basicAuth != nil {
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
rec := httptest.NewRecorder()
testProvider.HttpHandler().ServeHTTP(rec, req)
resp := rec.Result()
require.NoError(t, err)
assert.Equal(t, tt.wantCode, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
respBodyString := string(respBody)
t.Log(respBodyString)
t.Log(resp.Header)
if tt.json != "" {
assert.JSONEq(t, tt.json, respBodyString)
}
for _, c := range tt.contains {
assert.Contains(t, respBodyString, c)
}
for k, v := range tt.headerContains {
assert.Contains(t, resp.Header.Get(k), v)
}
})
}
}

View file

@ -5,7 +5,7 @@ import (
"errors"
"net/http"
httphelper "github.com/zitadel/oidc/pkg/http"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
)
type ProbesFn func(context.Context) error
@ -31,15 +31,6 @@ func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) {
ok(w)
}
func ReadySigner(s Signer) ProbesFn {
return func(ctx context.Context) error {
if s == nil {
return errors.New("no signer")
}
return s.Health(ctx)
}
}
func ReadyStorage(s Storage) ProbesFn {
return func(ctx context.Context) error {
if s == nil {

View file

@ -6,14 +6,14 @@ import (
"net/url"
"path"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type SessionEnder interface {
Decoder() httphelper.Decoder
Storage() Storage
IDTokenHintVerifier() IDTokenHintVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
DefaultLogoutRedirectURI() string
}
@ -60,7 +60,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
RedirectURI: ender.DefaultLogoutRedirectURI(),
}
if req.IdTokenHint != "" {
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier())
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
if err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
}

View file

@ -1,88 +1,38 @@
package op
import (
"context"
"errors"
"sync"
"github.com/zitadel/logging"
"gopkg.in/square/go-jose.v2"
)
type Signer interface {
Health(ctx context.Context) error
Signer() jose.Signer
var (
ErrSignerCreationFailed = errors.New("signer creation failed")
)
type SigningKey interface {
SignatureAlgorithm() jose.SignatureAlgorithm
Key() interface{}
ID() string
}
type tokenSigner struct {
signer jose.Signer
storage AuthStorage
alg jose.SignatureAlgorithm
lock sync.RWMutex
}
func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
s := &tokenSigner{
storage: storage,
}
select {
case <-ctx.Done():
return nil
case key := <-keyCh:
s.exchangeSigningKey(key)
}
go s.refreshSigningKey(ctx, keyCh)
return s
}
func (s *tokenSigner) Health(_ context.Context) error {
if s.signer == nil {
return errors.New("no signer")
}
if string(s.alg) == "" {
return errors.New("no signing algorithm")
}
return nil
}
func (s *tokenSigner) Signer() jose.Signer {
s.lock.RLock()
defer s.lock.RUnlock()
return s.signer
}
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for {
select {
case <-ctx.Done():
return
case key := <-keyCh:
s.exchangeSigningKey(key)
}
}
}
func (s *tokenSigner) exchangeSigningKey(key jose.SigningKey) {
s.lock.Lock()
defer s.lock.Unlock()
s.alg = key.Algorithm
if key.Algorithm == "" || key.Key == nil {
s.signer = nil
logging.Warn("signer has no key")
return
}
var err error
s.signer, err = jose.NewSigner(key, &jose.SignerOptions{})
func SignerFromKey(key SigningKey) (jose.Signer, error) {
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: key.SignatureAlgorithm(),
Key: &jose.JSONWebKey{
Key: key.Key(),
KeyID: key.ID(),
},
}, &jose.SignerOptions{})
if err != nil {
logging.New().WithError(err).Error("error creating signer")
return
return nil, ErrSignerCreationFailed //TODO: log / wrap error?
}
logging.Info("signer exchanged signing key")
return signer, nil
}
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.alg
type Key interface {
ID() string
Algorithm() jose.SignatureAlgorithm
Use() string
Key() interface{}
}

View file

@ -7,7 +7,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type AuthStorage interface {
@ -25,6 +25,8 @@ type AuthStorage interface {
//
// * *oidc.JWTTokenRequest from a JWT that is the assertion value of a JWT Profile
// Grant: https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
//
// * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
CreateAccessToken(context.Context, TokenRequest) (accessTokenID string, expiration time.Time, err error)
// The TokenRequest parameter of CreateAccessAndRefreshTokens can be any of:
@ -36,6 +38,8 @@ type AuthStorage interface {
// * AuthRequest as by returned by the AuthRequestByID or AuthRequestByCode (above).
// Used for the authorization code flow which requested offline_access scope and
// registered the refresh_token grant type in advance
//
// * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error)
TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (RefreshTokenRequest, error)
@ -44,44 +48,85 @@ type AuthStorage interface {
// RevokeToken should revoke a token. In the situation that the original request was to
// revoke an access token, then tokenOrTokenID will be a tokenID and userID will be set
// but if the original request was for a refresh token, then userID will be empty and
// tokenOrTokenID will be the refresh token, not its ID.
// tokenOrTokenID will be the refresh token, not its ID. RevokeToken depends upon GetRefreshTokenInfo
// to get information from refresh tokens that are not either "<tokenID>:<userID>" strings
// nor JWTs.
RevokeToken(ctx context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error
GetSigningKey(context.Context, chan<- jose.SigningKey)
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
}
// CanRefreshTokenInfo is an optional additional interface that Storage can support.
// Supporting CanRefreshTokenInfo is required to be able to (revoke) a refresh token that
// is neither an encrypted string of <tokenID>:<userID> nor a JWT.
type CanRefreshTokenInfo interface {
// GetRefreshTokenInfo must return ErrInvalidRefreshToken when presented
// with a token that is not a refresh token.
GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error)
SigningKey(context.Context) (SigningKey, error)
SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
KeySet(context.Context) ([]Key, error)
}
type ClientCredentialsStorage interface {
ClientCredentials(ctx context.Context, clientID, clientSecret string) (Client, error)
ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error)
}
type TokenExchangeStorage interface {
// ValidateTokenExchangeRequest will be called to validate parsed (including tokens) Token Exchange Grant request.
//
// Important validations can include:
// - permissions
// - set requested token type to some default value if it is empty (rfc 8693 allows it) using SetRequestedTokenType method.
// Depending on RequestedTokenType - the following tokens will be issued:
// - RefreshTokenType - both access and refresh tokens
// - AccessTokenType - only access token
// - IDTokenType - only id token
// - validation of subject's token type on possibility to be exchanged to the requested token type (according to your requirements)
// - scopes (and update them using SetCurrentScopes method)
// - set new subject if it differs from exchange subject (impersonation flow)
//
// Request will include subject's and/or actor's token claims if correspinding tokens are access/id_token issued by op
// or third party tokens parsed by TokenExchangeTokensVerifierStorage interface methods.
ValidateTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) error
// CreateTokenExchangeRequest will be called after parsing and validating token exchange request.
// Stored request is not accessed later by op - so it is up to implementer to decide
// should this method actually store the request or not (common use case - store for it for audit purposes)
CreateTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) error
// GetPrivateClaimsFromTokenExchangeRequest will be called during access token creation.
// Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) (claims map[string]interface{}, err error)
// SetUserinfoFromTokenExchangeRequest will be called during id token creation.
// Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo *oidc.UserInfo, request TokenExchangeRequest) error
}
// TokenExchangeTokensVerifierStorage is an optional interface used in token exchange process to verify tokens
// issued by third-party applications. If interface is not implemented - only tokens issued by op will be exchanged.
type TokenExchangeTokensVerifierStorage interface {
VerifyExchangeSubjectToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, subject string, tokenClaims map[string]interface{}, err error)
VerifyExchangeActorToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, actor string, tokenClaims map[string]interface{}, err error)
}
var ErrInvalidRefreshToken = errors.New("invalid_refresh_token")
type ClientCredentialsStorage interface {
ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error)
}
type OPStorage interface {
// GetClientByClientID loads a Client. The returned Client is never cached and is only used to
// handle the current request.
GetClientByClientID(ctx context.Context, clientID string) (Client, error)
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error
SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error
SetIntrospectionFromToken(ctx context.Context, userinfo oidc.IntrospectionResponse, tokenID, subject, clientID string) error
SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
// GetKeyByIDAndUserID is mis-named. It does not pass userID. Instead
// it passes the clientID.
GetKeyByIDAndUserID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error)
GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error)
ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error)
}
// JWTProfileTokenStorage is an additional, optional storage to implement
// implementing it, allows specifying the [AccessTokenType] of the access_token returned form the JWT Profile TokenRequest
type JWTProfileTokenStorage interface {
JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error)
}
// Storage is a required parameter for NewOpenIDProvider(). In addition to the
// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
// then the grant type "client_credentials" will be supported. In that case, the access
@ -102,3 +147,50 @@ type EndSessionRequest struct {
ClientID string
RedirectURI string
}
var ErrDuplicateUserCode = errors.New("user code already exists")
type DeviceAuthorizationState struct {
ClientID string
Scopes []string
Expires time.Time
Done bool
Subject string
Denied bool
}
type DeviceAuthorizationStorage interface {
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
// User code will be used by the user to complete the login flow and must be unique.
// ErrDuplicateUserCode signals the caller should try again with a new code.
//
// Note that user codes are low entropy keys and when many exist in the
// database, the change for collisions increases. Therefore implementers
// of this interface must make sure that user codes of expired authentication flows are purged,
// after some time.
StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error
// GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database.
// The method is polled untill the the authorization is eighter Completed, Expired or Denied.
GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
// GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow,
// identified by the user code.
GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error)
// CompleteDeviceAuthorization marks a device authorization entry as Completed,
// identified by userCode. The Subject is added to the state, so that
// GetDeviceAuthorizatonState can use it to create a new Access Token.
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
// DenyDeviceAuthorization marks a device authorization entry as Denied.
DenyDeviceAuthorization(ctx context.Context, userCode string) error
}
func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) {
storage, ok := s.(DeviceAuthorizationStorage)
if !ok {
return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported")
}
return storage, nil
}

View file

@ -4,14 +4,12 @@ import (
"context"
"time"
"github.com/zitadel/oidc/pkg/crypto"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/strings"
"github.com/zitadel/oidc/v2/pkg/crypto"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/strings"
)
type TokenCreator interface {
Issuer() string
Signer() Signer
Storage() Storage
Crypto() Crypto
}
@ -22,6 +20,13 @@ type TokenRequest interface {
GetScopes() []string
}
type AccessTokenClient interface {
GetID() string
ClockSkew() time.Duration
RestrictAdditionalAccessTokenScopes() func(scopes []string) []string
GrantTypes() []oidc.GrantType
}
func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Client, creator TokenCreator, createAccessToken bool, code, refreshToken string) (*oidc.AccessTokenResponse, error) {
var accessToken, newRefreshToken string
var validity time.Duration
@ -32,7 +37,7 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
return nil, err
}
}
idToken, err := CreateIDToken(ctx, creator.Issuer(), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client)
idToken, err := CreateIDToken(ctx, IssuerFromContext(ctx), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), client)
if err != nil {
return nil, err
}
@ -57,7 +62,7 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
}, nil
}
func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client Client) (id, newRefreshToken string, exp time.Time, err error) {
func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) {
if needsRefreshToken(tokenRequest, client) {
return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken)
}
@ -65,10 +70,12 @@ func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storag
return
}
func needsRefreshToken(tokenRequest TokenRequest, client Client) bool {
func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool {
switch req := tokenRequest.(type) {
case AuthRequest:
return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
case TokenExchangeRequest:
return req.GetRequestedTokenType() == oidc.RefreshTokenType
case RefreshTokenRequest:
return true
default:
@ -76,7 +83,7 @@ func needsRefreshToken(tokenRequest TokenRequest, client Client) bool {
}
}
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) {
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) {
id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client)
if err != nil {
return "", "", 0, err
@ -87,7 +94,7 @@ func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTok
}
validity = exp.Add(clockSkew).Sub(time.Now().UTC())
if accessTokenType == AccessTokenTypeJWT {
accessToken, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage())
accessToken, err = CreateJWT(ctx, IssuerFromContext(ctx), tokenRequest, exp, id, client, creator.Storage())
return
}
accessToken, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
@ -98,17 +105,41 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
return crypto.Encrypt(tokenID + ":" + subject)
}
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) {
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, client AccessTokenClient, storage Storage) (string, error) {
claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id, client.GetID(), client.ClockSkew())
if client != nil {
restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes())
privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
var (
privateClaims map[string]interface{}
err error
)
tokenExchangeRequest, okReq := tokenRequest.(TokenExchangeRequest)
teStorage, okStorage := storage.(TokenExchangeStorage)
if okReq && okStorage {
privateClaims, err = teStorage.GetPrivateClaimsFromTokenExchangeRequest(
ctx,
tokenExchangeRequest,
)
} else {
privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
}
if err != nil {
return "", err
}
claims.SetPrivateClaims(privateClaims)
claims.Claims = privateClaims
}
return crypto.Sign(claims, signer.Signer())
signingKey, err := storage.SigningKey(ctx)
if err != nil {
return "", err
}
signer, err := SignerFromKey(signingKey)
if err != nil {
return "", err
}
return crypto.Sign(claims, signer)
}
type IDTokenRequest interface {
@ -120,7 +151,7 @@ type IDTokenRequest interface {
GetSubject() string
}
func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, client Client) (string, error) {
func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, client Client) (string, error) {
exp := time.Now().UTC().Add(client.ClockSkew()).Add(validity)
var acr, nonce string
if authRequest, ok := request.(AuthRequest); ok {
@ -129,33 +160,50 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
}
claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew())
scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes())
signingKey, err := storage.SigningKey(ctx)
if err != nil {
return "", err
}
if accessToken != "" {
atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
atHash, err := oidc.ClaimHash(accessToken, signingKey.SignatureAlgorithm())
if err != nil {
return "", err
}
claims.SetAccessTokenHash(atHash)
claims.AccessTokenHash = atHash
if !client.IDTokenUserinfoClaimsAssertion() {
scopes = removeUserinfoScopes(scopes)
}
}
if len(scopes) > 0 {
userInfo := oidc.NewUserInfo()
tokenExchangeRequest, okReq := request.(TokenExchangeRequest)
teStorage, okStorage := storage.(TokenExchangeStorage)
if okReq && okStorage {
userInfo := new(oidc.UserInfo)
err := teStorage.SetUserinfoFromTokenExchangeRequest(ctx, userInfo, tokenExchangeRequest)
if err != nil {
return "", err
}
claims.SetUserInfo(userInfo)
} else if len(scopes) > 0 {
userInfo := new(oidc.UserInfo)
err := storage.SetUserinfoFromScopes(ctx, userInfo, request.GetSubject(), request.GetClientID(), scopes)
if err != nil {
return "", err
}
claims.SetUserinfo(userInfo)
claims.SetUserInfo(userInfo)
}
if code != "" {
codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
codeHash, err := oidc.ClaimHash(code, signingKey.SignatureAlgorithm())
if err != nil {
return "", err
}
claims.SetCodeHash(codeHash)
claims.CodeHash = codeHash
}
return crypto.Sign(claims, signer.Signer())
signer, err := SignerFromKey(signingKey)
if err != nil {
return "", err
}
return crypto.Sign(claims, signer)
}
func removeUserinfoScopes(scopes []string) []string {

View file

@ -5,8 +5,8 @@ import (
"net/http"
"net/url"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// ClientCredentialsExchange handles the OAuth 2.0 client_credentials grant, including
@ -63,15 +63,15 @@ func ParseClientCredentialsRequest(r *http.Request, decoder httphelper.Decoder)
return request, nil
}
// ValidateClientCredentialsRequest validates the refresh_token request parameters including authorization check of the client
// and returns the data representing the original auth request corresponding to the refresh_token
// ValidateClientCredentialsRequest validates the client_credentials request parameters including authorization check of the client
// and returns a TokenRequest and Client implementation to be used in the client_credentials response, resp. creation of the corresponding access_token.
func ValidateClientCredentialsRequest(ctx context.Context, request *oidc.ClientCredentialsRequest, exchanger Exchanger) (TokenRequest, Client, error) {
storage, ok := exchanger.Storage().(ClientCredentialsStorage)
if !ok {
return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
}
client, err := AuthorizeClientCredentialsClient(ctx, request, exchanger)
client, err := AuthorizeClientCredentialsClient(ctx, request, storage)
if err != nil {
return nil, nil, err
}
@ -84,12 +84,8 @@ func ValidateClientCredentialsRequest(ctx context.Context, request *oidc.ClientC
return tokenRequest, client, nil
}
func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientCredentialsRequest, exchanger Exchanger) (Client, error) {
if err := AuthorizeClientIDSecret(ctx, request.ClientID, request.ClientSecret, exchanger.Storage()); err != nil {
return nil, err
}
client, err := exchanger.Storage().GetClientByClientID(ctx, request.ClientID)
func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientCredentialsRequest, storage ClientCredentialsStorage) (Client, error) {
client, err := storage.ClientCredentials(ctx, request.ClientID, request.ClientSecret)
if err != nil {
return nil, oidc.ErrInvalidClient().WithParent(err)
}
@ -102,7 +98,7 @@ func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientC
}
func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeJWT, creator, client, "")
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
if err != nil {
return nil, err
}

View file

@ -4,8 +4,8 @@ import (
"context"
"net/http"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// CodeExchange handles the OAuth 2.0 authorization_code grant, including

View file

@ -1,11 +1,399 @@
package op
import (
"errors"
"context"
"net/http"
"net/url"
"strings"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// TokenExchange will handle the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange")
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, errors.New("unimplemented"))
type TokenExchangeRequest interface {
GetAMR() []string
GetAudience() []string
GetResourses() []string
GetAuthTime() time.Time
GetClientID() string
GetScopes() []string
GetSubject() string
GetRequestedTokenType() oidc.TokenType
GetExchangeSubject() string
GetExchangeSubjectTokenType() oidc.TokenType
GetExchangeSubjectTokenIDOrToken() string
GetExchangeSubjectTokenClaims() map[string]interface{}
GetExchangeActor() string
GetExchangeActorTokenType() oidc.TokenType
GetExchangeActorTokenIDOrToken() string
GetExchangeActorTokenClaims() map[string]interface{}
SetCurrentScopes(scopes []string)
SetRequestedTokenType(tt oidc.TokenType)
SetSubject(subject string)
}
type tokenExchangeRequest struct {
exchangeSubjectTokenIDOrToken string
exchangeSubjectTokenType oidc.TokenType
exchangeSubject string
exchangeSubjectTokenClaims map[string]interface{}
exchangeActorTokenIDOrToken string
exchangeActorTokenType oidc.TokenType
exchangeActor string
exchangeActorTokenClaims map[string]interface{}
resource []string
audience oidc.Audience
scopes oidc.SpaceDelimitedArray
requestedTokenType oidc.TokenType
clientID string
authTime time.Time
subject string
}
func (r *tokenExchangeRequest) GetAMR() []string {
return []string{}
}
func (r *tokenExchangeRequest) GetAudience() []string {
return r.audience
}
func (r *tokenExchangeRequest) GetResourses() []string {
return r.resource
}
func (r *tokenExchangeRequest) GetAuthTime() time.Time {
return r.authTime
}
func (r *tokenExchangeRequest) GetClientID() string {
return r.clientID
}
func (r *tokenExchangeRequest) GetScopes() []string {
return r.scopes
}
func (r *tokenExchangeRequest) GetRequestedTokenType() oidc.TokenType {
return r.requestedTokenType
}
func (r *tokenExchangeRequest) GetExchangeSubject() string {
return r.exchangeSubject
}
func (r *tokenExchangeRequest) GetExchangeSubjectTokenType() oidc.TokenType {
return r.exchangeSubjectTokenType
}
func (r *tokenExchangeRequest) GetExchangeSubjectTokenIDOrToken() string {
return r.exchangeSubjectTokenIDOrToken
}
func (r *tokenExchangeRequest) GetExchangeSubjectTokenClaims() map[string]interface{} {
return r.exchangeSubjectTokenClaims
}
func (r *tokenExchangeRequest) GetExchangeActor() string {
return r.exchangeActor
}
func (r *tokenExchangeRequest) GetExchangeActorTokenType() oidc.TokenType {
return r.exchangeActorTokenType
}
func (r *tokenExchangeRequest) GetExchangeActorTokenIDOrToken() string {
return r.exchangeActorTokenIDOrToken
}
func (r *tokenExchangeRequest) GetExchangeActorTokenClaims() map[string]interface{} {
return r.exchangeActorTokenClaims
}
func (r *tokenExchangeRequest) GetSubject() string {
return r.subject
}
func (r *tokenExchangeRequest) SetCurrentScopes(scopes []string) {
r.scopes = scopes
}
func (r *tokenExchangeRequest) SetRequestedTokenType(tt oidc.TokenType) {
r.requestedTokenType = tt
}
func (r *tokenExchangeRequest) SetSubject(subject string) {
r.subject = subject
}
// TokenExchange handles the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange")
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
}
tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger)
if err != nil {
RequestError(w, r, err)
return
}
resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger)
if err != nil {
RequestError(w, r, err)
return
}
httphelper.MarshalJSON(w, resp)
}
// ParseTokenExchangeRequest parses the http request into oidc.TokenExchangeRequest
func ParseTokenExchangeRequest(r *http.Request, decoder httphelper.Decoder) (_ *oidc.TokenExchangeRequest, clientID, clientSecret string, err error) {
err = r.ParseForm()
if err != nil {
return nil, "", "", oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
request := new(oidc.TokenExchangeRequest)
err = decoder.Decode(request, r.Form)
if err != nil {
return nil, "", "", oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
var ok bool
if clientID, clientSecret, ok = r.BasicAuth(); ok {
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return nil, "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
}
return request, clientID, clientSecret, nil
}
// ValidateTokenExchangeRequest validates the token exchange request parameters including authorization check of the client,
// subject_token and actor_token
func ValidateTokenExchangeRequest(
ctx context.Context,
oidcTokenExchangeRequest *oidc.TokenExchangeRequest,
clientID, clientSecret string,
exchanger Exchanger,
) (TokenExchangeRequest, Client, error) {
if oidcTokenExchangeRequest.SubjectToken == "" {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token missing")
}
if oidcTokenExchangeRequest.SubjectTokenType == "" {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing")
}
storage := exchanger.Storage()
teStorage, ok := storage.(TokenExchangeStorage)
if !ok {
return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported")
}
client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger)
if err != nil {
return nil, nil, err
}
if oidcTokenExchangeRequest.RequestedTokenType != "" && !oidcTokenExchangeRequest.RequestedTokenType.IsSupported() {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported")
}
if !oidcTokenExchangeRequest.SubjectTokenType.IsSupported() {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported")
}
if oidcTokenExchangeRequest.ActorTokenType != "" && !oidcTokenExchangeRequest.ActorTokenType.IsSupported() {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported")
}
exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger,
oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false)
if !ok {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid")
}
var (
exchangeActorTokenIDOrToken, exchangeActor string
exchangeActorTokenClaims map[string]interface{}
)
if oidcTokenExchangeRequest.ActorToken != "" {
exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger,
oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true)
if !ok {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid")
}
}
req := &tokenExchangeRequest{
exchangeSubjectTokenIDOrToken: exchangeSubjectTokenIDOrToken,
exchangeSubjectTokenType: oidcTokenExchangeRequest.SubjectTokenType,
exchangeSubject: exchangeSubject,
exchangeSubjectTokenClaims: exchangeSubjectTokenClaims,
exchangeActorTokenIDOrToken: exchangeActorTokenIDOrToken,
exchangeActorTokenType: oidcTokenExchangeRequest.ActorTokenType,
exchangeActor: exchangeActor,
exchangeActorTokenClaims: exchangeActorTokenClaims,
subject: exchangeSubject,
resource: oidcTokenExchangeRequest.Resource,
audience: oidcTokenExchangeRequest.Audience,
scopes: oidcTokenExchangeRequest.Scopes,
requestedTokenType: oidcTokenExchangeRequest.RequestedTokenType,
clientID: client.GetID(),
authTime: time.Now(),
}
err = teStorage.ValidateTokenExchangeRequest(ctx, req)
if err != nil {
return nil, nil, err
}
err = teStorage.CreateTokenExchangeRequest(ctx, req)
if err != nil {
return nil, nil, err
}
return req, client, nil
}
func GetTokenIDAndSubjectFromToken(
ctx context.Context,
exchanger Exchanger,
token string,
tokenType oidc.TokenType,
isActor bool,
) (tokenIDOrToken, subject string, claims map[string]interface{}, ok bool) {
switch tokenType {
case oidc.AccessTokenType:
var accessTokenClaims *oidc.AccessTokenClaims
tokenIDOrToken, subject, accessTokenClaims, ok = getTokenIDAndClaims(ctx, exchanger, token)
claims = accessTokenClaims.Claims
case oidc.RefreshTokenType:
refreshTokenRequest, err := exchanger.Storage().TokenRequestByRefreshToken(ctx, token)
if err != nil {
break
}
tokenIDOrToken, subject, ok = token, refreshTokenRequest.GetSubject(), true
case oidc.IDTokenType:
idTokenClaims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, token, exchanger.IDTokenHintVerifier(ctx))
if err != nil {
break
}
tokenIDOrToken, subject, claims, ok = token, idTokenClaims.Subject, idTokenClaims.Claims, true
}
if !ok {
if verifier, ok := exchanger.Storage().(TokenExchangeTokensVerifierStorage); ok {
var err error
if isActor {
tokenIDOrToken, subject, claims, err = verifier.VerifyExchangeActorToken(ctx, token, tokenType)
} else {
tokenIDOrToken, subject, claims, err = verifier.VerifyExchangeSubjectToken(ctx, token, tokenType)
}
if err != nil {
return "", "", nil, false
}
return tokenIDOrToken, subject, claims, true
}
return "", "", nil, false
}
return tokenIDOrToken, subject, claims, true
}
// AuthorizeTokenExchangeClient authorizes a client by validating the client_id and client_secret
func AuthorizeTokenExchangeClient(ctx context.Context, clientID, clientSecret string, exchanger Exchanger) (client Client, err error) {
if err := AuthorizeClientIDSecret(ctx, clientID, clientSecret, exchanger.Storage()); err != nil {
return nil, err
}
client, err = exchanger.Storage().GetClientByClientID(ctx, clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithParent(err)
}
return client, nil
}
func CreateTokenExchangeResponse(
ctx context.Context,
tokenExchangeRequest TokenExchangeRequest,
client Client,
creator TokenCreator,
) (_ *oidc.TokenExchangeResponse, err error) {
var (
token, refreshToken, tokenType string
validity time.Duration
)
switch tokenExchangeRequest.GetRequestedTokenType() {
case oidc.AccessTokenType, oidc.RefreshTokenType:
token, refreshToken, validity, err = CreateAccessToken(ctx, tokenExchangeRequest, client.AccessTokenType(), creator, client, "")
if err != nil {
return nil, err
}
tokenType = oidc.BearerToken
case oidc.IDTokenType:
token, err = CreateIDToken(ctx, IssuerFromContext(ctx), tokenExchangeRequest, client.IDTokenLifetime(), "", "", creator.Storage(), client)
if err != nil {
return nil, err
}
// not applicable (see https://datatracker.ietf.org/doc/html/rfc8693#section-2-2-1-2-6)
tokenType = "N_A"
default:
// oidc.JWTTokenType and other custom token types are not supported for issuing.
// In the future it can be considered to have custom tokens generation logic injected via op configuration
// or via expanding Storage interface
oidc.ErrInvalidRequest().WithDescription("requested_token_type is invalid")
}
exp := uint64(validity.Seconds())
return &oidc.TokenExchangeResponse{
AccessToken: token,
IssuedTokenType: tokenExchangeRequest.GetRequestedTokenType(),
TokenType: tokenType,
ExpiresIn: exp,
RefreshToken: refreshToken,
Scopes: tokenExchangeRequest.GetScopes(),
}, nil
}
func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, *oidc.AccessTokenClaims, bool) {
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", nil, false
}
return splitToken[0], splitToken[1], nil, true
}
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil {
return "", "", nil, false
}
return accessTokenClaims.JWTID, accessTokenClaims.Subject, accessTokenClaims, true
}

View file

@ -1,24 +1,24 @@
package op
import (
"context"
"errors"
"net/http"
"net/url"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type Introspector interface {
Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
AccessTokenVerifier(context.Context) AccessTokenVerifier
}
type IntrospectorJWTProfile interface {
Introspector
JWTProfileVerifier() JWTProfileVerifier
JWTProfileVerifier(context.Context) JWTProfileVerifier
}
func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *http.Request) {
@ -28,7 +28,7 @@ func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *
}
func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) {
response := oidc.NewIntrospectionResponse()
response := new(oidc.IntrospectionResponse)
token, clientID, err := ParseTokenIntrospectionRequest(r, introspector)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
@ -44,43 +44,24 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
httphelper.MarshalJSON(w, response)
return
}
response.SetActive(true)
response.Active = true
httphelper.MarshalJSON(w, response)
}
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
err = r.ParseForm()
clientID, authenticated, err := ClientIDFromRequest(r, introspector)
if err != nil {
return "", "", errors.New("unable to parse request")
return "", "", err
}
req := new(struct {
oidc.IntrospectionRequest
oidc.ClientAssertionParams
})
if !authenticated {
return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
req := new(oidc.IntrospectionRequest)
err = introspector.Decoder().Decode(req, r.Form)
if err != nil {
return "", "", errors.New("unable to parse request")
}
if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier())
if err == nil {
return req.Token, profile.Issuer, nil
}
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", "", errors.New("invalid basic auth header")
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", "", errors.New("invalid basic auth header")
}
if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
return "", "", err
}
return req.Token, clientID, nil
}
return "", "", errors.New("invalid authorization")
return req.Token, clientID, nil
}

View file

@ -5,13 +5,13 @@ import (
"net/http"
"time"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type JWTAuthorizationGrantExchanger interface {
Exchanger
JWTProfileVerifier() JWTProfileVerifier
JWTProfileVerifier(context.Context) JWTProfileVerifier
}
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1
@ -21,7 +21,7 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati
RequestError(w, r, err)
}
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier())
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context()))
if err != nil {
RequestError(w, r, err)
return
@ -53,27 +53,65 @@ func ParseJWTProfileGrantRequest(r *http.Request, decoder httphelper.Decoder) (*
return tokenReq, nil
}
// CreateJWTTokenResponse creates
// CreateJWTTokenResponse creates an access_token response for a JWT Profile Grant request
// by default the access_token is an opaque string, but can be specified by implementing the JWTProfileTokenStorage interface
func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
id, exp, err := creator.Storage().CreateAccessToken(ctx, tokenRequest)
if err != nil {
return nil, err
}
accessToken, err := CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
if err != nil {
return nil, err
// return an opaque token as default to not break current implementations
tokenType := AccessTokenTypeBearer
// the current CreateAccessToken function, esp. CreateJWT requires an implementation of an AccessTokenClient
client := &jwtProfileClient{
id: tokenRequest.GetSubject(),
}
// by implementing the JWTProfileTokenStorage the storage can specify the AccessTokenType to be returned
tokenStorage, ok := creator.Storage().(JWTProfileTokenStorage)
if ok {
var err error
tokenType, err = tokenStorage.JWTProfileTokenType(ctx, tokenRequest)
if err != nil {
return nil, err
}
}
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
if err != nil {
return nil, err
}
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(exp.Sub(time.Now().UTC()).Seconds()),
ExpiresIn: uint64(validity.Seconds()),
}, nil
}
// ParseJWTProfileRequest has been renamed to ParseJWTProfileGrantRequest
//
//deprecated: use ParseJWTProfileGrantRequest
// deprecated: use ParseJWTProfileGrantRequest
func ParseJWTProfileRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
return ParseJWTProfileGrantRequest(r, decoder)
}
type jwtProfileClient struct {
id string
}
func (j *jwtProfileClient) GetID() string {
return j.id
}
func (j *jwtProfileClient) ClockSkew() time.Duration {
return 0
}
func (j *jwtProfileClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
return func(scopes []string) []string {
return scopes
}
}
func (j *jwtProfileClient) GrantTypes() []oidc.GrantType {
return []oidc.GrantType{
oidc.GrantTypeBearer,
}
}

View file

@ -6,9 +6,9 @@ import (
"net/http"
"time"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/pkg/strings"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/strings"
)
type RefreshTokenRequest interface {

View file

@ -5,15 +5,13 @@ import (
"net/http"
"net/url"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type Exchanger interface {
Issuer() string
Storage() Storage
Decoder() httphelper.Decoder
Signer() Signer
Crypto() Crypto
AuthMethodPostSupported() bool
AuthMethodPrivateKeyJWTSupported() bool
@ -21,6 +19,9 @@ type Exchanger interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
AccessTokenVerifier(context.Context) AccessTokenVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
}
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ClientCredentialsExchange(w, r, exchanger)
return
}
case string(oidc.GrantTypeDeviceCode):
if exchanger.GrantTypeDeviceCodeSupported() {
DeviceAccessToken(w, r, exchanger)
return
}
case "":
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return
@ -122,7 +128,7 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.C
// AuthorizePrivateJWTKey authorizes a client by validating the client_assertion's signature with a previously
// registered public key (JWT Profile)
func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchanger JWTAuthorizationGrantExchanger) (Client, error) {
jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier())
jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier(ctx))
if err != nil {
return nil, err
}
@ -136,8 +142,8 @@ func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchang
return client, nil
}
// ValidateGrantType ensures that the requested grant_type is allowed by the Client
func ValidateGrantType(client Client, grantType oidc.GrantType) bool {
// ValidateGrantType ensures that the requested grant_type is allowed by the client
func ValidateGrantType(client interface{ GrantTypes() []oidc.GrantType }, grantType oidc.GrantType) bool {
if client == nil {
return false
}

View file

@ -7,22 +7,22 @@ import (
"net/url"
"strings"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type Revoker interface {
Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
AccessTokenVerifier(context.Context) AccessTokenVerifier
AuthMethodPrivateKeyJWTSupported() bool
AuthMethodPostSupported() bool
}
type RevokerJWTProfile interface {
Revoker
JWTProfileVerifier() JWTProfileVerifier
JWTProfileVerifier(context.Context) JWTProfileVerifier
}
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
@ -39,8 +39,8 @@ func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) {
}
var subject string
doDecrypt := true
if canRefreshInfo, ok := revoker.Storage().(CanRefreshTokenInfo); ok && tokenTypeHint != "access_token" {
userID, tokenID, err := canRefreshInfo.GetRefreshTokenInfo(r.Context(), clientID, token)
if tokenTypeHint != "access_token" {
userID, tokenID, err := revoker.Storage().GetRefreshTokenInfo(r.Context(), clientID, token)
if err != nil {
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
if !errors.Is(err, ErrInvalidRefreshToken) {
@ -87,7 +87,7 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token
if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() {
return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
}
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier())
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier(r.Context()))
if err == nil {
return req.Token, req.TokenTypeHint, profile.Issuer, nil
}
@ -151,9 +151,9 @@ func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider Use
}
return splitToken[0], splitToken[1], true
}
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil {
return "", "", false
}
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
}

View file

@ -6,15 +6,15 @@ import (
"net/http"
"strings"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type UserinfoProvider interface {
Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
AccessTokenVerifier(context.Context) AccessTokenVerifier
}
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
@ -34,7 +34,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
http.Error(w, "access token invalid", http.StatusUnauthorized)
return
}
info := oidc.NewUserInfo()
info := new(oidc.UserInfo)
err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin"))
if err != nil {
httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden)
@ -81,9 +81,9 @@ func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider
}
return splitToken[0], splitToken[1], true
}
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil {
return "", "", false
}
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
}

View file

@ -4,7 +4,7 @@ import (
"context"
"time"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type AccessTokenVerifier interface {
@ -18,8 +18,6 @@ type accessTokenVerifier struct {
maxAgeIAT time.Duration
offset time.Duration
supportedSignAlgs []string
maxAge time.Duration
acr oidc.ACRVerifier
keySet oidc.KeySet
}
@ -67,29 +65,29 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok
return verifier
}
// VerifyAccessToken validates the access token (issuer, signature and expiration)
func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) {
claims := oidc.EmptyAccessTokenClaims()
// VerifyAccessToken validates the access token (issuer, signature and expiration).
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) {
var nilClaims C
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
return nilClaims, err
}
payload, err := oidc.ParseToken(decrypted, claims)
payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil {
return nil, err
return nilClaims, err
}
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err
return nilClaims, err
}
return claims, nil

View file

@ -0,0 +1,70 @@
package op_test
import (
"context"
"fmt"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
)
// MyCustomClaims extends the TokenClaims base,
// so it implements the oidc.Claims interface.
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
type MyCustomClaims struct {
oidc.TokenClaims
NotBefore oidc.Time `json:"nbf,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
SessionID string `json:"sid,omitempty"`
Scopes []string `json:"scope,omitempty"`
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
Foo string `json:"foo,omitempty"`
Bar *Nested `json:"bar,omitempty"`
}
// Nested struct types are also possible.
type Nested struct {
Count int `json:"count,omitempty"`
Tags []string `json:"tags,omitempty"`
}
/*
accessToken carries the following claims. foo and bar are custom claims
{
"aud": [
"unit",
"test"
],
"bar": {
"count": 22,
"tags": [
"some",
"tags"
]
},
"exp": 4802234675,
"foo": "Hello, World!",
"iat": 1678097014,
"iss": "local.com",
"jti": "9876",
"nbf": 1678097014,
"sub": "tim@local.com"
}
*/
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM0Njc1LCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MDk3MDE0LCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MDk3MDE0LCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.OUgk-B7OXjYlYFj-nogqSDJiQE19tPrbzqUHEAjcEiJkaWo6-IpGVfDiGKm-TxjXQsNScxpaY0Pg3XIh1xK6TgtfYtoLQm-5RYw_mXgb9xqZB2VgPs6nNEYFUDM513MOU0EBc0QMyqAEGzW-HiSPAb4ugCvkLtM1yo11Xyy6vksAdZNs_mJDT4X3vFXnr0jk0ugnAW6fTN3_voC0F_9HQUAkmd750OIxkAHxAMvEPQcpbLHenVvX_Q0QMrzClVrxehn5TVMfmkYYg7ocr876Bq9xQGPNHAcrwvVIJqdg5uMUA38L3HC2BEueG6furZGvc7-qDWAT1VR9liM5ieKpPg`
func ExampleVerifyAccessToken_customClaims() {
v := op.NewAccessTokenVerifier("local.com", tu.KeySet{})
// VerifyAccessToken can be called with the *MyCustomClaims.
claims, err := op.VerifyAccessToken[*MyCustomClaims](context.TODO(), accessToken, v)
if err != nil {
panic(err)
}
// Here we have typesafe access to the custom claims
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
// Output: Hello, World! 22 [some tags]
}

View file

@ -0,0 +1,126 @@
package op
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func TestNewAccessTokenVerifier(t *testing.T) {
type args struct {
issuer string
keySet oidc.KeySet
opts []AccessTokenVerifierOpt
}
tests := []struct {
name string
args args
want AccessTokenVerifier
}{
{
name: "simple",
args: args{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
},
want: &accessTokenVerifier{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
},
},
{
name: "with signature algorithm",
args: args{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
opts: []AccessTokenVerifierOpt{
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
},
},
want: &accessTokenVerifier{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
supportedSignAlgs: []string{"ABC", "DEF"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewAccessTokenVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
assert.Equal(t, tt.want, got)
})
}
}
func TestVerifyAccessToken(t *testing.T) {
verifier := &accessTokenVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{},
}
tests := []struct {
name string
tokenClaims func() (string, *oidc.AccessTokenClaims)
wantErr bool
}{
{
name: "success",
tokenClaims: tu.ValidAccessToken,
},
{
name: "parse err",
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return "~~~~", nil },
wantErr: true,
},
{
name: "invalid signature",
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return tu.InvalidSignatureToken, nil },
wantErr: true,
},
{
name: "wrong issuer",
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
return tu.NewAccessToken(
"foo", tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidJWTID, tu.ValidClientID,
tu.ValidSkew,
)
},
wantErr: true,
},
{
name: "expired",
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
return tu.NewAccessToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration.Add(-time.Hour), tu.ValidJWTID, tu.ValidClientID,
tu.ValidSkew,
)
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims()
got, err := VerifyAccessToken[*oidc.AccessTokenClaims](context.Background(), token, verifier)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, got, want)
})
}
}

View file

@ -4,7 +4,7 @@ import (
"context"
"time"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type IDTokenHintVerifier interface {
@ -73,41 +73,41 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
}
// VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
claims := oidc.EmptyIDTokenClaims()
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) {
var nilClaims C
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
return nilClaims, err
}
payload, err := oidc.ParseToken(decrypted, claims)
payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil {
return nil, err
return nilClaims, err
}
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
return nil, err
return nilClaims, err
}
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
return nil, err
return nilClaims, err
}
return claims, nil
}

View file

@ -0,0 +1,161 @@
package op
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func TestNewIDTokenHintVerifier(t *testing.T) {
type args struct {
issuer string
keySet oidc.KeySet
opts []IDTokenHintVerifierOpt
}
tests := []struct {
name string
args args
want IDTokenHintVerifier
}{
{
name: "simple",
args: args{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
},
want: &idTokenHintVerifier{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
},
},
{
name: "with signature algorithm",
args: args{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
opts: []IDTokenHintVerifierOpt{
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
},
},
want: &idTokenHintVerifier{
issuer: tu.ValidIssuer,
keySet: tu.KeySet{},
supportedSignAlgs: []string{"ABC", "DEF"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewIDTokenHintVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
assert.Equal(t, tt.want, got)
})
}
}
func TestVerifyIDTokenHint(t *testing.T) {
verifier := &idTokenHintVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
maxAge: 2 * time.Minute,
acr: tu.ACRVerify,
keySet: tu.KeySet{},
}
tests := []struct {
name string
tokenClaims func() (string, *oidc.IDTokenClaims)
wantErr bool
}{
{
name: "success",
tokenClaims: tu.ValidIDToken,
},
{
name: "parse err",
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
wantErr: true,
},
{
name: "invalid signature",
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
wantErr: true,
},
{
name: "wrong issuer",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
"foo", tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "expired",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "wrong IAT",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
)
},
wantErr: true,
},
{
name: "wrong acr",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
{
name: "expired auth",
tokenClaims: func() (string, *oidc.IDTokenClaims) {
return tu.NewIDToken(
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
)
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims()
got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, got, want)
})
}
}

View file

@ -8,7 +8,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type JWTProfileVerifier interface {
@ -104,7 +104,7 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
}
type jwtProfileKeyStorage interface {
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
}
func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
@ -122,7 +122,7 @@ type jwtProfileKeySet struct {
// VerifySignature implements oidc.KeySet by getting the public key from Storage implementation
func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
keyID, _ := oidc.GetKeyIDAndAlg(jws)
key, err := k.storage.GetKeyByIDAndUserID(ctx, keyID, k.clientID)
key, err := k.storage.GetKeyByIDAndClientID(ctx, keyID, k.clientID)
if err != nil {
return nil, fmt.Errorf("error fetching keys: %w", err)
}