Merge branch 'next' into main-next
prepare the merge of next into main by resolving merge conflicts.
This commit is contained in:
commit
0476b5946e
122 changed files with 8195 additions and 2858 deletions
|
@ -1,31 +1,25 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/crypto"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
var Encoder = func() httphelper.Encoder {
|
||||
e := schema.NewEncoder()
|
||||
e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string {
|
||||
return value.Interface().(oidc.SpaceDelimitedArray).Encode()
|
||||
})
|
||||
return e
|
||||
}()
|
||||
var Encoder = httphelper.Encoder(oidc.NewEncoder())
|
||||
|
||||
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
|
||||
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
|
||||
|
@ -90,6 +84,9 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS
|
|||
return http.ErrUseLastResponse
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
@ -148,6 +145,18 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa
|
|||
return nil
|
||||
}
|
||||
|
||||
func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
|
||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenRes := new(oidc.TokenExchangeResponse)
|
||||
if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokenRes, nil
|
||||
}
|
||||
|
||||
func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
|
||||
privateKey, err := crypto.BytesToPrivateKey(key)
|
||||
if err != nil {
|
||||
|
@ -167,7 +176,98 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
|
|||
Issuer: clientID,
|
||||
Subject: clientID,
|
||||
Audience: audience,
|
||||
ExpiresAt: oidc.Time(exp),
|
||||
IssuedAt: oidc.Time(iat),
|
||||
ExpiresAt: oidc.FromTime(exp),
|
||||
IssuedAt: oidc.FromTime(iat),
|
||||
}, signer)
|
||||
}
|
||||
|
||||
type DeviceAuthorizationCaller interface {
|
||||
GetDeviceAuthorizationEndpoint() string
|
||||
HttpClient() *http.Client
|
||||
}
|
||||
|
||||
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.ClientSecret != "" {
|
||||
req.SetBasicAuth(request.ClientID, request.ClientSecret)
|
||||
}
|
||||
|
||||
resp := new(oidc.DeviceAuthorizationResponse)
|
||||
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type DeviceAccessTokenRequest struct {
|
||||
*oidc.ClientCredentialsRequest
|
||||
oidc.DeviceAccessTokenRequest
|
||||
}
|
||||
|
||||
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.ClientSecret != "" {
|
||||
req.SetBasicAuth(request.ClientID, request.ClientSecret)
|
||||
}
|
||||
|
||||
httpResp, err := caller.HttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
resp := new(struct {
|
||||
*oidc.AccessTokenResponse
|
||||
*oidc.Error
|
||||
})
|
||||
if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if httpResp.StatusCode == http.StatusOK {
|
||||
return resp.AccessTokenResponse, nil
|
||||
}
|
||||
|
||||
return nil, resp.Error
|
||||
}
|
||||
|
||||
func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||
for {
|
||||
timer := time.After(interval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-timer:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, interval)
|
||||
defer cancel()
|
||||
|
||||
resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
interval += 5 * time.Second
|
||||
}
|
||||
var target *oidc.Error
|
||||
if !errors.As(err, &target) {
|
||||
return nil, err
|
||||
}
|
||||
switch target.ErrorType {
|
||||
case oidc.AuthorizationPending:
|
||||
continue
|
||||
case oidc.SlowDown:
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
package rp_test
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
|
@ -15,34 +14,142 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/example/server/storage"
|
||||
|
||||
"github.com/jeremija/gosubmit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/oidc/v2/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rs"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/tokenexchange"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestRelyingPartySession(t *testing.T) {
|
||||
t.Log("------- start example OP ------")
|
||||
ctx := context.Background()
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore())
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
var dh deferredHandler
|
||||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(ctx, opServer.URL, exampleStorage)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
|
||||
t.Log("------- run authorization code flow ------")
|
||||
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
|
||||
|
||||
t.Log("------- refresh tokens ------")
|
||||
|
||||
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
require.NoError(t, err, "refresh token")
|
||||
assert.NotNil(t, newTokens, "access token")
|
||||
t.Logf("new access token %s", newTokens.AccessToken)
|
||||
t.Logf("new refresh token %s", newTokens.RefreshToken)
|
||||
t.Logf("new token type %s", newTokens.TokenType)
|
||||
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
|
||||
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
} else {
|
||||
t.Logf("no redirect")
|
||||
}
|
||||
|
||||
t.Log("------ attempt refresh again (should fail) ------")
|
||||
t.Log("trying original refresh token", refreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with original")
|
||||
if newTokens.RefreshToken != "" {
|
||||
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with replacement")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceServerTokenExchange(t *testing.T) {
|
||||
t.Log("------- start example OP ------")
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
var dh deferredHandler
|
||||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
clientSecret := "secret"
|
||||
|
||||
t.Log("------- run authorization code flow ------")
|
||||
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
|
||||
|
||||
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret)
|
||||
require.NoError(t, err, "new resource server")
|
||||
|
||||
t.Log("------- exchage refresh tokens (impersonation) ------")
|
||||
|
||||
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
|
||||
resourceServer,
|
||||
refreshToken,
|
||||
oidc.RefreshTokenType,
|
||||
"",
|
||||
"",
|
||||
[]string{},
|
||||
[]string{},
|
||||
[]string{"profile", "custom_scope:impersonate:id2"},
|
||||
oidc.RefreshTokenType,
|
||||
)
|
||||
require.NoError(t, err, "refresh token")
|
||||
require.NotNil(t, tokenExchangeResponse, "token exchange response")
|
||||
assert.Equal(t, tokenExchangeResponse.IssuedTokenType, oidc.RefreshTokenType)
|
||||
assert.NotEmpty(t, tokenExchangeResponse.AccessToken, "access token")
|
||||
assert.NotEmpty(t, tokenExchangeResponse.RefreshToken, "refresh token")
|
||||
assert.Equal(t, []string(tokenExchangeResponse.Scopes), []string{"profile", "custom_scope:impersonate:id2"})
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
} else {
|
||||
t.Logf("no redirect")
|
||||
}
|
||||
|
||||
t.Log("------- attempt exchage again (should fail) ------")
|
||||
|
||||
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
|
||||
resourceServer,
|
||||
refreshToken,
|
||||
oidc.RefreshTokenType,
|
||||
"",
|
||||
"",
|
||||
[]string{},
|
||||
[]string{},
|
||||
[]string{"profile", "custom_scope:impersonate:id2"},
|
||||
oidc.RefreshTokenType,
|
||||
)
|
||||
require.Error(t, err, "refresh token")
|
||||
assert.Contains(t, err.Error(), "subject_token is invalid")
|
||||
require.Nil(t, tokenExchangeResponse, "token exchange response")
|
||||
|
||||
}
|
||||
|
||||
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {
|
||||
targetURL := "http://local-site"
|
||||
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
|
||||
require.NoError(t, err, "local url")
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
client := storage.WebClient(clientID, "secret", targetURL)
|
||||
client := storage.WebClient(clientID, clientSecret, targetURL)
|
||||
storage.RegisterClients(client)
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
|
@ -58,10 +165,10 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
t.Log("------- create RP ------")
|
||||
key := []byte("test1234test1234")
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
provider, err := rp.NewRelyingPartyOIDC(
|
||||
provider, err = rp.NewRelyingPartyOIDC(
|
||||
opServer.URL,
|
||||
clientID,
|
||||
"secret",
|
||||
clientSecret,
|
||||
targetURL,
|
||||
[]string{"openid", "email", "profile", "offline_access"},
|
||||
rp.WithPKCE(cookieHandler),
|
||||
|
@ -70,8 +177,10 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err, "new rp")
|
||||
|
||||
t.Log("------- get redirect from local client (rp) to OP ------")
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
capturedW := httptest.NewRecorder()
|
||||
get := httptest.NewRequest("GET", localURL.String(), nil)
|
||||
|
@ -114,7 +223,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
|
||||
t.Log("------- post to login form, get redirect to OP ------")
|
||||
postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL,
|
||||
gosubmit.Set("username", "test-user"),
|
||||
gosubmit.Set("username", "test-user@local-site"),
|
||||
gosubmit.Set("password", "verysecure"))
|
||||
t.Logf("Get redirect from %s", postLoginRedirectURL)
|
||||
|
||||
|
@ -130,19 +239,19 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
t.Logf("setting cookie %s", cookie)
|
||||
}
|
||||
|
||||
var accessToken, refreshToken, idToken, email string
|
||||
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
|
||||
var email string
|
||||
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
|
||||
require.NotNil(t, tokens, "tokens")
|
||||
require.NotNil(t, info, "info")
|
||||
t.Log("access token", tokens.AccessToken)
|
||||
t.Log("refresh token", tokens.RefreshToken)
|
||||
t.Log("id token", tokens.IDToken)
|
||||
t.Log("email", info.GetEmail())
|
||||
t.Log("email", info.Email)
|
||||
|
||||
accessToken = tokens.AccessToken
|
||||
refreshToken = tokens.RefreshToken
|
||||
idToken = tokens.IDToken
|
||||
email = info.GetEmail()
|
||||
email = info.Email
|
||||
http.Redirect(w, r, targetURL, 302)
|
||||
}
|
||||
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
|
||||
|
@ -154,7 +263,6 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
require.Less(t, capturedW.Code, 400, "token exchange response code")
|
||||
// TODO: how to check the custom header was sent to the server?
|
||||
|
||||
//nolint:bodyclose
|
||||
|
@ -169,43 +277,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
assert.NotEmpty(t, accessToken, "access token")
|
||||
assert.NotEmpty(t, email, "email")
|
||||
|
||||
t.Log("------- refresh tokens ------")
|
||||
|
||||
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
require.NoError(t, err, "refresh token")
|
||||
assert.NotNil(t, newTokens, "access token")
|
||||
t.Logf("new access token %s", newTokens.AccessToken)
|
||||
t.Logf("new refresh token %s", newTokens.RefreshToken)
|
||||
t.Logf("new token type %s", newTokens.TokenType)
|
||||
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
|
||||
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
} else {
|
||||
t.Logf("no redirect")
|
||||
}
|
||||
|
||||
t.Log("------ attempt refresh again (should fail) ------")
|
||||
t.Log("trying original refresh token", refreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with original")
|
||||
if newTokens.RefreshToken != "" {
|
||||
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with replacement")
|
||||
}
|
||||
|
||||
t.Run("WithPrompt", func(t *testing.T) {
|
||||
opts := rp.WithPrompt("foo", "bar")()
|
||||
url := provider.OAuthConfig().AuthCodeURL("some", opts...)
|
||||
|
||||
require.Contains(t, url, "prompt=foo+bar")
|
||||
})
|
||||
return provider, accessToken, refreshToken, idToken
|
||||
}
|
||||
|
||||
type deferredHandler struct {
|
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// jwtProfileTokenSource implement the oauth2.TokenSource
|
||||
|
|
|
@ -4,22 +4,22 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
loginPath = "/login"
|
||||
)
|
||||
|
||||
func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens {
|
||||
func CodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens[C] {
|
||||
codeflowCtx, codeflowCancel := context.WithCancel(ctx)
|
||||
defer codeflowCancel()
|
||||
|
||||
tokenChan := make(chan *oidc.Tokens, 1)
|
||||
tokenChan := make(chan *oidc.Tokens[C], 1)
|
||||
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) {
|
||||
tokenChan <- tokens
|
||||
msg := "<p><strong>Success!</strong></p>"
|
||||
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// DelegationTokenRequest is an implementation of TokenExchangeRequest
|
||||
// it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional
|
||||
//"urn:ietf:params:oauth:token-type:access_token" actor token for an
|
||||
//"urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
// "urn:ietf:params:oauth:token-type:access_token" actor token for an
|
||||
// "urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
|
||||
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
|
||||
}
|
||||
|
|
62
pkg/client/rp/device.go
Normal file
62
pkg/client/rp/device.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
|
||||
confg := rp.OAuthConfig()
|
||||
req := &oidc.ClientCredentialsRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
Scope: scopes,
|
||||
ClientID: confg.ClientID,
|
||||
ClientSecret: confg.ClientSecret,
|
||||
}
|
||||
|
||||
if signer := rp.Signer(); signer != nil {
|
||||
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build assertion: %w", err)
|
||||
}
|
||||
req.ClientAssertion = assertion
|
||||
req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// DeviceAuthorization starts a new Device Authorization flow as defined
|
||||
// in RFC 8628, section 3.1 and 3.2:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.CallDeviceAuthorizationEndpoint(req, rp)
|
||||
}
|
||||
|
||||
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
|
||||
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
||||
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
||||
req := &client.DeviceAccessTokenRequest{
|
||||
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
DeviceCode: deviceCode,
|
||||
},
|
||||
}
|
||||
|
||||
req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
|
||||
}
|
|
@ -9,8 +9,8 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/zitadel/oidc/pkg/rp Verifier
|
|
@ -1,67 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/rp (interfaces: Verifier)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
// MockVerifier is a mock of Verifier interface
|
||||
type MockVerifier struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockVerifierMockRecorder
|
||||
}
|
||||
|
||||
// MockVerifierMockRecorder is the mock recorder for MockVerifier
|
||||
type MockVerifierMockRecorder struct {
|
||||
mock *MockVerifier
|
||||
}
|
||||
|
||||
// NewMockVerifier creates a new mock instance
|
||||
func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier {
|
||||
mock := &MockVerifier{ctrl: ctrl}
|
||||
mock.recorder = &MockVerifierMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Verify mocks base method
|
||||
func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*oidc.IDTokenClaims)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Verify indicates an expected call of Verify
|
||||
func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// VerifyIDToken mocks base method
|
||||
func (m *MockVerifier) VerifyIDToken(arg0 context.Context, arg1 string) (*oidc.IDTokenClaims, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "VerifyIDToken", arg0, arg1)
|
||||
ret0, _ := ret[0].(*oidc.IDTokenClaims)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// VerifyIDToken indicates an expected call of VerifyIDToken
|
||||
func (mr *MockVerifierMockRecorder) VerifyIDToken(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyIDToken", reflect.TypeOf((*MockVerifier)(nil).VerifyIDToken), arg0, arg1)
|
||||
}
|
|
@ -14,9 +14,9 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -54,11 +54,15 @@ type RelyingParty interface {
|
|||
GetEndSessionEndpoint() string
|
||||
|
||||
// GetRevokeEndpoint returns the endpoint to revoke a specific token
|
||||
// "GetRevokeEndpoint() string" will be added in a future release
|
||||
GetRevokeEndpoint() string
|
||||
|
||||
// UserinfoEndpoint returns the userinfo
|
||||
UserinfoEndpoint() string
|
||||
|
||||
// GetDeviceAuthorizationEndpoint returns the enpoint which can
|
||||
// be used to start a DeviceAuthorization flow.
|
||||
GetDeviceAuthorizationEndpoint() string
|
||||
|
||||
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
|
||||
IDTokenVerifier() IDTokenVerifier
|
||||
// ErrorHandler returns the handler used for callback errors
|
||||
|
@ -121,6 +125,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
|
|||
return rp.endpoints.UserinfoURL
|
||||
}
|
||||
|
||||
func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
|
||||
return rp.endpoints.DeviceAuthorizationURL
|
||||
}
|
||||
|
||||
func (rp *relyingParty) GetEndSessionEndpoint() string {
|
||||
return rp.endpoints.EndSessionURL
|
||||
}
|
||||
|
@ -371,7 +379,7 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
|
|||
|
||||
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
||||
// returning it parsed together with the oauth2 tokens (access, refresh)
|
||||
func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
|
||||
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
|
@ -384,7 +392,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
|
|||
}
|
||||
|
||||
if rp.IsOAuth2Only() {
|
||||
return &oidc.Tokens{Token: token}, nil
|
||||
return &oidc.Tokens[C]{Token: token}, nil
|
||||
}
|
||||
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
|
@ -392,21 +400,21 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
|
|||
return nil, errors.New("id_token missing")
|
||||
}
|
||||
|
||||
idToken, err := VerifyTokens(ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
|
||||
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
|
||||
|
||||
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
||||
// including cookie handling for secure `state` transfer
|
||||
// and optional PKCE code verifier checking.
|
||||
// Custom paramaters can optionally be set to the token URL.
|
||||
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := tryReadStateCookie(w, r, rp)
|
||||
if err != nil {
|
||||
|
@ -439,7 +447,7 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
|
|||
}
|
||||
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
|
||||
}
|
||||
tokens, err := CodeExchange(r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -448,13 +456,13 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlPara
|
|||
}
|
||||
}
|
||||
|
||||
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
|
||||
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo)
|
||||
|
||||
// UserinfoCallback wraps the callback function of the CodeExchangeHandler
|
||||
// and calls the userinfo endpoint with the access token
|
||||
// on success it will pass the userinfo into its callback function as well
|
||||
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
|
||||
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
if err != nil {
|
||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
||||
|
@ -465,17 +473,17 @@ func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
|
|||
}
|
||||
|
||||
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) {
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
|
||||
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("authorization", tokenType+" "+token)
|
||||
userinfo := oidc.NewUserInfo()
|
||||
userinfo := new(oidc.UserInfo)
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userinfo.GetSubject() != subject {
|
||||
if userinfo.Subject != subject {
|
||||
return nil, ErrUserInfoSubNotMatching
|
||||
}
|
||||
return userinfo, nil
|
||||
|
@ -506,11 +514,12 @@ type OptionFunc func(RelyingParty)
|
|||
|
||||
type Endpoints struct {
|
||||
oauth2.Endpoint
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
EndSessionURL string
|
||||
RevokeURL string
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
EndSessionURL string
|
||||
RevokeURL string
|
||||
DeviceAuthorizationURL string
|
||||
}
|
||||
|
||||
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||
|
@ -520,11 +529,12 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
|||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
TokenURL: discoveryConfig.TokenEndpoint,
|
||||
},
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
||||
RevokeURL: discoveryConfig.RevocationEndpoint,
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
EndSessionURL: discoveryConfig.EndSessionEndpoint,
|
||||
RevokeURL: discoveryConfig.RevocationEndpoint,
|
||||
DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type IDTokenVerifier interface {
|
||||
|
@ -20,76 +20,78 @@ type IDTokenVerifier interface {
|
|||
}
|
||||
|
||||
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||
idToken, err := VerifyIDToken(ctx, idTokenString, v)
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
|
||||
return nil, err
|
||||
if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
return idToken, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyIDToken validates the id token according to
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||
claims := oidc.EmptyIDTokenClaims()
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
payload, err := oidc.ParseToken(decrypted, claims)
|
||||
payload, err := oidc.ParseToken(decrypted, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckSubject(claims); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyAccessToken validates the access token according to
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||
if atHash == "" {
|
||||
return nil
|
||||
|
@ -112,7 +114,7 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
|
|||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
keySet: keySet,
|
||||
offset: 1 * time.Second,
|
||||
offset: time.Second,
|
||||
nonce: func(_ context.Context) string {
|
||||
return ""
|
||||
},
|
||||
|
@ -139,7 +141,7 @@ func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
|
|||
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
||||
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.maxAge = maxAge
|
||||
v.maxAgeIAT = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
|
|
339
pkg/client/rp/verifier_test.go
Normal file
339
pkg/client/rp/verifier_test.go
Normal file
|
@ -0,0 +1,339 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestVerifyTokens(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
clientID: tu.ValidClientID,
|
||||
}
|
||||
accessToken, _ := tu.ValidAccessToken()
|
||||
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
idTokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "without access token",
|
||||
idTokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "with access token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||
)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired id token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong access token",
|
||||
accessToken: accessToken,
|
||||
idTokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
idToken, want := tt.idTokenClaims()
|
||||
got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyIDToken(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty subject",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, "", tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong clientID",
|
||||
clientID: "foo",
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong IAT",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong acr",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired auth",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong nonce",
|
||||
clientID: tu.ValidClientID,
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, "foo",
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
verifier.clientID = tt.clientID
|
||||
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
token, _ := tu.ValidAccessToken()
|
||||
hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
accessToken string
|
||||
atHash string
|
||||
sigAlgorithm jose.SignatureAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty hash",
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: hash,
|
||||
sigAlgorithm: tu.SignatureAlgorithm,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid algorithm",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: hash,
|
||||
sigAlgorithm: "foo",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "mismatch",
|
||||
args: args{
|
||||
accessToken: token,
|
||||
atHash: "~~",
|
||||
sigAlgorithm: tu.SignatureAlgorithm,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewIDTokenVerifier(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
clientID string
|
||||
keySet oidc.KeySet
|
||||
options []VerifierOption
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want IDTokenVerifier
|
||||
}{
|
||||
{
|
||||
name: "nil nonce", // otherwise assert.Equal will fail on the function
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
clientID: tu.ValidClientID,
|
||||
keySet: tu.KeySet{},
|
||||
options: []VerifierOption{
|
||||
WithIssuedAtOffset(time.Minute),
|
||||
WithIssuedAtMaxAge(time.Hour),
|
||||
WithNonce(nil), // otherwise assert.Equal will fail on the function
|
||||
WithACRVerifier(nil),
|
||||
WithAuthTimeMaxAge(2 * time.Hour),
|
||||
WithSupportedSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
offset: time.Minute,
|
||||
maxAgeIAT: time.Hour,
|
||||
clientID: tu.ValidClientID,
|
||||
keySet: tu.KeySet{},
|
||||
nonce: nil,
|
||||
acr: nil,
|
||||
maxAge: 2 * time.Hour,
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
86
pkg/client/rp/verifier_tokens_example_test.go
Normal file
86
pkg/client/rp/verifier_tokens_example_test.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package rp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// MyCustomClaims extends the TokenClaims base,
|
||||
// so it implmeents the oidc.Claims interface.
|
||||
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
|
||||
type MyCustomClaims struct {
|
||||
oidc.TokenClaims
|
||||
NotBefore oidc.Time `json:"nbf,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
Foo string `json:"foo,omitempty"`
|
||||
Bar *Nested `json:"bar,omitempty"`
|
||||
}
|
||||
|
||||
// GetAccessTokenHash is required to implement
|
||||
// the oidc.IDClaims interface.
|
||||
func (c *MyCustomClaims) GetAccessTokenHash() string {
|
||||
return c.AccessTokenHash
|
||||
}
|
||||
|
||||
// Nested struct types are also possible.
|
||||
type Nested struct {
|
||||
Count int `json:"count,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
/*
|
||||
idToken carries the following claims. foo and bar are custom claims
|
||||
|
||||
{
|
||||
"acr": "something",
|
||||
"amr": [
|
||||
"foo",
|
||||
"bar"
|
||||
],
|
||||
"at_hash": "2dzbm_vIxy-7eRtqUIGPPw",
|
||||
"aud": [
|
||||
"unit",
|
||||
"test",
|
||||
"555666"
|
||||
],
|
||||
"auth_time": 1678100961,
|
||||
"azp": "555666",
|
||||
"bar": {
|
||||
"count": 22,
|
||||
"tags": [
|
||||
"some",
|
||||
"tags"
|
||||
]
|
||||
},
|
||||
"client_id": "555666",
|
||||
"exp": 4802238682,
|
||||
"foo": "Hello, World!",
|
||||
"iat": 1678101021,
|
||||
"iss": "local.com",
|
||||
"jti": "9876",
|
||||
"nbf": 1678101021,
|
||||
"nonce": "12345",
|
||||
"sub": "tim@local.com"
|
||||
}
|
||||
*/
|
||||
const idToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF0X2hhc2giOiIyZHpibV92SXh5LTdlUnRxVUlHUFB3IiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImF1dGhfdGltZSI6MTY3ODEwMDk2MSwiYXpwIjoiNTU1NjY2IiwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiY2xpZW50X2lkIjoiNTU1NjY2IiwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJub25jZSI6IjEyMzQ1Iiwic3ViIjoidGltQGxvY2FsLmNvbSJ9.t3GXSfVNNwiW1Suv9_84v0sdn2_-RWHVxhphhRozDXnsO7SDNOlGnEioemXABESxSzMclM7gB7mYy5Qah2ZUNx7eP5t2njoxEYfavgHwx7UJZ2NCg8NDPQyr-hlxelEcfdXK-I0oTd-FRDvF4rqPkD9Us52IpnplChCxnHFgh4wKwPqZZjv2IXVCtn0ilKW3hff1rMOYKEuLRcN2YP0gkyuqyHvcf2dMmjod0t4sLOTJ82rsCbMBC5CLpqv3nIC9HOGITkt1Kd-Am0n1LrdZvWwTo6RFe8AnzF0gpqjcB5Wg4Qeh58DIjZOz4f_8wnmJ_gCqyRh5vfSW4XHdbum0Tw`
|
||||
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.Zrz3LWSRjCMJZUMaI5dUbW4vGdSmEeJQ3ouhaX0bcW9rdFFLgBI4K2FWJhNivq8JDmCGSxwLu3mI680GWmDaEoAx1M5sCO9lqfIZHGZh-lfAXk27e6FPLlkTDBq8Bx4o4DJ9Fw0hRJGjUTjnYv5cq1vo2-UqldasL6CwTbkzNC_4oQFfRtuodC4Ql7dZ1HRv5LXuYx7KPkOssLZtV9cwtJp5nFzKjcf2zEE_tlbjcpynMwypornRUp1EhCWKRUGkJhJeiP71ECY5pQhShfjBu9Nc5wDpSnZmnk2S4YsPrRK3QkE-iEkas8BfsOCrGoErHjEJexAIDjasGO5PFLWfCA`
|
||||
|
||||
func ExampleVerifyTokens_customClaims() {
|
||||
v := rp.NewIDTokenVerifier("local.com", "555666", tu.KeySet{},
|
||||
rp.WithNonce(func(ctx context.Context) string { return "12345" }),
|
||||
)
|
||||
|
||||
// VerifyAccessToken can be called with the *MyCustomClaims.
|
||||
claims, err := rp.VerifyTokens[*MyCustomClaims](context.TODO(), accessToken, idToken, v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Here we have typesafe access to the custom claims
|
||||
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
|
||||
// Output: Hello, World! 22 [some tags]
|
||||
}
|
|
@ -6,13 +6,14 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type ResourceServer interface {
|
||||
IntrospectionURL() string
|
||||
TokenEndpoint() string
|
||||
HttpClient() *http.Client
|
||||
AuthFn() (interface{}, error)
|
||||
}
|
||||
|
@ -29,6 +30,10 @@ func (r *resourceServer) IntrospectionURL() string {
|
|||
return r.introspectURL
|
||||
}
|
||||
|
||||
func (r *resourceServer) TokenEndpoint() string {
|
||||
return r.tokenURL
|
||||
}
|
||||
|
||||
func (r *resourceServer) HttpClient() *http.Client {
|
||||
return r.httpClient
|
||||
}
|
||||
|
@ -107,7 +112,7 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.IntrospectionResponse, error) {
|
||||
func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) {
|
||||
authFn, err := rp.AuthFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -116,7 +121,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.Intr
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := oidc.NewIntrospectionResponse()
|
||||
resp := new(oidc.IntrospectionResponse)
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
127
pkg/client/tokenexchange/tokenexchange.go
Normal file
127
pkg/client/tokenexchange/tokenexchange.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package tokenexchange
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type TokenExchanger interface {
|
||||
TokenEndpoint() string
|
||||
HttpClient() *http.Client
|
||||
AuthFn() (interface{}, error)
|
||||
}
|
||||
|
||||
type OAuthTokenExchange struct {
|
||||
httpClient *http.Client
|
||||
tokenEndpoint string
|
||||
authFn func() (interface{}, error)
|
||||
}
|
||||
|
||||
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
return newOAuthTokenExchange(issuer, nil, options...)
|
||||
}
|
||||
|
||||
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
authorizer := func() (interface{}, error) {
|
||||
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
||||
}
|
||||
return newOAuthTokenExchange(issuer, authorizer, options...)
|
||||
}
|
||||
|
||||
func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
|
||||
te := &OAuthTokenExchange{
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(te)
|
||||
}
|
||||
|
||||
if te.tokenEndpoint == "" {
|
||||
config, err := client.Discover(issuer, te.httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
te.tokenEndpoint = config.TokenEndpoint
|
||||
}
|
||||
|
||||
if te.tokenEndpoint == "" {
|
||||
return nil, errors.New("tokenURL is empty: please provide with either `WithStaticTokenEndpoint` or a discovery url")
|
||||
}
|
||||
|
||||
te.authFn = authorizer
|
||||
|
||||
return te, nil
|
||||
}
|
||||
|
||||
func WithHTTPClient(client *http.Client) func(*OAuthTokenExchange) {
|
||||
return func(source *OAuthTokenExchange) {
|
||||
source.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*OAuthTokenExchange) {
|
||||
return func(source *OAuthTokenExchange) {
|
||||
source.tokenEndpoint = tokenEndpoint
|
||||
}
|
||||
}
|
||||
|
||||
func (te *OAuthTokenExchange) TokenEndpoint() string {
|
||||
return te.tokenEndpoint
|
||||
}
|
||||
|
||||
func (te *OAuthTokenExchange) HttpClient() *http.Client {
|
||||
return te.httpClient
|
||||
}
|
||||
|
||||
func (te *OAuthTokenExchange) AuthFn() (interface{}, error) {
|
||||
if te.authFn != nil {
|
||||
return te.authFn()
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
|
||||
// SubjectToken and SubjectTokenType are required parameters.
|
||||
func ExchangeToken(
|
||||
te TokenExchanger,
|
||||
SubjectToken string,
|
||||
SubjectTokenType oidc.TokenType,
|
||||
ActorToken string,
|
||||
ActorTokenType oidc.TokenType,
|
||||
Resource []string,
|
||||
Audience []string,
|
||||
Scopes []string,
|
||||
RequestedTokenType oidc.TokenType,
|
||||
) (*oidc.TokenExchangeResponse, error) {
|
||||
if SubjectToken == "" {
|
||||
return nil, errors.New("empty subject_token")
|
||||
}
|
||||
if SubjectTokenType == "" {
|
||||
return nil, errors.New("empty subject_token_type")
|
||||
}
|
||||
|
||||
authFn, err := te.AuthFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request := oidc.TokenExchangeRequest{
|
||||
GrantType: oidc.GrantTypeTokenExchange,
|
||||
SubjectToken: SubjectToken,
|
||||
SubjectTokenType: SubjectTokenType,
|
||||
ActorToken: ActorToken,
|
||||
ActorTokenType: ActorTokenType,
|
||||
Resource: Resource,
|
||||
Audience: Audience,
|
||||
Scopes: Scopes,
|
||||
RequestedTokenType: RequestedTokenType,
|
||||
}
|
||||
|
||||
return client.CallTokenExchangeEndpoint(request, authFn, te)
|
||||
}
|
|
@ -3,7 +3,7 @@ package oidc
|
|||
import (
|
||||
"crypto/sha256"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
29
pkg/oidc/device_authorization.go
Normal file
29
pkg/oidc/device_authorization.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package oidc
|
||||
|
||||
// DeviceAuthorizationRequest implements
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1,
|
||||
// 3.1 Device Authorization Request.
|
||||
type DeviceAuthorizationRequest struct {
|
||||
Scopes SpaceDelimitedArray `schema:"scope"`
|
||||
ClientID string `schema:"client_id"`
|
||||
}
|
||||
|
||||
// DeviceAuthorizationResponse implements
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.2
|
||||
// 3.2. Device Authorization Response.
|
||||
type DeviceAuthorizationResponse struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval,omitempty"`
|
||||
}
|
||||
|
||||
// DeviceAccessTokenRequest implements
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
|
||||
// Device Access Token Request.
|
||||
type DeviceAccessTokenRequest struct {
|
||||
GrantType GrantType `json:"grant_type" schema:"grant_type"`
|
||||
DeviceCode string `json:"device_code" schema:"device_code"`
|
||||
}
|
|
@ -30,6 +30,8 @@ type DiscoveryConfiguration struct {
|
|||
// EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
|
||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||
|
||||
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
|
||||
|
||||
// CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
|
||||
CheckSessionIframe string `json:"check_session_iframe,omitempty"`
|
||||
|
||||
|
|
|
@ -18,6 +18,14 @@ const (
|
|||
InteractionRequired errorType = "interaction_required"
|
||||
LoginRequired errorType = "login_required"
|
||||
RequestNotSupported errorType = "request_not_supported"
|
||||
|
||||
// Additional error codes as defined in
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.5
|
||||
// Device Access Token Response
|
||||
AuthorizationPending errorType = "authorization_pending"
|
||||
SlowDown errorType = "slow_down"
|
||||
AccessDenied errorType = "access_denied"
|
||||
ExpiredToken errorType = "expired_token"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -77,6 +85,32 @@ var (
|
|||
ErrorType: RequestNotSupported,
|
||||
}
|
||||
}
|
||||
|
||||
// Device Access Token errors:
|
||||
ErrAuthorizationPending = func() *Error {
|
||||
return &Error{
|
||||
ErrorType: AuthorizationPending,
|
||||
Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.",
|
||||
}
|
||||
}
|
||||
ErrSlowDown = func() *Error {
|
||||
return &Error{
|
||||
ErrorType: SlowDown,
|
||||
Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.",
|
||||
}
|
||||
}
|
||||
ErrAccessDenied = func() *Error {
|
||||
return &Error{
|
||||
ErrorType: AccessDenied,
|
||||
Description: "The authorization request was denied.",
|
||||
}
|
||||
}
|
||||
ErrExpiredDeviceCode = func() *Error {
|
||||
return &Error{
|
||||
ErrorType: ExpiredToken,
|
||||
Description: "The \"device_code\" has expired.",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
|
|
|
@ -1,12 +1,6 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
import "github.com/muhlemmer/gu"
|
||||
|
||||
type IntrospectionRequest struct {
|
||||
Token string `schema:"token"`
|
||||
|
@ -17,36 +11,11 @@ type ClientAssertionParams struct {
|
|||
ClientAssertionType string `schema:"client_assertion_type"`
|
||||
}
|
||||
|
||||
type IntrospectionResponse interface {
|
||||
UserInfoSetter
|
||||
IsActive() bool
|
||||
SetActive(bool)
|
||||
SetScopes(scopes []string)
|
||||
SetClientID(id string)
|
||||
SetTokenType(tokenType string)
|
||||
SetExpiration(exp time.Time)
|
||||
SetIssuedAt(iat time.Time)
|
||||
SetNotBefore(nbf time.Time)
|
||||
SetAudience(audience []string)
|
||||
SetIssuer(issuer string)
|
||||
SetJWTID(id string)
|
||||
GetScope() []string
|
||||
GetClientID() string
|
||||
GetTokenType() string
|
||||
GetExpiration() time.Time
|
||||
GetIssuedAt() time.Time
|
||||
GetNotBefore() time.Time
|
||||
GetSubject() string
|
||||
GetAudience() []string
|
||||
GetIssuer() string
|
||||
GetJWTID() string
|
||||
}
|
||||
|
||||
func NewIntrospectionResponse() IntrospectionResponse {
|
||||
return &introspectionResponse{}
|
||||
}
|
||||
|
||||
type introspectionResponse struct {
|
||||
// IntrospectionResponse implements RFC 7662, section 2.2 and
|
||||
// OpenID Connect Core 1.0, section 5.1 (UserInfo).
|
||||
// https://www.rfc-editor.org/rfc/rfc7662.html#section-2.2.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims.
|
||||
type IntrospectionResponse struct {
|
||||
Active bool `json:"active"`
|
||||
Scope SpaceDelimitedArray `json:"scope,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
|
@ -58,323 +27,50 @@ type introspectionResponse struct {
|
|||
Audience Audience `json:"aud,omitempty"`
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
JWTID string `json:"jti,omitempty"`
|
||||
userInfoProfile
|
||||
userInfoEmail
|
||||
userInfoPhone
|
||||
Username string `json:"username,omitempty"`
|
||||
UserInfoProfile
|
||||
UserInfoEmail
|
||||
UserInfoPhone
|
||||
|
||||
Address UserInfoAddress `json:"address,omitempty"`
|
||||
claims map[string]interface{}
|
||||
Address *UserInfoAddress `json:"address,omitempty"`
|
||||
Claims map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) IsActive() bool {
|
||||
return i.Active
|
||||
// SetUserInfo copies all relevant fields from UserInfo
|
||||
// into the IntroSpectionResponse.
|
||||
func (i *IntrospectionResponse) SetUserInfo(u *UserInfo) {
|
||||
i.Subject = u.Subject
|
||||
i.Username = u.PreferredUsername
|
||||
i.Address = gu.PtrCopy(u.Address)
|
||||
i.UserInfoProfile = u.UserInfoProfile
|
||||
i.UserInfoEmail = u.UserInfoEmail
|
||||
i.UserInfoPhone = u.UserInfoPhone
|
||||
if i.Claims == nil {
|
||||
i.Claims = gu.MapCopy(u.Claims)
|
||||
} else {
|
||||
gu.MapMerge(u.Claims, i.Claims)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetSubject() string {
|
||||
return i.Subject
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetName() string {
|
||||
return i.Name
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetGivenName() string {
|
||||
return i.GivenName
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetFamilyName() string {
|
||||
return i.FamilyName
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetMiddleName() string {
|
||||
return i.MiddleName
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetNickname() string {
|
||||
return i.Nickname
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetProfile() string {
|
||||
return i.Profile
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetPicture() string {
|
||||
return i.Picture
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetWebsite() string {
|
||||
return i.Website
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetGender() Gender {
|
||||
return i.Gender
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetBirthdate() string {
|
||||
return i.Birthdate
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetZoneinfo() string {
|
||||
return i.Zoneinfo
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetLocale() language.Tag {
|
||||
return i.Locale
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetPreferredUsername() string {
|
||||
return i.PreferredUsername
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetEmail() string {
|
||||
return i.Email
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) IsEmailVerified() bool {
|
||||
return bool(i.EmailVerified)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetPhoneNumber() string {
|
||||
return i.PhoneNumber
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) IsPhoneNumberVerified() bool {
|
||||
return i.PhoneNumberVerified
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetAddress() UserInfoAddress {
|
||||
// GetAddress is a safe getter that takes
|
||||
// care of a possible nil value.
|
||||
func (i *IntrospectionResponse) GetAddress() *UserInfoAddress {
|
||||
if i.Address == nil {
|
||||
return new(UserInfoAddress)
|
||||
}
|
||||
return i.Address
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetClaim(key string) interface{} {
|
||||
return i.claims[key]
|
||||
}
|
||||
// introspectionResponseAlias prevents loops on the JSON methods
|
||||
type introspectionResponseAlias IntrospectionResponse
|
||||
|
||||
func (i *introspectionResponse) GetClaims() map[string]interface{} {
|
||||
return i.claims
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetScope() []string {
|
||||
return []string(i.Scope)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetClientID() string {
|
||||
return i.ClientID
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetTokenType() string {
|
||||
return i.TokenType
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetExpiration() time.Time {
|
||||
return time.Time(i.Expiration)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetIssuedAt() time.Time {
|
||||
return time.Time(i.IssuedAt)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetNotBefore() time.Time {
|
||||
return time.Time(i.NotBefore)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetAudience() []string {
|
||||
return []string(i.Audience)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetIssuer() string {
|
||||
return i.Issuer
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) GetJWTID() string {
|
||||
return i.JWTID
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetActive(active bool) {
|
||||
i.Active = active
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetScopes(scope []string) {
|
||||
i.Scope = scope
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetClientID(id string) {
|
||||
i.ClientID = id
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetTokenType(tokenType string) {
|
||||
i.TokenType = tokenType
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetExpiration(exp time.Time) {
|
||||
i.Expiration = Time(exp)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetIssuedAt(iat time.Time) {
|
||||
i.IssuedAt = Time(iat)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetNotBefore(nbf time.Time) {
|
||||
i.NotBefore = Time(nbf)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetAudience(audience []string) {
|
||||
i.Audience = audience
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetIssuer(issuer string) {
|
||||
i.Issuer = issuer
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetJWTID(id string) {
|
||||
i.JWTID = id
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetSubject(sub string) {
|
||||
i.Subject = sub
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetName(name string) {
|
||||
i.Name = name
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetGivenName(name string) {
|
||||
i.GivenName = name
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetFamilyName(name string) {
|
||||
i.FamilyName = name
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetMiddleName(name string) {
|
||||
i.MiddleName = name
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetNickname(name string) {
|
||||
i.Nickname = name
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetUpdatedAt(date time.Time) {
|
||||
i.UpdatedAt = Time(date)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetProfile(profile string) {
|
||||
i.Profile = profile
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetPicture(picture string) {
|
||||
i.Picture = picture
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetWebsite(website string) {
|
||||
i.Website = website
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetGender(gender Gender) {
|
||||
i.Gender = gender
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetBirthdate(birthdate string) {
|
||||
i.Birthdate = birthdate
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetZoneinfo(zoneInfo string) {
|
||||
i.Zoneinfo = zoneInfo
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetLocale(locale language.Tag) {
|
||||
i.Locale = locale
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetPreferredUsername(name string) {
|
||||
i.PreferredUsername = name
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetEmail(email string, verified bool) {
|
||||
i.Email = email
|
||||
i.EmailVerified = boolString(verified)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetPhone(phone string, verified bool) {
|
||||
i.PhoneNumber = phone
|
||||
i.PhoneNumberVerified = verified
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) SetAddress(address UserInfoAddress) {
|
||||
i.Address = address
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) AppendClaims(key string, value interface{}) {
|
||||
if i.claims == nil {
|
||||
i.claims = make(map[string]interface{})
|
||||
func (i *IntrospectionResponse) MarshalJSON() ([]byte, error) {
|
||||
if i.Username == "" {
|
||||
i.Username = i.PreferredUsername
|
||||
}
|
||||
i.claims[key] = value
|
||||
return mergeAndMarshalClaims((*introspectionResponseAlias)(i), i.Claims)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) MarshalJSON() ([]byte, error) {
|
||||
type Alias introspectionResponse
|
||||
a := &struct {
|
||||
*Alias
|
||||
Expiration int64 `json:"exp,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
Locale interface{} `json:"locale,omitempty"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
}{
|
||||
Alias: (*Alias)(i),
|
||||
}
|
||||
if !i.Locale.IsRoot() {
|
||||
a.Locale = i.Locale
|
||||
}
|
||||
if !time.Time(i.UpdatedAt).IsZero() {
|
||||
a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
|
||||
}
|
||||
if !time.Time(i.Expiration).IsZero() {
|
||||
a.Expiration = time.Time(i.Expiration).Unix()
|
||||
}
|
||||
if !time.Time(i.IssuedAt).IsZero() {
|
||||
a.IssuedAt = time.Time(i.IssuedAt).Unix()
|
||||
}
|
||||
if !time.Time(i.NotBefore).IsZero() {
|
||||
a.NotBefore = time.Time(i.NotBefore).Unix()
|
||||
}
|
||||
a.Username = i.PreferredUsername
|
||||
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(i.claims) == 0 {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
err = json.Unmarshal(b, &i.claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims)
|
||||
}
|
||||
|
||||
return json.Marshal(i.claims)
|
||||
}
|
||||
|
||||
func (i *introspectionResponse) UnmarshalJSON(data []byte) error {
|
||||
type Alias introspectionResponse
|
||||
a := &struct {
|
||||
*Alias
|
||||
UpdatedAt int64 `json:"update_at,omitempty"`
|
||||
}{
|
||||
Alias: (*Alias)(i),
|
||||
}
|
||||
if err := json.Unmarshal(data, &a); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
|
||||
|
||||
if err := json.Unmarshal(data, &i.claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
func (i *IntrospectionResponse) UnmarshalJSON(data []byte) error {
|
||||
return unmarshalJSONMulti(data, (*introspectionResponseAlias)(i), &i.Claims)
|
||||
}
|
||||
|
|
78
pkg/oidc/introspection_test.go
Normal file
78
pkg/oidc/introspection_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
start *IntrospectionResponse
|
||||
want *IntrospectionResponse
|
||||
}{
|
||||
{
|
||||
|
||||
name: "nil claims",
|
||||
start: &IntrospectionResponse{},
|
||||
want: &IntrospectionResponse{
|
||||
Subject: userInfoData.Subject,
|
||||
Username: userInfoData.PreferredUsername,
|
||||
Address: userInfoData.Address,
|
||||
UserInfoProfile: userInfoData.UserInfoProfile,
|
||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||
Claims: userInfoData.Claims,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
name: "merge claims",
|
||||
start: &IntrospectionResponse{
|
||||
Claims: map[string]any{
|
||||
"hello": "world",
|
||||
},
|
||||
},
|
||||
want: &IntrospectionResponse{
|
||||
Subject: userInfoData.Subject,
|
||||
Username: userInfoData.PreferredUsername,
|
||||
Address: userInfoData.Address,
|
||||
UserInfoProfile: userInfoData.UserInfoProfile,
|
||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||
Claims: map[string]any{
|
||||
"foo": "bar",
|
||||
"hello": "world",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.start.SetUserInfo(userInfoData)
|
||||
assert.Equal(t, tt.want, tt.start)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntrospectionResponse_GetAddress(t *testing.T) {
|
||||
// nil address
|
||||
i := new(IntrospectionResponse)
|
||||
assert.Equal(t, &UserInfoAddress{}, i.GetAddress())
|
||||
|
||||
i.Address = &UserInfoAddress{PostalCode: "1234"}
|
||||
assert.Equal(t, i.Address, i.GetAddress())
|
||||
}
|
||||
|
||||
func TestIntrospectionResponse_MarshalJSON(t *testing.T) {
|
||||
got, err := json.Marshal(&IntrospectionResponse{
|
||||
UserInfoProfile: UserInfoProfile{
|
||||
PreferredUsername: "muhlemmer",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(got), `{"active":false,"username":"muhlemmer","preferred_username":"muhlemmer"}`)
|
||||
}
|
50
pkg/oidc/regression_assert_test.go
Normal file
50
pkg/oidc/regression_assert_test.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
//go:build !create_regression_data
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test_assert_regression verifies current output from
|
||||
// json.Marshal to stored regression data.
|
||||
// These tests are only ran when the create_regression_data
|
||||
// tag is NOT set.
|
||||
func Test_assert_regression(t *testing.T) {
|
||||
buf := new(strings.Builder)
|
||||
|
||||
for _, obj := range regressionData {
|
||||
name := jsonFilename(obj)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
file, err := os.Open(name)
|
||||
require.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(buf, file)
|
||||
require.NoError(t, err)
|
||||
want := buf.String()
|
||||
buf.Reset()
|
||||
|
||||
encodeJSON(t, buf, obj)
|
||||
first := buf.String()
|
||||
buf.Reset()
|
||||
|
||||
assert.JSONEq(t, want, first)
|
||||
|
||||
require.NoError(t,
|
||||
json.Unmarshal([]byte(first), obj),
|
||||
)
|
||||
second, err := json.Marshal(obj)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.JSONEq(t, want, string(second))
|
||||
})
|
||||
}
|
||||
}
|
24
pkg/oidc/regression_create_test.go
Normal file
24
pkg/oidc/regression_create_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
//go:build create_regression_data
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test_create_regression generates the regression data.
|
||||
// It is excluded from regular testing, unless
|
||||
// called with the create_regression_data tag:
|
||||
// go test -tags="create_regression_data" ./pkg/oidc
|
||||
func Test_create_regression(t *testing.T) {
|
||||
for _, obj := range regressionData {
|
||||
file, err := os.Create(jsonFilename(obj))
|
||||
require.NoError(t, err)
|
||||
defer file.Close()
|
||||
|
||||
encodeJSON(t, file, obj)
|
||||
}
|
||||
}
|
23
pkg/oidc/regression_data/oidc.AccessTokenClaims.json
Normal file
23
pkg/oidc/regression_data/oidc.AccessTokenClaims.json
Normal file
|
@ -0,0 +1,23 @@
|
|||
{
|
||||
"iss": "zitadel",
|
||||
"sub": "hello@me.com",
|
||||
"aud": [
|
||||
"foo",
|
||||
"bar"
|
||||
],
|
||||
"jti": "900",
|
||||
"azp": "just@me.com",
|
||||
"nonce": "6969",
|
||||
"acr": "something",
|
||||
"amr": [
|
||||
"some",
|
||||
"methods"
|
||||
],
|
||||
"scope": "email phone",
|
||||
"client_id": "777",
|
||||
"exp": 12345,
|
||||
"iat": 12000,
|
||||
"nbf": 12000,
|
||||
"auth_time": 12000,
|
||||
"foo": "bar"
|
||||
}
|
51
pkg/oidc/regression_data/oidc.IDTokenClaims.json
Normal file
51
pkg/oidc/regression_data/oidc.IDTokenClaims.json
Normal file
|
@ -0,0 +1,51 @@
|
|||
{
|
||||
"iss": "zitadel",
|
||||
"aud": [
|
||||
"foo",
|
||||
"bar"
|
||||
],
|
||||
"jti": "900",
|
||||
"azp": "just@me.com",
|
||||
"nonce": "6969",
|
||||
"at_hash": "acthashhash",
|
||||
"c_hash": "hashhash",
|
||||
"acr": "something",
|
||||
"amr": [
|
||||
"some",
|
||||
"methods"
|
||||
],
|
||||
"sid": "666",
|
||||
"client_id": "777",
|
||||
"exp": 12345,
|
||||
"iat": 12000,
|
||||
"nbf": 12000,
|
||||
"auth_time": 12000,
|
||||
"address": {
|
||||
"country": "Moon",
|
||||
"formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
|
||||
"locality": "Smallvile",
|
||||
"postal_code": "666-666",
|
||||
"region": "Outer space",
|
||||
"street_address": "Sesame street 666"
|
||||
},
|
||||
"birthdate": "1st of April",
|
||||
"email": "tim@zitadel.com",
|
||||
"email_verified": true,
|
||||
"family_name": "Möhlmann",
|
||||
"foo": "bar",
|
||||
"gender": "male",
|
||||
"given_name": "Tim",
|
||||
"locale": "nl",
|
||||
"middle_name": "Danger",
|
||||
"name": "Tim Möhlmann",
|
||||
"nickname": "muhlemmer",
|
||||
"phone_number": "+1234567890",
|
||||
"phone_number_verified": true,
|
||||
"picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
|
||||
"preferred_username": "muhlemmer",
|
||||
"profile": "https://github.com/muhlemmer",
|
||||
"sub": "hello@me.com",
|
||||
"updated_at": 1,
|
||||
"website": "https://zitadel.com",
|
||||
"zoneinfo": "Europe/Amsterdam"
|
||||
}
|
44
pkg/oidc/regression_data/oidc.IntrospectionResponse.json
Normal file
44
pkg/oidc/regression_data/oidc.IntrospectionResponse.json
Normal file
|
@ -0,0 +1,44 @@
|
|||
{
|
||||
"active": true,
|
||||
"address": {
|
||||
"country": "Moon",
|
||||
"formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
|
||||
"locality": "Smallvile",
|
||||
"postal_code": "666-666",
|
||||
"region": "Outer space",
|
||||
"street_address": "Sesame street 666"
|
||||
},
|
||||
"aud": [
|
||||
"foo",
|
||||
"bar"
|
||||
],
|
||||
"birthdate": "1st of April",
|
||||
"client_id": "777",
|
||||
"email": "tim@zitadel.com",
|
||||
"email_verified": true,
|
||||
"exp": 12345,
|
||||
"family_name": "Möhlmann",
|
||||
"foo": "bar",
|
||||
"gender": "male",
|
||||
"given_name": "Tim",
|
||||
"iat": 12000,
|
||||
"iss": "zitadel",
|
||||
"jti": "900",
|
||||
"locale": "nl",
|
||||
"middle_name": "Danger",
|
||||
"name": "Tim Möhlmann",
|
||||
"nbf": 12000,
|
||||
"nickname": "muhlemmer",
|
||||
"phone_number": "+1234567890",
|
||||
"phone_number_verified": true,
|
||||
"picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
|
||||
"preferred_username": "muhlemmer",
|
||||
"profile": "https://github.com/muhlemmer",
|
||||
"scope": "email phone",
|
||||
"sub": "hello@me.com",
|
||||
"token_type": "idtoken",
|
||||
"updated_at": 1,
|
||||
"username": "muhlemmer",
|
||||
"website": "https://zitadel.com",
|
||||
"zoneinfo": "Europe/Amsterdam"
|
||||
}
|
11
pkg/oidc/regression_data/oidc.JWTProfileAssertionClaims.json
Normal file
11
pkg/oidc/regression_data/oidc.JWTProfileAssertionClaims.json
Normal file
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"aud": [
|
||||
"foo",
|
||||
"bar"
|
||||
],
|
||||
"exp": 12345,
|
||||
"foo": "bar",
|
||||
"iat": 12000,
|
||||
"iss": "zitadel",
|
||||
"sub": "hello@me.com"
|
||||
}
|
30
pkg/oidc/regression_data/oidc.UserInfo.json
Normal file
30
pkg/oidc/regression_data/oidc.UserInfo.json
Normal file
|
@ -0,0 +1,30 @@
|
|||
{
|
||||
"address": {
|
||||
"country": "Moon",
|
||||
"formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
|
||||
"locality": "Smallvile",
|
||||
"postal_code": "666-666",
|
||||
"region": "Outer space",
|
||||
"street_address": "Sesame street 666"
|
||||
},
|
||||
"birthdate": "1st of April",
|
||||
"email": "tim@zitadel.com",
|
||||
"email_verified": true,
|
||||
"family_name": "Möhlmann",
|
||||
"foo": "bar",
|
||||
"gender": "male",
|
||||
"given_name": "Tim",
|
||||
"locale": "nl",
|
||||
"middle_name": "Danger",
|
||||
"name": "Tim Möhlmann",
|
||||
"nickname": "muhlemmer",
|
||||
"phone_number": "+1234567890",
|
||||
"phone_number_verified": true,
|
||||
"picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
|
||||
"preferred_username": "muhlemmer",
|
||||
"profile": "https://github.com/muhlemmer",
|
||||
"sub": "hello@me.com",
|
||||
"updated_at": 1,
|
||||
"website": "https://zitadel.com",
|
||||
"zoneinfo": "Europe/Amsterdam"
|
||||
}
|
40
pkg/oidc/regression_test.go
Normal file
40
pkg/oidc/regression_test.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package oidc
|
||||
|
||||
// This file contains common functions and data for regression testing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const dataDir = "regression_data"
|
||||
|
||||
// jsonFilename builds a filename for the regression testdata.
|
||||
// dataDir/<type_name>.json
|
||||
func jsonFilename(obj interface{}) string {
|
||||
name := fmt.Sprintf("%T.json", obj)
|
||||
return path.Join(
|
||||
dataDir,
|
||||
strings.TrimPrefix(name, "*"),
|
||||
)
|
||||
}
|
||||
|
||||
func encodeJSON(t *testing.T, w io.Writer, obj interface{}) {
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", "\t")
|
||||
require.NoError(t, enc.Encode(obj))
|
||||
}
|
||||
|
||||
var regressionData = []interface{}{
|
||||
accessTokenData,
|
||||
idTokenData,
|
||||
introspectionResponseData,
|
||||
userInfoData,
|
||||
jwtProfileAssertionData,
|
||||
}
|
|
@ -2,15 +2,13 @@ package oidc
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/crypto"
|
||||
"github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -20,374 +18,174 @@ const (
|
|||
PrefixBearer = BearerToken + " "
|
||||
)
|
||||
|
||||
type Tokens struct {
|
||||
type Tokens[C IDClaims] struct {
|
||||
*oauth2.Token
|
||||
IDTokenClaims IDTokenClaims
|
||||
IDTokenClaims C
|
||||
IDToken string
|
||||
}
|
||||
|
||||
type AccessTokenClaims interface {
|
||||
Claims
|
||||
GetSubject() string
|
||||
GetTokenID() string
|
||||
SetPrivateClaims(map[string]interface{})
|
||||
// TokenClaims contains the base Claims used all tokens.
|
||||
// It implements OpenID Connect Core 1.0, section 2.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
||||
// And RFC 9068: JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens,
|
||||
// section 2.2. https://datatracker.ietf.org/doc/html/rfc9068#name-data-structure
|
||||
//
|
||||
// TokenClaims implements the Claims interface,
|
||||
// and can be used to extend larger claim types by embedding.
|
||||
type TokenClaims struct {
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audience Audience `json:"aud,omitempty"`
|
||||
Expiration Time `json:"exp,omitempty"`
|
||||
IssuedAt Time `json:"iat,omitempty"`
|
||||
AuthTime Time `json:"auth_time,omitempty"`
|
||||
NotBefore Time `json:"nbf,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||
AuthorizedParty string `json:"azp,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
JWTID string `json:"jti,omitempty"`
|
||||
|
||||
// Additional information set by this framework
|
||||
SignatureAlg jose.SignatureAlgorithm `json:"-"`
|
||||
}
|
||||
|
||||
type IDTokenClaims interface {
|
||||
Claims
|
||||
GetNotBefore() time.Time
|
||||
GetJWTID() string
|
||||
GetAccessTokenHash() string
|
||||
GetCodeHash() string
|
||||
GetAuthenticationMethodsReferences() []string
|
||||
GetClientID() string
|
||||
GetSignatureAlgorithm() jose.SignatureAlgorithm
|
||||
SetAccessTokenHash(hash string)
|
||||
SetUserinfo(userinfo UserInfo)
|
||||
SetCodeHash(hash string)
|
||||
UserInfo
|
||||
func (c *TokenClaims) GetIssuer() string {
|
||||
return c.Issuer
|
||||
}
|
||||
|
||||
func EmptyAccessTokenClaims() AccessTokenClaims {
|
||||
return new(accessTokenClaims)
|
||||
func (c *TokenClaims) GetSubject() string {
|
||||
return c.Subject
|
||||
}
|
||||
|
||||
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id, clientID string, skew time.Duration) AccessTokenClaims {
|
||||
func (c *TokenClaims) GetAudience() []string {
|
||||
return c.Audience
|
||||
}
|
||||
|
||||
func (c *TokenClaims) GetExpiration() time.Time {
|
||||
return c.Expiration.AsTime()
|
||||
}
|
||||
|
||||
func (c *TokenClaims) GetIssuedAt() time.Time {
|
||||
return c.IssuedAt.AsTime()
|
||||
}
|
||||
|
||||
func (c *TokenClaims) GetNonce() string {
|
||||
return c.Nonce
|
||||
}
|
||||
|
||||
func (c *TokenClaims) GetAuthTime() time.Time {
|
||||
return c.AuthTime.AsTime()
|
||||
}
|
||||
|
||||
func (c *TokenClaims) GetAuthorizedParty() string {
|
||||
return c.AuthorizedParty
|
||||
}
|
||||
|
||||
func (c *TokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return c.SignatureAlg
|
||||
}
|
||||
|
||||
func (c *TokenClaims) GetAuthenticationContextClassReference() string {
|
||||
return c.AuthenticationContextClassReference
|
||||
}
|
||||
|
||||
func (c *TokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
|
||||
c.SignatureAlg = algorithm
|
||||
}
|
||||
|
||||
type AccessTokenClaims struct {
|
||||
TokenClaims
|
||||
Scopes SpaceDelimitedArray `json:"scope,omitempty"`
|
||||
Claims map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration) *AccessTokenClaims {
|
||||
now := time.Now().UTC().Add(-skew)
|
||||
if len(audience) == 0 {
|
||||
audience = append(audience, clientID)
|
||||
}
|
||||
return &accessTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: subject,
|
||||
Audience: audience,
|
||||
Expiration: Time(expiration),
|
||||
IssuedAt: Time(now),
|
||||
NotBefore: Time(now),
|
||||
JWTID: id,
|
||||
return &AccessTokenClaims{
|
||||
TokenClaims: TokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: subject,
|
||||
Audience: audience,
|
||||
Expiration: FromTime(expiration),
|
||||
IssuedAt: FromTime(now),
|
||||
NotBefore: FromTime(now),
|
||||
JWTID: jwtid,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type accessTokenClaims struct {
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audience Audience `json:"aud,omitempty"`
|
||||
Expiration Time `json:"exp,omitempty"`
|
||||
IssuedAt Time `json:"iat,omitempty"`
|
||||
NotBefore Time `json:"nbf,omitempty"`
|
||||
JWTID string `json:"jti,omitempty"`
|
||||
AuthorizedParty string `json:"azp,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
AuthTime Time `json:"auth_time,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
Scopes SpaceDelimitedArray `json:"scope,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
||||
type atcAlias AccessTokenClaims
|
||||
|
||||
claims map[string]interface{} `json:"-"`
|
||||
signatureAlg jose.SignatureAlgorithm `json:"-"`
|
||||
func (a *AccessTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
return mergeAndMarshalClaims((*atcAlias)(a), a.Claims)
|
||||
}
|
||||
|
||||
// GetIssuer implements the Claims interface
|
||||
func (a *accessTokenClaims) GetIssuer() string {
|
||||
return a.Issuer
|
||||
func (a *AccessTokenClaims) UnmarshalJSON(data []byte) error {
|
||||
return unmarshalJSONMulti(data, (*atcAlias)(a), &a.Claims)
|
||||
}
|
||||
|
||||
// GetAudience implements the Claims interface
|
||||
func (a *accessTokenClaims) GetAudience() []string {
|
||||
return a.Audience
|
||||
}
|
||||
|
||||
// GetExpiration implements the Claims interface
|
||||
func (a *accessTokenClaims) GetExpiration() time.Time {
|
||||
return time.Time(a.Expiration)
|
||||
}
|
||||
|
||||
// GetIssuedAt implements the Claims interface
|
||||
func (a *accessTokenClaims) GetIssuedAt() time.Time {
|
||||
return time.Time(a.IssuedAt)
|
||||
}
|
||||
|
||||
// GetNonce implements the Claims interface
|
||||
func (a *accessTokenClaims) GetNonce() string {
|
||||
return a.Nonce
|
||||
}
|
||||
|
||||
// GetAuthenticationContextClassReference implements the Claims interface
|
||||
func (a *accessTokenClaims) GetAuthenticationContextClassReference() string {
|
||||
return a.AuthenticationContextClassReference
|
||||
}
|
||||
|
||||
// GetAuthTime implements the Claims interface
|
||||
func (a *accessTokenClaims) GetAuthTime() time.Time {
|
||||
return time.Time(a.AuthTime)
|
||||
}
|
||||
|
||||
// GetAuthorizedParty implements the Claims interface
|
||||
func (a *accessTokenClaims) GetAuthorizedParty() string {
|
||||
return a.AuthorizedParty
|
||||
}
|
||||
|
||||
// SetSignatureAlgorithm implements the Claims interface
|
||||
func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
|
||||
a.signatureAlg = algorithm
|
||||
}
|
||||
|
||||
// GetSubject implements the AccessTokenClaims interface
|
||||
func (a *accessTokenClaims) GetSubject() string {
|
||||
return a.Subject
|
||||
}
|
||||
|
||||
// GetTokenID implements the AccessTokenClaims interface
|
||||
func (a *accessTokenClaims) GetTokenID() string {
|
||||
return a.JWTID
|
||||
}
|
||||
|
||||
// SetPrivateClaims implements the AccessTokenClaims interface
|
||||
func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) {
|
||||
a.claims = claims
|
||||
}
|
||||
|
||||
func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
type Alias accessTokenClaims
|
||||
s := &struct {
|
||||
*Alias
|
||||
Expiration int64 `json:"exp,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
}{
|
||||
Alias: (*Alias)(a),
|
||||
}
|
||||
if !time.Time(a.Expiration).IsZero() {
|
||||
s.Expiration = time.Time(a.Expiration).Unix()
|
||||
}
|
||||
if !time.Time(a.IssuedAt).IsZero() {
|
||||
s.IssuedAt = time.Time(a.IssuedAt).Unix()
|
||||
}
|
||||
if !time.Time(a.NotBefore).IsZero() {
|
||||
s.NotBefore = time.Time(a.NotBefore).Unix()
|
||||
}
|
||||
if !time.Time(a.AuthTime).IsZero() {
|
||||
s.AuthTime = time.Time(a.AuthTime).Unix()
|
||||
}
|
||||
b, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.claims == nil {
|
||||
return b, nil
|
||||
}
|
||||
info, err := json.Marshal(a.claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.ConcatenateJSON(b, info)
|
||||
}
|
||||
|
||||
func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
|
||||
type Alias accessTokenClaims
|
||||
if err := json.Unmarshal(data, (*Alias)(a)); err != nil {
|
||||
return err
|
||||
}
|
||||
claims := make(map[string]interface{})
|
||||
if err := json.Unmarshal(data, &claims); err != nil {
|
||||
return err
|
||||
}
|
||||
a.claims = claims
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func EmptyIDTokenClaims() IDTokenClaims {
|
||||
return new(idTokenClaims)
|
||||
}
|
||||
|
||||
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) IDTokenClaims {
|
||||
audience = AppendClientIDToAudience(clientID, audience)
|
||||
return &idTokenClaims{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
Expiration: Time(expiration),
|
||||
IssuedAt: Time(time.Now().UTC().Add(-skew)),
|
||||
AuthTime: Time(authTime.Add(-skew)),
|
||||
Nonce: nonce,
|
||||
AuthenticationContextClassReference: acr,
|
||||
AuthenticationMethodsReferences: amr,
|
||||
AuthorizedParty: clientID,
|
||||
UserInfo: &userinfo{Subject: subject},
|
||||
}
|
||||
}
|
||||
|
||||
type idTokenClaims struct {
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Audience Audience `json:"aud,omitempty"`
|
||||
Expiration Time `json:"exp,omitempty"`
|
||||
NotBefore Time `json:"nbf,omitempty"`
|
||||
IssuedAt Time `json:"iat,omitempty"`
|
||||
JWTID string `json:"jti,omitempty"`
|
||||
AuthorizedParty string `json:"azp,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
AuthTime Time `json:"auth_time,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
UserInfo `json:"-"`
|
||||
|
||||
signatureAlg jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
// GetIssuer implements the Claims interface
|
||||
func (t *idTokenClaims) GetIssuer() string {
|
||||
return t.Issuer
|
||||
}
|
||||
|
||||
// GetAudience implements the Claims interface
|
||||
func (t *idTokenClaims) GetAudience() []string {
|
||||
return t.Audience
|
||||
}
|
||||
|
||||
// GetExpiration implements the Claims interface
|
||||
func (t *idTokenClaims) GetExpiration() time.Time {
|
||||
return time.Time(t.Expiration)
|
||||
}
|
||||
|
||||
// GetIssuedAt implements the Claims interface
|
||||
func (t *idTokenClaims) GetIssuedAt() time.Time {
|
||||
return time.Time(t.IssuedAt)
|
||||
}
|
||||
|
||||
// GetNonce implements the Claims interface
|
||||
func (t *idTokenClaims) GetNonce() string {
|
||||
return t.Nonce
|
||||
}
|
||||
|
||||
// GetAuthenticationContextClassReference implements the Claims interface
|
||||
func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
|
||||
return t.AuthenticationContextClassReference
|
||||
}
|
||||
|
||||
// GetAuthTime implements the Claims interface
|
||||
func (t *idTokenClaims) GetAuthTime() time.Time {
|
||||
return time.Time(t.AuthTime)
|
||||
}
|
||||
|
||||
// GetAuthorizedParty implements the Claims interface
|
||||
func (t *idTokenClaims) GetAuthorizedParty() string {
|
||||
return t.AuthorizedParty
|
||||
}
|
||||
|
||||
// SetSignatureAlgorithm implements the Claims interface
|
||||
func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
|
||||
t.signatureAlg = alg
|
||||
}
|
||||
|
||||
// GetNotBefore implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) GetNotBefore() time.Time {
|
||||
return time.Time(t.NotBefore)
|
||||
}
|
||||
|
||||
// GetJWTID implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) GetJWTID() string {
|
||||
return t.JWTID
|
||||
// IDTokenClaims extends TokenClaims by further implementing
|
||||
// OpenID Connect Core 1.0, sections 3.1.3.6 (Code flow),
|
||||
// 3.2.2.10 (implicit), 3.3.2.11 (Hybrid) and 5.1 (UserInfo).
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#toc
|
||||
type IDTokenClaims struct {
|
||||
TokenClaims
|
||||
NotBefore Time `json:"nbf,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
UserInfoProfile
|
||||
UserInfoEmail
|
||||
UserInfoPhone
|
||||
Address *UserInfoAddress `json:"address,omitempty"`
|
||||
Claims map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// GetAccessTokenHash implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) GetAccessTokenHash() string {
|
||||
func (t *IDTokenClaims) GetAccessTokenHash() string {
|
||||
return t.AccessTokenHash
|
||||
}
|
||||
|
||||
// GetCodeHash implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) GetCodeHash() string {
|
||||
return t.CodeHash
|
||||
func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
|
||||
t.Subject = i.Subject
|
||||
t.UserInfoProfile = i.UserInfoProfile
|
||||
t.UserInfoEmail = i.UserInfoEmail
|
||||
t.UserInfoPhone = i.UserInfoPhone
|
||||
t.Address = i.Address
|
||||
}
|
||||
|
||||
// GetAuthenticationMethodsReferences implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
|
||||
return t.AuthenticationMethodsReferences
|
||||
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims {
|
||||
audience = AppendClientIDToAudience(clientID, audience)
|
||||
return &IDTokenClaims{
|
||||
TokenClaims: TokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: subject,
|
||||
Audience: audience,
|
||||
Expiration: FromTime(expiration),
|
||||
IssuedAt: FromTime(time.Now().Add(-skew)),
|
||||
AuthTime: FromTime(authTime.Add(-skew)),
|
||||
Nonce: nonce,
|
||||
AuthenticationContextClassReference: acr,
|
||||
AuthenticationMethodsReferences: amr,
|
||||
AuthorizedParty: clientID,
|
||||
ClientID: clientID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientID implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) GetClientID() string {
|
||||
return t.ClientID
|
||||
type itcAlias IDTokenClaims
|
||||
|
||||
func (i *IDTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
return mergeAndMarshalClaims((*itcAlias)(i), i.Claims)
|
||||
}
|
||||
|
||||
// GetSignatureAlgorithm implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return t.signatureAlg
|
||||
}
|
||||
|
||||
// SetAccessTokenHash implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) SetAccessTokenHash(hash string) {
|
||||
t.AccessTokenHash = hash
|
||||
}
|
||||
|
||||
// SetUserinfo implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) SetUserinfo(info UserInfo) {
|
||||
t.UserInfo = info
|
||||
}
|
||||
|
||||
// SetCodeHash implements the IDTokenClaims interface
|
||||
func (t *idTokenClaims) SetCodeHash(hash string) {
|
||||
t.CodeHash = hash
|
||||
}
|
||||
|
||||
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
type Alias idTokenClaims
|
||||
a := &struct {
|
||||
*Alias
|
||||
Expiration int64 `json:"exp,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
}{
|
||||
Alias: (*Alias)(t),
|
||||
}
|
||||
if !time.Time(t.Expiration).IsZero() {
|
||||
a.Expiration = time.Time(t.Expiration).Unix()
|
||||
}
|
||||
if !time.Time(t.IssuedAt).IsZero() {
|
||||
a.IssuedAt = time.Time(t.IssuedAt).Unix()
|
||||
}
|
||||
if !time.Time(t.NotBefore).IsZero() {
|
||||
a.NotBefore = time.Time(t.NotBefore).Unix()
|
||||
}
|
||||
if !time.Time(t.AuthTime).IsZero() {
|
||||
a.AuthTime = time.Time(t.AuthTime).Unix()
|
||||
}
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.UserInfo == nil {
|
||||
return b, nil
|
||||
}
|
||||
info, err := json.Marshal(t.UserInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.ConcatenateJSON(b, info)
|
||||
}
|
||||
|
||||
func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
|
||||
type Alias idTokenClaims
|
||||
if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
|
||||
return err
|
||||
}
|
||||
userinfo := new(userinfo)
|
||||
if err := json.Unmarshal(data, userinfo); err != nil {
|
||||
return err
|
||||
}
|
||||
t.UserInfo = userinfo
|
||||
|
||||
return nil
|
||||
func (i *IDTokenClaims) UnmarshalJSON(data []byte) error {
|
||||
return unmarshalJSONMulti(data, (*itcAlias)(i), &i.Claims)
|
||||
}
|
||||
|
||||
type AccessTokenResponse struct {
|
||||
|
@ -399,19 +197,7 @@ type AccessTokenResponse struct {
|
|||
State string `json:"state,omitempty" schema:"state,omitempty"`
|
||||
}
|
||||
|
||||
type JWTProfileAssertionClaims interface {
|
||||
GetKeyID() string
|
||||
GetPrivateKey() []byte
|
||||
GetIssuer() string
|
||||
GetSubject() string
|
||||
GetAudience() []string
|
||||
GetExpiration() time.Time
|
||||
GetIssuedAt() time.Time
|
||||
SetCustomClaim(key string, value interface{})
|
||||
GetCustomClaim(key string) interface{}
|
||||
}
|
||||
|
||||
type jwtProfileAssertion struct {
|
||||
type JWTProfileAssertionClaims struct {
|
||||
PrivateKeyID string `json:"-"`
|
||||
PrivateKey []byte `json:"-"`
|
||||
Issuer string `json:"iss"`
|
||||
|
@ -420,91 +206,21 @@ type jwtProfileAssertion struct {
|
|||
Expiration Time `json:"exp"`
|
||||
IssuedAt Time `json:"iat"`
|
||||
|
||||
customClaims map[string]interface{}
|
||||
Claims map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) MarshalJSON() ([]byte, error) {
|
||||
type Alias jwtProfileAssertion
|
||||
a := (*Alias)(j)
|
||||
type jpaAlias JWTProfileAssertionClaims
|
||||
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(j.customClaims) == 0 {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
err = json.Unmarshal(b, &j.customClaims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("jws: invalid map of custom claims %v", j.customClaims)
|
||||
}
|
||||
|
||||
return json.Marshal(j.customClaims)
|
||||
func (j *JWTProfileAssertionClaims) MarshalJSON() ([]byte, error) {
|
||||
return mergeAndMarshalClaims((*jpaAlias)(j), j.Claims)
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) UnmarshalJSON(data []byte) error {
|
||||
type Alias jwtProfileAssertion
|
||||
a := (*Alias)(j)
|
||||
|
||||
err := json.Unmarshal(data, a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, &j.customClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
func (j *JWTProfileAssertionClaims) UnmarshalJSON(data []byte) error {
|
||||
return unmarshalJSONMulti(data, (*jpaAlias)(j), &j.Claims)
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetKeyID() string {
|
||||
return j.PrivateKeyID
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetPrivateKey() []byte {
|
||||
return j.PrivateKey
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) SetCustomClaim(key string, value interface{}) {
|
||||
if j.customClaims == nil {
|
||||
j.customClaims = make(map[string]interface{})
|
||||
}
|
||||
j.customClaims[key] = value
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetCustomClaim(key string) interface{} {
|
||||
if j.customClaims == nil {
|
||||
return nil
|
||||
}
|
||||
return j.customClaims[key]
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetIssuer() string {
|
||||
return j.Issuer
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetSubject() string {
|
||||
return j.Subject
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetAudience() []string {
|
||||
return j.Audience
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetExpiration() time.Time {
|
||||
return time.Time(j.Expiration)
|
||||
}
|
||||
|
||||
func (j *jwtProfileAssertion) GetIssuedAt() time.Time {
|
||||
return time.Time(j.IssuedAt)
|
||||
}
|
||||
|
||||
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) {
|
||||
data, err := ioutil.ReadFile(filename)
|
||||
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -524,19 +240,19 @@ func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string, op
|
|||
return GenerateJWTProfileToken(NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...))
|
||||
}
|
||||
|
||||
func JWTProfileDelegatedSubject(sub string) func(*jwtProfileAssertion) {
|
||||
return func(j *jwtProfileAssertion) {
|
||||
func JWTProfileDelegatedSubject(sub string) func(*JWTProfileAssertionClaims) {
|
||||
return func(j *JWTProfileAssertionClaims) {
|
||||
j.Subject = sub
|
||||
}
|
||||
}
|
||||
|
||||
func JWTProfileCustomClaim(key string, value interface{}) func(*jwtProfileAssertion) {
|
||||
return func(j *jwtProfileAssertion) {
|
||||
j.customClaims[key] = value
|
||||
func JWTProfileCustomClaim(key string, value interface{}) func(*JWTProfileAssertionClaims) {
|
||||
return func(j *JWTProfileAssertionClaims) {
|
||||
j.Claims[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) {
|
||||
func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
|
||||
keyData := new(struct {
|
||||
KeyID string `json:"keyId"`
|
||||
Key string `json:"key"`
|
||||
|
@ -549,18 +265,18 @@ func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...
|
|||
return NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...), nil
|
||||
}
|
||||
|
||||
type AssertionOption func(*jwtProfileAssertion)
|
||||
type AssertionOption func(*JWTProfileAssertionClaims)
|
||||
|
||||
func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) JWTProfileAssertionClaims {
|
||||
j := &jwtProfileAssertion{
|
||||
func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) *JWTProfileAssertionClaims {
|
||||
j := &JWTProfileAssertionClaims{
|
||||
PrivateKey: key,
|
||||
PrivateKeyID: keyID,
|
||||
Issuer: userID,
|
||||
Subject: userID,
|
||||
IssuedAt: Time(time.Now().UTC()),
|
||||
Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
|
||||
IssuedAt: FromTime(time.Now().UTC()),
|
||||
Expiration: FromTime(time.Now().Add(1 * time.Hour).UTC()),
|
||||
Audience: audience,
|
||||
customClaims: make(map[string]interface{}),
|
||||
Claims: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
|
@ -588,14 +304,14 @@ func AppendClientIDToAudience(clientID string, audience []string) []string {
|
|||
return append(audience, clientID)
|
||||
}
|
||||
|
||||
func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) {
|
||||
privateKey, err := crypto.BytesToPrivateKey(assertion.GetPrivateKey())
|
||||
func GenerateJWTProfileToken(assertion *JWTProfileAssertionClaims) (string, error) {
|
||||
privateKey, err := crypto.BytesToPrivateKey(assertion.PrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
key := jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.GetKeyID()},
|
||||
Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.PrivateKeyID},
|
||||
}
|
||||
signer, err := jose.NewSigner(key, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
|
@ -612,3 +328,12 @@ func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error
|
|||
}
|
||||
return signedAssertion.CompactSerialize()
|
||||
}
|
||||
|
||||
type TokenExchangeResponse struct {
|
||||
AccessToken string `json:"access_token"` // Can be access token or ID token
|
||||
IssuedTokenType TokenType `json:"issued_token_type"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn uint64 `json:"expires_in,omitempty"`
|
||||
Scopes SpaceDelimitedArray `json:"scope,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
|
|
@ -27,6 +27,9 @@ const (
|
|||
// GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
|
||||
GrantTypeImplicit GrantType = "implicit"
|
||||
|
||||
// GrantTypeDeviceCode
|
||||
GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
// ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
|
||||
// used for the OAuth JWT Profile Client Authentication
|
||||
ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||
|
@ -35,11 +38,34 @@ const (
|
|||
var AllGrantTypes = []GrantType{
|
||||
GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
|
||||
GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
|
||||
ClientAssertionTypeJWTAssertion,
|
||||
GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion,
|
||||
}
|
||||
|
||||
type GrantType string
|
||||
|
||||
const (
|
||||
AccessTokenType TokenType = "urn:ietf:params:oauth:token-type:access_token"
|
||||
RefreshTokenType TokenType = "urn:ietf:params:oauth:token-type:refresh_token"
|
||||
IDTokenType TokenType = "urn:ietf:params:oauth:token-type:id_token"
|
||||
JWTTokenType TokenType = "urn:ietf:params:oauth:token-type:jwt"
|
||||
)
|
||||
|
||||
var AllTokenTypes = []TokenType{
|
||||
AccessTokenType, RefreshTokenType, IDTokenType, JWTTokenType,
|
||||
}
|
||||
|
||||
type TokenType string
|
||||
|
||||
func (t TokenType) IsSupported() bool {
|
||||
for _, tt := range AllTokenTypes {
|
||||
if t == tt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type TokenRequest interface {
|
||||
// GrantType GrantType `schema:"grant_type"`
|
||||
GrantType() GrantType
|
||||
|
@ -161,12 +187,12 @@ func (j *JWTTokenRequest) GetAudience() []string {
|
|||
|
||||
// GetExpiration implements the Claims interface
|
||||
func (j *JWTTokenRequest) GetExpiration() time.Time {
|
||||
return time.Time(j.ExpiresAt)
|
||||
return j.ExpiresAt.AsTime()
|
||||
}
|
||||
|
||||
// GetIssuedAt implements the Claims interface
|
||||
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
|
||||
return time.Time(j.IssuedAt)
|
||||
return j.ExpiresAt.AsTime()
|
||||
}
|
||||
|
||||
// GetNonce implements the Claims interface
|
||||
|
@ -203,19 +229,22 @@ func (j *JWTTokenRequest) GetScopes() []string {
|
|||
}
|
||||
|
||||
type TokenExchangeRequest struct {
|
||||
subjectToken string `schema:"subject_token"`
|
||||
subjectTokenType string `schema:"subject_token_type"`
|
||||
actorToken string `schema:"actor_token"`
|
||||
actorTokenType string `schema:"actor_token_type"`
|
||||
resource []string `schema:"resource"`
|
||||
audience Audience `schema:"audience"`
|
||||
Scope SpaceDelimitedArray `schema:"scope"`
|
||||
requestedTokenType string `schema:"requested_token_type"`
|
||||
GrantType GrantType `schema:"grant_type"`
|
||||
SubjectToken string `schema:"subject_token"`
|
||||
SubjectTokenType TokenType `schema:"subject_token_type"`
|
||||
ActorToken string `schema:"actor_token"`
|
||||
ActorTokenType TokenType `schema:"actor_token_type"`
|
||||
Resource []string `schema:"resource"`
|
||||
Audience Audience `schema:"audience"`
|
||||
Scopes SpaceDelimitedArray `schema:"scope"`
|
||||
RequestedTokenType TokenType `schema:"requested_token_type"`
|
||||
}
|
||||
|
||||
type ClientCredentialsRequest struct {
|
||||
GrantType GrantType `schema:"grant_type"`
|
||||
Scope SpaceDelimitedArray `schema:"scope"`
|
||||
ClientID string `schema:"client_id"`
|
||||
ClientSecret string `schema:"client_secret"`
|
||||
GrantType GrantType `schema:"grant_type"`
|
||||
Scope SpaceDelimitedArray `schema:"scope"`
|
||||
ClientID string `schema:"client_id"`
|
||||
ClientSecret string `schema:"client_secret"`
|
||||
ClientAssertion string `schema:"client_assertion"`
|
||||
ClientAssertionType string `schema:"client_assertion_type"`
|
||||
}
|
||||
|
|
227
pkg/oidc/token_test.go
Normal file
227
pkg/oidc/token_test.go
Normal file
|
@ -0,0 +1,227 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
tokenClaimsData = TokenClaims{
|
||||
Issuer: "zitadel",
|
||||
Subject: "hello@me.com",
|
||||
Audience: Audience{"foo", "bar"},
|
||||
Expiration: 12345,
|
||||
IssuedAt: 12000,
|
||||
JWTID: "900",
|
||||
AuthorizedParty: "just@me.com",
|
||||
Nonce: "6969",
|
||||
AuthTime: 12000,
|
||||
NotBefore: 12000,
|
||||
AuthenticationContextClassReference: "something",
|
||||
AuthenticationMethodsReferences: []string{"some", "methods"},
|
||||
ClientID: "777",
|
||||
SignatureAlg: jose.ES256,
|
||||
}
|
||||
accessTokenData = &AccessTokenClaims{
|
||||
TokenClaims: tokenClaimsData,
|
||||
Scopes: []string{"email", "phone"},
|
||||
Claims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
}
|
||||
idTokenData = &IDTokenClaims{
|
||||
TokenClaims: tokenClaimsData,
|
||||
NotBefore: 12000,
|
||||
AccessTokenHash: "acthashhash",
|
||||
CodeHash: "hashhash",
|
||||
SessionID: "666",
|
||||
UserInfoProfile: userInfoData.UserInfoProfile,
|
||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||
Address: userInfoData.Address,
|
||||
Claims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
}
|
||||
introspectionResponseData = &IntrospectionResponse{
|
||||
Active: true,
|
||||
Scope: SpaceDelimitedArray{"email", "phone"},
|
||||
ClientID: "777",
|
||||
TokenType: "idtoken",
|
||||
Expiration: 12345,
|
||||
IssuedAt: 12000,
|
||||
NotBefore: 12000,
|
||||
Subject: "hello@me.com",
|
||||
Audience: Audience{"foo", "bar"},
|
||||
Issuer: "zitadel",
|
||||
JWTID: "900",
|
||||
Username: "muhlemmer",
|
||||
UserInfoProfile: userInfoData.UserInfoProfile,
|
||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||
Address: userInfoData.Address,
|
||||
Claims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
}
|
||||
userInfoData = &UserInfo{
|
||||
Subject: "hello@me.com",
|
||||
UserInfoProfile: UserInfoProfile{
|
||||
Name: "Tim Möhlmann",
|
||||
GivenName: "Tim",
|
||||
FamilyName: "Möhlmann",
|
||||
MiddleName: "Danger",
|
||||
Nickname: "muhlemmer",
|
||||
Profile: "https://github.com/muhlemmer",
|
||||
Picture: "https://avatars.githubusercontent.com/u/5411563?v=4",
|
||||
Website: "https://zitadel.com",
|
||||
Gender: "male",
|
||||
Birthdate: "1st of April",
|
||||
Zoneinfo: "Europe/Amsterdam",
|
||||
Locale: NewLocale(language.Dutch),
|
||||
UpdatedAt: 1,
|
||||
PreferredUsername: "muhlemmer",
|
||||
},
|
||||
UserInfoEmail: UserInfoEmail{
|
||||
Email: "tim@zitadel.com",
|
||||
EmailVerified: true,
|
||||
},
|
||||
UserInfoPhone: UserInfoPhone{
|
||||
PhoneNumber: "+1234567890",
|
||||
PhoneNumberVerified: true,
|
||||
},
|
||||
Address: &UserInfoAddress{
|
||||
Formatted: "Sesame street 666\n666-666, Smallvile\nMoon",
|
||||
StreetAddress: "Sesame street 666",
|
||||
Locality: "Smallvile",
|
||||
Region: "Outer space",
|
||||
PostalCode: "666-666",
|
||||
Country: "Moon",
|
||||
},
|
||||
Claims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
}
|
||||
jwtProfileAssertionData = &JWTProfileAssertionClaims{
|
||||
PrivateKeyID: "8888",
|
||||
PrivateKey: []byte("qwerty"),
|
||||
Issuer: "zitadel",
|
||||
Subject: "hello@me.com",
|
||||
Audience: Audience{"foo", "bar"},
|
||||
Expiration: 12345,
|
||||
IssuedAt: 12000,
|
||||
Claims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestTokenClaims(t *testing.T) {
|
||||
claims := tokenClaimsData
|
||||
|
||||
assert.Equal(t, claims.Issuer, tokenClaimsData.GetIssuer())
|
||||
assert.Equal(t, claims.Subject, tokenClaimsData.GetSubject())
|
||||
assert.Equal(t, []string(claims.Audience), tokenClaimsData.GetAudience())
|
||||
assert.Equal(t, claims.Expiration.AsTime(), tokenClaimsData.GetExpiration())
|
||||
assert.Equal(t, claims.IssuedAt.AsTime(), tokenClaimsData.GetIssuedAt())
|
||||
assert.Equal(t, claims.Nonce, tokenClaimsData.GetNonce())
|
||||
assert.Equal(t, claims.AuthTime.AsTime(), tokenClaimsData.GetAuthTime())
|
||||
assert.Equal(t, claims.AuthorizedParty, tokenClaimsData.GetAuthorizedParty())
|
||||
assert.Equal(t, claims.SignatureAlg, tokenClaimsData.GetSignatureAlgorithm())
|
||||
assert.Equal(t, claims.AuthenticationContextClassReference, tokenClaimsData.GetAuthenticationContextClassReference())
|
||||
|
||||
claims.SetSignatureAlgorithm(jose.ES384)
|
||||
assert.Equal(t, jose.ES384, claims.SignatureAlg)
|
||||
}
|
||||
|
||||
func TestNewAccessTokenClaims(t *testing.T) {
|
||||
want := &AccessTokenClaims{
|
||||
TokenClaims: TokenClaims{
|
||||
Issuer: "zitadel",
|
||||
Subject: "hello@me.com",
|
||||
Audience: Audience{"foo"},
|
||||
Expiration: 12345,
|
||||
JWTID: "900",
|
||||
},
|
||||
}
|
||||
|
||||
got := NewAccessTokenClaims(
|
||||
want.Issuer, want.Subject, nil,
|
||||
want.Expiration.AsTime(), want.JWTID, "foo", time.Second,
|
||||
)
|
||||
|
||||
// test if the dynamic timestamps are around now,
|
||||
// allowing for a delta of 1, just in case we flip on
|
||||
// either side of a second boundry.
|
||||
nowMinusSkew := NowTime() - 1
|
||||
assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1)
|
||||
assert.InDelta(t, int64(nowMinusSkew), int64(got.NotBefore), 1)
|
||||
|
||||
// Make equal not fail on dynamic timestamp
|
||||
got.IssuedAt = 0
|
||||
got.NotBefore = 0
|
||||
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestIDTokenClaims_GetAccessTokenHash(t *testing.T) {
|
||||
assert.Equal(t, idTokenData.AccessTokenHash, idTokenData.GetAccessTokenHash())
|
||||
}
|
||||
|
||||
func TestIDTokenClaims_SetUserInfo(t *testing.T) {
|
||||
want := IDTokenClaims{
|
||||
TokenClaims: TokenClaims{
|
||||
Subject: userInfoData.Subject,
|
||||
},
|
||||
UserInfoProfile: userInfoData.UserInfoProfile,
|
||||
UserInfoEmail: userInfoData.UserInfoEmail,
|
||||
UserInfoPhone: userInfoData.UserInfoPhone,
|
||||
Address: userInfoData.Address,
|
||||
}
|
||||
|
||||
var got IDTokenClaims
|
||||
got.SetUserInfo(userInfoData)
|
||||
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNewIDTokenClaims(t *testing.T) {
|
||||
want := &IDTokenClaims{
|
||||
TokenClaims: TokenClaims{
|
||||
Issuer: "zitadel",
|
||||
Subject: "hello@me.com",
|
||||
Audience: Audience{"foo", "just@me.com"},
|
||||
Expiration: 12345,
|
||||
AuthTime: 12000,
|
||||
Nonce: "6969",
|
||||
AuthenticationContextClassReference: "something",
|
||||
AuthenticationMethodsReferences: []string{"some", "methods"},
|
||||
AuthorizedParty: "just@me.com",
|
||||
ClientID: "just@me.com",
|
||||
},
|
||||
}
|
||||
|
||||
got := NewIDTokenClaims(
|
||||
want.Issuer, want.Subject, want.Audience,
|
||||
want.Expiration.AsTime(),
|
||||
want.AuthTime.AsTime().Add(time.Second),
|
||||
want.Nonce, want.AuthenticationContextClassReference,
|
||||
want.AuthenticationMethodsReferences, want.AuthorizedParty,
|
||||
time.Second,
|
||||
)
|
||||
|
||||
// test if the dynamic timestamp is around now,
|
||||
// allowing for a delta of 1, just in case we flip on
|
||||
// either side of a second boundry.
|
||||
nowMinusSkew := NowTime() - 1
|
||||
assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1)
|
||||
|
||||
// Make equal not fail on dynamic timestamp
|
||||
got.IssuedAt = 0
|
||||
|
||||
assert.Equal(t, want, got)
|
||||
}
|
|
@ -4,9 +4,11 @@ import (
|
|||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
@ -44,6 +46,39 @@ func (d *Display) UnmarshalText(text []byte) error {
|
|||
|
||||
type Gender string
|
||||
|
||||
type Locale struct {
|
||||
tag language.Tag
|
||||
}
|
||||
|
||||
func NewLocale(tag language.Tag) *Locale {
|
||||
return &Locale{tag: tag}
|
||||
}
|
||||
|
||||
func (l *Locale) Tag() language.Tag {
|
||||
if l == nil {
|
||||
return language.Und
|
||||
}
|
||||
|
||||
return l.tag
|
||||
}
|
||||
|
||||
func (l *Locale) String() string {
|
||||
return l.Tag().String()
|
||||
}
|
||||
|
||||
func (l *Locale) MarshalJSON() ([]byte, error) {
|
||||
tag := l.Tag()
|
||||
if tag.IsRoot() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return json.Marshal(tag)
|
||||
}
|
||||
|
||||
func (l *Locale) UnmarshalJSON(data []byte) error {
|
||||
return json.Unmarshal(data, &l.tag)
|
||||
}
|
||||
|
||||
type Locales []language.Tag
|
||||
|
||||
func (l *Locales) UnmarshalText(text []byte) error {
|
||||
|
@ -125,19 +160,52 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
|
|||
return strings.Join(s, " "), nil
|
||||
}
|
||||
|
||||
type Time time.Time
|
||||
|
||||
func (t *Time) UnmarshalJSON(data []byte) error {
|
||||
var i int64
|
||||
if err := json.Unmarshal(data, &i); err != nil {
|
||||
return err
|
||||
}
|
||||
*t = Time(time.Unix(i, 0).UTC())
|
||||
return nil
|
||||
// NewEncoder returns a schema Encoder with
|
||||
// a registered encoder for SpaceDelimitedArray.
|
||||
func NewEncoder() *schema.Encoder {
|
||||
e := schema.NewEncoder()
|
||||
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
|
||||
return value.Interface().(SpaceDelimitedArray).Encode()
|
||||
})
|
||||
return e
|
||||
}
|
||||
|
||||
func (t *Time) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(time.Time(*t).UTC().Unix())
|
||||
type Time int64
|
||||
|
||||
func (ts Time) AsTime() time.Time {
|
||||
return time.Unix(int64(ts), 0)
|
||||
}
|
||||
|
||||
func FromTime(tt time.Time) Time {
|
||||
return Time(tt.Unix())
|
||||
}
|
||||
|
||||
func NowTime() Time {
|
||||
return FromTime(time.Now())
|
||||
}
|
||||
|
||||
func (ts *Time) UnmarshalJSON(data []byte) error {
|
||||
var v any
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
return fmt.Errorf("oidc.Time: %w", err)
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case float64:
|
||||
*ts = Time(x)
|
||||
case string:
|
||||
// Compatibility with Auth0:
|
||||
// https://github.com/zitadel/oidc/issues/292
|
||||
tt, err := time.Parse(time.RFC3339, x)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oidc.Time: %w", err)
|
||||
}
|
||||
*ts = FromTime(tt)
|
||||
case nil:
|
||||
*ts = 0
|
||||
default:
|
||||
return fmt.Errorf("oidc.Time: unable to parse type %T with value %v", x, x)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RequestObject struct {
|
||||
|
@ -150,5 +218,4 @@ func (r *RequestObject) GetIssuer() string {
|
|||
return r.Issuer
|
||||
}
|
||||
|
||||
func (r *RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
|
||||
}
|
||||
func (*RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {}
|
||||
|
|
|
@ -3,11 +3,14 @@ package oidc
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
|
@ -109,6 +112,117 @@ func TestDisplay_UnmarshalText(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestLocale_Tag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
l *Locale
|
||||
want language.Tag
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
l: nil,
|
||||
want: language.Und,
|
||||
},
|
||||
{
|
||||
name: "Und",
|
||||
l: NewLocale(language.Und),
|
||||
want: language.Und,
|
||||
},
|
||||
{
|
||||
name: "language",
|
||||
l: NewLocale(language.Afrikaans),
|
||||
want: language.Afrikaans,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, tt.l.Tag())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocale_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
l *Locale
|
||||
want language.Tag
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
l: nil,
|
||||
want: language.Und,
|
||||
},
|
||||
{
|
||||
name: "Und",
|
||||
l: NewLocale(language.Und),
|
||||
want: language.Und,
|
||||
},
|
||||
{
|
||||
name: "language",
|
||||
l: NewLocale(language.Afrikaans),
|
||||
want: language.Afrikaans,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want.String(), tt.l.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocale_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
l *Locale
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
l: nil,
|
||||
want: "null",
|
||||
},
|
||||
{
|
||||
name: "und",
|
||||
l: NewLocale(language.Und),
|
||||
want: "null",
|
||||
},
|
||||
{
|
||||
name: "language",
|
||||
l: NewLocale(language.Afrikaans),
|
||||
want: `"af"`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := json.Marshal(tt.l)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocale_UnmarshalJSON(t *testing.T) {
|
||||
type a struct {
|
||||
Locale *Locale `json:"locale,omitempty"`
|
||||
}
|
||||
want := a{
|
||||
Locale: NewLocale(language.Afrikaans),
|
||||
}
|
||||
|
||||
const input = `{"locale": "af"}`
|
||||
var got a
|
||||
|
||||
require.NoError(t,
|
||||
json.Unmarshal([]byte(input), &got),
|
||||
)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLocales_UnmarshalText(t *testing.T) {
|
||||
type args struct {
|
||||
text []byte
|
||||
|
@ -335,3 +449,74 @@ func TestSpaceDelimitatedArray_ValuerNil(t *testing.T) {
|
|||
assert.Equal(t, SpaceDelimitedArray(nil), reversed, "scan nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEncoder(t *testing.T) {
|
||||
type request struct {
|
||||
Scopes SpaceDelimitedArray `schema:"scope"`
|
||||
}
|
||||
a := request{
|
||||
Scopes: SpaceDelimitedArray{"foo", "bar"},
|
||||
}
|
||||
|
||||
values := make(url.Values)
|
||||
NewEncoder().Encode(a, values)
|
||||
assert.Equal(t, url.Values{"scope": []string{"foo bar"}}, values)
|
||||
|
||||
var b request
|
||||
schema.NewDecoder().Decode(&b, values)
|
||||
assert.Equal(t, a, b)
|
||||
}
|
||||
|
||||
func TestTime_UnmarshalJSON(t *testing.T) {
|
||||
type dst struct {
|
||||
UpdatedAt Time `json:"updated_at"`
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
json string
|
||||
want dst
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "RFC3339", // https://github.com/zitadel/oidc/issues/292
|
||||
json: `{"updated_at": "2021-05-11T21:13:25.566Z"}`,
|
||||
want: dst{UpdatedAt: 1620767605},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
json: `{"updated_at":1620767605}`,
|
||||
want: dst{UpdatedAt: 1620767605},
|
||||
},
|
||||
{
|
||||
name: "time parse error",
|
||||
json: `{"updated_at":"foo"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "null",
|
||||
json: `{"updated_at":null}`,
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
json: `{"updated_at":["foo","bar"]}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got dst
|
||||
err := json.Unmarshal([]byte(tt.json), &got)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
t.Run("syntax error", func(t *testing.T) {
|
||||
var ts Time
|
||||
err := ts.UnmarshalJSON([]byte{'~'})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,320 +1,73 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
type UserInfo interface {
|
||||
GetSubject() string
|
||||
// UserInfo implements OpenID Connect Core 1.0, section 5.1.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims.
|
||||
type UserInfo struct {
|
||||
Subject string `json:"sub,omitempty"`
|
||||
UserInfoProfile
|
||||
UserInfoEmail
|
||||
UserInfoPhone
|
||||
GetAddress() UserInfoAddress
|
||||
GetClaim(key string) interface{}
|
||||
GetClaims() map[string]interface{}
|
||||
Address *UserInfoAddress `json:"address,omitempty"`
|
||||
|
||||
Claims map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
type UserInfoProfile interface {
|
||||
GetName() string
|
||||
GetGivenName() string
|
||||
GetFamilyName() string
|
||||
GetMiddleName() string
|
||||
GetNickname() string
|
||||
GetProfile() string
|
||||
GetPicture() string
|
||||
GetWebsite() string
|
||||
GetGender() Gender
|
||||
GetBirthdate() string
|
||||
GetZoneinfo() string
|
||||
GetLocale() language.Tag
|
||||
GetPreferredUsername() string
|
||||
func (u *UserInfo) AppendClaims(k string, v any) {
|
||||
if u.Claims == nil {
|
||||
u.Claims = make(map[string]any)
|
||||
}
|
||||
|
||||
u.Claims[k] = v
|
||||
}
|
||||
|
||||
type UserInfoEmail interface {
|
||||
GetEmail() string
|
||||
IsEmailVerified() bool
|
||||
}
|
||||
|
||||
type UserInfoPhone interface {
|
||||
GetPhoneNumber() string
|
||||
IsPhoneNumberVerified() bool
|
||||
}
|
||||
|
||||
type UserInfoAddress interface {
|
||||
GetFormatted() string
|
||||
GetStreetAddress() string
|
||||
GetLocality() string
|
||||
GetRegion() string
|
||||
GetPostalCode() string
|
||||
GetCountry() string
|
||||
}
|
||||
|
||||
type UserInfoSetter interface {
|
||||
UserInfo
|
||||
SetSubject(sub string)
|
||||
UserInfoProfileSetter
|
||||
SetEmail(email string, verified bool)
|
||||
SetPhone(phone string, verified bool)
|
||||
SetAddress(address UserInfoAddress)
|
||||
AppendClaims(key string, values interface{})
|
||||
}
|
||||
|
||||
type UserInfoProfileSetter interface {
|
||||
SetName(name string)
|
||||
SetGivenName(name string)
|
||||
SetFamilyName(name string)
|
||||
SetMiddleName(name string)
|
||||
SetNickname(name string)
|
||||
SetUpdatedAt(date time.Time)
|
||||
SetProfile(profile string)
|
||||
SetPicture(profile string)
|
||||
SetWebsite(website string)
|
||||
SetGender(gender Gender)
|
||||
SetBirthdate(birthdate string)
|
||||
SetZoneinfo(zoneInfo string)
|
||||
SetLocale(locale language.Tag)
|
||||
SetPreferredUsername(name string)
|
||||
}
|
||||
|
||||
func NewUserInfo() UserInfoSetter {
|
||||
return &userinfo{}
|
||||
}
|
||||
|
||||
type userinfo struct {
|
||||
Subject string `json:"sub,omitempty"`
|
||||
userInfoProfile
|
||||
userInfoEmail
|
||||
userInfoPhone
|
||||
Address UserInfoAddress `json:"address,omitempty"`
|
||||
|
||||
claims map[string]interface{}
|
||||
}
|
||||
|
||||
func (u *userinfo) GetSubject() string {
|
||||
return u.Subject
|
||||
}
|
||||
|
||||
func (u *userinfo) GetName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
func (u *userinfo) GetGivenName() string {
|
||||
return u.GivenName
|
||||
}
|
||||
|
||||
func (u *userinfo) GetFamilyName() string {
|
||||
return u.FamilyName
|
||||
}
|
||||
|
||||
func (u *userinfo) GetMiddleName() string {
|
||||
return u.MiddleName
|
||||
}
|
||||
|
||||
func (u *userinfo) GetNickname() string {
|
||||
return u.Nickname
|
||||
}
|
||||
|
||||
func (u *userinfo) GetProfile() string {
|
||||
return u.Profile
|
||||
}
|
||||
|
||||
func (u *userinfo) GetPicture() string {
|
||||
return u.Picture
|
||||
}
|
||||
|
||||
func (u *userinfo) GetWebsite() string {
|
||||
return u.Website
|
||||
}
|
||||
|
||||
func (u *userinfo) GetGender() Gender {
|
||||
return u.Gender
|
||||
}
|
||||
|
||||
func (u *userinfo) GetBirthdate() string {
|
||||
return u.Birthdate
|
||||
}
|
||||
|
||||
func (u *userinfo) GetZoneinfo() string {
|
||||
return u.Zoneinfo
|
||||
}
|
||||
|
||||
func (u *userinfo) GetLocale() language.Tag {
|
||||
return u.Locale
|
||||
}
|
||||
|
||||
func (u *userinfo) GetPreferredUsername() string {
|
||||
return u.PreferredUsername
|
||||
}
|
||||
|
||||
func (u *userinfo) GetEmail() string {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
func (u *userinfo) IsEmailVerified() bool {
|
||||
return bool(u.EmailVerified)
|
||||
}
|
||||
|
||||
func (u *userinfo) GetPhoneNumber() string {
|
||||
return u.PhoneNumber
|
||||
}
|
||||
|
||||
func (u *userinfo) IsPhoneNumberVerified() bool {
|
||||
return u.PhoneNumberVerified
|
||||
}
|
||||
|
||||
func (u *userinfo) GetAddress() UserInfoAddress {
|
||||
// GetAddress is a safe getter that takes
|
||||
// care of a possible nil value.
|
||||
func (u *UserInfo) GetAddress() *UserInfoAddress {
|
||||
if u.Address == nil {
|
||||
return &userInfoAddress{}
|
||||
return new(UserInfoAddress)
|
||||
}
|
||||
return u.Address
|
||||
}
|
||||
|
||||
func (u *userinfo) GetClaim(key string) interface{} {
|
||||
return u.claims[key]
|
||||
type uiAlias UserInfo
|
||||
|
||||
func (u *UserInfo) MarshalJSON() ([]byte, error) {
|
||||
return mergeAndMarshalClaims((*uiAlias)(u), u.Claims)
|
||||
}
|
||||
|
||||
func (u *userinfo) GetClaims() map[string]interface{} {
|
||||
return u.claims
|
||||
func (u *UserInfo) UnmarshalJSON(data []byte) error {
|
||||
return unmarshalJSONMulti(data, (*uiAlias)(u), &u.Claims)
|
||||
}
|
||||
|
||||
func (u *userinfo) SetSubject(sub string) {
|
||||
u.Subject = sub
|
||||
type UserInfoProfile struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
MiddleName string `json:"middle_name,omitempty"`
|
||||
Nickname string `json:"nickname,omitempty"`
|
||||
Profile string `json:"profile,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Website string `json:"website,omitempty"`
|
||||
Gender Gender `json:"gender,omitempty"`
|
||||
Birthdate string `json:"birthdate,omitempty"`
|
||||
Zoneinfo string `json:"zoneinfo,omitempty"`
|
||||
Locale *Locale `json:"locale,omitempty"`
|
||||
UpdatedAt Time `json:"updated_at,omitempty"`
|
||||
PreferredUsername string `json:"preferred_username,omitempty"`
|
||||
}
|
||||
|
||||
func (u *userinfo) SetName(name string) {
|
||||
u.Name = name
|
||||
}
|
||||
|
||||
func (u *userinfo) SetGivenName(name string) {
|
||||
u.GivenName = name
|
||||
}
|
||||
|
||||
func (u *userinfo) SetFamilyName(name string) {
|
||||
u.FamilyName = name
|
||||
}
|
||||
|
||||
func (u *userinfo) SetMiddleName(name string) {
|
||||
u.MiddleName = name
|
||||
}
|
||||
|
||||
func (u *userinfo) SetNickname(name string) {
|
||||
u.Nickname = name
|
||||
}
|
||||
|
||||
func (u *userinfo) SetUpdatedAt(date time.Time) {
|
||||
u.UpdatedAt = Time(date)
|
||||
}
|
||||
|
||||
func (u *userinfo) SetProfile(profile string) {
|
||||
u.Profile = profile
|
||||
}
|
||||
|
||||
func (u *userinfo) SetPicture(picture string) {
|
||||
u.Picture = picture
|
||||
}
|
||||
|
||||
func (u *userinfo) SetWebsite(website string) {
|
||||
u.Website = website
|
||||
}
|
||||
|
||||
func (u *userinfo) SetGender(gender Gender) {
|
||||
u.Gender = gender
|
||||
}
|
||||
|
||||
func (u *userinfo) SetBirthdate(birthdate string) {
|
||||
u.Birthdate = birthdate
|
||||
}
|
||||
|
||||
func (u *userinfo) SetZoneinfo(zoneInfo string) {
|
||||
u.Zoneinfo = zoneInfo
|
||||
}
|
||||
|
||||
func (u *userinfo) SetLocale(locale language.Tag) {
|
||||
u.Locale = locale
|
||||
}
|
||||
|
||||
func (u *userinfo) SetPreferredUsername(name string) {
|
||||
u.PreferredUsername = name
|
||||
}
|
||||
|
||||
func (u *userinfo) SetEmail(email string, verified bool) {
|
||||
u.Email = email
|
||||
u.EmailVerified = boolString(verified)
|
||||
}
|
||||
|
||||
func (u *userinfo) SetPhone(phone string, verified bool) {
|
||||
u.PhoneNumber = phone
|
||||
u.PhoneNumberVerified = verified
|
||||
}
|
||||
|
||||
func (u *userinfo) SetAddress(address UserInfoAddress) {
|
||||
u.Address = address
|
||||
}
|
||||
|
||||
func (u *userinfo) AppendClaims(key string, value interface{}) {
|
||||
if u.claims == nil {
|
||||
u.claims = make(map[string]interface{})
|
||||
}
|
||||
u.claims[key] = value
|
||||
}
|
||||
|
||||
func (u *userInfoAddress) GetFormatted() string {
|
||||
return u.Formatted
|
||||
}
|
||||
|
||||
func (u *userInfoAddress) GetStreetAddress() string {
|
||||
return u.StreetAddress
|
||||
}
|
||||
|
||||
func (u *userInfoAddress) GetLocality() string {
|
||||
return u.Locality
|
||||
}
|
||||
|
||||
func (u *userInfoAddress) GetRegion() string {
|
||||
return u.Region
|
||||
}
|
||||
|
||||
func (u *userInfoAddress) GetPostalCode() string {
|
||||
return u.PostalCode
|
||||
}
|
||||
|
||||
func (u *userInfoAddress) GetCountry() string {
|
||||
return u.Country
|
||||
}
|
||||
|
||||
type userInfoProfile struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
MiddleName string `json:"middle_name,omitempty"`
|
||||
Nickname string `json:"nickname,omitempty"`
|
||||
Profile string `json:"profile,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Website string `json:"website,omitempty"`
|
||||
Gender Gender `json:"gender,omitempty"`
|
||||
Birthdate string `json:"birthdate,omitempty"`
|
||||
Zoneinfo string `json:"zoneinfo,omitempty"`
|
||||
Locale language.Tag `json:"locale,omitempty"`
|
||||
UpdatedAt Time `json:"updated_at,omitempty"`
|
||||
PreferredUsername string `json:"preferred_username,omitempty"`
|
||||
}
|
||||
|
||||
type userInfoEmail struct {
|
||||
type UserInfoEmail struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
|
||||
// Handle providers that return email_verified as a string
|
||||
// https://forums.aws.amazon.com/thread.jspa?messageID=949441󧳁
|
||||
// https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11
|
||||
EmailVerified boolString `json:"email_verified,omitempty"`
|
||||
EmailVerified Bool `json:"email_verified,omitempty"`
|
||||
}
|
||||
|
||||
type boolString bool
|
||||
type Bool bool
|
||||
|
||||
func (bs *boolString) UnmarshalJSON(data []byte) error {
|
||||
func (bs *Bool) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "true" || string(data) == `"true"` {
|
||||
*bs = true
|
||||
}
|
||||
|
@ -322,12 +75,12 @@ func (bs *boolString) UnmarshalJSON(data []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type userInfoPhone struct {
|
||||
type UserInfoPhone struct {
|
||||
PhoneNumber string `json:"phone_number,omitempty"`
|
||||
PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
|
||||
}
|
||||
|
||||
type userInfoAddress struct {
|
||||
type UserInfoAddress struct {
|
||||
Formatted string `json:"formatted,omitempty"`
|
||||
StreetAddress string `json:"street_address,omitempty"`
|
||||
Locality string `json:"locality,omitempty"`
|
||||
|
@ -336,76 +89,6 @@ type userInfoAddress struct {
|
|||
Country string `json:"country,omitempty"`
|
||||
}
|
||||
|
||||
func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
|
||||
return &userInfoAddress{
|
||||
StreetAddress: streetAddress,
|
||||
Locality: locality,
|
||||
Region: region,
|
||||
PostalCode: postalCode,
|
||||
Country: country,
|
||||
Formatted: formatted,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *userinfo) MarshalJSON() ([]byte, error) {
|
||||
type Alias userinfo
|
||||
a := &struct {
|
||||
*Alias
|
||||
Locale interface{} `json:"locale,omitempty"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
}{
|
||||
Alias: (*Alias)(u),
|
||||
}
|
||||
if !u.Locale.IsRoot() {
|
||||
a.Locale = u.Locale
|
||||
}
|
||||
if !time.Time(u.UpdatedAt).IsZero() {
|
||||
a.UpdatedAt = time.Time(u.UpdatedAt).Unix()
|
||||
}
|
||||
|
||||
b, err := json.Marshal(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(u.claims) == 0 {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
err = json.Unmarshal(b, &u.claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("jws: invalid map of custom claims %v", u.claims)
|
||||
}
|
||||
|
||||
return json.Marshal(u.claims)
|
||||
}
|
||||
|
||||
func (u *userinfo) UnmarshalJSON(data []byte) error {
|
||||
type Alias userinfo
|
||||
a := &struct {
|
||||
Address *userInfoAddress `json:"address,omitempty"`
|
||||
*Alias
|
||||
UpdatedAt int64 `json:"update_at,omitempty"`
|
||||
}{
|
||||
Alias: (*Alias)(u),
|
||||
}
|
||||
if err := json.Unmarshal(data, &a); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if a.Address != nil {
|
||||
u.Address = a.Address
|
||||
}
|
||||
|
||||
u.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
|
||||
|
||||
if err := json.Unmarshal(data, &u.claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserInfoRequest struct {
|
||||
AccessToken string `schema:"access_token"`
|
||||
}
|
||||
|
|
|
@ -7,21 +7,54 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUserInfo_AppendClaims(t *testing.T) {
|
||||
u := new(UserInfo)
|
||||
u.AppendClaims("a", "b")
|
||||
want := map[string]any{"a": "b"}
|
||||
assert.Equal(t, want, u.Claims)
|
||||
|
||||
u.AppendClaims("d", "e")
|
||||
want["d"] = "e"
|
||||
assert.Equal(t, want, u.Claims)
|
||||
}
|
||||
|
||||
func TestUserInfo_GetAddress(t *testing.T) {
|
||||
// nil address
|
||||
u := new(UserInfo)
|
||||
assert.Equal(t, &UserInfoAddress{}, u.GetAddress())
|
||||
|
||||
u.Address = &UserInfoAddress{PostalCode: "1234"}
|
||||
assert.Equal(t, u.Address, u.GetAddress())
|
||||
}
|
||||
|
||||
func TestUserInfoMarshal(t *testing.T) {
|
||||
userinfo := NewUserInfo()
|
||||
userinfo.SetSubject("test")
|
||||
userinfo.SetAddress(NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
|
||||
userinfo.SetEmail("test", true)
|
||||
userinfo.SetPhone("0791234567", true)
|
||||
userinfo.SetName("Test")
|
||||
userinfo.AppendClaims("private_claim", "test")
|
||||
userinfo := &UserInfo{
|
||||
Subject: "test",
|
||||
Address: &UserInfoAddress{
|
||||
StreetAddress: "Test 789\nPostfach 2",
|
||||
},
|
||||
UserInfoEmail: UserInfoEmail{
|
||||
Email: "test",
|
||||
EmailVerified: true,
|
||||
},
|
||||
UserInfoPhone: UserInfoPhone{
|
||||
PhoneNumber: "0791234567",
|
||||
PhoneNumberVerified: true,
|
||||
},
|
||||
UserInfoProfile: UserInfoProfile{
|
||||
Name: "Test",
|
||||
},
|
||||
Claims: map[string]any{"private_claim": "test"},
|
||||
}
|
||||
|
||||
marshal, err := json.Marshal(userinfo)
|
||||
out := NewUserInfo()
|
||||
assert.NoError(t, err)
|
||||
|
||||
out := new(UserInfo)
|
||||
assert.NoError(t, json.Unmarshal(marshal, out))
|
||||
assert.Equal(t, userinfo.GetAddress(), out.GetAddress())
|
||||
assert.Equal(t, userinfo, out)
|
||||
expected, err := json.Marshal(out)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, marshal)
|
||||
}
|
||||
|
@ -29,91 +62,55 @@ func TestUserInfoMarshal(t *testing.T) {
|
|||
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("unmarsha email_verified from json bool true", func(t *testing.T) {
|
||||
t.Run("unmarshal email_verified from json bool true", func(t *testing.T) {
|
||||
jsonBool := []byte(`{"email": "my@email.com", "email_verified": true}`)
|
||||
|
||||
var uie userInfoEmail
|
||||
var uie UserInfoEmail
|
||||
|
||||
err := json.Unmarshal(jsonBool, &uie)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, userInfoEmail{
|
||||
assert.Equal(t, UserInfoEmail{
|
||||
Email: "my@email.com",
|
||||
EmailVerified: true,
|
||||
}, uie)
|
||||
})
|
||||
|
||||
t.Run("unmarsha email_verified from json string true", func(t *testing.T) {
|
||||
t.Run("unmarshal email_verified from json string true", func(t *testing.T) {
|
||||
jsonBool := []byte(`{"email": "my@email.com", "email_verified": "true"}`)
|
||||
|
||||
var uie userInfoEmail
|
||||
var uie UserInfoEmail
|
||||
|
||||
err := json.Unmarshal(jsonBool, &uie)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, userInfoEmail{
|
||||
assert.Equal(t, UserInfoEmail{
|
||||
Email: "my@email.com",
|
||||
EmailVerified: true,
|
||||
}, uie)
|
||||
})
|
||||
|
||||
t.Run("unmarsha email_verified from json bool false", func(t *testing.T) {
|
||||
t.Run("unmarshal email_verified from json bool false", func(t *testing.T) {
|
||||
jsonBool := []byte(`{"email": "my@email.com", "email_verified": false}`)
|
||||
|
||||
var uie userInfoEmail
|
||||
var uie UserInfoEmail
|
||||
|
||||
err := json.Unmarshal(jsonBool, &uie)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, userInfoEmail{
|
||||
assert.Equal(t, UserInfoEmail{
|
||||
Email: "my@email.com",
|
||||
EmailVerified: false,
|
||||
}, uie)
|
||||
})
|
||||
|
||||
t.Run("unmarsha email_verified from json string false", func(t *testing.T) {
|
||||
t.Run("unmarshal email_verified from json string false", func(t *testing.T) {
|
||||
jsonBool := []byte(`{"email": "my@email.com", "email_verified": "false"}`)
|
||||
|
||||
var uie userInfoEmail
|
||||
var uie UserInfoEmail
|
||||
|
||||
err := json.Unmarshal(jsonBool, &uie)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, userInfoEmail{
|
||||
assert.Equal(t, UserInfoEmail{
|
||||
Email: "my@email.com",
|
||||
EmailVerified: false,
|
||||
}, uie)
|
||||
})
|
||||
}
|
||||
|
||||
// issue 203 test case.
|
||||
func Test_userinfo_GetAddress_issue_203(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
}{
|
||||
{
|
||||
name: "with address",
|
||||
data: `{"address":{"street_address":"Test 789\nPostfach 2"},"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
|
||||
},
|
||||
{
|
||||
name: "without address",
|
||||
data: `{"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
|
||||
},
|
||||
{
|
||||
name: "null address",
|
||||
data: `{"address":null,"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info := &userinfo{}
|
||||
err := json.Unmarshal([]byte(tt.data), info)
|
||||
assert.NoError(t, err)
|
||||
|
||||
info.GetAddress().GetCountry() //<- used to panic
|
||||
|
||||
// now shortly assure that a marshalling still produces the same as was parsed into the struct
|
||||
marshal, err := json.Marshal(info)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.data, string(marshal))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
49
pkg/oidc/util.go
Normal file
49
pkg/oidc/util.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// mergeAndMarshalClaims merges registered and the custom
|
||||
// claims map into a single JSON object.
|
||||
// Registered fields overwrite custom claims.
|
||||
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) {
|
||||
// Use a buffer for memory re-use, instead off letting
|
||||
// json allocate a new []byte for every step.
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// Marshal the registered claims into JSON
|
||||
if err := json.NewEncoder(buf).Encode(registered); err != nil {
|
||||
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
||||
}
|
||||
|
||||
if len(claims) > 0 {
|
||||
// Merge JSON data into custom claims.
|
||||
// The full-read action by the decoder resets the buffer
|
||||
// to zero len, while retaining underlaying cap.
|
||||
if err := json.NewDecoder(buf).Decode(&claims); err != nil {
|
||||
return nil, fmt.Errorf("oidc registered claims: %w", err)
|
||||
}
|
||||
|
||||
// Marshal the final result.
|
||||
if err := json.NewEncoder(buf).Encode(claims); err != nil {
|
||||
return nil, fmt.Errorf("oidc custom claims: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// unmarshalJSONMulti unmarshals the same JSON data into multiple destinations.
|
||||
// Each destination must be a pointer, as per json.Unmarshal rules.
|
||||
// Returns on the first error and destinations may be partly filled with data.
|
||||
func unmarshalJSONMulti(data []byte, destinations ...any) error {
|
||||
for _, dst := range destinations {
|
||||
if err := json.Unmarshal(data, dst); err != nil {
|
||||
return fmt.Errorf("oidc: %w into %T", err, dst)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
147
pkg/oidc/util_test.go
Normal file
147
pkg/oidc/util_test.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type jsonErrorTest struct{}
|
||||
|
||||
func (jsonErrorTest) MarshalJSON() ([]byte, error) {
|
||||
return nil, errors.New("test")
|
||||
}
|
||||
|
||||
func Test_mergeAndMarshalClaims(t *testing.T) {
|
||||
type args struct {
|
||||
registered any
|
||||
claims map[string]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "encoder error",
|
||||
args: args{
|
||||
registered: jsonErrorTest{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no claims",
|
||||
args: args{
|
||||
registered: struct {
|
||||
Foo string `json:"foo,omitempty"`
|
||||
}{
|
||||
Foo: "bar",
|
||||
},
|
||||
},
|
||||
want: "{\"foo\":\"bar\"}\n",
|
||||
},
|
||||
{
|
||||
name: "with claims",
|
||||
args: args{
|
||||
registered: struct {
|
||||
Foo string `json:"foo,omitempty"`
|
||||
}{
|
||||
Foo: "bar",
|
||||
},
|
||||
claims: map[string]any{
|
||||
"bar": "foo",
|
||||
},
|
||||
},
|
||||
want: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
|
||||
},
|
||||
{
|
||||
name: "registered overwrites custom",
|
||||
args: args{
|
||||
registered: struct {
|
||||
Foo string `json:"foo,omitempty"`
|
||||
}{
|
||||
Foo: "bar",
|
||||
},
|
||||
claims: map[string]any{
|
||||
"foo": "Hello, World!",
|
||||
},
|
||||
},
|
||||
want: "{\"foo\":\"bar\"}\n",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := mergeAndMarshalClaims(tt.args.registered, tt.args.claims)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_unmarshalJSONMulti(t *testing.T) {
|
||||
type dst struct {
|
||||
Foo string `json:"foo,omitempty"`
|
||||
}
|
||||
|
||||
type args struct {
|
||||
data string
|
||||
destinations []any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []any
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
args: args{
|
||||
data: "~!~~",
|
||||
destinations: []any{
|
||||
&dst{},
|
||||
&map[string]any{},
|
||||
},
|
||||
},
|
||||
want: []any{
|
||||
&dst{},
|
||||
&map[string]any{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
data: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
|
||||
destinations: []any{
|
||||
&dst{},
|
||||
&map[string]any{},
|
||||
},
|
||||
},
|
||||
want: []any{
|
||||
&dst{Foo: "bar"},
|
||||
&map[string]any{
|
||||
"foo": "bar",
|
||||
"bar": "foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := unmarshalJSONMulti([]byte(tt.args.data), tt.args.destinations...)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, tt.args.destinations)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
str "github.com/zitadel/oidc/pkg/strings"
|
||||
str "github.com/zitadel/oidc/v2/pkg/strings"
|
||||
)
|
||||
|
||||
type Claims interface {
|
||||
|
@ -32,6 +32,12 @@ type ClaimsSignature interface {
|
|||
SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
|
||||
}
|
||||
|
||||
type IDClaims interface {
|
||||
Claims
|
||||
GetSignatureAlgorithm() jose.SignatureAlgorithm
|
||||
GetAccessTokenHash() string
|
||||
}
|
||||
|
||||
var (
|
||||
ErrParse = errors.New("parsing of request failed")
|
||||
ErrIssuerInvalid = errors.New("issuer does not match")
|
||||
|
|
|
@ -12,9 +12,9 @@ import (
|
|||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
str "github.com/zitadel/oidc/pkg/strings"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
str "github.com/zitadel/oidc/v2/pkg/strings"
|
||||
)
|
||||
|
||||
type AuthRequest interface {
|
||||
|
@ -39,10 +39,8 @@ type Authorizer interface {
|
|||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
Encoder() httphelper.Encoder
|
||||
Signer() Signer
|
||||
IDTokenHintVerifier() IDTokenHintVerifier
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
Crypto() Crypto
|
||||
Issuer() string
|
||||
RequestObjectSupported() bool
|
||||
}
|
||||
|
||||
|
@ -73,8 +71,9 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
|
||||
authReq, err = ParseRequestObject(r.Context(), authReq, authorizer.Storage(), authorizer.Issuer())
|
||||
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
|
@ -92,7 +91,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
if validater, ok := authorizer.(AuthorizeValidator); ok {
|
||||
validation = validater.ValidateAuthRequest
|
||||
}
|
||||
userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier())
|
||||
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
|
@ -101,12 +100,12 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID)
|
||||
req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
||||
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
|
||||
return
|
||||
|
@ -390,7 +389,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
|
|||
if idTokenHint == "" {
|
||||
return "", nil
|
||||
}
|
||||
claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier)
|
||||
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier)
|
||||
if err != nil {
|
||||
return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
|
||||
"If you have any questions, you may contact the administrator of the application.")
|
||||
|
|
|
@ -13,10 +13,10 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/pkg/op"
|
||||
"github.com/zitadel/oidc/pkg/op/mock"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
)
|
||||
|
||||
//
|
||||
|
|
100
pkg/op/client.go
100
pkg/op/client.go
|
@ -1,13 +1,19 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
//go:generate go get github.com/dmarkham/enumer
|
||||
//go:generate go run github.com/dmarkham/enumer -linecomment -sql -json -text -yaml -gqlgen -type=ApplicationType,AccessTokenType
|
||||
//go:generate go mod tidy
|
||||
|
||||
const (
|
||||
ApplicationTypeWeb ApplicationType = iota // web
|
||||
|
@ -67,3 +73,95 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT
|
|||
func IsConfidentialType(c Client) bool {
|
||||
return c.ApplicationType() == ApplicationTypeWeb
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidAuthHeader = errors.New("invalid basic auth header")
|
||||
ErrNoClientCredentials = errors.New("no client credentials provided")
|
||||
ErrMissingClientID = errors.New("client_id missing from request")
|
||||
)
|
||||
|
||||
type ClientJWTProfile interface {
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
}
|
||||
|
||||
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
|
||||
if ca.ClientAssertion == "" {
|
||||
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||
}
|
||||
|
||||
profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
|
||||
if err != nil {
|
||||
return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
|
||||
}
|
||||
return profile.Issuer, nil
|
||||
}
|
||||
|
||||
func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||
}
|
||||
clientID, err = url.QueryUnescape(clientID)
|
||||
if err != nil {
|
||||
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
|
||||
}
|
||||
clientSecret, err = url.QueryUnescape(clientSecret)
|
||||
if err != nil {
|
||||
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
|
||||
}
|
||||
if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
|
||||
return "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
type ClientProvider interface {
|
||||
Decoder() httphelper.Decoder
|
||||
Storage() Storage
|
||||
}
|
||||
|
||||
type clientData struct {
|
||||
ClientID string `schema:"client_id"`
|
||||
oidc.ClientAssertionParams
|
||||
}
|
||||
|
||||
// ClientIDFromRequest parses the request form and tries to obtain the client ID
|
||||
// and reports if it is authenticated, using a JWT or static client secrets over
|
||||
// http basic auth.
|
||||
//
|
||||
// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
|
||||
// present in the form data, JWT assertion will be verified and the
|
||||
// client ID is taken from there.
|
||||
// If any of them is absent, basic auth is attempted.
|
||||
// In absence of basic auth data, the unauthenticated client id from the form
|
||||
// data is returned.
|
||||
//
|
||||
// If no client id can be obtained by any method, oidc.ErrInvalidClient
|
||||
// is returned with ErrMissingClientID wrapped in it.
|
||||
func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
|
||||
err = r.ParseForm()
|
||||
if err != nil {
|
||||
return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
|
||||
}
|
||||
|
||||
data := new(clientData)
|
||||
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
JWTProfile, ok := p.(ClientJWTProfile)
|
||||
if ok {
|
||||
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
|
||||
}
|
||||
if !ok || errors.Is(err, ErrNoClientCredentials) {
|
||||
clientID, err = ClientBasicAuth(r, p.Storage())
|
||||
}
|
||||
if err == nil {
|
||||
return clientID, true, nil
|
||||
}
|
||||
|
||||
if data.ClientID == "" {
|
||||
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
|
||||
}
|
||||
return data.ClientID, false, nil
|
||||
}
|
||||
|
|
253
pkg/op/client_test.go
Normal file
253
pkg/op/client_test.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
)
|
||||
|
||||
type testClientJWTProfile struct{}
|
||||
|
||||
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
|
||||
|
||||
func TestClientJWTAuth(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
ca oidc.ClientAssertionParams
|
||||
verifier op.ClientJWTProfile
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantClientID string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty assertion",
|
||||
args: args{
|
||||
context.Background(),
|
||||
oidc.ClientAssertionParams{},
|
||||
testClientJWTProfile{},
|
||||
},
|
||||
wantErr: op.ErrNoClientCredentials,
|
||||
},
|
||||
{
|
||||
name: "verification error",
|
||||
args: args{
|
||||
context.Background(),
|
||||
oidc.ClientAssertionParams{
|
||||
ClientAssertion: "foo",
|
||||
},
|
||||
testClientJWTProfile{},
|
||||
},
|
||||
wantErr: oidc.ErrParse,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientBasicAuth(t *testing.T) {
|
||||
errWrong := errors.New("wrong secret")
|
||||
|
||||
type args struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args *args
|
||||
storage op.Storage
|
||||
wantClientID string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "no args",
|
||||
wantErr: op.ErrNoClientCredentials,
|
||||
},
|
||||
{
|
||||
name: "username unescape err",
|
||||
args: &args{
|
||||
username: "%",
|
||||
password: "bar",
|
||||
},
|
||||
wantErr: op.ErrInvalidAuthHeader,
|
||||
},
|
||||
{
|
||||
name: "password unescape err",
|
||||
args: &args{
|
||||
username: "foo",
|
||||
password: "%",
|
||||
},
|
||||
wantErr: op.ErrInvalidAuthHeader,
|
||||
},
|
||||
{
|
||||
name: "auth error",
|
||||
args: &args{
|
||||
username: "foo",
|
||||
password: "wrong",
|
||||
},
|
||||
storage: func() op.Storage {
|
||||
s := mock.NewMockStorage(gomock.NewController(t))
|
||||
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong)
|
||||
return s
|
||||
}(),
|
||||
wantErr: errWrong,
|
||||
},
|
||||
{
|
||||
name: "auth error",
|
||||
args: &args{
|
||||
username: "foo",
|
||||
password: "bar",
|
||||
},
|
||||
storage: func() op.Storage {
|
||||
s := mock.NewMockStorage(gomock.NewController(t))
|
||||
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||
return s
|
||||
}(),
|
||||
wantClientID: "foo",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||
if tt.args != nil {
|
||||
r.SetBasicAuth(tt.args.username, tt.args.password)
|
||||
}
|
||||
|
||||
gotClientID, err := op.ClientBasicAuth(r, tt.storage)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errReader struct{}
|
||||
|
||||
func (errReader) Read([]byte) (int, error) {
|
||||
return 0, io.ErrNoProgress
|
||||
}
|
||||
|
||||
type testClientProvider struct {
|
||||
storage op.Storage
|
||||
}
|
||||
|
||||
func (testClientProvider) Decoder() httphelper.Decoder {
|
||||
return schema.NewDecoder()
|
||||
}
|
||||
|
||||
func (p testClientProvider) Storage() op.Storage {
|
||||
return p.storage
|
||||
}
|
||||
|
||||
func TestClientIDFromRequest(t *testing.T) {
|
||||
type args struct {
|
||||
body io.Reader
|
||||
p op.ClientProvider
|
||||
}
|
||||
type basicAuth struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
basicAuth *basicAuth
|
||||
wantClientID string
|
||||
wantAuthenticated bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "parse error",
|
||||
args: args{
|
||||
body: errReader{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unauthenticated",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{
|
||||
"client_id": []string{"foo"},
|
||||
}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: mock.NewStorage(t),
|
||||
},
|
||||
},
|
||||
wantClientID: "foo",
|
||||
wantAuthenticated: false,
|
||||
},
|
||||
{
|
||||
name: "authenticated",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: func() op.Storage {
|
||||
s := mock.NewMockStorage(gomock.NewController(t))
|
||||
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||
return s
|
||||
}(),
|
||||
},
|
||||
},
|
||||
basicAuth: &basicAuth{
|
||||
username: "foo",
|
||||
password: "bar",
|
||||
},
|
||||
wantClientID: "foo",
|
||||
wantAuthenticated: true,
|
||||
},
|
||||
{
|
||||
name: "missing client id",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: mock.NewStorage(t),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if tt.basicAuth != nil {
|
||||
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||
}
|
||||
|
||||
gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||
assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,20 +2,24 @@ package op
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
const (
|
||||
OidcDevMode = "ZITADEL_OIDC_DEV"
|
||||
// deprecated: use OidcDevMode (ZITADEL_OIDC_DEV=true)
|
||||
devMode = "CAOS_OIDC_DEV"
|
||||
var (
|
||||
ErrInvalidIssuerPath = errors.New("no fragments or query allowed for issuer")
|
||||
ErrInvalidIssuerNoIssuer = errors.New("missing issuer")
|
||||
ErrInvalidIssuerURL = errors.New("invalid url for issuer")
|
||||
ErrInvalidIssuerMissingHost = errors.New("host for issuer missing")
|
||||
ErrInvalidIssuerHTTPS = errors.New("scheme for issuer must be `https`")
|
||||
)
|
||||
|
||||
type Configuration interface {
|
||||
Issuer() string
|
||||
IssuerFromRequest(r *http.Request) string
|
||||
Insecure() bool
|
||||
AuthorizationEndpoint() Endpoint
|
||||
TokenEndpoint() Endpoint
|
||||
IntrospectionEndpoint() Endpoint
|
||||
|
@ -23,6 +27,7 @@ type Configuration interface {
|
|||
RevocationEndpoint() Endpoint
|
||||
EndSessionEndpoint() Endpoint
|
||||
KeysEndpoint() Endpoint
|
||||
DeviceAuthorizationEndpoint() Endpoint
|
||||
|
||||
AuthMethodPostSupported() bool
|
||||
CodeMethodS256Supported() bool
|
||||
|
@ -32,6 +37,7 @@ type Configuration interface {
|
|||
GrantTypeTokenExchangeSupported() bool
|
||||
GrantTypeJWTAuthorizationSupported() bool
|
||||
GrantTypeClientCredentialsSupported() bool
|
||||
GrantTypeDeviceCodeSupported() bool
|
||||
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
|
||||
IntrospectionEndpointSigningAlgorithmsSupported() []string
|
||||
RevocationAuthMethodPrivateKeyJWTSupported() bool
|
||||
|
@ -40,38 +46,77 @@ type Configuration interface {
|
|||
RequestObjectSigningAlgorithmsSupported() []string
|
||||
|
||||
SupportedUILocales() []language.Tag
|
||||
DeviceAuthorization() DeviceAuthorizationConfig
|
||||
}
|
||||
|
||||
func ValidateIssuer(issuer string) error {
|
||||
type IssuerFromRequest func(r *http.Request) string
|
||||
|
||||
func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) {
|
||||
return func(allowInsecure bool) (IssuerFromRequest, error) {
|
||||
issuerPath, err := url.Parse(path)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidIssuerURL
|
||||
}
|
||||
if err := ValidateIssuerPath(issuerPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(r *http.Request) string {
|
||||
return dynamicIssuer(r.Host, path, allowInsecure)
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) {
|
||||
return func(allowInsecure bool) (IssuerFromRequest, error) {
|
||||
if err := ValidateIssuer(issuer, allowInsecure); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(_ *http.Request) string {
|
||||
return issuer
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateIssuer(issuer string, allowInsecure bool) error {
|
||||
if issuer == "" {
|
||||
return errors.New("missing issuer")
|
||||
return ErrInvalidIssuerNoIssuer
|
||||
}
|
||||
u, err := url.Parse(issuer)
|
||||
if err != nil {
|
||||
return errors.New("invalid url for issuer")
|
||||
return ErrInvalidIssuerURL
|
||||
}
|
||||
if u.Host == "" {
|
||||
return errors.New("host for issuer missing")
|
||||
return ErrInvalidIssuerMissingHost
|
||||
}
|
||||
if u.Scheme != "https" {
|
||||
if !devLocalAllowed(u) {
|
||||
return errors.New("scheme for issuer must be `https`")
|
||||
if !devLocalAllowed(u, allowInsecure) {
|
||||
return ErrInvalidIssuerHTTPS
|
||||
}
|
||||
}
|
||||
if u.Fragment != "" || len(u.Query()) > 0 {
|
||||
return errors.New("no fragments or query allowed for issuer")
|
||||
return ValidateIssuerPath(u)
|
||||
}
|
||||
|
||||
func ValidateIssuerPath(issuer *url.URL) error {
|
||||
if issuer.Fragment != "" || len(issuer.Query()) > 0 {
|
||||
return ErrInvalidIssuerPath
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func devLocalAllowed(url *url.URL) bool {
|
||||
_, b := os.LookupEnv(OidcDevMode)
|
||||
if !b {
|
||||
// check the old / current env var as well
|
||||
_, b = os.LookupEnv(devMode)
|
||||
if !b {
|
||||
return b
|
||||
}
|
||||
func devLocalAllowed(url *url.URL, allowInsecure bool) bool {
|
||||
if !allowInsecure {
|
||||
return false
|
||||
}
|
||||
return url.Scheme == "http"
|
||||
}
|
||||
|
||||
func dynamicIssuer(issuer, path string, allowInsecure bool) string {
|
||||
schema := "https"
|
||||
if allowInsecure {
|
||||
schema = "http"
|
||||
}
|
||||
if len(path) > 0 && !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
return schema + "://" + issuer + path
|
||||
}
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"os"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateIssuer(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
issuer string
|
||||
allowInsecure bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -16,65 +20,97 @@ func TestValidateIssuer(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
"missing issuer fails",
|
||||
args{""},
|
||||
args{
|
||||
issuer: "",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid url for issuer fails",
|
||||
args{":issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid url for issuer fails",
|
||||
args{":issuer"},
|
||||
args{
|
||||
issuer: ":issuer",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host for issuer missing fails",
|
||||
args{"https:///issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host for not https fails",
|
||||
args{"http://issuer.com"},
|
||||
args{
|
||||
issuer: "https:///issuer",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with fragment fails",
|
||||
args{"https://issuer.com/#issuer"},
|
||||
args{
|
||||
issuer: "https://issuer.com/#issuer",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with query fails",
|
||||
args{"https://issuer.com?issuer=me"},
|
||||
args{
|
||||
issuer: "https://issuer.com?issuer=me",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with http fails",
|
||||
args{
|
||||
issuer: "http://issuer.com",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with https ok",
|
||||
args{"https://issuer.com"},
|
||||
args{
|
||||
issuer: "https://issuer.com",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"localhost with http fails",
|
||||
args{"http://localhost:9999"},
|
||||
"custom scheme fails",
|
||||
args{
|
||||
issuer: "custom://localhost:9999",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"http with allowInsecure ok",
|
||||
args{
|
||||
issuer: "http://localhost:9999",
|
||||
allowInsecure: true,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"https with allowInsecure ok",
|
||||
args{
|
||||
issuer: "https://localhost:9999",
|
||||
allowInsecure: true,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"custom scheme with allowInsecure fails",
|
||||
args{
|
||||
issuer: "custom://localhost:9999",
|
||||
allowInsecure: true,
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
// ensure env is not set
|
||||
//nolint:errcheck
|
||||
os.Unsetenv(OidcDevMode)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||
if err := ValidateIssuer(tt.args.issuer, tt.args.allowInsecure); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
|
||||
func TestValidateIssuerPath(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
issuerPath *url.URL
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -82,17 +118,217 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
|
|||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"localhost with http with dev ok",
|
||||
args{"http://localhost:9999"},
|
||||
"empty ok",
|
||||
args{func() *url.URL {
|
||||
u, _ := url.Parse("")
|
||||
return u
|
||||
}()},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"custom ok",
|
||||
args{func() *url.URL {
|
||||
u, _ := url.Parse("/custom")
|
||||
return u
|
||||
}()},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"fragment fails",
|
||||
args{func() *url.URL {
|
||||
u, _ := url.Parse("#fragment")
|
||||
return u
|
||||
}()},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"query fails",
|
||||
args{func() *url.URL {
|
||||
u, _ := url.Parse("?query=value")
|
||||
return u
|
||||
}()},
|
||||
true,
|
||||
},
|
||||
}
|
||||
//nolint:errcheck
|
||||
os.Setenv(OidcDevMode, "true")
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if err := ValidateIssuerPath(tt.args.issuerPath); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuerPath() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIssuerFromHost(t *testing.T) {
|
||||
type args struct {
|
||||
path string
|
||||
allowInsecure bool
|
||||
target string
|
||||
}
|
||||
type res struct {
|
||||
issuer string
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"invalid issuer path",
|
||||
args{
|
||||
path: "/#fragment",
|
||||
allowInsecure: false,
|
||||
},
|
||||
res{
|
||||
issuer: "",
|
||||
err: ErrInvalidIssuerPath,
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty path secure",
|
||||
args{
|
||||
path: "",
|
||||
allowInsecure: false,
|
||||
target: "https://issuer.com",
|
||||
},
|
||||
res{
|
||||
issuer: "https://issuer.com",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom path secure",
|
||||
args{
|
||||
path: "/custom/",
|
||||
allowInsecure: false,
|
||||
target: "https://issuer.com",
|
||||
},
|
||||
res{
|
||||
issuer: "https://issuer.com/custom/",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom path no leading slash",
|
||||
args{
|
||||
path: "custom/",
|
||||
allowInsecure: false,
|
||||
target: "https://issuer.com",
|
||||
},
|
||||
res{
|
||||
issuer: "https://issuer.com/custom/",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty path unsecure",
|
||||
args{
|
||||
path: "",
|
||||
allowInsecure: true,
|
||||
target: "http://issuer.com",
|
||||
},
|
||||
res{
|
||||
issuer: "http://issuer.com",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom path unsecure",
|
||||
args{
|
||||
path: "/custom/",
|
||||
allowInsecure: true,
|
||||
target: "http://issuer.com",
|
||||
},
|
||||
res{
|
||||
issuer: "http://issuer.com/custom/",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
issuer, err := IssuerFromHost(tt.args.path)(tt.args.allowInsecure)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
req := httptest.NewRequest("", tt.args.target, nil)
|
||||
assert.Equal(t, tt.res.issuer, issuer(req))
|
||||
}
|
||||
if tt.res.err != nil {
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticIssuer(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
allowInsecure bool
|
||||
}
|
||||
type res struct {
|
||||
issuer string
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"invalid issuer",
|
||||
args{
|
||||
issuer: "",
|
||||
allowInsecure: false,
|
||||
},
|
||||
res{
|
||||
issuer: "",
|
||||
err: ErrInvalidIssuerNoIssuer,
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty path secure",
|
||||
args{
|
||||
issuer: "https://issuer.com",
|
||||
allowInsecure: false,
|
||||
},
|
||||
res{
|
||||
issuer: "https://issuer.com",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom path secure",
|
||||
args{
|
||||
issuer: "https://issuer.com/custom/",
|
||||
allowInsecure: false,
|
||||
},
|
||||
res{
|
||||
issuer: "https://issuer.com/custom/",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"unsecure",
|
||||
args{
|
||||
issuer: "http://issuer.com",
|
||||
allowInsecure: true,
|
||||
},
|
||||
res{
|
||||
issuer: "http://issuer.com",
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
issuer, err := StaticIssuer(tt.args.issuer)(tt.args.allowInsecure)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.res.issuer, issuer(nil))
|
||||
}
|
||||
if tt.res.err != nil {
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
53
pkg/op/context.go
Normal file
53
pkg/op/context.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
issuerKey key = 0
|
||||
)
|
||||
|
||||
type IssuerInterceptor struct {
|
||||
issuerFromRequest IssuerFromRequest
|
||||
}
|
||||
|
||||
// NewIssuerInterceptor will set the issuer into the context
|
||||
// by the provided IssuerFromRequest (e.g. returned from StaticIssuer or IssuerFromHost)
|
||||
func NewIssuerInterceptor(issuerFromRequest IssuerFromRequest) *IssuerInterceptor {
|
||||
return &IssuerInterceptor{
|
||||
issuerFromRequest: issuerFromRequest,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *IssuerInterceptor) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
i.setIssuerCtx(w, r, next)
|
||||
})
|
||||
}
|
||||
|
||||
func (i *IssuerInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
i.setIssuerCtx(w, r, next)
|
||||
}
|
||||
}
|
||||
|
||||
// IssuerFromContext reads the issuer from the context (set by an IssuerInterceptor)
|
||||
// it will return an empty string if not found
|
||||
func IssuerFromContext(ctx context.Context) string {
|
||||
ctxIssuer, _ := ctx.Value(issuerKey).(string)
|
||||
return ctxIssuer
|
||||
}
|
||||
|
||||
// ContextWithIssuer returns a new context with issuer set to it.
|
||||
func ContextWithIssuer(ctx context.Context, issuer string) context.Context {
|
||||
return context.WithValue(ctx, issuerKey, issuer)
|
||||
}
|
||||
|
||||
func (i *IssuerInterceptor) setIssuerCtx(w http.ResponseWriter, r *http.Request, next http.Handler) {
|
||||
r = r.WithContext(ContextWithIssuer(r.Context(), i.issuerFromRequest(r)))
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
76
pkg/op/context_test.go
Normal file
76
pkg/op/context_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIssuerInterceptor(t *testing.T) {
|
||||
type fields struct {
|
||||
issuerFromRequest IssuerFromRequest
|
||||
}
|
||||
type args struct {
|
||||
r *http.Request
|
||||
next http.Handler
|
||||
}
|
||||
type res struct {
|
||||
issuer string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
fields{
|
||||
func(r *http.Request) string {
|
||||
return ""
|
||||
},
|
||||
},
|
||||
args{},
|
||||
res{
|
||||
issuer: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"static",
|
||||
fields{
|
||||
func(r *http.Request) string {
|
||||
return "static"
|
||||
},
|
||||
},
|
||||
args{},
|
||||
res{
|
||||
issuer: "static",
|
||||
},
|
||||
},
|
||||
{
|
||||
"host",
|
||||
fields{
|
||||
func(r *http.Request) string {
|
||||
return r.Host
|
||||
},
|
||||
},
|
||||
args{},
|
||||
res{
|
||||
issuer: "issuer.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
i := NewIssuerInterceptor(tt.fields.issuerFromRequest)
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, tt.res.issuer, IssuerFromContext(r.Context()))
|
||||
})
|
||||
req := httptest.NewRequest("", "https://issuer.com", nil)
|
||||
i.Handler(next).ServeHTTP(nil, req)
|
||||
i.HandlerFunc(next).ServeHTTP(nil, req)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
)
|
||||
|
||||
type Crypto interface {
|
||||
|
|
265
pkg/op/device.go
Normal file
265
pkg/op/device.go
Normal file
|
@ -0,0 +1,265 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type DeviceAuthorizationConfig struct {
|
||||
Lifetime time.Duration
|
||||
PollInterval time.Duration
|
||||
UserFormURL string // the URL where the user must go to authorize the device
|
||||
UserCode UserCodeConfig
|
||||
}
|
||||
|
||||
type UserCodeConfig struct {
|
||||
CharSet string
|
||||
CharAmount int
|
||||
DashInterval int
|
||||
}
|
||||
|
||||
const (
|
||||
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
|
||||
CharSetDigits = "0123456789"
|
||||
)
|
||||
|
||||
var (
|
||||
UserCodeBase20 = UserCodeConfig{
|
||||
CharSet: CharSetBase20,
|
||||
CharAmount: 8,
|
||||
DashInterval: 4,
|
||||
}
|
||||
UserCodeDigits = UserCodeConfig{
|
||||
CharSet: CharSetDigits,
|
||||
CharAmount: 9,
|
||||
DashInterval: 3,
|
||||
}
|
||||
)
|
||||
|
||||
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := DeviceAuthorization(w, r, o); err != nil {
|
||||
RequestError(w, r, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
|
||||
storage, err := assertDeviceStorage(o.Storage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := ParseDeviceCodeRequest(r, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := o.DeviceAuthorization()
|
||||
|
||||
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expires := time.Now().Add(config.Lifetime)
|
||||
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response := &oidc.DeviceAuthorizationResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: config.UserFormURL,
|
||||
ExpiresIn: int(config.Lifetime / time.Second),
|
||||
Interval: int(config.PollInterval / time.Second),
|
||||
}
|
||||
|
||||
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
|
||||
|
||||
httphelper.MarshalJSON(w, response)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
|
||||
clientID, _, err := ClientIDFromRequest(r, o)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := new(oidc.DeviceAuthorizationRequest)
|
||||
if err := o.Decoder().Decode(req, r.Form); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
|
||||
}
|
||||
req.ClientID = clientID
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// 16 bytes gives 128 bit of entropy.
|
||||
// results in a 22 character base64 encoded string.
|
||||
const RecommendedDeviceCodeBytes = 16
|
||||
|
||||
func NewDeviceCode(nBytes int) (string, error) {
|
||||
bytes := make([]byte, nBytes)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("%w getting entropy for device code", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
|
||||
var buf strings.Builder
|
||||
if dashInterval > 0 {
|
||||
buf.Grow(charAmount + charAmount/dashInterval - 1)
|
||||
} else {
|
||||
buf.Grow(charAmount)
|
||||
}
|
||||
|
||||
max := big.NewInt(int64(len(charSet)))
|
||||
|
||||
for i := 0; i < charAmount; i++ {
|
||||
if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
|
||||
buf.WriteByte('-')
|
||||
}
|
||||
|
||||
bi, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w getting entropy for user code", err)
|
||||
}
|
||||
|
||||
buf.WriteRune(charSet[int(bi.Int64())])
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
type deviceAccessTokenRequest struct {
|
||||
subject string
|
||||
audience []string
|
||||
scopes []string
|
||||
}
|
||||
|
||||
func (r *deviceAccessTokenRequest) GetSubject() string {
|
||||
return r.subject
|
||||
}
|
||||
|
||||
func (r *deviceAccessTokenRequest) GetAudience() []string {
|
||||
return r.audience
|
||||
}
|
||||
|
||||
func (r *deviceAccessTokenRequest) GetScopes() []string {
|
||||
return r.scopes
|
||||
}
|
||||
|
||||
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
if err := deviceAccessToken(w, r, exchanger); err != nil {
|
||||
RequestError(w, r, err)
|
||||
}
|
||||
}
|
||||
|
||||
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
|
||||
// use a limited context timeout shorter as the default
|
||||
// poll interval of 5 seconds.
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if clientAuthenticated != IsConfidentialType(client) {
|
||||
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||
WithDescription("confidential client requires authentication")
|
||||
}
|
||||
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
subject: state.Subject,
|
||||
audience: []string{clientID},
|
||||
scopes: state.Scopes,
|
||||
}
|
||||
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httphelper.MarshalJSON(w, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
||||
req := new(oidc.DeviceAccessTokenRequest)
|
||||
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||
storage, err := assertDeviceStorage(exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, oidc.ErrSlowDown().WithParent(err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, oidc.ErrAccessDenied().WithParent(err)
|
||||
}
|
||||
if state.Denied {
|
||||
return state, oidc.ErrAccessDenied()
|
||||
}
|
||||
if state.Done {
|
||||
return state, nil
|
||||
}
|
||||
if time.Now().After(state.Expires) {
|
||||
return state, oidc.ErrExpiredDeviceCode()
|
||||
}
|
||||
return state, oidc.ErrAuthorizationPending()
|
||||
}
|
||||
|
||||
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
|
||||
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: uint64(validity.Seconds()),
|
||||
}, nil
|
||||
}
|
407
pkg/op/device_test.go
Normal file
407
pkg/op/device_test.go
Normal file
|
@ -0,0 +1,407 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
mr "math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
func Test_deviceAuthorizationHandler(t *testing.T) {
|
||||
req := &oidc.DeviceAuthorizationRequest{
|
||||
Scopes: []string{"foo", "bar"},
|
||||
ClientID: "web",
|
||||
}
|
||||
values := make(url.Values)
|
||||
testProvider.Encoder().Encode(req, values)
|
||||
body := strings.NewReader(values.Encode())
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
||||
op.DeviceAuthorizationHandler(testProvider)(w, r)
|
||||
})
|
||||
|
||||
result := w.Result()
|
||||
|
||||
assert.Less(t, result.StatusCode, 300)
|
||||
|
||||
got, _ := io.ReadAll(result.Body)
|
||||
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
||||
}
|
||||
|
||||
func TestParseDeviceCodeRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *oidc.DeviceAuthorizationRequest
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty request",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
req: &oidc.DeviceAuthorizationRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
|
||||
ClientID: "web",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var body io.Reader
|
||||
if tt.req != nil {
|
||||
values := make(url.Values)
|
||||
testProvider.Encoder().Encode(tt.req, values)
|
||||
body = strings.NewReader(values.Encode())
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
got, err := op.ParseDeviceCodeRequest(r, testProvider)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.req, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runWithRandReader(r io.Reader, f func()) {
|
||||
originalReader := rand.Reader
|
||||
rand.Reader = r
|
||||
defer func() {
|
||||
rand.Reader = originalReader
|
||||
}()
|
||||
|
||||
f()
|
||||
}
|
||||
|
||||
func TestNewDeviceCode(t *testing.T) {
|
||||
t.Run("reader error", func(t *testing.T) {
|
||||
runWithRandReader(errReader{}, func() {
|
||||
_, err := op.NewDeviceCode(16)
|
||||
require.Error(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("different lengths, rand reader", func(t *testing.T) {
|
||||
for i := 1; i <= 32; i++ {
|
||||
got, err := op.NewDeviceCode(i)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestNewUserCode(t *testing.T) {
|
||||
type args struct {
|
||||
charset []rune
|
||||
charAmount int
|
||||
dashInterval int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
reader io.Reader
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "reader error",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetBase20),
|
||||
charAmount: 8,
|
||||
dashInterval: 4,
|
||||
},
|
||||
reader: errReader{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "base20",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetBase20),
|
||||
charAmount: 8,
|
||||
dashInterval: 4,
|
||||
},
|
||||
reader: mr.New(mr.NewSource(1)),
|
||||
want: "XKCD-HTTD",
|
||||
},
|
||||
{
|
||||
name: "digits",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetDigits),
|
||||
charAmount: 9,
|
||||
dashInterval: 3,
|
||||
},
|
||||
reader: mr.New(mr.NewSource(1)),
|
||||
want: "271-256-225",
|
||||
},
|
||||
{
|
||||
name: "no dashes",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetDigits),
|
||||
charAmount: 9,
|
||||
},
|
||||
reader: mr.New(mr.NewSource(1)),
|
||||
want: "271256225",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
runWithRandReader(tt.reader, func() {
|
||||
got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
|
||||
if tt.wantErr {
|
||||
require.ErrorIs(t, err, io.ErrNoProgress)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("crypto/rand", func(t *testing.T) {
|
||||
const testN = 100000
|
||||
|
||||
for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} {
|
||||
t.Run(c.CharSet, func(t *testing.T) {
|
||||
results := make(map[string]int)
|
||||
|
||||
for i := 0; i < testN; i++ {
|
||||
code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
|
||||
require.NoError(t, err)
|
||||
results[code]++
|
||||
}
|
||||
|
||||
t.Log(results)
|
||||
|
||||
var duplicates int
|
||||
for code, count := range results {
|
||||
assert.Less(t, count, 3, code)
|
||||
if count == 2 {
|
||||
duplicates++
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkNewUserCode(b *testing.B) {
|
||||
type args struct {
|
||||
charset []rune
|
||||
charAmount int
|
||||
dashInterval int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
reader io.Reader
|
||||
}{
|
||||
{
|
||||
name: "math rand, base20",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetBase20),
|
||||
charAmount: 8,
|
||||
dashInterval: 4,
|
||||
},
|
||||
reader: mr.New(mr.NewSource(1)),
|
||||
},
|
||||
{
|
||||
name: "math rand, digits",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetDigits),
|
||||
charAmount: 9,
|
||||
dashInterval: 3,
|
||||
},
|
||||
reader: mr.New(mr.NewSource(1)),
|
||||
},
|
||||
{
|
||||
name: "crypto rand, base20",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetBase20),
|
||||
charAmount: 8,
|
||||
dashInterval: 4,
|
||||
},
|
||||
reader: rand.Reader,
|
||||
},
|
||||
{
|
||||
name: "crypto rand, digits",
|
||||
args: args{
|
||||
charset: []rune(op.CharSetDigits),
|
||||
charAmount: 9,
|
||||
dashInterval: 3,
|
||||
},
|
||||
reader: rand.Reader,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
runWithRandReader(tt.reader, func() {
|
||||
b.Run(tt.name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceAccessToken(t *testing.T) {
|
||||
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
|
||||
storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
|
||||
|
||||
values := make(url.Values)
|
||||
values.Set("client_id", "native")
|
||||
values.Set("grant_type", string(oidc.GrantTypeDeviceCode))
|
||||
values.Set("device_code", "qwerty")
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
op.DeviceAccessToken(w, r, testProvider)
|
||||
|
||||
result := w.Result()
|
||||
got, _ := io.ReadAll(result.Body)
|
||||
t.Log(string(got))
|
||||
assert.Less(t, result.StatusCode, 300)
|
||||
assert.NotEmpty(t, string(got))
|
||||
}
|
||||
|
||||
func TestCheckDeviceAuthorizationState(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"})
|
||||
|
||||
storage.DenyDeviceAuthorization(context.Background(), "denied")
|
||||
storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim")
|
||||
|
||||
exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
|
||||
defer cancel()
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
clientID string
|
||||
deviceCode string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *op.DeviceAuthorizationState
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "pending",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientID: "native",
|
||||
deviceCode: "pending",
|
||||
},
|
||||
want: &op.DeviceAuthorizationState{
|
||||
ClientID: "native",
|
||||
Scopes: []string{"foo"},
|
||||
Expires: now.Add(time.Minute),
|
||||
},
|
||||
wantErr: oidc.ErrAuthorizationPending(),
|
||||
},
|
||||
{
|
||||
name: "slow down",
|
||||
args: args{
|
||||
ctx: exceededCtx,
|
||||
clientID: "native",
|
||||
deviceCode: "ok",
|
||||
},
|
||||
wantErr: oidc.ErrSlowDown(),
|
||||
},
|
||||
{
|
||||
name: "wrong client",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientID: "foo",
|
||||
deviceCode: "ok",
|
||||
},
|
||||
wantErr: oidc.ErrAccessDenied(),
|
||||
},
|
||||
{
|
||||
name: "denied",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientID: "native",
|
||||
deviceCode: "denied",
|
||||
},
|
||||
want: &op.DeviceAuthorizationState{
|
||||
ClientID: "native",
|
||||
Scopes: []string{"foo"},
|
||||
Expires: now.Add(time.Minute),
|
||||
Denied: true,
|
||||
},
|
||||
wantErr: oidc.ErrAccessDenied(),
|
||||
},
|
||||
{
|
||||
name: "completed",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientID: "native",
|
||||
deviceCode: "completed",
|
||||
},
|
||||
want: &op.DeviceAuthorizationState{
|
||||
ClientID: "native",
|
||||
Scopes: []string{"foo"},
|
||||
Expires: now.Add(time.Minute),
|
||||
Subject: "tim",
|
||||
Done: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientID: "native",
|
||||
deviceCode: "expired",
|
||||
},
|
||||
want: &op.DeviceAuthorizationState{
|
||||
ClientID: "native",
|
||||
Scopes: []string{"foo"},
|
||||
Expires: now.Add(-time.Minute),
|
||||
},
|
||||
wantErr: oidc.ErrExpiredDeviceCode(),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,49 +1,17 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
Discover(w, CreateDiscoveryConfig(c, s))
|
||||
}
|
||||
}
|
||||
|
||||
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
|
||||
httphelper.MarshalJSON(w, config)
|
||||
}
|
||||
|
||||
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: c.Issuer(),
|
||||
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
|
||||
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
|
||||
IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()),
|
||||
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
|
||||
RevocationEndpoint: c.RevocationEndpoint().Absolute(c.Issuer()),
|
||||
EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
|
||||
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
|
||||
ScopesSupported: Scopes(c),
|
||||
ResponseTypesSupported: ResponseTypes(c),
|
||||
GrantTypesSupported: GrantTypes(c),
|
||||
SubjectTypesSupported: SubjectTypes(c),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
|
||||
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(c),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c),
|
||||
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(c),
|
||||
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(c),
|
||||
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c),
|
||||
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(c),
|
||||
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(c),
|
||||
ClaimsSupported: SupportedClaims(c),
|
||||
CodeChallengeMethodsSupported: CodeChallengeMethods(c),
|
||||
UILocalesSupported: c.SupportedUILocales(),
|
||||
RequestParameterSupported: c.RequestObjectSupported(),
|
||||
}
|
||||
type DiscoverStorage interface {
|
||||
SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
|
||||
}
|
||||
|
||||
var DefaultSupportedScopes = []string{
|
||||
|
@ -55,6 +23,47 @@ var DefaultSupportedScopes = []string{
|
|||
oidc.ScopeOfflineAccess,
|
||||
}
|
||||
|
||||
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
Discover(w, CreateDiscoveryConfig(r, c, s))
|
||||
}
|
||||
}
|
||||
|
||||
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
|
||||
httphelper.MarshalJSON(w, config)
|
||||
}
|
||||
|
||||
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
|
||||
issuer := config.IssuerFromRequest(r)
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
|
||||
TokenEndpoint: config.TokenEndpoint().Absolute(issuer),
|
||||
IntrospectionEndpoint: config.IntrospectionEndpoint().Absolute(issuer),
|
||||
UserinfoEndpoint: config.UserinfoEndpoint().Absolute(issuer),
|
||||
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
|
||||
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
|
||||
JwksURI: config.KeysEndpoint().Absolute(issuer),
|
||||
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
|
||||
ScopesSupported: Scopes(config),
|
||||
ResponseTypesSupported: ResponseTypes(config),
|
||||
GrantTypesSupported: GrantTypes(config),
|
||||
SubjectTypesSupported: SubjectTypes(config),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage),
|
||||
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
|
||||
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
|
||||
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
|
||||
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
|
||||
ClaimsSupported: SupportedClaims(config),
|
||||
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
|
||||
UILocalesSupported: config.SupportedUILocales(),
|
||||
RequestParameterSupported: config.RequestObjectSupported(),
|
||||
}
|
||||
}
|
||||
|
||||
func Scopes(c Configuration) []string {
|
||||
return DefaultSupportedScopes // TODO: config
|
||||
}
|
||||
|
@ -84,9 +93,94 @@ func GrantTypes(c Configuration) []oidc.GrantType {
|
|||
if c.GrantTypeJWTAuthorizationSupported() {
|
||||
grantTypes = append(grantTypes, oidc.GrantTypeBearer)
|
||||
}
|
||||
if c.GrantTypeDeviceCodeSupported() {
|
||||
grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
|
||||
}
|
||||
return grantTypes
|
||||
}
|
||||
|
||||
func SubjectTypes(c Configuration) []string {
|
||||
return []string{"public"} //TODO: config
|
||||
}
|
||||
|
||||
func SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string {
|
||||
algorithms, err := storage.SignatureAlgorithms(ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
algs := make([]string, len(algorithms))
|
||||
for i, algorithm := range algorithms {
|
||||
algs[i] = string(algorithm)
|
||||
}
|
||||
return algs
|
||||
}
|
||||
|
||||
func RequestObjectSigAlgorithms(c Configuration) []string {
|
||||
if !c.RequestObjectSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.RequestObjectSigningAlgorithmsSupported()
|
||||
}
|
||||
|
||||
func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
|
||||
authMethods := []oidc.AuthMethod{
|
||||
oidc.AuthMethodNone,
|
||||
oidc.AuthMethodBasic,
|
||||
}
|
||||
if c.AuthMethodPostSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPost)
|
||||
}
|
||||
if c.AuthMethodPrivateKeyJWTSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
|
||||
}
|
||||
return authMethods
|
||||
}
|
||||
|
||||
func TokenSigAlgorithms(c Configuration) []string {
|
||||
if !c.AuthMethodPrivateKeyJWTSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.TokenEndpointSigningAlgorithmsSupported()
|
||||
}
|
||||
|
||||
func IntrospectionSigAlgorithms(c Configuration) []string {
|
||||
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.IntrospectionEndpointSigningAlgorithmsSupported()
|
||||
}
|
||||
|
||||
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
|
||||
authMethods := []oidc.AuthMethod{
|
||||
oidc.AuthMethodBasic,
|
||||
}
|
||||
if c.AuthMethodPrivateKeyJWTSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
|
||||
}
|
||||
return authMethods
|
||||
}
|
||||
|
||||
func RevocationSigAlgorithms(c Configuration) []string {
|
||||
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.RevocationEndpointSigningAlgorithmsSupported()
|
||||
}
|
||||
|
||||
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
|
||||
authMethods := []oidc.AuthMethod{
|
||||
oidc.AuthMethodNone,
|
||||
oidc.AuthMethodBasic,
|
||||
}
|
||||
if c.AuthMethodPostSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPost)
|
||||
}
|
||||
if c.AuthMethodPrivateKeyJWTSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
|
||||
}
|
||||
return authMethods
|
||||
}
|
||||
|
||||
func SupportedClaims(c Configuration) []string {
|
||||
return []string{ // TODO: config
|
||||
"sub",
|
||||
|
@ -116,59 +210,6 @@ func SupportedClaims(c Configuration) []string {
|
|||
}
|
||||
}
|
||||
|
||||
func SigAlgorithms(s Signer) []string {
|
||||
return []string{string(s.SignatureAlgorithm())}
|
||||
}
|
||||
|
||||
func SubjectTypes(c Configuration) []string {
|
||||
return []string{"public"} // TODO: config
|
||||
}
|
||||
|
||||
func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
|
||||
authMethods := []oidc.AuthMethod{
|
||||
oidc.AuthMethodNone,
|
||||
oidc.AuthMethodBasic,
|
||||
}
|
||||
if c.AuthMethodPostSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPost)
|
||||
}
|
||||
if c.AuthMethodPrivateKeyJWTSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
|
||||
}
|
||||
return authMethods
|
||||
}
|
||||
|
||||
func TokenSigAlgorithms(c Configuration) []string {
|
||||
if !c.AuthMethodPrivateKeyJWTSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.TokenEndpointSigningAlgorithmsSupported()
|
||||
}
|
||||
|
||||
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
|
||||
authMethods := []oidc.AuthMethod{
|
||||
oidc.AuthMethodBasic,
|
||||
}
|
||||
if c.AuthMethodPrivateKeyJWTSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
|
||||
}
|
||||
return authMethods
|
||||
}
|
||||
|
||||
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
|
||||
authMethods := []oidc.AuthMethod{
|
||||
oidc.AuthMethodNone,
|
||||
oidc.AuthMethodBasic,
|
||||
}
|
||||
if c.AuthMethodPostSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPost)
|
||||
}
|
||||
if c.AuthMethodPrivateKeyJWTSupported() {
|
||||
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
|
||||
}
|
||||
return authMethods
|
||||
}
|
||||
|
||||
func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
|
||||
codeMethods := make([]oidc.CodeChallengeMethod, 0, 1)
|
||||
if c.CodeMethodS256Supported() {
|
||||
|
@ -176,24 +217,3 @@ func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
|
|||
}
|
||||
return codeMethods
|
||||
}
|
||||
|
||||
func IntrospectionSigAlgorithms(c Configuration) []string {
|
||||
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.IntrospectionEndpointSigningAlgorithmsSupported()
|
||||
}
|
||||
|
||||
func RevocationSigAlgorithms(c Configuration) []string {
|
||||
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.RevocationEndpointSigningAlgorithmsSupported()
|
||||
}
|
||||
|
||||
func RequestObjectSigAlgorithms(c Configuration) []string {
|
||||
if !c.RequestObjectSupported() {
|
||||
return nil
|
||||
}
|
||||
return c.RequestObjectSigningAlgorithmsSupported()
|
||||
}
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/pkg/op"
|
||||
"github.com/zitadel/oidc/pkg/op/mock"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
)
|
||||
|
||||
func TestDiscover(t *testing.T) {
|
||||
|
@ -47,8 +48,9 @@ func TestDiscover(t *testing.T) {
|
|||
|
||||
func TestCreateDiscoveryConfig(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
s op.Signer
|
||||
request *http.Request
|
||||
c op.Configuration
|
||||
s op.DiscoverStorage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -59,9 +61,8 @@ func TestCreateDiscoveryConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -83,9 +84,8 @@ func Test_scopes(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("scopes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got := op.Scopes(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -99,13 +99,16 @@ func Test_ResponseTypes(t *testing.T) {
|
|||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{
|
||||
"code and implicit flow",
|
||||
args{},
|
||||
[]string{"code", "id_token", "id_token token"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got := op.ResponseTypes(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -117,63 +120,53 @@ func Test_GrantTypes(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedClaims(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SupportedClaims() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockSigner(gomock.NewController(t))
|
||||
type args struct {
|
||||
s op.Signer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
want []oidc.GrantType
|
||||
}{
|
||||
{
|
||||
"",
|
||||
args{func() op.Signer {
|
||||
m.EXPECT().SignatureAlgorithm().Return(jose.RS256)
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
"code and implicit flow",
|
||||
args{
|
||||
func() op.Configuration {
|
||||
c := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
c.EXPECT().GrantTypeRefreshTokenSupported().Return(false)
|
||||
c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
|
||||
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
|
||||
c.EXPECT().GrantTypeClientCredentialsSupported().Return(false)
|
||||
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
|
||||
return c
|
||||
}(),
|
||||
},
|
||||
[]oidc.GrantType{
|
||||
oidc.GrantTypeCode,
|
||||
oidc.GrantTypeImplicit,
|
||||
},
|
||||
},
|
||||
{
|
||||
"code, implicit flow, refresh token, token exchange, jwt profile, client_credentials",
|
||||
args{
|
||||
func() op.Configuration {
|
||||
c := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
c.EXPECT().GrantTypeRefreshTokenSupported().Return(true)
|
||||
c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
|
||||
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
|
||||
c.EXPECT().GrantTypeClientCredentialsSupported().Return(true)
|
||||
c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
|
||||
return c
|
||||
}(),
|
||||
},
|
||||
[]oidc.GrantType{
|
||||
oidc.GrantTypeCode,
|
||||
oidc.GrantTypeImplicit,
|
||||
oidc.GrantTypeRefreshToken,
|
||||
oidc.GrantTypeClientCredentials,
|
||||
oidc.GrantTypeTokenExchange,
|
||||
oidc.GrantTypeBearer,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got := op.GrantTypes(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -195,9 +188,80 @@ func Test_SubjectTypes(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got := op.SubjectTypes(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockDiscoverStorage(gomock.NewController(t))
|
||||
type args struct {
|
||||
s op.DiscoverStorage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"",
|
||||
args{func() op.DiscoverStorage {
|
||||
m.EXPECT().SignatureAlgorithms(gomock.Any()).Return([]jose.SignatureAlgorithm{jose.RS256}, nil)
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.SigAlgorithms(context.Background(), tt.args.s)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RequestObjectSigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"not supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().RequestObjectSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().RequestObjectSupported().Return(true)
|
||||
m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return(nil)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, list",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().RequestObjectSupported().Return(true)
|
||||
m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return([]string{"RS256"})
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.RequestObjectSigAlgorithms(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -244,9 +308,311 @@ func Test_AuthMethodsTokenEndpoint(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.AuthMethodsTokenEndpoint(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("authMethods() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got := op.AuthMethodsTokenEndpoint(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TokenSigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"not supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return(nil)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, list",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.TokenSigAlgorithms(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IntrospectionSigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"not supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return(nil)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, list",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.IntrospectionSigAlgorithms(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AuthMethodsIntrospectionEndpoint(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []oidc.AuthMethod
|
||||
}{
|
||||
{
|
||||
"basic only",
|
||||
args{func() op.Configuration {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
[]oidc.AuthMethod{oidc.AuthMethodBasic},
|
||||
},
|
||||
{
|
||||
"basic and private_key_jwt",
|
||||
args{func() op.Configuration {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
return m
|
||||
}()},
|
||||
[]oidc.AuthMethod{oidc.AuthMethodBasic, oidc.AuthMethodPrivateKeyJWT},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.AuthMethodsIntrospectionEndpoint(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RevocationSigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"not supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, empty",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return(nil)
|
||||
return m
|
||||
}()},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"supported, list",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.RevocationSigAlgorithms(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AuthMethodsRevocationEndpoint(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []oidc.AuthMethod
|
||||
}{
|
||||
{
|
||||
"none and basic",
|
||||
args{func() op.Configuration {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
m.EXPECT().AuthMethodPostSupported().Return(false)
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic},
|
||||
},
|
||||
{
|
||||
"none, basic and post",
|
||||
args{func() op.Configuration {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
m.EXPECT().AuthMethodPostSupported().Return(true)
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost},
|
||||
},
|
||||
{
|
||||
"none, basic, post and private_key_jwt",
|
||||
args{func() op.Configuration {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
m.EXPECT().AuthMethodPostSupported().Return(true)
|
||||
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
|
||||
return m
|
||||
}()},
|
||||
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost, oidc.AuthMethodPrivateKeyJWT},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.AuthMethodsRevocationEndpoint(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedClaims(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"scopes",
|
||||
args{},
|
||||
[]string{
|
||||
"sub",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"iss",
|
||||
"auth_time",
|
||||
"nonce",
|
||||
"acr",
|
||||
"amr",
|
||||
"c_hash",
|
||||
"at_hash",
|
||||
"act",
|
||||
"scopes",
|
||||
"client_id",
|
||||
"azp",
|
||||
"preferred_username",
|
||||
"name",
|
||||
"family_name",
|
||||
"given_name",
|
||||
"locale",
|
||||
"email",
|
||||
"email_verified",
|
||||
"phone_number",
|
||||
"phone_number_verified",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.SupportedClaims(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CodeChallengeMethods(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []oidc.CodeChallengeMethod
|
||||
}{
|
||||
{
|
||||
"not supported",
|
||||
args{func() op.Configuration {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
m.EXPECT().CodeMethodS256Supported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
[]oidc.CodeChallengeMethod{},
|
||||
},
|
||||
{
|
||||
"S256",
|
||||
args{func() op.Configuration {
|
||||
m := mock.NewMockConfiguration(gomock.NewController(t))
|
||||
m.EXPECT().CodeMethodS256Supported().Return(true)
|
||||
return m
|
||||
}()},
|
||||
[]oidc.CodeChallengeMethod{oidc.CodeChallengeMethodS256},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.CodeChallengeMethods(tt.args.c)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package op_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
func TestEndpoint_Path(t *testing.T) {
|
||||
|
|
|
@ -3,8 +3,8 @@ package op
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
|
|
|
@ -6,11 +6,11 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
)
|
||||
|
||||
type KeyProvider interface {
|
||||
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||
KeySet(context.Context) ([]Key, error)
|
||||
}
|
||||
|
||||
func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -20,10 +20,23 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
|
|||
}
|
||||
|
||||
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
||||
keySet, err := k.GetKeySet(r.Context())
|
||||
keySet, err := k.KeySet(r.Context())
|
||||
if err != nil {
|
||||
httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
httphelper.MarshalJSON(w, keySet)
|
||||
httphelper.MarshalJSON(w, jsonWebKeySet(keySet))
|
||||
}
|
||||
|
||||
func jsonWebKeySet(keys []Key) *jose.JSONWebKeySet {
|
||||
webKeys := make([]jose.JSONWebKey, len(keys))
|
||||
for i, key := range keys {
|
||||
webKeys[i] = jose.JSONWebKey{
|
||||
KeyID: key.ID(),
|
||||
Algorithm: string(key.Algorithm()),
|
||||
Use: key.Use(),
|
||||
Key: key.Key(),
|
||||
}
|
||||
}
|
||||
return &jose.JSONWebKeySet{Keys: webKeys}
|
||||
}
|
||||
|
|
|
@ -11,9 +11,9 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/pkg/op"
|
||||
"github.com/zitadel/oidc/pkg/op/mock"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
)
|
||||
|
||||
func TestKeys(t *testing.T) {
|
||||
|
@ -35,7 +35,7 @@ func TestKeys(t *testing.T) {
|
|||
args: args{
|
||||
k: func() op.KeyProvider {
|
||||
m := mock.NewMockKeyProvider(gomock.NewController(t))
|
||||
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
|
||||
m.EXPECT().KeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
|
||||
return m
|
||||
}(),
|
||||
},
|
||||
|
@ -51,39 +51,39 @@ func TestKeys(t *testing.T) {
|
|||
args: args{
|
||||
k: func() op.KeyProvider {
|
||||
m := mock.NewMockKeyProvider(gomock.NewController(t))
|
||||
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, nil)
|
||||
m.EXPECT().KeySet(gomock.Any()).Return(nil, nil)
|
||||
return m
|
||||
}(),
|
||||
},
|
||||
res: res{
|
||||
statusCode: http.StatusOK,
|
||||
contentType: "application/json",
|
||||
body: `{"keys":[]}
|
||||
`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list",
|
||||
args: args{
|
||||
k: func() op.KeyProvider {
|
||||
m := mock.NewMockKeyProvider(gomock.NewController(t))
|
||||
m.EXPECT().GetKeySet(gomock.Any()).Return(
|
||||
&jose.JSONWebKeySet{Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: &rsa.PublicKey{
|
||||
N: big.NewInt(1),
|
||||
E: 1,
|
||||
},
|
||||
KeyID: "id",
|
||||
},
|
||||
}},
|
||||
nil,
|
||||
)
|
||||
ctrl := gomock.NewController(t)
|
||||
m := mock.NewMockKeyProvider(ctrl)
|
||||
k := mock.NewMockKey(ctrl)
|
||||
k.EXPECT().Key().Return(&rsa.PublicKey{
|
||||
N: big.NewInt(1),
|
||||
E: 1,
|
||||
})
|
||||
k.EXPECT().ID().Return("id")
|
||||
k.EXPECT().Algorithm().Return(jose.RS256)
|
||||
k.EXPECT().Use().Return("sig")
|
||||
m.EXPECT().KeySet(gomock.Any()).Return([]op.Key{k}, nil)
|
||||
return m
|
||||
}(),
|
||||
},
|
||||
res: res{
|
||||
statusCode: http.StatusOK,
|
||||
contentType: "application/json",
|
||||
body: `{"keys":[{"kty":"RSA","kid":"id","n":"AQ","e":"AQ"}]}
|
||||
body: `{"keys":[{"use":"sig","kty":"RSA","kid":"id","alg":"RS256","n":"AQ","e":"AQ"}]}
|
||||
`,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Authorizer)
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
http "github.com/zitadel/oidc/pkg/http"
|
||||
op "github.com/zitadel/oidc/pkg/op"
|
||||
http "github.com/zitadel/oidc/v2/pkg/http"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
// MockAuthorizer is a mock of Authorizer interface.
|
||||
|
@ -78,31 +79,17 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
|||
}
|
||||
|
||||
// IDTokenHintVerifier mocks base method.
|
||||
func (m *MockAuthorizer) IDTokenHintVerifier() op.IDTokenHintVerifier {
|
||||
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenHintVerifier")
|
||||
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
|
||||
ret0, _ := ret[0].(op.IDTokenHintVerifier)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IDTokenHintVerifier indicates an expected call of IDTokenHintVerifier.
|
||||
func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier() *gomock.Call {
|
||||
func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier))
|
||||
}
|
||||
|
||||
// Issuer mocks base method.
|
||||
func (m *MockAuthorizer) Issuer() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Issuer")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Issuer indicates an expected call of Issuer.
|
||||
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
|
||||
}
|
||||
|
||||
// RequestObjectSupported mocks base method.
|
||||
|
@ -119,20 +106,6 @@ func (mr *MockAuthorizerMockRecorder) RequestObjectSupported() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported))
|
||||
}
|
||||
|
||||
// Signer mocks base method.
|
||||
func (m *MockAuthorizer) Signer() op.Signer {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Signer")
|
||||
ret0, _ := ret[0].(op.Signer)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Signer indicates an expected call of Signer.
|
||||
func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer))
|
||||
}
|
||||
|
||||
// Storage mocks base method.
|
||||
func (m *MockAuthorizer) Storage() op.Storage {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
"github.com/gorilla/schema"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
func NewAuthorizer(t *testing.T) op.Authorizer {
|
||||
|
@ -20,23 +20,13 @@ func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer {
|
|||
m := NewAuthorizer(t)
|
||||
ExpectDecoder(m)
|
||||
ExpectEncoder(m)
|
||||
ExpectSigner(m, t)
|
||||
//ExpectSigner(m, t)
|
||||
ExpectStorage(m, t)
|
||||
ExpectVerifier(m, t)
|
||||
// ExpectErrorHandler(m, t, wantErr)
|
||||
return m
|
||||
}
|
||||
|
||||
// func NewAuthorizerExpectDecoderFails(t *testing.T) op.Authorizer {
|
||||
// m := NewAuthorizer(t)
|
||||
// ExpectDecoderFails(m)
|
||||
// ExpectEncoder(m)
|
||||
// ExpectSigner(m, t)
|
||||
// ExpectStorage(m, t)
|
||||
// ExpectErrorHandler(m, t)
|
||||
// return m
|
||||
// }
|
||||
|
||||
func ExpectDecoder(a op.Authorizer) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Decoder().AnyTimes().Return(schema.NewDecoder())
|
||||
|
@ -47,17 +37,18 @@ func ExpectEncoder(a op.Authorizer) {
|
|||
mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder())
|
||||
}
|
||||
|
||||
func ExpectSigner(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Signer().DoAndReturn(
|
||||
func() op.Signer {
|
||||
return &Sig{}
|
||||
})
|
||||
}
|
||||
//
|
||||
//func ExpectSigner(a op.Authorizer, t *testing.T) {
|
||||
// mockA := a.(*MockAuthorizer)
|
||||
// mockA.EXPECT().Signer().DoAndReturn(
|
||||
// func() op.Signer {
|
||||
// return &Sig{}
|
||||
// })
|
||||
//}
|
||||
|
||||
func ExpectVerifier(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().IDTokenHintVerifier().DoAndReturn(
|
||||
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
|
||||
func() op.IDTokenHintVerifier {
|
||||
return op.NewIDTokenHintVerifier("", nil)
|
||||
})
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
func NewClient(t *testing.T) op.Client {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Client)
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -9,8 +9,8 @@ import (
|
|||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
oidc "github.com/zitadel/oidc/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/pkg/op"
|
||||
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface.
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Configuration)
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
http "net/http"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
op "github.com/zitadel/oidc/pkg/op"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
language "golang.org/x/text/language"
|
||||
)
|
||||
|
||||
|
@ -91,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
|
||||
}
|
||||
|
||||
// DeviceAuthorization mocks base method.
|
||||
func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeviceAuthorization")
|
||||
ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeviceAuthorization indicates an expected call of DeviceAuthorization.
|
||||
func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
|
||||
}
|
||||
|
||||
// DeviceAuthorizationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
|
||||
func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
|
||||
}
|
||||
|
||||
// EndSessionEndpoint mocks base method.
|
||||
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -119,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
|
||||
}
|
||||
|
||||
// GrantTypeDeviceCodeSupported mocks base method.
|
||||
func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
|
||||
func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
|
||||
}
|
||||
|
||||
// GrantTypeJWTAuthorizationSupported mocks base method.
|
||||
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -161,6 +204,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeTokenExchangeSupported() *gomo
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported))
|
||||
}
|
||||
|
||||
// Insecure mocks base method.
|
||||
func (m *MockConfiguration) Insecure() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Insecure")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Insecure indicates an expected call of Insecure.
|
||||
func (mr *MockConfigurationMockRecorder) Insecure() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insecure", reflect.TypeOf((*MockConfiguration)(nil).Insecure))
|
||||
}
|
||||
|
||||
// IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method.
|
||||
func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -203,18 +260,18 @@ func (mr *MockConfigurationMockRecorder) IntrospectionEndpointSigningAlgorithmsS
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported))
|
||||
}
|
||||
|
||||
// Issuer mocks base method.
|
||||
func (m *MockConfiguration) Issuer() string {
|
||||
// IssuerFromRequest mocks base method.
|
||||
func (m *MockConfiguration) IssuerFromRequest(arg0 *http.Request) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Issuer")
|
||||
ret := m.ctrl.Call(m, "IssuerFromRequest", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Issuer indicates an expected call of Issuer.
|
||||
func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call {
|
||||
// IssuerFromRequest indicates an expected call of IssuerFromRequest.
|
||||
func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssuerFromRequest", reflect.TypeOf((*MockConfiguration)(nil).IssuerFromRequest), arg0)
|
||||
}
|
||||
|
||||
// KeysEndpoint mocks base method.
|
||||
|
|
51
pkg/op/mock/discovery.mock.go
Normal file
51
pkg/op/mock/discovery.mock.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// MockDiscoverStorage is a mock of DiscoverStorage interface.
|
||||
type MockDiscoverStorage struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDiscoverStorageMockRecorder
|
||||
}
|
||||
|
||||
// MockDiscoverStorageMockRecorder is the mock recorder for MockDiscoverStorage.
|
||||
type MockDiscoverStorageMockRecorder struct {
|
||||
mock *MockDiscoverStorage
|
||||
}
|
||||
|
||||
// NewMockDiscoverStorage creates a new mock instance.
|
||||
func NewMockDiscoverStorage(ctrl *gomock.Controller) *MockDiscoverStorage {
|
||||
mock := &MockDiscoverStorage{ctrl: ctrl}
|
||||
mock.recorder = &MockDiscoverStorageMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDiscoverStorage) EXPECT() *MockDiscoverStorageMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// SignatureAlgorithms mocks base method.
|
||||
func (m *MockDiscoverStorage) SignatureAlgorithms(arg0 context.Context) ([]jose.SignatureAlgorithm, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignatureAlgorithms", arg0)
|
||||
ret0, _ := ret[0].([]jose.SignatureAlgorithm)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SignatureAlgorithms indicates an expected call of SignatureAlgorithms.
|
||||
func (mr *MockDiscoverStorageMockRecorder) SignatureAlgorithms(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithms", reflect.TypeOf((*MockDiscoverStorage)(nil).SignatureAlgorithms), arg0)
|
||||
}
|
|
@ -1,8 +1,10 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/pkg/op Storage
|
||||
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/pkg/op Authorizer
|
||||
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/pkg/op Client
|
||||
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/pkg/op Configuration
|
||||
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/pkg/op Signer
|
||||
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/pkg/op KeyProvider
|
||||
//go:generate go install github.com/golang/mock/mockgen@v1.6.0
|
||||
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage
|
||||
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer
|
||||
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client
|
||||
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration
|
||||
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage
|
||||
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key
|
||||
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/op (interfaces: KeyProvider)
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
// MockKeyProvider is a mock of KeyProvider interface.
|
||||
|
@ -35,17 +35,17 @@ func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// GetKeySet mocks base method.
|
||||
func (m *MockKeyProvider) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
|
||||
// KeySet mocks base method.
|
||||
func (m *MockKeyProvider) KeySet(arg0 context.Context) ([]op.Key, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||
ret0, _ := ret[0].(*jose.JSONWebKeySet)
|
||||
ret := m.ctrl.Call(m, "KeySet", arg0)
|
||||
ret0, _ := ret[0].([]op.Key)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKeySet indicates an expected call of GetKeySet.
|
||||
func (mr *MockKeyProviderMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||
// KeySet indicates an expected call of KeySet.
|
||||
func (mr *MockKeyProviderMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockKeyProvider)(nil).GetKeySet), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockKeyProvider)(nil).KeySet), arg0)
|
||||
}
|
||||
|
|
|
@ -1,56 +1,69 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Signer)
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// MockSigner is a mock of Signer interface.
|
||||
type MockSigner struct {
|
||||
// MockSigningKey is a mock of SigningKey interface.
|
||||
type MockSigningKey struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSignerMockRecorder
|
||||
recorder *MockSigningKeyMockRecorder
|
||||
}
|
||||
|
||||
// MockSignerMockRecorder is the mock recorder for MockSigner.
|
||||
type MockSignerMockRecorder struct {
|
||||
mock *MockSigner
|
||||
// MockSigningKeyMockRecorder is the mock recorder for MockSigningKey.
|
||||
type MockSigningKeyMockRecorder struct {
|
||||
mock *MockSigningKey
|
||||
}
|
||||
|
||||
// NewMockSigner creates a new mock instance.
|
||||
func NewMockSigner(ctrl *gomock.Controller) *MockSigner {
|
||||
mock := &MockSigner{ctrl: ctrl}
|
||||
mock.recorder = &MockSignerMockRecorder{mock}
|
||||
// NewMockSigningKey creates a new mock instance.
|
||||
func NewMockSigningKey(ctrl *gomock.Controller) *MockSigningKey {
|
||||
mock := &MockSigningKey{ctrl: ctrl}
|
||||
mock.recorder = &MockSigningKeyMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
|
||||
func (m *MockSigningKey) EXPECT() *MockSigningKeyMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Health mocks base method.
|
||||
func (m *MockSigner) Health(arg0 context.Context) error {
|
||||
// ID mocks base method.
|
||||
func (m *MockSigningKey) ID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Health", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
ret := m.ctrl.Call(m, "ID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Health indicates an expected call of Health.
|
||||
func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
|
||||
// ID indicates an expected call of ID.
|
||||
func (mr *MockSigningKeyMockRecorder) ID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockSigningKey)(nil).ID))
|
||||
}
|
||||
|
||||
// Key mocks base method.
|
||||
func (m *MockSigningKey) Key() interface{} {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Key")
|
||||
ret0, _ := ret[0].(interface{})
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Key indicates an expected call of Key.
|
||||
func (mr *MockSigningKeyMockRecorder) Key() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockSigningKey)(nil).Key))
|
||||
}
|
||||
|
||||
// SignatureAlgorithm mocks base method.
|
||||
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
func (m *MockSigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignatureAlgorithm")
|
||||
ret0, _ := ret[0].(jose.SignatureAlgorithm)
|
||||
|
@ -58,21 +71,86 @@ func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
|||
}
|
||||
|
||||
// SignatureAlgorithm indicates an expected call of SignatureAlgorithm.
|
||||
func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
|
||||
func (mr *MockSigningKeyMockRecorder) SignatureAlgorithm() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigningKey)(nil).SignatureAlgorithm))
|
||||
}
|
||||
|
||||
// Signer mocks base method.
|
||||
func (m *MockSigner) Signer() jose.Signer {
|
||||
// MockKey is a mock of Key interface.
|
||||
type MockKey struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockKeyMockRecorder
|
||||
}
|
||||
|
||||
// MockKeyMockRecorder is the mock recorder for MockKey.
|
||||
type MockKeyMockRecorder struct {
|
||||
mock *MockKey
|
||||
}
|
||||
|
||||
// NewMockKey creates a new mock instance.
|
||||
func NewMockKey(ctrl *gomock.Controller) *MockKey {
|
||||
mock := &MockKey{ctrl: ctrl}
|
||||
mock.recorder = &MockKeyMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockKey) EXPECT() *MockKeyMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Algorithm mocks base method.
|
||||
func (m *MockKey) Algorithm() jose.SignatureAlgorithm {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Signer")
|
||||
ret0, _ := ret[0].(jose.Signer)
|
||||
ret := m.ctrl.Call(m, "Algorithm")
|
||||
ret0, _ := ret[0].(jose.SignatureAlgorithm)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Signer indicates an expected call of Signer.
|
||||
func (mr *MockSignerMockRecorder) Signer() *gomock.Call {
|
||||
// Algorithm indicates an expected call of Algorithm.
|
||||
func (mr *MockKeyMockRecorder) Algorithm() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockKey)(nil).Algorithm))
|
||||
}
|
||||
|
||||
// ID mocks base method.
|
||||
func (m *MockKey) ID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ID indicates an expected call of ID.
|
||||
func (mr *MockKeyMockRecorder) ID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockKey)(nil).ID))
|
||||
}
|
||||
|
||||
// Key mocks base method.
|
||||
func (m *MockKey) Key() interface{} {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Key")
|
||||
ret0, _ := ret[0].(interface{})
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Key indicates an expected call of Key.
|
||||
func (mr *MockKeyMockRecorder) Key() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockKey)(nil).Key))
|
||||
}
|
||||
|
||||
// Use mocks base method.
|
||||
func (m *MockKey) Use() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Use")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Use indicates an expected call of Use.
|
||||
func (mr *MockKeyMockRecorder) Use() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Use", reflect.TypeOf((*MockKey)(nil).Use))
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/pkg/op (interfaces: Storage)
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -10,8 +10,8 @@ import (
|
|||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
oidc "github.com/zitadel/oidc/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/pkg/op"
|
||||
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
|
@ -159,34 +159,19 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetKeyByIDAndUserID mocks base method.
|
||||
func (m *MockStorage) GetKeyByIDAndUserID(arg0 context.Context, arg1, arg2 string) (*jose.JSONWebKey, error) {
|
||||
// GetKeyByIDAndClientID mocks base method.
|
||||
func (m *MockStorage) GetKeyByIDAndClientID(arg0 context.Context, arg1, arg2 string) (*jose.JSONWebKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeyByIDAndUserID", arg0, arg1, arg2)
|
||||
ret := m.ctrl.Call(m, "GetKeyByIDAndClientID", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*jose.JSONWebKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKeyByIDAndUserID indicates an expected call of GetKeyByIDAndUserID.
|
||||
func (mr *MockStorageMockRecorder) GetKeyByIDAndUserID(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
// GetKeyByIDAndClientID indicates an expected call of GetKeyByIDAndClientID.
|
||||
func (mr *MockStorageMockRecorder) GetKeyByIDAndClientID(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyByIDAndUserID", reflect.TypeOf((*MockStorage)(nil).GetKeyByIDAndUserID), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetKeySet mocks base method.
|
||||
func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||
ret0, _ := ret[0].(*jose.JSONWebKeySet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKeySet indicates an expected call of GetKeySet.
|
||||
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyByIDAndClientID", reflect.TypeOf((*MockStorage)(nil).GetKeyByIDAndClientID), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetPrivateClaimsFromScopes mocks base method.
|
||||
|
@ -204,16 +189,20 @@ func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2,
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// GetSigningKey mocks base method.
|
||||
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey) {
|
||||
// GetRefreshTokenInfo mocks base method.
|
||||
func (m *MockStorage) GetRefreshTokenInfo(arg0 context.Context, arg1, arg2 string) (string, string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "GetSigningKey", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "GetRefreshTokenInfo", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(string)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetSigningKey indicates an expected call of GetSigningKey.
|
||||
func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1 interface{}) *gomock.Call {
|
||||
// GetRefreshTokenInfo indicates an expected call of GetRefreshTokenInfo.
|
||||
func (mr *MockStorageMockRecorder) GetRefreshTokenInfo(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefreshTokenInfo", reflect.TypeOf((*MockStorage)(nil).GetRefreshTokenInfo), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Health mocks base method.
|
||||
|
@ -230,6 +219,21 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
|
||||
}
|
||||
|
||||
// KeySet mocks base method.
|
||||
func (m *MockStorage) KeySet(arg0 context.Context) ([]op.Key, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeySet", arg0)
|
||||
ret0, _ := ret[0].([]op.Key)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// KeySet indicates an expected call of KeySet.
|
||||
func (mr *MockStorageMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockStorage)(nil).KeySet), arg0)
|
||||
}
|
||||
|
||||
// RevokeToken mocks base method.
|
||||
func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -259,7 +263,7 @@ func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *g
|
|||
}
|
||||
|
||||
// SetIntrospectionFromToken mocks base method.
|
||||
func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
|
||||
func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 *oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetIntrospectionFromToken", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
@ -273,7 +277,7 @@ func (mr *MockStorageMockRecorder) SetIntrospectionFromToken(arg0, arg1, arg2, a
|
|||
}
|
||||
|
||||
// SetUserinfoFromScopes mocks base method.
|
||||
func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3 string, arg4 []string) error {
|
||||
func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3 string, arg4 []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetUserinfoFromScopes", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
@ -287,7 +291,7 @@ func (mr *MockStorageMockRecorder) SetUserinfoFromScopes(arg0, arg1, arg2, arg3,
|
|||
}
|
||||
|
||||
// SetUserinfoFromToken mocks base method.
|
||||
func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3, arg4 string) error {
|
||||
func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3, arg4 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetUserinfoFromToken", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
@ -300,6 +304,36 @@ func (mr *MockStorageMockRecorder) SetUserinfoFromToken(arg0, arg1, arg2, arg3,
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).SetUserinfoFromToken), arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
||||
|
||||
// SignatureAlgorithms mocks base method.
|
||||
func (m *MockStorage) SignatureAlgorithms(arg0 context.Context) ([]jose.SignatureAlgorithm, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignatureAlgorithms", arg0)
|
||||
ret0, _ := ret[0].([]jose.SignatureAlgorithm)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SignatureAlgorithms indicates an expected call of SignatureAlgorithms.
|
||||
func (mr *MockStorageMockRecorder) SignatureAlgorithms(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithms", reflect.TypeOf((*MockStorage)(nil).SignatureAlgorithms), arg0)
|
||||
}
|
||||
|
||||
// SigningKey mocks base method.
|
||||
func (m *MockStorage) SigningKey(arg0 context.Context) (op.SigningKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SigningKey", arg0)
|
||||
ret0, _ := ret[0].(op.SigningKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SigningKey indicates an expected call of SigningKey.
|
||||
func (mr *MockStorageMockRecorder) SigningKey(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SigningKey", reflect.TypeOf((*MockStorage)(nil).SigningKey), arg0)
|
||||
}
|
||||
|
||||
// TerminateSession mocks base method.
|
||||
func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -6,13 +6,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
func NewStorage(t *testing.T) op.Storage {
|
||||
|
@ -41,13 +38,13 @@ func NewMockStorageAny(t *testing.T) op.Storage {
|
|||
|
||||
func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKeyInvalid(m)
|
||||
//ExpectSigningKeyInvalid(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageSigningKey(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKey(m)
|
||||
//ExpectSigningKey(m)
|
||||
return m
|
||||
}
|
||||
|
||||
|
@ -85,24 +82,6 @@ func ExpectValidClientID(s op.Storage) {
|
|||
})
|
||||
}
|
||||
|
||||
func ExpectSigningKeyInvalid(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, keyCh chan<- jose.SigningKey) {
|
||||
keyCh <- jose.SigningKey{}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func ExpectSigningKey(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, keyCh chan<- jose.SigningKey) {
|
||||
keyCh <- jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
type ConfClient struct {
|
||||
id string
|
||||
appType op.ApplicationType
|
||||
|
|
352
pkg/op/op.go
352
pkg/op/op.go
|
@ -12,8 +12,8 @@ import (
|
|||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -27,80 +27,84 @@ const (
|
|||
defaultRevocationEndpoint = "revoke"
|
||||
defaultEndSessionEndpoint = "end_session"
|
||||
defaultKeysEndpoint = "keys"
|
||||
defaultDeviceAuthzEndpoint = "/device_authorization"
|
||||
)
|
||||
|
||||
var DefaultEndpoints = &endpoints{
|
||||
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
|
||||
Token: NewEndpoint(defaultTokenEndpoint),
|
||||
Introspection: NewEndpoint(defaultIntrospectEndpoint),
|
||||
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
|
||||
Revocation: NewEndpoint(defaultRevocationEndpoint),
|
||||
EndSession: NewEndpoint(defaultEndSessionEndpoint),
|
||||
JwksURI: NewEndpoint(defaultKeysEndpoint),
|
||||
}
|
||||
var (
|
||||
DefaultEndpoints = &endpoints{
|
||||
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
|
||||
Token: NewEndpoint(defaultTokenEndpoint),
|
||||
Introspection: NewEndpoint(defaultIntrospectEndpoint),
|
||||
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
|
||||
Revocation: NewEndpoint(defaultRevocationEndpoint),
|
||||
EndSession: NewEndpoint(defaultEndSessionEndpoint),
|
||||
JwksURI: NewEndpoint(defaultKeysEndpoint),
|
||||
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
|
||||
}
|
||||
|
||||
defaultCORSOptions = cors.Options{
|
||||
AllowCredentials: true,
|
||||
AllowedHeaders: []string{
|
||||
"Origin",
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Authorization",
|
||||
"Content-Type",
|
||||
"X-Requested-With",
|
||||
},
|
||||
AllowedMethods: []string{
|
||||
http.MethodGet,
|
||||
http.MethodHead,
|
||||
http.MethodPost,
|
||||
},
|
||||
ExposedHeaders: []string{
|
||||
"Location",
|
||||
"Content-Length",
|
||||
},
|
||||
AllowOriginFunc: func(_ string) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type OpenIDProvider interface {
|
||||
Configuration
|
||||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
Encoder() httphelper.Encoder
|
||||
IDTokenHintVerifier() IDTokenHintVerifier
|
||||
AccessTokenVerifier() AccessTokenVerifier
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
Crypto() Crypto
|
||||
DefaultLogoutRedirectURI() string
|
||||
Signer() Signer
|
||||
Probes() []ProbesFn
|
||||
HttpHandler() http.Handler
|
||||
}
|
||||
|
||||
type HttpInterceptor func(http.Handler) http.Handler
|
||||
|
||||
var defaultCORSOptions = cors.Options{
|
||||
AllowCredentials: true,
|
||||
AllowedHeaders: []string{
|
||||
"Origin",
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Authorization",
|
||||
"Content-Type",
|
||||
"X-Requested-With",
|
||||
},
|
||||
AllowedMethods: []string{
|
||||
http.MethodGet,
|
||||
http.MethodHead,
|
||||
http.MethodPost,
|
||||
},
|
||||
ExposedHeaders: []string{
|
||||
"Location",
|
||||
"Content-Length",
|
||||
},
|
||||
AllowOriginFunc: func(_ string) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
|
||||
intercept := buildInterceptor(interceptors...)
|
||||
router := mux.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
router.Use(intercept(o.IssuerFromRequest, interceptors...))
|
||||
router.HandleFunc(healthEndpoint, healthHandler)
|
||||
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Signer()))
|
||||
router.Handle(o.AuthorizationEndpoint().Relative(), intercept(authorizeHandler(o)))
|
||||
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").Handler(intercept(authorizeCallbackHandler(o)))
|
||||
router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o)))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
|
||||
router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
|
||||
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o))
|
||||
router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
|
||||
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
|
||||
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
|
||||
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
|
||||
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o)))
|
||||
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
|
||||
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
|
||||
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o))
|
||||
return router
|
||||
}
|
||||
|
||||
// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
|
||||
func AuthCallbackURL(o OpenIDProvider) func(string) string {
|
||||
return func(requestID string) string {
|
||||
return o.AuthorizationEndpoint().Absolute(o.Issuer()) + authCallbackPathSuffix + "?id=" + requestID
|
||||
func AuthCallbackURL(o OpenIDProvider) func(context.Context, string) string {
|
||||
return func(ctx context.Context, requestID string) string {
|
||||
return o.AuthorizationEndpoint().Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -109,7 +113,6 @@ func authCallbackPath(o OpenIDProvider) string {
|
|||
}
|
||||
|
||||
type Config struct {
|
||||
Issuer string
|
||||
CryptoKey [32]byte
|
||||
DefaultLogoutRedirectURI string
|
||||
CodeMethodS256 bool
|
||||
|
@ -118,44 +121,52 @@ type Config struct {
|
|||
GrantTypeRefreshToken bool
|
||||
RequestObjectSupported bool
|
||||
SupportedUILocales []language.Tag
|
||||
DeviceAuthorization DeviceAuthorizationConfig
|
||||
}
|
||||
|
||||
type endpoints struct {
|
||||
Authorization Endpoint
|
||||
Token Endpoint
|
||||
Introspection Endpoint
|
||||
Userinfo Endpoint
|
||||
Revocation Endpoint
|
||||
EndSession Endpoint
|
||||
CheckSessionIframe Endpoint
|
||||
JwksURI Endpoint
|
||||
Authorization Endpoint
|
||||
Token Endpoint
|
||||
Introspection Endpoint
|
||||
Userinfo Endpoint
|
||||
Revocation Endpoint
|
||||
EndSession Endpoint
|
||||
CheckSessionIframe Endpoint
|
||||
JwksURI Endpoint
|
||||
DeviceAuthorization Endpoint
|
||||
}
|
||||
|
||||
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
||||
// a http.Router that handles a suite of endpoints (some paths can be overridden):
|
||||
// /healthz
|
||||
// /ready
|
||||
// /.well-known/openid-configuration
|
||||
// /oauth/token
|
||||
// /oauth/introspect
|
||||
// /callback
|
||||
// /authorize
|
||||
// /userinfo
|
||||
// /revoke
|
||||
// /end_session
|
||||
// /keys
|
||||
//
|
||||
// /healthz
|
||||
// /ready
|
||||
// /.well-known/openid-configuration
|
||||
// /oauth/token
|
||||
// /oauth/introspect
|
||||
// /callback
|
||||
// /authorize
|
||||
// /userinfo
|
||||
// /revoke
|
||||
// /end_session
|
||||
// /keys
|
||||
// /device_authorization
|
||||
//
|
||||
// This does not include login. Login is handled with a redirect that includes the
|
||||
// request ID. The redirect for logins is specified per-client by Client.LoginURL().
|
||||
// Successful logins should mark the request as authorized and redirect back to to
|
||||
// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back
|
||||
// to the AuthCallbackURL, the request id should be passed as the "id" parameter.
|
||||
func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opOpts ...Option) (OpenIDProvider, error) {
|
||||
err := ValidateIssuer(config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func NewOpenIDProvider(issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
|
||||
return newProvider(config, storage, StaticIssuer(issuer), opOpts...)
|
||||
}
|
||||
|
||||
o := &openidProvider{
|
||||
func NewDynamicOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
|
||||
return newProvider(config, storage, IssuerFromHost(path), opOpts...)
|
||||
}
|
||||
|
||||
func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) {
|
||||
o := &Provider{
|
||||
config: config,
|
||||
storage: storage,
|
||||
endpoints: DefaultEndpoints,
|
||||
|
@ -168,36 +179,32 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
|
|||
}
|
||||
}
|
||||
|
||||
keyCh := make(chan jose.SigningKey)
|
||||
go storage.GetSigningKey(ctx, keyCh)
|
||||
o.signer = NewSigner(ctx, storage, keyCh)
|
||||
o.issuer, err = issuer(o.insecure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.httpHandler = CreateRouter(o, o.interceptors...)
|
||||
|
||||
o.decoder = schema.NewDecoder()
|
||||
o.decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
o.encoder = schema.NewEncoder()
|
||||
o.encoder = oidc.NewEncoder()
|
||||
|
||||
o.crypto = NewAESCrypto(config.CryptoKey)
|
||||
|
||||
// Avoid potential race conditions by calling these early
|
||||
_ = o.AccessTokenVerifier() // sets accessTokenVerifier
|
||||
_ = o.IDTokenHintVerifier() // sets idTokenHintVerifier
|
||||
_ = o.JWTProfileVerifier() // sets jwtProfileVerifier
|
||||
_ = o.openIDKeySet() // sets keySet
|
||||
_ = o.openIDKeySet() // sets keySet
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
type openidProvider struct {
|
||||
type Provider struct {
|
||||
config *Config
|
||||
issuer IssuerFromRequest
|
||||
insecure bool
|
||||
endpoints *endpoints
|
||||
storage Storage
|
||||
signer Signer
|
||||
idTokenHintVerifier IDTokenHintVerifier
|
||||
jwtProfileVerifier JWTProfileVerifier
|
||||
accessTokenVerifier AccessTokenVerifier
|
||||
keySet *openIDKeySet
|
||||
crypto Crypto
|
||||
httpHandler http.Handler
|
||||
|
@ -209,159 +216,163 @@ type openidProvider struct {
|
|||
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
|
||||
}
|
||||
|
||||
func (o *openidProvider) Issuer() string {
|
||||
return o.config.Issuer
|
||||
func (o *Provider) IssuerFromRequest(r *http.Request) string {
|
||||
return o.issuer(r)
|
||||
}
|
||||
|
||||
func (o *openidProvider) AuthorizationEndpoint() Endpoint {
|
||||
func (o *Provider) Insecure() bool {
|
||||
return o.insecure
|
||||
}
|
||||
|
||||
func (o *Provider) AuthorizationEndpoint() Endpoint {
|
||||
return o.endpoints.Authorization
|
||||
}
|
||||
|
||||
func (o *openidProvider) TokenEndpoint() Endpoint {
|
||||
func (o *Provider) TokenEndpoint() Endpoint {
|
||||
return o.endpoints.Token
|
||||
}
|
||||
|
||||
func (o *openidProvider) IntrospectionEndpoint() Endpoint {
|
||||
func (o *Provider) IntrospectionEndpoint() Endpoint {
|
||||
return o.endpoints.Introspection
|
||||
}
|
||||
|
||||
func (o *openidProvider) UserinfoEndpoint() Endpoint {
|
||||
func (o *Provider) UserinfoEndpoint() Endpoint {
|
||||
return o.endpoints.Userinfo
|
||||
}
|
||||
|
||||
func (o *openidProvider) RevocationEndpoint() Endpoint {
|
||||
func (o *Provider) RevocationEndpoint() Endpoint {
|
||||
return o.endpoints.Revocation
|
||||
}
|
||||
|
||||
func (o *openidProvider) EndSessionEndpoint() Endpoint {
|
||||
func (o *Provider) EndSessionEndpoint() Endpoint {
|
||||
return o.endpoints.EndSession
|
||||
}
|
||||
|
||||
func (o *openidProvider) KeysEndpoint() Endpoint {
|
||||
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
|
||||
return o.endpoints.DeviceAuthorization
|
||||
}
|
||||
|
||||
func (o *Provider) KeysEndpoint() Endpoint {
|
||||
return o.endpoints.JwksURI
|
||||
}
|
||||
|
||||
func (o *openidProvider) AuthMethodPostSupported() bool {
|
||||
func (o *Provider) AuthMethodPostSupported() bool {
|
||||
return o.config.AuthMethodPost
|
||||
}
|
||||
|
||||
func (o *openidProvider) CodeMethodS256Supported() bool {
|
||||
func (o *Provider) CodeMethodS256Supported() bool {
|
||||
return o.config.CodeMethodS256
|
||||
}
|
||||
|
||||
func (o *openidProvider) AuthMethodPrivateKeyJWTSupported() bool {
|
||||
func (o *Provider) AuthMethodPrivateKeyJWTSupported() bool {
|
||||
return o.config.AuthMethodPrivateKeyJWT
|
||||
}
|
||||
|
||||
func (o *openidProvider) TokenEndpointSigningAlgorithmsSupported() []string {
|
||||
func (o *Provider) TokenEndpointSigningAlgorithmsSupported() []string {
|
||||
return []string{"RS256"}
|
||||
}
|
||||
|
||||
func (o *openidProvider) GrantTypeRefreshTokenSupported() bool {
|
||||
func (o *Provider) GrantTypeRefreshTokenSupported() bool {
|
||||
return o.config.GrantTypeRefreshToken
|
||||
}
|
||||
|
||||
func (o *openidProvider) GrantTypeTokenExchangeSupported() bool {
|
||||
return false
|
||||
func (o *Provider) GrantTypeTokenExchangeSupported() bool {
|
||||
_, ok := o.storage.(TokenExchangeStorage)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (o *openidProvider) GrantTypeJWTAuthorizationSupported() bool {
|
||||
func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *openidProvider) GrantTypeClientCredentialsSupported() bool {
|
||||
func (o *Provider) GrantTypeDeviceCodeSupported() bool {
|
||||
_, ok := o.storage.(DeviceAuthorizationStorage)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *Provider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
|
||||
return []string{"RS256"}
|
||||
}
|
||||
|
||||
func (o *Provider) GrantTypeClientCredentialsSupported() bool {
|
||||
_, ok := o.storage.(ClientCredentialsStorage)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (o *openidProvider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
|
||||
func (o *Provider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *openidProvider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
|
||||
func (o *Provider) RevocationEndpointSigningAlgorithmsSupported() []string {
|
||||
return []string{"RS256"}
|
||||
}
|
||||
|
||||
func (o *openidProvider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *openidProvider) RevocationEndpointSigningAlgorithmsSupported() []string {
|
||||
return []string{"RS256"}
|
||||
}
|
||||
|
||||
func (o *openidProvider) RequestObjectSupported() bool {
|
||||
func (o *Provider) RequestObjectSupported() bool {
|
||||
return o.config.RequestObjectSupported
|
||||
}
|
||||
|
||||
func (o *openidProvider) RequestObjectSigningAlgorithmsSupported() []string {
|
||||
func (o *Provider) RequestObjectSigningAlgorithmsSupported() []string {
|
||||
return []string{"RS256"}
|
||||
}
|
||||
|
||||
func (o *openidProvider) SupportedUILocales() []language.Tag {
|
||||
func (o *Provider) SupportedUILocales() []language.Tag {
|
||||
return o.config.SupportedUILocales
|
||||
}
|
||||
|
||||
func (o *openidProvider) Storage() Storage {
|
||||
func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
|
||||
return o.config.DeviceAuthorization
|
||||
}
|
||||
|
||||
func (o *Provider) Storage() Storage {
|
||||
return o.storage
|
||||
}
|
||||
|
||||
func (o *openidProvider) Decoder() httphelper.Decoder {
|
||||
func (o *Provider) Decoder() httphelper.Decoder {
|
||||
return o.decoder
|
||||
}
|
||||
|
||||
func (o *openidProvider) Encoder() httphelper.Encoder {
|
||||
func (o *Provider) Encoder() httphelper.Encoder {
|
||||
return o.encoder
|
||||
}
|
||||
|
||||
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier {
|
||||
if o.idTokenHintVerifier == nil {
|
||||
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
|
||||
}
|
||||
return o.idTokenHintVerifier
|
||||
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier {
|
||||
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
|
||||
}
|
||||
|
||||
func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier {
|
||||
if o.jwtProfileVerifier == nil {
|
||||
o.jwtProfileVerifier = NewJWTProfileVerifier(o.Storage(), o.Issuer(), 1*time.Hour, time.Second)
|
||||
}
|
||||
return o.jwtProfileVerifier
|
||||
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier {
|
||||
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
||||
}
|
||||
|
||||
func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier {
|
||||
if o.accessTokenVerifier == nil {
|
||||
o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet(), o.accessTokenVerifierOpts...)
|
||||
}
|
||||
return o.accessTokenVerifier
|
||||
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier {
|
||||
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
|
||||
}
|
||||
|
||||
func (o *openidProvider) openIDKeySet() oidc.KeySet {
|
||||
func (o *Provider) openIDKeySet() oidc.KeySet {
|
||||
if o.keySet == nil {
|
||||
o.keySet = &openIDKeySet{o.Storage()}
|
||||
}
|
||||
return o.keySet
|
||||
}
|
||||
|
||||
func (o *openidProvider) Crypto() Crypto {
|
||||
func (o *Provider) Crypto() Crypto {
|
||||
return o.crypto
|
||||
}
|
||||
|
||||
func (o *openidProvider) DefaultLogoutRedirectURI() string {
|
||||
func (o *Provider) DefaultLogoutRedirectURI() string {
|
||||
return o.config.DefaultLogoutRedirectURI
|
||||
}
|
||||
|
||||
func (o *openidProvider) Signer() Signer {
|
||||
return o.signer
|
||||
}
|
||||
|
||||
func (o *openidProvider) Probes() []ProbesFn {
|
||||
func (o *Provider) Probes() []ProbesFn {
|
||||
return []ProbesFn{
|
||||
ReadySigner(o.Signer()),
|
||||
ReadyStorage(o.Storage()),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *openidProvider) HttpHandler() http.Handler {
|
||||
func (o *Provider) HttpHandler() http.Handler {
|
||||
return o.httpHandler
|
||||
}
|
||||
|
||||
|
@ -372,22 +383,31 @@ type openIDKeySet struct {
|
|||
// VerifySignature implements the oidc.KeySet interface
|
||||
// providing an implementation for the keys stored in the OP Storage interface
|
||||
func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
keySet, err := o.Storage.GetKeySet(ctx)
|
||||
keySet, err := o.Storage.KeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching keys: %w", err)
|
||||
}
|
||||
keyID, alg := oidc.GetKeyIDAndAlg(jws)
|
||||
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...)
|
||||
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, jsonWebKeySet(keySet).Keys...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid signature: %w", err)
|
||||
}
|
||||
return jws.Verify(&key)
|
||||
}
|
||||
|
||||
type Option func(o *openidProvider) error
|
||||
type Option func(o *Provider) error
|
||||
|
||||
// WithAllowInsecure allows the use of http (instead of https) for issuers
|
||||
// this is not recommended for production use and violates the OIDC specification
|
||||
func WithAllowInsecure() Option {
|
||||
return func(o *Provider) error {
|
||||
o.insecure = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -397,7 +417,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
|
||||
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -407,7 +427,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
|
||||
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -417,7 +437,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
|
||||
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -427,7 +447,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
|
||||
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -437,7 +457,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
|
||||
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -447,7 +467,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
|
||||
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -457,7 +477,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
o.endpoints.Authorization = auth
|
||||
o.endpoints.Token = token
|
||||
o.endpoints.Userinfo = userInfo
|
||||
|
@ -469,38 +489,32 @@ func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys End
|
|||
}
|
||||
|
||||
func WithHttpInterceptors(interceptors ...HttpInterceptor) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
o.interceptors = append(o.interceptors, interceptors...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
o.accessTokenVerifierOpts = opts
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
|
||||
return func(o *openidProvider) error {
|
||||
return func(o *Provider) error {
|
||||
o.idTokenHintVerifierOpts = opts
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildInterceptor(interceptors ...HttpInterceptor) func(http.HandlerFunc) http.Handler {
|
||||
return func(handlerFunc http.HandlerFunc) http.Handler {
|
||||
handler := handlerFuncToHandler(handlerFunc)
|
||||
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
|
||||
issuerInterceptor := NewIssuerInterceptor(i)
|
||||
return func(handler http.Handler) http.Handler {
|
||||
for i := len(interceptors) - 1; i >= 0; i-- {
|
||||
handler = interceptors[i](handler)
|
||||
}
|
||||
return handler
|
||||
return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler))
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFuncToHandler(handlerFunc http.HandlerFunc) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerFunc(w, r)
|
||||
})
|
||||
}
|
||||
|
|
392
pkg/op/op_test.go
Normal file
392
pkg/op/op_test.go
Normal file
|
@ -0,0 +1,392 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
var testProvider op.OpenIDProvider
|
||||
|
||||
const (
|
||||
testIssuer = "https://localhost:9998/"
|
||||
pathLoggedOut = "/logged-out"
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := &op.Config{
|
||||
CryptoKey: sha256.Sum256([]byte("test")),
|
||||
DefaultLogoutRedirectURI: pathLoggedOut,
|
||||
CodeMethodS256: true,
|
||||
AuthMethodPost: true,
|
||||
AuthMethodPrivateKeyJWT: true,
|
||||
GrantTypeRefreshToken: true,
|
||||
RequestObjectSupported: true,
|
||||
SupportedUILocales: []language.Tag{language.English},
|
||||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||
Lifetime: 5 * time.Minute,
|
||||
PollInterval: 5 * time.Second,
|
||||
UserFormURL: testIssuer + "device",
|
||||
UserCode: op.UserCodeBase20,
|
||||
},
|
||||
}
|
||||
|
||||
storage.RegisterClients(
|
||||
storage.NativeClient("native"),
|
||||
storage.WebClient("web", "secret", "https://example.com"),
|
||||
storage.WebClient("api", "secret"),
|
||||
)
|
||||
|
||||
var err error
|
||||
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
|
||||
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type routesTestStorage interface {
|
||||
op.Storage
|
||||
AuthRequestDone(id string) error
|
||||
}
|
||||
|
||||
func mapAsValues(m map[string]string) string {
|
||||
values := make(url.Values, len(m))
|
||||
for k, v := range m {
|
||||
values.Set(k, v)
|
||||
}
|
||||
return values.Encode()
|
||||
}
|
||||
|
||||
func TestRoutes(t *testing.T) {
|
||||
storage := testProvider.Storage().(routesTestStorage)
|
||||
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
||||
|
||||
client, err := storage.GetClientByClientID(ctx, "web")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcAuthReq := &oidc.AuthRequest{
|
||||
ClientID: client.GetID(),
|
||||
RedirectURI: "https://example.com",
|
||||
MaxAge: gu.Ptr[uint](300),
|
||||
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
|
||||
ResponseType: oidc.ResponseTypeCode,
|
||||
}
|
||||
|
||||
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
|
||||
require.NoError(t, err)
|
||||
storage.AuthRequestDone(authReq.GetID())
|
||||
|
||||
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
|
||||
require.NoError(t, err)
|
||||
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcAuthReq.IDTokenHint = idToken
|
||||
|
||||
serverURL, err := url.Parse(testIssuer)
|
||||
require.NoError(t, err)
|
||||
|
||||
type basicAuth struct {
|
||||
username, password string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
basicAuth *basicAuth
|
||||
header map[string]string
|
||||
values map[string]string
|
||||
body map[string]string
|
||||
wantCode int
|
||||
headerContains map[string]string
|
||||
json string // test for exact json output
|
||||
contains []string // when the body output is not constant, we just check for snippets to be present in the response
|
||||
}{
|
||||
{
|
||||
name: "health",
|
||||
method: http.MethodGet,
|
||||
path: "/healthz",
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"status":"ok"}`,
|
||||
},
|
||||
{
|
||||
name: "ready",
|
||||
method: http.MethodGet,
|
||||
path: "/ready",
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"status":"ok"}`,
|
||||
},
|
||||
{
|
||||
name: "discovery",
|
||||
method: http.MethodGet,
|
||||
path: oidc.DiscoveryEndpoint,
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
|
||||
},
|
||||
{
|
||||
name: "authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.AuthorizationEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"client_id": client.GetID(),
|
||||
"redirect_uri": "https://example.com",
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"response_type": string(oidc.ResponseTypeCode),
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
|
||||
},
|
||||
{
|
||||
name: "authorization callback",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.AuthorizationEndpoint().Relative() + "/callback",
|
||||
values: map[string]string{"id": authReq.GetID()},
|
||||
wantCode: http.StatusFound,
|
||||
headerContains: map[string]string{"Location": "https://example.com?code="},
|
||||
contains: []string{
|
||||
`<a href="https://example.com?code=`,
|
||||
">Found</a>.",
|
||||
},
|
||||
},
|
||||
{
|
||||
// This call will fail. A successfull test is already
|
||||
// part of client/integration_test.go
|
||||
name: "code exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeCode),
|
||||
"code": "123",
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
json: `{"error":"invalid_client"}`,
|
||||
},
|
||||
{
|
||||
name: "JWT authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeBearer),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"assertion": jwtToken,
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: "{\"error\":\"server_error\",\"error_description\":\"audience is not valid: Audience must contain client_id \\\"https://localhost:9998/\\\"\"}",
|
||||
},
|
||||
{
|
||||
name: "Token exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeTokenExchange),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"subject_token": jwtToken,
|
||||
"subject_token_type": string(oidc.AccessTokenType),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"access_token":"`,
|
||||
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Client credentials exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"sid1", "verysecret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeClientCredentials),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
|
||||
},
|
||||
{
|
||||
// This call will fail. A successfull test is already
|
||||
// part of device_test.go
|
||||
name: "device token",
|
||||
method: http.MethodPost,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
header: map[string]string{
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
body: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeDeviceCode),
|
||||
"device_code": "123",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
|
||||
},
|
||||
{
|
||||
name: "missing grant type",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
|
||||
},
|
||||
{
|
||||
name: "unsupported grant type",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": "foo",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
|
||||
},
|
||||
{
|
||||
name: "introspection",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.IntrospectionEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"token": accessToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
|
||||
},
|
||||
{
|
||||
name: "user info",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.UserinfoEndpoint().Relative(),
|
||||
header: map[string]string{
|
||||
"authorization": "Bearer " + accessToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
|
||||
},
|
||||
{
|
||||
name: "refresh token",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeRefreshToken),
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": client.GetID(),
|
||||
"client_secret": "secret",
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"access_token":"`,
|
||||
`","token_type":"Bearer","refresh_token":"`,
|
||||
`","expires_in":299,"id_token":"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "revoke",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.RevocationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"token": accessTokenRevoke,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "end session",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.EndSessionEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"id_token_hint": idToken,
|
||||
"client_id": "web",
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
headerContains: map[string]string{"Location": "/logged-out"},
|
||||
contains: []string{`<a href="/logged-out">Found</a>.`},
|
||||
},
|
||||
{
|
||||
name: "keys",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.KeysEndpoint().Relative(),
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
|
||||
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "device authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"device_code":"`, `","user_code":"`,
|
||||
`","verification_uri":"https://localhost:9998/device"`,
|
||||
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
|
||||
`","expires_in":300,"interval":5}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u := gu.PtrCopy(serverURL)
|
||||
u.Path = tt.path
|
||||
if tt.values != nil {
|
||||
u.RawQuery = mapAsValues(tt.values)
|
||||
}
|
||||
var body io.Reader
|
||||
if tt.body != nil {
|
||||
body = strings.NewReader(mapAsValues(tt.body))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(tt.method, u.String(), body)
|
||||
for k, v := range tt.header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
if tt.basicAuth != nil {
|
||||
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
testProvider.HttpHandler().ServeHTTP(rec, req)
|
||||
|
||||
resp := rec.Result()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCode, resp.StatusCode)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBodyString := string(respBody)
|
||||
t.Log(respBodyString)
|
||||
t.Log(resp.Header)
|
||||
|
||||
if tt.json != "" {
|
||||
assert.JSONEq(t, tt.json, respBodyString)
|
||||
}
|
||||
for _, c := range tt.contains {
|
||||
assert.Contains(t, respBodyString, c)
|
||||
}
|
||||
for k, v := range tt.headerContains {
|
||||
assert.Contains(t, resp.Header.Get(k), v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||
"errors"
|
||||
"net/http"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
)
|
||||
|
||||
type ProbesFn func(context.Context) error
|
||||
|
@ -31,15 +31,6 @@ func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) {
|
|||
ok(w)
|
||||
}
|
||||
|
||||
func ReadySigner(s Signer) ProbesFn {
|
||||
return func(ctx context.Context) error {
|
||||
if s == nil {
|
||||
return errors.New("no signer")
|
||||
}
|
||||
return s.Health(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func ReadyStorage(s Storage) ProbesFn {
|
||||
return func(ctx context.Context) error {
|
||||
if s == nil {
|
||||
|
|
|
@ -6,14 +6,14 @@ import (
|
|||
"net/url"
|
||||
"path"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type SessionEnder interface {
|
||||
Decoder() httphelper.Decoder
|
||||
Storage() Storage
|
||||
IDTokenHintVerifier() IDTokenHintVerifier
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
DefaultLogoutRedirectURI() string
|
||||
}
|
||||
|
||||
|
@ -60,7 +60,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
|
|||
RedirectURI: ender.DefaultLogoutRedirectURI(),
|
||||
}
|
||||
if req.IdTokenHint != "" {
|
||||
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier())
|
||||
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
|
||||
}
|
||||
|
|
|
@ -1,88 +1,38 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
type Signer interface {
|
||||
Health(ctx context.Context) error
|
||||
Signer() jose.Signer
|
||||
var (
|
||||
ErrSignerCreationFailed = errors.New("signer creation failed")
|
||||
)
|
||||
|
||||
type SigningKey interface {
|
||||
SignatureAlgorithm() jose.SignatureAlgorithm
|
||||
Key() interface{}
|
||||
ID() string
|
||||
}
|
||||
|
||||
type tokenSigner struct {
|
||||
signer jose.Signer
|
||||
storage AuthStorage
|
||||
alg jose.SignatureAlgorithm
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
|
||||
s := &tokenSigner{
|
||||
storage: storage,
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case key := <-keyCh:
|
||||
s.exchangeSigningKey(key)
|
||||
}
|
||||
go s.refreshSigningKey(ctx, keyCh)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *tokenSigner) Health(_ context.Context) error {
|
||||
if s.signer == nil {
|
||||
return errors.New("no signer")
|
||||
}
|
||||
if string(s.alg) == "" {
|
||||
return errors.New("no signing algorithm")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tokenSigner) Signer() jose.Signer {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
return s.signer
|
||||
}
|
||||
|
||||
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case key := <-keyCh:
|
||||
s.exchangeSigningKey(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenSigner) exchangeSigningKey(key jose.SigningKey) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.alg = key.Algorithm
|
||||
if key.Algorithm == "" || key.Key == nil {
|
||||
s.signer = nil
|
||||
logging.Warn("signer has no key")
|
||||
return
|
||||
}
|
||||
var err error
|
||||
s.signer, err = jose.NewSigner(key, &jose.SignerOptions{})
|
||||
func SignerFromKey(key SigningKey) (jose.Signer, error) {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: key.SignatureAlgorithm(),
|
||||
Key: &jose.JSONWebKey{
|
||||
Key: key.Key(),
|
||||
KeyID: key.ID(),
|
||||
},
|
||||
}, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Error("error creating signer")
|
||||
return
|
||||
return nil, ErrSignerCreationFailed //TODO: log / wrap error?
|
||||
}
|
||||
logging.Info("signer exchanged signing key")
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return s.alg
|
||||
type Key interface {
|
||||
ID() string
|
||||
Algorithm() jose.SignatureAlgorithm
|
||||
Use() string
|
||||
Key() interface{}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type AuthStorage interface {
|
||||
|
@ -25,6 +25,8 @@ type AuthStorage interface {
|
|||
//
|
||||
// * *oidc.JWTTokenRequest from a JWT that is the assertion value of a JWT Profile
|
||||
// Grant: https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
|
||||
//
|
||||
// * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
|
||||
CreateAccessToken(context.Context, TokenRequest) (accessTokenID string, expiration time.Time, err error)
|
||||
|
||||
// The TokenRequest parameter of CreateAccessAndRefreshTokens can be any of:
|
||||
|
@ -36,6 +38,8 @@ type AuthStorage interface {
|
|||
// * AuthRequest as by returned by the AuthRequestByID or AuthRequestByCode (above).
|
||||
// Used for the authorization code flow which requested offline_access scope and
|
||||
// registered the refresh_token grant type in advance
|
||||
//
|
||||
// * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
|
||||
CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error)
|
||||
TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (RefreshTokenRequest, error)
|
||||
|
||||
|
@ -44,44 +48,85 @@ type AuthStorage interface {
|
|||
// RevokeToken should revoke a token. In the situation that the original request was to
|
||||
// revoke an access token, then tokenOrTokenID will be a tokenID and userID will be set
|
||||
// but if the original request was for a refresh token, then userID will be empty and
|
||||
// tokenOrTokenID will be the refresh token, not its ID.
|
||||
// tokenOrTokenID will be the refresh token, not its ID. RevokeToken depends upon GetRefreshTokenInfo
|
||||
// to get information from refresh tokens that are not either "<tokenID>:<userID>" strings
|
||||
// nor JWTs.
|
||||
RevokeToken(ctx context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error
|
||||
|
||||
GetSigningKey(context.Context, chan<- jose.SigningKey)
|
||||
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||
}
|
||||
|
||||
// CanRefreshTokenInfo is an optional additional interface that Storage can support.
|
||||
// Supporting CanRefreshTokenInfo is required to be able to (revoke) a refresh token that
|
||||
// is neither an encrypted string of <tokenID>:<userID> nor a JWT.
|
||||
type CanRefreshTokenInfo interface {
|
||||
// GetRefreshTokenInfo must return ErrInvalidRefreshToken when presented
|
||||
// with a token that is not a refresh token.
|
||||
GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error)
|
||||
|
||||
SigningKey(context.Context) (SigningKey, error)
|
||||
SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
|
||||
KeySet(context.Context) ([]Key, error)
|
||||
}
|
||||
|
||||
type ClientCredentialsStorage interface {
|
||||
ClientCredentials(ctx context.Context, clientID, clientSecret string) (Client, error)
|
||||
ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error)
|
||||
}
|
||||
|
||||
type TokenExchangeStorage interface {
|
||||
// ValidateTokenExchangeRequest will be called to validate parsed (including tokens) Token Exchange Grant request.
|
||||
//
|
||||
// Important validations can include:
|
||||
// - permissions
|
||||
// - set requested token type to some default value if it is empty (rfc 8693 allows it) using SetRequestedTokenType method.
|
||||
// Depending on RequestedTokenType - the following tokens will be issued:
|
||||
// - RefreshTokenType - both access and refresh tokens
|
||||
// - AccessTokenType - only access token
|
||||
// - IDTokenType - only id token
|
||||
// - validation of subject's token type on possibility to be exchanged to the requested token type (according to your requirements)
|
||||
// - scopes (and update them using SetCurrentScopes method)
|
||||
// - set new subject if it differs from exchange subject (impersonation flow)
|
||||
//
|
||||
// Request will include subject's and/or actor's token claims if correspinding tokens are access/id_token issued by op
|
||||
// or third party tokens parsed by TokenExchangeTokensVerifierStorage interface methods.
|
||||
ValidateTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) error
|
||||
|
||||
// CreateTokenExchangeRequest will be called after parsing and validating token exchange request.
|
||||
// Stored request is not accessed later by op - so it is up to implementer to decide
|
||||
// should this method actually store the request or not (common use case - store for it for audit purposes)
|
||||
CreateTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) error
|
||||
|
||||
// GetPrivateClaimsFromTokenExchangeRequest will be called during access token creation.
|
||||
// Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
|
||||
GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) (claims map[string]interface{}, err error)
|
||||
|
||||
// SetUserinfoFromTokenExchangeRequest will be called during id token creation.
|
||||
// Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
|
||||
SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo *oidc.UserInfo, request TokenExchangeRequest) error
|
||||
}
|
||||
|
||||
// TokenExchangeTokensVerifierStorage is an optional interface used in token exchange process to verify tokens
|
||||
// issued by third-party applications. If interface is not implemented - only tokens issued by op will be exchanged.
|
||||
type TokenExchangeTokensVerifierStorage interface {
|
||||
VerifyExchangeSubjectToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, subject string, tokenClaims map[string]interface{}, err error)
|
||||
VerifyExchangeActorToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, actor string, tokenClaims map[string]interface{}, err error)
|
||||
}
|
||||
|
||||
var ErrInvalidRefreshToken = errors.New("invalid_refresh_token")
|
||||
|
||||
type ClientCredentialsStorage interface {
|
||||
ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error)
|
||||
}
|
||||
|
||||
type OPStorage interface {
|
||||
// GetClientByClientID loads a Client. The returned Client is never cached and is only used to
|
||||
// handle the current request.
|
||||
GetClientByClientID(ctx context.Context, clientID string) (Client, error)
|
||||
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
|
||||
SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error
|
||||
SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error
|
||||
SetIntrospectionFromToken(ctx context.Context, userinfo oidc.IntrospectionResponse, tokenID, subject, clientID string) error
|
||||
SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
|
||||
SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
|
||||
SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
|
||||
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
|
||||
|
||||
// GetKeyByIDAndUserID is mis-named. It does not pass userID. Instead
|
||||
// it passes the clientID.
|
||||
GetKeyByIDAndUserID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error)
|
||||
GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error)
|
||||
ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error)
|
||||
}
|
||||
|
||||
// JWTProfileTokenStorage is an additional, optional storage to implement
|
||||
// implementing it, allows specifying the [AccessTokenType] of the access_token returned form the JWT Profile TokenRequest
|
||||
type JWTProfileTokenStorage interface {
|
||||
JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error)
|
||||
}
|
||||
|
||||
// Storage is a required parameter for NewOpenIDProvider(). In addition to the
|
||||
// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
|
||||
// then the grant type "client_credentials" will be supported. In that case, the access
|
||||
|
@ -102,3 +147,50 @@ type EndSessionRequest struct {
|
|||
ClientID string
|
||||
RedirectURI string
|
||||
}
|
||||
|
||||
var ErrDuplicateUserCode = errors.New("user code already exists")
|
||||
|
||||
type DeviceAuthorizationState struct {
|
||||
ClientID string
|
||||
Scopes []string
|
||||
Expires time.Time
|
||||
Done bool
|
||||
Subject string
|
||||
Denied bool
|
||||
}
|
||||
|
||||
type DeviceAuthorizationStorage interface {
|
||||
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
|
||||
// User code will be used by the user to complete the login flow and must be unique.
|
||||
// ErrDuplicateUserCode signals the caller should try again with a new code.
|
||||
//
|
||||
// Note that user codes are low entropy keys and when many exist in the
|
||||
// database, the change for collisions increases. Therefore implementers
|
||||
// of this interface must make sure that user codes of expired authentication flows are purged,
|
||||
// after some time.
|
||||
StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error
|
||||
|
||||
// GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database.
|
||||
// The method is polled untill the the authorization is eighter Completed, Expired or Denied.
|
||||
GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
|
||||
|
||||
// GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow,
|
||||
// identified by the user code.
|
||||
GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error)
|
||||
|
||||
// CompleteDeviceAuthorization marks a device authorization entry as Completed,
|
||||
// identified by userCode. The Subject is added to the state, so that
|
||||
// GetDeviceAuthorizatonState can use it to create a new Access Token.
|
||||
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
|
||||
|
||||
// DenyDeviceAuthorization marks a device authorization entry as Denied.
|
||||
DenyDeviceAuthorization(ctx context.Context, userCode string) error
|
||||
}
|
||||
|
||||
func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) {
|
||||
storage, ok := s.(DeviceAuthorizationStorage)
|
||||
if !ok {
|
||||
return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported")
|
||||
}
|
||||
return storage, nil
|
||||
}
|
||||
|
|
|
@ -4,14 +4,12 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/crypto"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/pkg/strings"
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/strings"
|
||||
)
|
||||
|
||||
type TokenCreator interface {
|
||||
Issuer() string
|
||||
Signer() Signer
|
||||
Storage() Storage
|
||||
Crypto() Crypto
|
||||
}
|
||||
|
@ -22,6 +20,13 @@ type TokenRequest interface {
|
|||
GetScopes() []string
|
||||
}
|
||||
|
||||
type AccessTokenClient interface {
|
||||
GetID() string
|
||||
ClockSkew() time.Duration
|
||||
RestrictAdditionalAccessTokenScopes() func(scopes []string) []string
|
||||
GrantTypes() []oidc.GrantType
|
||||
}
|
||||
|
||||
func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Client, creator TokenCreator, createAccessToken bool, code, refreshToken string) (*oidc.AccessTokenResponse, error) {
|
||||
var accessToken, newRefreshToken string
|
||||
var validity time.Duration
|
||||
|
@ -32,7 +37,7 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
idToken, err := CreateIDToken(ctx, creator.Issuer(), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client)
|
||||
idToken, err := CreateIDToken(ctx, IssuerFromContext(ctx), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -57,7 +62,7 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
|
|||
}, nil
|
||||
}
|
||||
|
||||
func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client Client) (id, newRefreshToken string, exp time.Time, err error) {
|
||||
func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) {
|
||||
if needsRefreshToken(tokenRequest, client) {
|
||||
return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken)
|
||||
}
|
||||
|
@ -65,10 +70,12 @@ func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storag
|
|||
return
|
||||
}
|
||||
|
||||
func needsRefreshToken(tokenRequest TokenRequest, client Client) bool {
|
||||
func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool {
|
||||
switch req := tokenRequest.(type) {
|
||||
case AuthRequest:
|
||||
return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
|
||||
case TokenExchangeRequest:
|
||||
return req.GetRequestedTokenType() == oidc.RefreshTokenType
|
||||
case RefreshTokenRequest:
|
||||
return true
|
||||
default:
|
||||
|
@ -76,7 +83,7 @@ func needsRefreshToken(tokenRequest TokenRequest, client Client) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) {
|
||||
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) {
|
||||
id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
|
@ -87,7 +94,7 @@ func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTok
|
|||
}
|
||||
validity = exp.Add(clockSkew).Sub(time.Now().UTC())
|
||||
if accessTokenType == AccessTokenTypeJWT {
|
||||
accessToken, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage())
|
||||
accessToken, err = CreateJWT(ctx, IssuerFromContext(ctx), tokenRequest, exp, id, client, creator.Storage())
|
||||
return
|
||||
}
|
||||
accessToken, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
|
||||
|
@ -98,17 +105,41 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
|
|||
return crypto.Encrypt(tokenID + ":" + subject)
|
||||
}
|
||||
|
||||
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) {
|
||||
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, client AccessTokenClient, storage Storage) (string, error) {
|
||||
claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id, client.GetID(), client.ClockSkew())
|
||||
if client != nil {
|
||||
restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes())
|
||||
privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
|
||||
|
||||
var (
|
||||
privateClaims map[string]interface{}
|
||||
err error
|
||||
)
|
||||
|
||||
tokenExchangeRequest, okReq := tokenRequest.(TokenExchangeRequest)
|
||||
teStorage, okStorage := storage.(TokenExchangeStorage)
|
||||
if okReq && okStorage {
|
||||
privateClaims, err = teStorage.GetPrivateClaimsFromTokenExchangeRequest(
|
||||
ctx,
|
||||
tokenExchangeRequest,
|
||||
)
|
||||
} else {
|
||||
privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetPrivateClaims(privateClaims)
|
||||
claims.Claims = privateClaims
|
||||
}
|
||||
return crypto.Sign(claims, signer.Signer())
|
||||
signingKey, err := storage.SigningKey(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
signer, err := SignerFromKey(signingKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return crypto.Sign(claims, signer)
|
||||
}
|
||||
|
||||
type IDTokenRequest interface {
|
||||
|
@ -120,7 +151,7 @@ type IDTokenRequest interface {
|
|||
GetSubject() string
|
||||
}
|
||||
|
||||
func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, client Client) (string, error) {
|
||||
func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, client Client) (string, error) {
|
||||
exp := time.Now().UTC().Add(client.ClockSkew()).Add(validity)
|
||||
var acr, nonce string
|
||||
if authRequest, ok := request.(AuthRequest); ok {
|
||||
|
@ -129,33 +160,50 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
|
|||
}
|
||||
claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew())
|
||||
scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes())
|
||||
signingKey, err := storage.SigningKey(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if accessToken != "" {
|
||||
atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
|
||||
atHash, err := oidc.ClaimHash(accessToken, signingKey.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetAccessTokenHash(atHash)
|
||||
claims.AccessTokenHash = atHash
|
||||
if !client.IDTokenUserinfoClaimsAssertion() {
|
||||
scopes = removeUserinfoScopes(scopes)
|
||||
}
|
||||
}
|
||||
if len(scopes) > 0 {
|
||||
userInfo := oidc.NewUserInfo()
|
||||
|
||||
tokenExchangeRequest, okReq := request.(TokenExchangeRequest)
|
||||
teStorage, okStorage := storage.(TokenExchangeStorage)
|
||||
if okReq && okStorage {
|
||||
userInfo := new(oidc.UserInfo)
|
||||
err := teStorage.SetUserinfoFromTokenExchangeRequest(ctx, userInfo, tokenExchangeRequest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetUserInfo(userInfo)
|
||||
} else if len(scopes) > 0 {
|
||||
userInfo := new(oidc.UserInfo)
|
||||
err := storage.SetUserinfoFromScopes(ctx, userInfo, request.GetSubject(), request.GetClientID(), scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetUserinfo(userInfo)
|
||||
claims.SetUserInfo(userInfo)
|
||||
}
|
||||
if code != "" {
|
||||
codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
|
||||
codeHash, err := oidc.ClaimHash(code, signingKey.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims.SetCodeHash(codeHash)
|
||||
claims.CodeHash = codeHash
|
||||
}
|
||||
|
||||
return crypto.Sign(claims, signer.Signer())
|
||||
signer, err := SignerFromKey(signingKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return crypto.Sign(claims, signer)
|
||||
}
|
||||
|
||||
func removeUserinfoScopes(scopes []string) []string {
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// ClientCredentialsExchange handles the OAuth 2.0 client_credentials grant, including
|
||||
|
@ -63,15 +63,15 @@ func ParseClientCredentialsRequest(r *http.Request, decoder httphelper.Decoder)
|
|||
return request, nil
|
||||
}
|
||||
|
||||
// ValidateClientCredentialsRequest validates the refresh_token request parameters including authorization check of the client
|
||||
// and returns the data representing the original auth request corresponding to the refresh_token
|
||||
// ValidateClientCredentialsRequest validates the client_credentials request parameters including authorization check of the client
|
||||
// and returns a TokenRequest and Client implementation to be used in the client_credentials response, resp. creation of the corresponding access_token.
|
||||
func ValidateClientCredentialsRequest(ctx context.Context, request *oidc.ClientCredentialsRequest, exchanger Exchanger) (TokenRequest, Client, error) {
|
||||
storage, ok := exchanger.Storage().(ClientCredentialsStorage)
|
||||
if !ok {
|
||||
return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
|
||||
}
|
||||
|
||||
client, err := AuthorizeClientCredentialsClient(ctx, request, exchanger)
|
||||
client, err := AuthorizeClientCredentialsClient(ctx, request, storage)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -84,12 +84,8 @@ func ValidateClientCredentialsRequest(ctx context.Context, request *oidc.ClientC
|
|||
return tokenRequest, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientCredentialsRequest, exchanger Exchanger) (Client, error) {
|
||||
if err := AuthorizeClientIDSecret(ctx, request.ClientID, request.ClientSecret, exchanger.Storage()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, request.ClientID)
|
||||
func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientCredentialsRequest, storage ClientCredentialsStorage) (Client, error) {
|
||||
client, err := storage.ClientCredentials(ctx, request.ClientID, request.ClientSecret)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithParent(err)
|
||||
}
|
||||
|
@ -102,7 +98,7 @@ func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientC
|
|||
}
|
||||
|
||||
func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
|
||||
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeJWT, creator, client, "")
|
||||
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// CodeExchange handles the OAuth 2.0 authorization_code grant, including
|
||||
|
|
|
@ -1,11 +1,399 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
// TokenExchange will handle the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange")
|
||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
RequestError(w, r, errors.New("unimplemented"))
|
||||
type TokenExchangeRequest interface {
|
||||
GetAMR() []string
|
||||
GetAudience() []string
|
||||
GetResourses() []string
|
||||
GetAuthTime() time.Time
|
||||
GetClientID() string
|
||||
GetScopes() []string
|
||||
GetSubject() string
|
||||
GetRequestedTokenType() oidc.TokenType
|
||||
|
||||
GetExchangeSubject() string
|
||||
GetExchangeSubjectTokenType() oidc.TokenType
|
||||
GetExchangeSubjectTokenIDOrToken() string
|
||||
GetExchangeSubjectTokenClaims() map[string]interface{}
|
||||
|
||||
GetExchangeActor() string
|
||||
GetExchangeActorTokenType() oidc.TokenType
|
||||
GetExchangeActorTokenIDOrToken() string
|
||||
GetExchangeActorTokenClaims() map[string]interface{}
|
||||
|
||||
SetCurrentScopes(scopes []string)
|
||||
SetRequestedTokenType(tt oidc.TokenType)
|
||||
SetSubject(subject string)
|
||||
}
|
||||
|
||||
type tokenExchangeRequest struct {
|
||||
exchangeSubjectTokenIDOrToken string
|
||||
exchangeSubjectTokenType oidc.TokenType
|
||||
exchangeSubject string
|
||||
exchangeSubjectTokenClaims map[string]interface{}
|
||||
|
||||
exchangeActorTokenIDOrToken string
|
||||
exchangeActorTokenType oidc.TokenType
|
||||
exchangeActor string
|
||||
exchangeActorTokenClaims map[string]interface{}
|
||||
|
||||
resource []string
|
||||
audience oidc.Audience
|
||||
scopes oidc.SpaceDelimitedArray
|
||||
requestedTokenType oidc.TokenType
|
||||
clientID string
|
||||
authTime time.Time
|
||||
subject string
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetAMR() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetAudience() []string {
|
||||
return r.audience
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetResourses() []string {
|
||||
return r.resource
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetAuthTime() time.Time {
|
||||
return r.authTime
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetClientID() string {
|
||||
return r.clientID
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetScopes() []string {
|
||||
return r.scopes
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetRequestedTokenType() oidc.TokenType {
|
||||
return r.requestedTokenType
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeSubject() string {
|
||||
return r.exchangeSubject
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeSubjectTokenType() oidc.TokenType {
|
||||
return r.exchangeSubjectTokenType
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeSubjectTokenIDOrToken() string {
|
||||
return r.exchangeSubjectTokenIDOrToken
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeSubjectTokenClaims() map[string]interface{} {
|
||||
return r.exchangeSubjectTokenClaims
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeActor() string {
|
||||
return r.exchangeActor
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeActorTokenType() oidc.TokenType {
|
||||
return r.exchangeActorTokenType
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeActorTokenIDOrToken() string {
|
||||
return r.exchangeActorTokenIDOrToken
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetExchangeActorTokenClaims() map[string]interface{} {
|
||||
return r.exchangeActorTokenClaims
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) GetSubject() string {
|
||||
return r.subject
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) SetCurrentScopes(scopes []string) {
|
||||
r.scopes = scopes
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) SetRequestedTokenType(tt oidc.TokenType) {
|
||||
r.requestedTokenType = tt
|
||||
}
|
||||
|
||||
func (r *tokenExchangeRequest) SetSubject(subject string) {
|
||||
r.subject = subject
|
||||
}
|
||||
|
||||
// TokenExchange handles the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange")
|
||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
}
|
||||
|
||||
tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
httphelper.MarshalJSON(w, resp)
|
||||
}
|
||||
|
||||
// ParseTokenExchangeRequest parses the http request into oidc.TokenExchangeRequest
|
||||
func ParseTokenExchangeRequest(r *http.Request, decoder httphelper.Decoder) (_ *oidc.TokenExchangeRequest, clientID, clientSecret string, err error) {
|
||||
err = r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, "", "", oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||
}
|
||||
|
||||
request := new(oidc.TokenExchangeRequest)
|
||||
err = decoder.Decode(request, r.Form)
|
||||
if err != nil {
|
||||
return nil, "", "", oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if clientID, clientSecret, ok = r.BasicAuth(); ok {
|
||||
clientID, err = url.QueryUnescape(clientID)
|
||||
if err != nil {
|
||||
return nil, "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||
}
|
||||
|
||||
clientSecret, err = url.QueryUnescape(clientSecret)
|
||||
if err != nil {
|
||||
return nil, "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||
}
|
||||
}
|
||||
|
||||
return request, clientID, clientSecret, nil
|
||||
}
|
||||
|
||||
// ValidateTokenExchangeRequest validates the token exchange request parameters including authorization check of the client,
|
||||
// subject_token and actor_token
|
||||
func ValidateTokenExchangeRequest(
|
||||
ctx context.Context,
|
||||
oidcTokenExchangeRequest *oidc.TokenExchangeRequest,
|
||||
clientID, clientSecret string,
|
||||
exchanger Exchanger,
|
||||
) (TokenExchangeRequest, Client, error) {
|
||||
if oidcTokenExchangeRequest.SubjectToken == "" {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token missing")
|
||||
}
|
||||
|
||||
if oidcTokenExchangeRequest.SubjectTokenType == "" {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing")
|
||||
}
|
||||
|
||||
storage := exchanger.Storage()
|
||||
teStorage, ok := storage.(TokenExchangeStorage)
|
||||
if !ok {
|
||||
return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported")
|
||||
}
|
||||
|
||||
client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if oidcTokenExchangeRequest.RequestedTokenType != "" && !oidcTokenExchangeRequest.RequestedTokenType.IsSupported() {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported")
|
||||
}
|
||||
|
||||
if !oidcTokenExchangeRequest.SubjectTokenType.IsSupported() {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported")
|
||||
}
|
||||
|
||||
if oidcTokenExchangeRequest.ActorTokenType != "" && !oidcTokenExchangeRequest.ActorTokenType.IsSupported() {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported")
|
||||
}
|
||||
|
||||
exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger,
|
||||
oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false)
|
||||
if !ok {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid")
|
||||
}
|
||||
|
||||
var (
|
||||
exchangeActorTokenIDOrToken, exchangeActor string
|
||||
exchangeActorTokenClaims map[string]interface{}
|
||||
)
|
||||
if oidcTokenExchangeRequest.ActorToken != "" {
|
||||
exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger,
|
||||
oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true)
|
||||
if !ok {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
req := &tokenExchangeRequest{
|
||||
exchangeSubjectTokenIDOrToken: exchangeSubjectTokenIDOrToken,
|
||||
exchangeSubjectTokenType: oidcTokenExchangeRequest.SubjectTokenType,
|
||||
exchangeSubject: exchangeSubject,
|
||||
exchangeSubjectTokenClaims: exchangeSubjectTokenClaims,
|
||||
|
||||
exchangeActorTokenIDOrToken: exchangeActorTokenIDOrToken,
|
||||
exchangeActorTokenType: oidcTokenExchangeRequest.ActorTokenType,
|
||||
exchangeActor: exchangeActor,
|
||||
exchangeActorTokenClaims: exchangeActorTokenClaims,
|
||||
|
||||
subject: exchangeSubject,
|
||||
resource: oidcTokenExchangeRequest.Resource,
|
||||
audience: oidcTokenExchangeRequest.Audience,
|
||||
scopes: oidcTokenExchangeRequest.Scopes,
|
||||
requestedTokenType: oidcTokenExchangeRequest.RequestedTokenType,
|
||||
clientID: client.GetID(),
|
||||
authTime: time.Now(),
|
||||
}
|
||||
|
||||
err = teStorage.ValidateTokenExchangeRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = teStorage.CreateTokenExchangeRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return req, client, nil
|
||||
}
|
||||
|
||||
func GetTokenIDAndSubjectFromToken(
|
||||
ctx context.Context,
|
||||
exchanger Exchanger,
|
||||
token string,
|
||||
tokenType oidc.TokenType,
|
||||
isActor bool,
|
||||
) (tokenIDOrToken, subject string, claims map[string]interface{}, ok bool) {
|
||||
switch tokenType {
|
||||
case oidc.AccessTokenType:
|
||||
var accessTokenClaims *oidc.AccessTokenClaims
|
||||
tokenIDOrToken, subject, accessTokenClaims, ok = getTokenIDAndClaims(ctx, exchanger, token)
|
||||
claims = accessTokenClaims.Claims
|
||||
case oidc.RefreshTokenType:
|
||||
refreshTokenRequest, err := exchanger.Storage().TokenRequestByRefreshToken(ctx, token)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
tokenIDOrToken, subject, ok = token, refreshTokenRequest.GetSubject(), true
|
||||
case oidc.IDTokenType:
|
||||
idTokenClaims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, token, exchanger.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
tokenIDOrToken, subject, claims, ok = token, idTokenClaims.Subject, idTokenClaims.Claims, true
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if verifier, ok := exchanger.Storage().(TokenExchangeTokensVerifierStorage); ok {
|
||||
var err error
|
||||
if isActor {
|
||||
tokenIDOrToken, subject, claims, err = verifier.VerifyExchangeActorToken(ctx, token, tokenType)
|
||||
} else {
|
||||
tokenIDOrToken, subject, claims, err = verifier.VerifyExchangeSubjectToken(ctx, token, tokenType)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", nil, false
|
||||
}
|
||||
|
||||
return tokenIDOrToken, subject, claims, true
|
||||
}
|
||||
|
||||
return "", "", nil, false
|
||||
}
|
||||
|
||||
return tokenIDOrToken, subject, claims, true
|
||||
}
|
||||
|
||||
// AuthorizeTokenExchangeClient authorizes a client by validating the client_id and client_secret
|
||||
func AuthorizeTokenExchangeClient(ctx context.Context, clientID, clientSecret string, exchanger Exchanger) (client Client, err error) {
|
||||
if err := AuthorizeClientIDSecret(ctx, clientID, clientSecret, exchanger.Storage()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err = exchanger.Storage().GetClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithParent(err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func CreateTokenExchangeResponse(
|
||||
ctx context.Context,
|
||||
tokenExchangeRequest TokenExchangeRequest,
|
||||
client Client,
|
||||
creator TokenCreator,
|
||||
) (_ *oidc.TokenExchangeResponse, err error) {
|
||||
|
||||
var (
|
||||
token, refreshToken, tokenType string
|
||||
validity time.Duration
|
||||
)
|
||||
|
||||
switch tokenExchangeRequest.GetRequestedTokenType() {
|
||||
case oidc.AccessTokenType, oidc.RefreshTokenType:
|
||||
token, refreshToken, validity, err = CreateAccessToken(ctx, tokenExchangeRequest, client.AccessTokenType(), creator, client, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenType = oidc.BearerToken
|
||||
case oidc.IDTokenType:
|
||||
token, err = CreateIDToken(ctx, IssuerFromContext(ctx), tokenExchangeRequest, client.IDTokenLifetime(), "", "", creator.Storage(), client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// not applicable (see https://datatracker.ietf.org/doc/html/rfc8693#section-2-2-1-2-6)
|
||||
tokenType = "N_A"
|
||||
default:
|
||||
// oidc.JWTTokenType and other custom token types are not supported for issuing.
|
||||
// In the future it can be considered to have custom tokens generation logic injected via op configuration
|
||||
// or via expanding Storage interface
|
||||
oidc.ErrInvalidRequest().WithDescription("requested_token_type is invalid")
|
||||
}
|
||||
|
||||
exp := uint64(validity.Seconds())
|
||||
return &oidc.TokenExchangeResponse{
|
||||
AccessToken: token,
|
||||
IssuedTokenType: tokenExchangeRequest.GetRequestedTokenType(),
|
||||
TokenType: tokenType,
|
||||
ExpiresIn: exp,
|
||||
RefreshToken: refreshToken,
|
||||
Scopes: tokenExchangeRequest.GetScopes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, *oidc.AccessTokenClaims, bool) {
|
||||
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
|
||||
if err == nil {
|
||||
splitToken := strings.Split(tokenIDSubject, ":")
|
||||
if len(splitToken) != 2 {
|
||||
return "", "", nil, false
|
||||
}
|
||||
|
||||
return splitToken[0], splitToken[1], nil, true
|
||||
}
|
||||
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
if err != nil {
|
||||
return "", "", nil, false
|
||||
}
|
||||
|
||||
return accessTokenClaims.JWTID, accessTokenClaims.Subject, accessTokenClaims, true
|
||||
}
|
||||
|
|
|
@ -1,24 +1,24 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type Introspector interface {
|
||||
Decoder() httphelper.Decoder
|
||||
Crypto() Crypto
|
||||
Storage() Storage
|
||||
AccessTokenVerifier() AccessTokenVerifier
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
}
|
||||
|
||||
type IntrospectorJWTProfile interface {
|
||||
Introspector
|
||||
JWTProfileVerifier() JWTProfileVerifier
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
}
|
||||
|
||||
func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -28,7 +28,7 @@ func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *
|
|||
}
|
||||
|
||||
func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) {
|
||||
response := oidc.NewIntrospectionResponse()
|
||||
response := new(oidc.IntrospectionResponse)
|
||||
token, clientID, err := ParseTokenIntrospectionRequest(r, introspector)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
|
@ -44,43 +44,24 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
|
|||
httphelper.MarshalJSON(w, response)
|
||||
return
|
||||
}
|
||||
response.SetActive(true)
|
||||
response.Active = true
|
||||
httphelper.MarshalJSON(w, response)
|
||||
}
|
||||
|
||||
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
|
||||
err = r.ParseForm()
|
||||
clientID, authenticated, err := ClientIDFromRequest(r, introspector)
|
||||
if err != nil {
|
||||
return "", "", errors.New("unable to parse request")
|
||||
return "", "", err
|
||||
}
|
||||
req := new(struct {
|
||||
oidc.IntrospectionRequest
|
||||
oidc.ClientAssertionParams
|
||||
})
|
||||
if !authenticated {
|
||||
return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||
}
|
||||
|
||||
req := new(oidc.IntrospectionRequest)
|
||||
err = introspector.Decoder().Decode(req, r.Form)
|
||||
if err != nil {
|
||||
return "", "", errors.New("unable to parse request")
|
||||
}
|
||||
if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
|
||||
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier())
|
||||
if err == nil {
|
||||
return req.Token, profile.Issuer, nil
|
||||
}
|
||||
}
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if ok {
|
||||
clientID, err = url.QueryUnescape(clientID)
|
||||
if err != nil {
|
||||
return "", "", errors.New("invalid basic auth header")
|
||||
}
|
||||
clientSecret, err = url.QueryUnescape(clientSecret)
|
||||
if err != nil {
|
||||
return "", "", errors.New("invalid basic auth header")
|
||||
}
|
||||
if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return req.Token, clientID, nil
|
||||
}
|
||||
return "", "", errors.New("invalid authorization")
|
||||
|
||||
return req.Token, clientID, nil
|
||||
}
|
||||
|
|
|
@ -5,13 +5,13 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type JWTAuthorizationGrantExchanger interface {
|
||||
Exchanger
|
||||
JWTProfileVerifier() JWTProfileVerifier
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
}
|
||||
|
||||
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1
|
||||
|
@ -21,7 +21,7 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati
|
|||
RequestError(w, r, err)
|
||||
}
|
||||
|
||||
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier())
|
||||
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context()))
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
return
|
||||
|
@ -53,27 +53,65 @@ func ParseJWTProfileGrantRequest(r *http.Request, decoder httphelper.Decoder) (*
|
|||
return tokenReq, nil
|
||||
}
|
||||
|
||||
// CreateJWTTokenResponse creates
|
||||
// CreateJWTTokenResponse creates an access_token response for a JWT Profile Grant request
|
||||
// by default the access_token is an opaque string, but can be specified by implementing the JWTProfileTokenStorage interface
|
||||
func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
|
||||
id, exp, err := creator.Storage().CreateAccessToken(ctx, tokenRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessToken, err := CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// return an opaque token as default to not break current implementations
|
||||
tokenType := AccessTokenTypeBearer
|
||||
|
||||
// the current CreateAccessToken function, esp. CreateJWT requires an implementation of an AccessTokenClient
|
||||
client := &jwtProfileClient{
|
||||
id: tokenRequest.GetSubject(),
|
||||
}
|
||||
|
||||
// by implementing the JWTProfileTokenStorage the storage can specify the AccessTokenType to be returned
|
||||
tokenStorage, ok := creator.Storage().(JWTProfileTokenStorage)
|
||||
if ok {
|
||||
var err error
|
||||
tokenType, err = tokenStorage.JWTProfileTokenType(ctx, tokenRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: uint64(exp.Sub(time.Now().UTC()).Seconds()),
|
||||
ExpiresIn: uint64(validity.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseJWTProfileRequest has been renamed to ParseJWTProfileGrantRequest
|
||||
//
|
||||
//deprecated: use ParseJWTProfileGrantRequest
|
||||
// deprecated: use ParseJWTProfileGrantRequest
|
||||
func ParseJWTProfileRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
|
||||
return ParseJWTProfileGrantRequest(r, decoder)
|
||||
}
|
||||
|
||||
type jwtProfileClient struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (j *jwtProfileClient) GetID() string {
|
||||
return j.id
|
||||
}
|
||||
|
||||
func (j *jwtProfileClient) ClockSkew() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (j *jwtProfileClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
|
||||
return func(scopes []string) []string {
|
||||
return scopes
|
||||
}
|
||||
}
|
||||
|
||||
func (j *jwtProfileClient) GrantTypes() []oidc.GrantType {
|
||||
return []oidc.GrantType{
|
||||
oidc.GrantTypeBearer,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/pkg/strings"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/strings"
|
||||
)
|
||||
|
||||
type RefreshTokenRequest interface {
|
||||
|
|
|
@ -5,15 +5,13 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type Exchanger interface {
|
||||
Issuer() string
|
||||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
Signer() Signer
|
||||
Crypto() Crypto
|
||||
AuthMethodPostSupported() bool
|
||||
AuthMethodPrivateKeyJWTSupported() bool
|
||||
|
@ -21,6 +19,9 @@ type Exchanger interface {
|
|||
GrantTypeTokenExchangeSupported() bool
|
||||
GrantTypeJWTAuthorizationSupported() bool
|
||||
GrantTypeClientCredentialsSupported() bool
|
||||
GrantTypeDeviceCodeSupported() bool
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
}
|
||||
|
||||
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
ClientCredentialsExchange(w, r, exchanger)
|
||||
return
|
||||
}
|
||||
case string(oidc.GrantTypeDeviceCode):
|
||||
if exchanger.GrantTypeDeviceCodeSupported() {
|
||||
DeviceAccessToken(w, r, exchanger)
|
||||
return
|
||||
}
|
||||
case "":
|
||||
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
|
||||
return
|
||||
|
@ -122,7 +128,7 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.C
|
|||
// AuthorizePrivateJWTKey authorizes a client by validating the client_assertion's signature with a previously
|
||||
// registered public key (JWT Profile)
|
||||
func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchanger JWTAuthorizationGrantExchanger) (Client, error) {
|
||||
jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier())
|
||||
jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -136,8 +142,8 @@ func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchang
|
|||
return client, nil
|
||||
}
|
||||
|
||||
// ValidateGrantType ensures that the requested grant_type is allowed by the Client
|
||||
func ValidateGrantType(client Client, grantType oidc.GrantType) bool {
|
||||
// ValidateGrantType ensures that the requested grant_type is allowed by the client
|
||||
func ValidateGrantType(client interface{ GrantTypes() []oidc.GrantType }, grantType oidc.GrantType) bool {
|
||||
if client == nil {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -7,22 +7,22 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type Revoker interface {
|
||||
Decoder() httphelper.Decoder
|
||||
Crypto() Crypto
|
||||
Storage() Storage
|
||||
AccessTokenVerifier() AccessTokenVerifier
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
AuthMethodPrivateKeyJWTSupported() bool
|
||||
AuthMethodPostSupported() bool
|
||||
}
|
||||
|
||||
type RevokerJWTProfile interface {
|
||||
Revoker
|
||||
JWTProfileVerifier() JWTProfileVerifier
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
}
|
||||
|
||||
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -39,8 +39,8 @@ func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) {
|
|||
}
|
||||
var subject string
|
||||
doDecrypt := true
|
||||
if canRefreshInfo, ok := revoker.Storage().(CanRefreshTokenInfo); ok && tokenTypeHint != "access_token" {
|
||||
userID, tokenID, err := canRefreshInfo.GetRefreshTokenInfo(r.Context(), clientID, token)
|
||||
if tokenTypeHint != "access_token" {
|
||||
userID, tokenID, err := revoker.Storage().GetRefreshTokenInfo(r.Context(), clientID, token)
|
||||
if err != nil {
|
||||
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
|
||||
if !errors.Is(err, ErrInvalidRefreshToken) {
|
||||
|
@ -87,7 +87,7 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token
|
|||
if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() {
|
||||
return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
|
||||
}
|
||||
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier())
|
||||
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier(r.Context()))
|
||||
if err == nil {
|
||||
return req.Token, req.TokenTypeHint, profile.Issuer, nil
|
||||
}
|
||||
|
@ -151,9 +151,9 @@ func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider Use
|
|||
}
|
||||
return splitToken[0], splitToken[1], true
|
||||
}
|
||||
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
|
||||
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
|
||||
return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
|
||||
}
|
||||
|
|
|
@ -6,15 +6,15 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type UserinfoProvider interface {
|
||||
Decoder() httphelper.Decoder
|
||||
Crypto() Crypto
|
||||
Storage() Storage
|
||||
AccessTokenVerifier() AccessTokenVerifier
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
}
|
||||
|
||||
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -34,7 +34,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
|
|||
http.Error(w, "access token invalid", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
info := oidc.NewUserInfo()
|
||||
info := new(oidc.UserInfo)
|
||||
err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin"))
|
||||
if err != nil {
|
||||
httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden)
|
||||
|
@ -81,9 +81,9 @@ func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider
|
|||
}
|
||||
return splitToken[0], splitToken[1], true
|
||||
}
|
||||
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
|
||||
accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
|
||||
return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type AccessTokenVerifier interface {
|
||||
|
@ -18,8 +18,6 @@ type accessTokenVerifier struct {
|
|||
maxAgeIAT time.Duration
|
||||
offset time.Duration
|
||||
supportedSignAlgs []string
|
||||
maxAge time.Duration
|
||||
acr oidc.ACRVerifier
|
||||
keySet oidc.KeySet
|
||||
}
|
||||
|
||||
|
@ -67,29 +65,29 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok
|
|||
return verifier
|
||||
}
|
||||
|
||||
// VerifyAccessToken validates the access token (issuer, signature and expiration)
|
||||
func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) {
|
||||
claims := oidc.EmptyAccessTokenClaims()
|
||||
// VerifyAccessToken validates the access token (issuer, signature and expiration).
|
||||
func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
payload, err := oidc.ParseToken(decrypted, claims)
|
||||
payload, err := oidc.ParseToken(decrypted, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
|
|
70
pkg/op/verifier_access_token_example_test.go
Normal file
70
pkg/op/verifier_access_token_example_test.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
)
|
||||
|
||||
// MyCustomClaims extends the TokenClaims base,
|
||||
// so it implements the oidc.Claims interface.
|
||||
// Instead of carrying a map, we add needed fields// to the struct for type safe access.
|
||||
type MyCustomClaims struct {
|
||||
oidc.TokenClaims
|
||||
NotBefore oidc.Time `json:"nbf,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
Scopes []string `json:"scope,omitempty"`
|
||||
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
||||
Foo string `json:"foo,omitempty"`
|
||||
Bar *Nested `json:"bar,omitempty"`
|
||||
}
|
||||
|
||||
// Nested struct types are also possible.
|
||||
type Nested struct {
|
||||
Count int `json:"count,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
/*
|
||||
accessToken carries the following claims. foo and bar are custom claims
|
||||
|
||||
{
|
||||
"aud": [
|
||||
"unit",
|
||||
"test"
|
||||
],
|
||||
"bar": {
|
||||
"count": 22,
|
||||
"tags": [
|
||||
"some",
|
||||
"tags"
|
||||
]
|
||||
},
|
||||
"exp": 4802234675,
|
||||
"foo": "Hello, World!",
|
||||
"iat": 1678097014,
|
||||
"iss": "local.com",
|
||||
"jti": "9876",
|
||||
"nbf": 1678097014,
|
||||
"sub": "tim@local.com"
|
||||
}
|
||||
*/
|
||||
const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM0Njc1LCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MDk3MDE0LCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MDk3MDE0LCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.OUgk-B7OXjYlYFj-nogqSDJiQE19tPrbzqUHEAjcEiJkaWo6-IpGVfDiGKm-TxjXQsNScxpaY0Pg3XIh1xK6TgtfYtoLQm-5RYw_mXgb9xqZB2VgPs6nNEYFUDM513MOU0EBc0QMyqAEGzW-HiSPAb4ugCvkLtM1yo11Xyy6vksAdZNs_mJDT4X3vFXnr0jk0ugnAW6fTN3_voC0F_9HQUAkmd750OIxkAHxAMvEPQcpbLHenVvX_Q0QMrzClVrxehn5TVMfmkYYg7ocr876Bq9xQGPNHAcrwvVIJqdg5uMUA38L3HC2BEueG6furZGvc7-qDWAT1VR9liM5ieKpPg`
|
||||
|
||||
func ExampleVerifyAccessToken_customClaims() {
|
||||
v := op.NewAccessTokenVerifier("local.com", tu.KeySet{})
|
||||
|
||||
// VerifyAccessToken can be called with the *MyCustomClaims.
|
||||
claims, err := op.VerifyAccessToken[*MyCustomClaims](context.TODO(), accessToken, v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Here we have typesafe access to the custom claims
|
||||
fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
|
||||
// Output: Hello, World! 22 [some tags]
|
||||
}
|
126
pkg/op/verifier_access_token_test.go
Normal file
126
pkg/op/verifier_access_token_test.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestNewAccessTokenVerifier(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
keySet oidc.KeySet
|
||||
opts []AccessTokenVerifierOpt
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want AccessTokenVerifier
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
want: &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with signature algorithm",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
opts: []AccessTokenVerifierOpt{
|
||||
WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewAccessTokenVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
verifier := &accessTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenClaims func() (string, *oidc.AccessTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
tokenClaims: tu.ValidAccessToken,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
|
||||
return tu.NewAccessToken(
|
||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidJWTID, tu.ValidClientID,
|
||||
tu.ValidSkew,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
tokenClaims: func() (string, *oidc.AccessTokenClaims) {
|
||||
return tu.NewAccessToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidJWTID, tu.ValidClientID,
|
||||
tu.ValidSkew,
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
|
||||
got, err := VerifyAccessToken[*oidc.AccessTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,7 +4,7 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type IDTokenHintVerifier interface {
|
||||
|
@ -73,41 +73,41 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
|
|||
}
|
||||
|
||||
// VerifyIDTokenHint validates the id token according to
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
|
||||
claims := oidc.EmptyIDTokenClaims()
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
payload, err := oidc.ParseToken(decrypted, claims)
|
||||
payload, err := oidc.ParseToken(decrypted, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
||||
return nil, err
|
||||
return nilClaims, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
|
161
pkg/op/verifier_id_token_hint_test.go
Normal file
161
pkg/op/verifier_id_token_hint_test.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestNewIDTokenHintVerifier(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
keySet oidc.KeySet
|
||||
opts []IDTokenHintVerifierOpt
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want IDTokenHintVerifier
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
want: &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with signature algorithm",
|
||||
args: args{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
opts: []IDTokenHintVerifierOpt{
|
||||
WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
keySet: tu.KeySet{},
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewIDTokenHintVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyIDTokenHint(t *testing.T) {
|
||||
verifier := &idTokenHintVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
keySet: tu.KeySet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
"foo", tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong IAT",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong acr",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired auth",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
|
||||
got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
)
|
||||
|
||||
type JWTProfileVerifier interface {
|
||||
|
@ -104,7 +104,7 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
|
|||
}
|
||||
|
||||
type jwtProfileKeyStorage interface {
|
||||
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
||||
GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
||||
}
|
||||
|
||||
func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
|
||||
|
@ -122,7 +122,7 @@ type jwtProfileKeySet struct {
|
|||
// VerifySignature implements oidc.KeySet by getting the public key from Storage implementation
|
||||
func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
|
||||
keyID, _ := oidc.GetKeyIDAndAlg(jws)
|
||||
key, err := k.storage.GetKeyByIDAndUserID(ctx, keyID, k.clientID)
|
||||
key, err := k.storage.GetKeyByIDAndClientID(ctx, keyID, k.clientID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching keys: %w", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue