run integration tests against both Server and Provider

This commit is contained in:
Tim Möhlmann 2023-09-21 19:15:03 +03:00
parent af2d2942a1
commit 46839e095b
5 changed files with 47 additions and 10 deletions

View file

@ -40,7 +40,7 @@ var counter atomic.Int64
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
//
// Use one of the pre-made clients in storage/clients.go or register a new one.
func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router {
func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router {
// the OpenID Provider requires a 32-byte key for (token) encryption
// be sure to create a proper crypto random key and manage it securely!
key := sha256.Sum256([]byte("test"))
@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router
registerDeviceAuth(storage, r)
})
handler := http.Handler(provider)
if wrapServer {
handler = op.NewLegacyServer(provider)
}
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
// is served on the correct path
//
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
// then you would have to set the path prefix (/custom/path/)
router.Mount("/", provider)
router.Mount("/", handler)
return router
}

View file

@ -27,7 +27,7 @@ func main() {
Level: slog.LevelDebug,
}),
)
router := exampleop.SetupServer(issuer, storage, logger)
router := exampleop.SetupServer(issuer, storage, logger, false)
server := &http.Server{
Addr: ":" + port,

View file

@ -3,6 +3,7 @@ package client_test
import (
"bytes"
"context"
"fmt"
"io"
"math/rand"
"net/http"
@ -50,6 +51,14 @@ func TestMain(m *testing.M) {
}
func TestRelyingPartySession(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testRelyingPartySession(t, wrapServer)
})
}
}
func testRelyingPartySession(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) {
}
func TestResourceServerTokenExchange(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testResourceServerTokenExchange(t, wrapServer)
})
}
}
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)

View file

@ -3,6 +3,7 @@ package op
import (
"context"
"net/http"
"net/url"
"github.com/go-chi/chi"
"github.com/rs/cors"
@ -66,17 +67,24 @@ func (s *webServer) createRouter() {
s.Handler = router
}
func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) {
if err := r.ParseForm(); err != nil {
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
if err = r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
cc := new(ClientCredentials)
if err := s.decoder.Decode(cc, r.Form); err != nil {
if err = s.decoder.Decode(cc, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
// Basic auth takes precedence, so if set it overwrites the form data.
if clientID, clientSecret, ok := r.BasicAuth(); ok {
cc.ClientID, cc.ClientSecret = clientID, clientSecret
cc.ClientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
}
if cc.ClientID == "" && cc.ClientAssertion == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")

View file

@ -6,6 +6,7 @@ import (
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
@ -15,9 +16,15 @@ type LegacyServer struct {
}
func NewLegacyServer(provider OpenIDProvider) http.Handler {
return RegisterServer(&LegacyServer{
server := RegisterServer(&LegacyServer{
provider: provider,
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter()
router.Mount("/", server)
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
return router
}
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {