Merge 6e7ee79a68
into d6e37fa741
This commit is contained in:
commit
2722c0e042
3 changed files with 171 additions and 22 deletions
|
@ -3,6 +3,9 @@ package client_test
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/hkdf"
|
||||||
|
cryptoRand "crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
@ -18,6 +21,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
"github.com/jeremija/gosubmit"
|
"github.com/jeremija/gosubmit"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -42,6 +47,15 @@ var Logger = slog.New(
|
||||||
|
|
||||||
var CTX context.Context
|
var CTX context.Context
|
||||||
|
|
||||||
|
type cookieSpec struct {
|
||||||
|
cookieHandler *httphelper.CookieHandler
|
||||||
|
extraCookies []*http.Cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultCookieSpec = cookieSpec{
|
||||||
|
cookieHandler: httphelper.NewCookieHandler([]byte("test1234test1234"), []byte("test1234test1234"), httphelper.WithUnsecure()),
|
||||||
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
os.Exit(func() int {
|
os.Exit(func() int {
|
||||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
|
||||||
|
@ -53,14 +67,57 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRelyingPartySession(t *testing.T) {
|
func TestRelyingPartySession(t *testing.T) {
|
||||||
|
secret := make([]byte, 512)
|
||||||
|
_, _ = cryptoRand.Read(secret)
|
||||||
|
hashFunc := sha256.New
|
||||||
|
requestAwareCookieHandler := httphelper.NewRequestAwareCookieHandler(func(r *http.Request) (*securecookie.SecureCookie, error) {
|
||||||
|
loginIDCookie, err := r.Cookie("login_id")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
loginID, err := uuid.Parse(loginIDCookie.Value)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
prk, err := hkdf.Extract(hashFunc, secret, loginID[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
hash, err := hkdf.Expand(hashFunc, prk, "INTEGRITY", 64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
block, err := hkdf.Expand(hashFunc, prk, "ENCRYPTION", 32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return securecookie.New(hash, block), nil
|
||||||
|
}, httphelper.WithUnsecure())
|
||||||
|
|
||||||
|
loginID := uuid.New()
|
||||||
|
loginIDCookie := &http.Cookie{
|
||||||
|
Name: "login_id",
|
||||||
|
Value: loginID.String(),
|
||||||
|
Secure: false,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieCases := []cookieSpec{
|
||||||
|
defaultCookieSpec,
|
||||||
|
{
|
||||||
|
cookieHandler: requestAwareCookieHandler,
|
||||||
|
extraCookies: []*http.Cookie{
|
||||||
|
loginIDCookie,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cookieCase := range cookieCases {
|
||||||
for _, wrapServer := range []bool{false, true} {
|
for _, wrapServer := range []bool{false, true} {
|
||||||
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
|
t.Run(fmt.Sprint("wrapServer ", wrapServer, " requestAwareCookieHandler ", cookieCase.cookieHandler.IsRequestAware()), func(t *testing.T) {
|
||||||
testRelyingPartySession(t, wrapServer)
|
testRelyingPartySession(t, wrapServer, cookieCase)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func testRelyingPartySession(t *testing.T, wrapServer bool) {
|
func testRelyingPartySession(t *testing.T, wrapServer bool, cookieSpec cookieSpec) {
|
||||||
t.Log("------- start example OP ------")
|
t.Log("------- start example OP ------")
|
||||||
targetURL := "http://local-site"
|
targetURL := "http://local-site"
|
||||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||||
|
@ -74,7 +131,7 @@ func testRelyingPartySession(t *testing.T, wrapServer bool) {
|
||||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||||
|
|
||||||
t.Log("------- run authorization code flow ------")
|
t.Log("------- run authorization code flow ------")
|
||||||
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
|
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret", cookieSpec)
|
||||||
|
|
||||||
t.Log("------- refresh tokens ------")
|
t.Log("------- refresh tokens ------")
|
||||||
|
|
||||||
|
@ -220,7 +277,7 @@ func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
|
||||||
clientSecret := "secret"
|
clientSecret := "secret"
|
||||||
|
|
||||||
t.Log("------- run authorization code flow ------")
|
t.Log("------- run authorization code flow ------")
|
||||||
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
|
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret, defaultCookieSpec)
|
||||||
|
|
||||||
resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
|
resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
|
||||||
require.NoError(t, err, "new resource server")
|
require.NoError(t, err, "new resource server")
|
||||||
|
@ -275,7 +332,7 @@ func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
|
||||||
require.Nil(t, tokenExchangeResponse, "token exchange response")
|
require.Nil(t, tokenExchangeResponse, "token exchange response")
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
|
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string, cookieSpec cookieSpec) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||||
targetURL := "http://local-site"
|
targetURL := "http://local-site"
|
||||||
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
|
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
|
||||||
require.NoError(t, err, "local url")
|
require.NoError(t, err, "local url")
|
||||||
|
@ -294,8 +351,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("------- create RP ------")
|
t.Log("------- create RP ------")
|
||||||
key := []byte("test1234test1234")
|
|
||||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
|
||||||
provider, err = rp.NewRelyingPartyOIDC(
|
provider, err = rp.NewRelyingPartyOIDC(
|
||||||
CTX,
|
CTX,
|
||||||
opServer.URL,
|
opServer.URL,
|
||||||
|
@ -303,7 +358,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
||||||
clientSecret,
|
clientSecret,
|
||||||
targetURL,
|
targetURL,
|
||||||
[]string{"openid", "email", "profile", "offline_access"},
|
[]string{"openid", "email", "profile", "offline_access"},
|
||||||
rp.WithPKCE(cookieHandler),
|
rp.WithPKCE(cookieSpec.cookieHandler),
|
||||||
rp.WithAuthStyle(oauth2.AuthStyleInHeader),
|
rp.WithAuthStyle(oauth2.AuthStyleInHeader),
|
||||||
rp.WithVerifierOpts(
|
rp.WithVerifierOpts(
|
||||||
rp.WithIssuedAtOffset(5*time.Second),
|
rp.WithIssuedAtOffset(5*time.Second),
|
||||||
|
@ -317,6 +372,11 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
||||||
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
|
||||||
capturedW := httptest.NewRecorder()
|
capturedW := httptest.NewRecorder()
|
||||||
get := httptest.NewRequest("GET", localURL.String(), nil)
|
get := httptest.NewRequest("GET", localURL.String(), nil)
|
||||||
|
for _, cookie := range cookieSpec.extraCookies {
|
||||||
|
get.AddCookie(cookie)
|
||||||
|
http.SetCookie(capturedW, cookie)
|
||||||
|
}
|
||||||
|
|
||||||
rp.AuthURLHandler(func() string { return state }, provider,
|
rp.AuthURLHandler(func() string { return state }, provider,
|
||||||
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
|
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
|
||||||
rp.WithURLParam("custom", "param"),
|
rp.WithURLParam("custom", "param"),
|
||||||
|
|
|
@ -410,12 +410,19 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam
|
||||||
}
|
}
|
||||||
|
|
||||||
state := stateFn()
|
state := stateFn()
|
||||||
if err := trySetStateCookie(w, state, rp); err != nil {
|
if err := trySetStateCookie(r, w, state, rp); err != nil {
|
||||||
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
|
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if rp.IsPKCE() {
|
if rp.IsPKCE() {
|
||||||
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
|
var codeChallenge string
|
||||||
|
var err error
|
||||||
|
if rp.CookieHandler().IsRequestAware() {
|
||||||
|
codeChallenge, err = GenerateAndStoreCodeChallengeWithRequest(r, w, rp)
|
||||||
|
} else {
|
||||||
|
codeChallenge, err = GenerateAndStoreCodeChallenge(w, rp)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
|
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
|
@ -436,6 +443,15 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
|
||||||
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateAndStoreCodeChallenge generates a PKCE code challenge and stores its verifier into a secure cookie
|
||||||
|
func GenerateAndStoreCodeChallengeWithRequest(r *http.Request, w http.ResponseWriter, rp RelyingParty) (string, error) {
|
||||||
|
codeVerifier := base64.RawURLEncoding.EncodeToString([]byte(uuid.New().String()))
|
||||||
|
if err := rp.CookieHandler().SetRequestAwareCookie(r, w, pkceCode, codeVerifier); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||||
|
}
|
||||||
|
|
||||||
// ErrMissingIDToken is returned when an id_token was expected,
|
// ErrMissingIDToken is returned when an id_token was expected,
|
||||||
// but not received in the token response.
|
// but not received in the token response.
|
||||||
var ErrMissingIDToken = errors.New("id_token missing")
|
var ErrMissingIDToken = errors.New("id_token missing")
|
||||||
|
@ -607,9 +623,16 @@ func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject st
|
||||||
return userinfo, nil
|
return userinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func trySetStateCookie(w http.ResponseWriter, state string, rp RelyingParty) error {
|
func trySetStateCookie(r *http.Request, w http.ResponseWriter, state string, rp RelyingParty) error {
|
||||||
if rp.CookieHandler() != nil {
|
if rp.CookieHandler() != nil {
|
||||||
if err := rp.CookieHandler().SetCookie(w, stateParam, state); err != nil {
|
var err error
|
||||||
|
if rp.CookieHandler().IsRequestAware() {
|
||||||
|
err = rp.CookieHandler().SetRequestAwareCookie(r, w, stateParam, state)
|
||||||
|
} else {
|
||||||
|
err = rp.CookieHandler().SetCookie(w, stateParam, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
type CookieHandler struct {
|
type CookieHandler struct {
|
||||||
securecookie *securecookie.SecureCookie
|
securecookie *securecookie.SecureCookie
|
||||||
|
secureCookieFunc func(r *http.Request) (*securecookie.SecureCookie, error)
|
||||||
secureOnly bool
|
secureOnly bool
|
||||||
sameSite http.SameSite
|
sameSite http.SameSite
|
||||||
maxAge int
|
maxAge int
|
||||||
|
@ -30,6 +31,21 @@ func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *Coo
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewRequestAwareCookieHandler(secureCookieFunc func(r *http.Request) (*securecookie.SecureCookie, error), opts ...CookieHandlerOpt) *CookieHandler {
|
||||||
|
c := &CookieHandler{
|
||||||
|
secureCookieFunc: secureCookieFunc,
|
||||||
|
secureOnly: true,
|
||||||
|
sameSite: http.SameSiteLaxMode,
|
||||||
|
path: "/",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
type CookieHandlerOpt func(*CookieHandler)
|
type CookieHandlerOpt func(*CookieHandler)
|
||||||
|
|
||||||
func WithUnsecure() CookieHandlerOpt {
|
func WithUnsecure() CookieHandlerOpt {
|
||||||
|
@ -47,6 +63,10 @@ func WithSameSite(sameSite http.SameSite) CookieHandlerOpt {
|
||||||
func WithMaxAge(maxAge int) CookieHandlerOpt {
|
func WithMaxAge(maxAge int) CookieHandlerOpt {
|
||||||
return func(c *CookieHandler) {
|
return func(c *CookieHandler) {
|
||||||
c.maxAge = maxAge
|
c.maxAge = maxAge
|
||||||
|
if c.IsRequestAware() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.securecookie.MaxAge(maxAge)
|
c.securecookie.MaxAge(maxAge)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,8 +88,17 @@ func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
secureCookie := c.securecookie
|
||||||
|
if c.IsRequestAware() {
|
||||||
|
secureCookie, err = c.secureCookieFunc(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var value string
|
var value string
|
||||||
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil {
|
if err := secureCookie.Decode(name, cookie.Value, &value); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return value, nil
|
return value, nil
|
||||||
|
@ -87,6 +116,10 @@ func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
|
func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
|
||||||
|
if c.IsRequestAware() {
|
||||||
|
return errors.New("Cookie handler is request aware")
|
||||||
|
}
|
||||||
|
|
||||||
encoded, err := c.securecookie.Encode(name, value)
|
encoded, err := c.securecookie.Encode(name, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -104,6 +137,35 @@ func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) err
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CookieHandler) SetRequestAwareCookie(r *http.Request, w http.ResponseWriter, name string, value string) error {
|
||||||
|
if !c.IsRequestAware() {
|
||||||
|
return errors.New("Cookie handler is not request aware")
|
||||||
|
}
|
||||||
|
|
||||||
|
secureCookie, err := c.secureCookieFunc(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded, err := secureCookie.Encode(name, value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: encoded,
|
||||||
|
Domain: c.domain,
|
||||||
|
Path: c.path,
|
||||||
|
MaxAge: c.maxAge,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: c.secureOnly,
|
||||||
|
SameSite: c.sameSite,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
|
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -116,3 +178,7 @@ func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
|
||||||
SameSite: c.sameSite,
|
SameSite: c.sameSite,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CookieHandler) IsRequestAware() bool {
|
||||||
|
return c.secureCookieFunc != nil
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue