diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 98a9d3a..fe20492 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -3,6 +3,9 @@ package client_test import ( "bytes" "context" + "crypto/hkdf" + cryptoRand "crypto/rand" + "crypto/sha256" "fmt" "io" "log/slog" @@ -18,6 +21,8 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/gorilla/securecookie" "github.com/jeremija/gosubmit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -42,6 +47,15 @@ var Logger = slog.New( var CTX context.Context +type cookieSpec struct { + cookieHandler *httphelper.CookieHandler + extraCookies []*http.Cookie +} + +var defaultCookieSpec = cookieSpec{ + cookieHandler: httphelper.NewCookieHandler([]byte("test1234test1234"), []byte("test1234test1234"), httphelper.WithUnsecure()), +} + func TestMain(m *testing.M) { os.Exit(func() int { ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT) @@ -53,14 +67,57 @@ func TestMain(m *testing.M) { } func TestRelyingPartySession(t *testing.T) { - for _, wrapServer := range []bool{false, true} { - t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { - testRelyingPartySession(t, wrapServer) - }) + secret := make([]byte, 512) + _, _ = cryptoRand.Read(secret) + hashFunc := sha256.New + requestAwareCookieHandler := httphelper.NewRequestAwareCookieHandler(func(r *http.Request) (*securecookie.SecureCookie, error) { + loginIDCookie, err := r.Cookie("login_id") + require.NoError(t, err) + + loginID, err := uuid.Parse(loginIDCookie.Value) + require.NoError(t, err) + + prk, err := hkdf.Extract(hashFunc, secret, loginID[:]) + require.NoError(t, err) + + hash, err := hkdf.Expand(hashFunc, prk, "INTEGRITY", 64) + require.NoError(t, err) + + block, err := hkdf.Expand(hashFunc, prk, "ENCRYPTION", 32) + require.NoError(t, err) + + return securecookie.New(hash, block), nil + }, httphelper.WithUnsecure()) + + loginID := uuid.New() + loginIDCookie := &http.Cookie{ + Name: "login_id", + Value: loginID.String(), + Secure: false, + SameSite: http.SameSiteLaxMode, + Path: "/", + } + + cookieCases := []cookieSpec{ + defaultCookieSpec, + { + cookieHandler: requestAwareCookieHandler, + extraCookies: []*http.Cookie{ + loginIDCookie, + }, + }, + } + + for _, cookieCase := range cookieCases { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer, " requestAwareCookieHandler ", cookieCase.cookieHandler.IsRequestAware()), func(t *testing.T) { + testRelyingPartySession(t, wrapServer, cookieCase) + }) + } } } -func testRelyingPartySession(t *testing.T, wrapServer bool) { +func testRelyingPartySession(t *testing.T, wrapServer bool, cookieSpec cookieSpec) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -74,7 +131,7 @@ func testRelyingPartySession(t *testing.T, wrapServer bool) { clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) t.Log("------- run authorization code flow ------") - provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret", cookieSpec) t.Log("------- refresh tokens ------") @@ -220,7 +277,7 @@ func testResourceServerTokenExchange(t *testing.T, wrapServer bool) { clientSecret := "secret" t.Log("------- run authorization code flow ------") - provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret, defaultCookieSpec) resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) require.NoError(t, err, "new resource server") @@ -275,7 +332,7 @@ func testResourceServerTokenExchange(t *testing.T, wrapServer bool) { require.Nil(t, tokenExchangeResponse, "token exchange response") } -func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) { +func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string, cookieSpec cookieSpec) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) { targetURL := "http://local-site" localURL, err := url.Parse(targetURL + "/login?requestID=1234") require.NoError(t, err, "local url") @@ -294,8 +351,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, } t.Log("------- create RP ------") - key := []byte("test1234test1234") - cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) provider, err = rp.NewRelyingPartyOIDC( CTX, opServer.URL, @@ -303,7 +358,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret, targetURL, []string{"openid", "email", "profile", "offline_access"}, - rp.WithPKCE(cookieHandler), + rp.WithPKCE(cookieSpec.cookieHandler), rp.WithAuthStyle(oauth2.AuthStyleInHeader), rp.WithVerifierOpts( rp.WithIssuedAtOffset(5*time.Second), @@ -317,6 +372,11 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, state := "state-" + strconv.FormatInt(seed.Int63(), 25) capturedW := httptest.NewRecorder() get := httptest.NewRequest("GET", localURL.String(), nil) + for _, cookie := range cookieSpec.extraCookies { + get.AddCookie(cookie) + http.SetCookie(capturedW, cookie) + } + rp.AuthURLHandler(func() string { return state }, provider, rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"), rp.WithURLParam("custom", "param"),