This commit is contained in:
Mark Laing 2025-06-17 08:03:25 -05:00 committed by GitHub
commit 2722c0e042
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 171 additions and 22 deletions

View file

@ -3,6 +3,9 @@ package client_test
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/hkdf"
cryptoRand "crypto/rand"
"crypto/sha256"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@ -18,6 +21,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/uuid"
"github.com/gorilla/securecookie"
"github.com/jeremija/gosubmit" "github.com/jeremija/gosubmit"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -42,6 +47,15 @@ var Logger = slog.New(
var CTX context.Context var CTX context.Context
type cookieSpec struct {
cookieHandler *httphelper.CookieHandler
extraCookies []*http.Cookie
}
var defaultCookieSpec = cookieSpec{
cookieHandler: httphelper.NewCookieHandler([]byte("test1234test1234"), []byte("test1234test1234"), httphelper.WithUnsecure()),
}
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
os.Exit(func() int { os.Exit(func() int {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
@ -53,14 +67,57 @@ func TestMain(m *testing.M) {
} }
func TestRelyingPartySession(t *testing.T) { func TestRelyingPartySession(t *testing.T) {
secret := make([]byte, 512)
_, _ = cryptoRand.Read(secret)
hashFunc := sha256.New
requestAwareCookieHandler := httphelper.NewRequestAwareCookieHandler(func(r *http.Request) (*securecookie.SecureCookie, error) {
loginIDCookie, err := r.Cookie("login_id")
require.NoError(t, err)
loginID, err := uuid.Parse(loginIDCookie.Value)
require.NoError(t, err)
prk, err := hkdf.Extract(hashFunc, secret, loginID[:])
require.NoError(t, err)
hash, err := hkdf.Expand(hashFunc, prk, "INTEGRITY", 64)
require.NoError(t, err)
block, err := hkdf.Expand(hashFunc, prk, "ENCRYPTION", 32)
require.NoError(t, err)
return securecookie.New(hash, block), nil
}, httphelper.WithUnsecure())
loginID := uuid.New()
loginIDCookie := &http.Cookie{
Name: "login_id",
Value: loginID.String(),
Secure: false,
SameSite: http.SameSiteLaxMode,
Path: "/",
}
cookieCases := []cookieSpec{
defaultCookieSpec,
{
cookieHandler: requestAwareCookieHandler,
extraCookies: []*http.Cookie{
loginIDCookie,
},
},
}
for _, cookieCase := range cookieCases {
for _, wrapServer := range []bool{false, true} { for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { t.Run(fmt.Sprint("wrapServer ", wrapServer, " requestAwareCookieHandler ", cookieCase.cookieHandler.IsRequestAware()), func(t *testing.T) {
testRelyingPartySession(t, wrapServer) testRelyingPartySession(t, wrapServer, cookieCase)
}) })
} }
} }
}
func testRelyingPartySession(t *testing.T, wrapServer bool) { func testRelyingPartySession(t *testing.T, wrapServer bool, cookieSpec cookieSpec) {
t.Log("------- start example OP ------") t.Log("------- start example OP ------")
targetURL := "http://local-site" targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -74,7 +131,7 @@ func testRelyingPartySession(t *testing.T, wrapServer bool) {
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
t.Log("------- run authorization code flow ------") t.Log("------- run authorization code flow ------")
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret", cookieSpec)
t.Log("------- refresh tokens ------") t.Log("------- refresh tokens ------")
@ -220,7 +277,7 @@ func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
clientSecret := "secret" clientSecret := "secret"
t.Log("------- run authorization code flow ------") t.Log("------- run authorization code flow ------")
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret, defaultCookieSpec)
resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
require.NoError(t, err, "new resource server") require.NoError(t, err, "new resource server")
@ -275,7 +332,7 @@ func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
require.Nil(t, tokenExchangeResponse, "token exchange response") require.Nil(t, tokenExchangeResponse, "token exchange response")
} }
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) { func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string, cookieSpec cookieSpec) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
targetURL := "http://local-site" targetURL := "http://local-site"
localURL, err := url.Parse(targetURL + "/login?requestID=1234") localURL, err := url.Parse(targetURL + "/login?requestID=1234")
require.NoError(t, err, "local url") require.NoError(t, err, "local url")
@ -294,8 +351,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
} }
t.Log("------- create RP ------") t.Log("------- create RP ------")
key := []byte("test1234test1234")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
provider, err = rp.NewRelyingPartyOIDC( provider, err = rp.NewRelyingPartyOIDC(
CTX, CTX,
opServer.URL, opServer.URL,
@ -303,7 +358,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
clientSecret, clientSecret,
targetURL, targetURL,
[]string{"openid", "email", "profile", "offline_access"}, []string{"openid", "email", "profile", "offline_access"},
rp.WithPKCE(cookieHandler), rp.WithPKCE(cookieSpec.cookieHandler),
rp.WithAuthStyle(oauth2.AuthStyleInHeader), rp.WithAuthStyle(oauth2.AuthStyleInHeader),
rp.WithVerifierOpts( rp.WithVerifierOpts(
rp.WithIssuedAtOffset(5*time.Second), rp.WithIssuedAtOffset(5*time.Second),
@ -317,6 +372,11 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
state := "state-" + strconv.FormatInt(seed.Int63(), 25) state := "state-" + strconv.FormatInt(seed.Int63(), 25)
capturedW := httptest.NewRecorder() capturedW := httptest.NewRecorder()
get := httptest.NewRequest("GET", localURL.String(), nil) get := httptest.NewRequest("GET", localURL.String(), nil)
for _, cookie := range cookieSpec.extraCookies {
get.AddCookie(cookie)
http.SetCookie(capturedW, cookie)
}
rp.AuthURLHandler(func() string { return state }, provider, rp.AuthURLHandler(func() string { return state }, provider,
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"), rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
rp.WithURLParam("custom", "param"), rp.WithURLParam("custom", "param"),

View file

@ -410,12 +410,19 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam
} }
state := stateFn() state := stateFn()
if err := trySetStateCookie(w, state, rp); err != nil { if err := trySetStateCookie(r, w, state, rp); err != nil {
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp) unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
return return
} }
if rp.IsPKCE() { if rp.IsPKCE() {
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp) var codeChallenge string
var err error
if rp.CookieHandler().IsRequestAware() {
codeChallenge, err = GenerateAndStoreCodeChallengeWithRequest(r, w, rp)
} else {
codeChallenge, err = GenerateAndStoreCodeChallenge(w, rp)
}
if err != nil { if err != nil {
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp) unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
return return
@ -436,6 +443,15 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
return oidc.NewSHACodeChallenge(codeVerifier), nil return oidc.NewSHACodeChallenge(codeVerifier), nil
} }
// GenerateAndStoreCodeChallenge generates a PKCE code challenge and stores its verifier into a secure cookie
func GenerateAndStoreCodeChallengeWithRequest(r *http.Request, w http.ResponseWriter, rp RelyingParty) (string, error) {
codeVerifier := base64.RawURLEncoding.EncodeToString([]byte(uuid.New().String()))
if err := rp.CookieHandler().SetRequestAwareCookie(r, w, pkceCode, codeVerifier); err != nil {
return "", err
}
return oidc.NewSHACodeChallenge(codeVerifier), nil
}
// ErrMissingIDToken is returned when an id_token was expected, // ErrMissingIDToken is returned when an id_token was expected,
// but not received in the token response. // but not received in the token response.
var ErrMissingIDToken = errors.New("id_token missing") var ErrMissingIDToken = errors.New("id_token missing")
@ -607,9 +623,16 @@ func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject st
return userinfo, nil return userinfo, nil
} }
func trySetStateCookie(w http.ResponseWriter, state string, rp RelyingParty) error { func trySetStateCookie(r *http.Request, w http.ResponseWriter, state string, rp RelyingParty) error {
if rp.CookieHandler() != nil { if rp.CookieHandler() != nil {
if err := rp.CookieHandler().SetCookie(w, stateParam, state); err != nil { var err error
if rp.CookieHandler().IsRequestAware() {
err = rp.CookieHandler().SetRequestAwareCookie(r, w, stateParam, state)
} else {
err = rp.CookieHandler().SetCookie(w, stateParam, state)
}
if err != nil {
return err return err
} }
} }

View file

@ -9,6 +9,7 @@ import (
type CookieHandler struct { type CookieHandler struct {
securecookie *securecookie.SecureCookie securecookie *securecookie.SecureCookie
secureCookieFunc func(r *http.Request) (*securecookie.SecureCookie, error)
secureOnly bool secureOnly bool
sameSite http.SameSite sameSite http.SameSite
maxAge int maxAge int
@ -30,6 +31,21 @@ func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *Coo
return c return c
} }
func NewRequestAwareCookieHandler(secureCookieFunc func(r *http.Request) (*securecookie.SecureCookie, error), opts ...CookieHandlerOpt) *CookieHandler {
c := &CookieHandler{
secureCookieFunc: secureCookieFunc,
secureOnly: true,
sameSite: http.SameSiteLaxMode,
path: "/",
}
for _, opt := range opts {
opt(c)
}
return c
}
type CookieHandlerOpt func(*CookieHandler) type CookieHandlerOpt func(*CookieHandler)
func WithUnsecure() CookieHandlerOpt { func WithUnsecure() CookieHandlerOpt {
@ -47,6 +63,10 @@ func WithSameSite(sameSite http.SameSite) CookieHandlerOpt {
func WithMaxAge(maxAge int) CookieHandlerOpt { func WithMaxAge(maxAge int) CookieHandlerOpt {
return func(c *CookieHandler) { return func(c *CookieHandler) {
c.maxAge = maxAge c.maxAge = maxAge
if c.IsRequestAware() {
return
}
c.securecookie.MaxAge(maxAge) c.securecookie.MaxAge(maxAge)
} }
} }
@ -68,8 +88,17 @@ func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error
if err != nil { if err != nil {
return "", err return "", err
} }
secureCookie := c.securecookie
if c.IsRequestAware() {
secureCookie, err = c.secureCookieFunc(r)
if err != nil {
return "", err
}
}
var value string var value string
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil { if err := secureCookie.Decode(name, cookie.Value, &value); err != nil {
return "", err return "", err
} }
return value, nil return value, nil
@ -87,6 +116,10 @@ func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string,
} }
func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error { func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
if c.IsRequestAware() {
return errors.New("Cookie handler is request aware")
}
encoded, err := c.securecookie.Encode(name, value) encoded, err := c.securecookie.Encode(name, value)
if err != nil { if err != nil {
return err return err
@ -104,6 +137,35 @@ func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) err
return nil return nil
} }
func (c *CookieHandler) SetRequestAwareCookie(r *http.Request, w http.ResponseWriter, name string, value string) error {
if !c.IsRequestAware() {
return errors.New("Cookie handler is not request aware")
}
secureCookie, err := c.secureCookieFunc(r)
if err != nil {
return err
}
encoded, err := secureCookie.Encode(name, value)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: encoded,
Domain: c.domain,
Path: c.path,
MaxAge: c.maxAge,
HttpOnly: true,
Secure: c.secureOnly,
SameSite: c.sameSite,
})
return nil
}
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) { func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: name, Name: name,
@ -116,3 +178,7 @@ func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
SameSite: c.sameSite, SameSite: c.sameSite,
}) })
} }
func (c *CookieHandler) IsRequestAware() bool {
return c.secureCookieFunc != nil
}