run integration tests against both Server and Provider

This commit is contained in:
Tim Möhlmann 2023-09-21 19:15:03 +03:00
parent af2d2942a1
commit 46839e095b
5 changed files with 47 additions and 10 deletions

View file

@ -3,6 +3,7 @@ package client_test
import (
"bytes"
"context"
"fmt"
"io"
"math/rand"
"net/http"
@ -50,6 +51,14 @@ func TestMain(m *testing.M) {
}
func TestRelyingPartySession(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testRelyingPartySession(t, wrapServer)
})
}
}
func testRelyingPartySession(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) {
}
func TestResourceServerTokenExchange(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testResourceServerTokenExchange(t, wrapServer)
})
}
}
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)

View file

@ -3,6 +3,7 @@ package op
import (
"context"
"net/http"
"net/url"
"github.com/go-chi/chi"
"github.com/rs/cors"
@ -66,17 +67,24 @@ func (s *webServer) createRouter() {
s.Handler = router
}
func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) {
if err := r.ParseForm(); err != nil {
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
if err = r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
cc := new(ClientCredentials)
if err := s.decoder.Decode(cc, r.Form); err != nil {
if err = s.decoder.Decode(cc, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
// Basic auth takes precedence, so if set it overwrites the form data.
if clientID, clientSecret, ok := r.BasicAuth(); ok {
cc.ClientID, cc.ClientSecret = clientID, clientSecret
cc.ClientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
}
if cc.ClientID == "" && cc.ClientAssertion == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")

View file

@ -6,6 +6,7 @@ import (
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
@ -15,9 +16,15 @@ type LegacyServer struct {
}
func NewLegacyServer(provider OpenIDProvider) http.Handler {
return RegisterServer(&LegacyServer{
server := RegisterServer(&LegacyServer{
provider: provider,
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter()
router.Mount("/", server)
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
return router
}
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {