run integration tests against both Server and Provider

This commit is contained in:
Tim Möhlmann 2023-09-21 19:15:03 +03:00
parent af2d2942a1
commit 46839e095b
5 changed files with 47 additions and 10 deletions

View file

@ -40,7 +40,7 @@ var counter atomic.Int64
// SetupServer creates an OIDC server with Issuer=http://localhost:<port> // SetupServer creates an OIDC server with Issuer=http://localhost:<port>
// //
// Use one of the pre-made clients in storage/clients.go or register a new one. // Use one of the pre-made clients in storage/clients.go or register a new one.
func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router { func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router {
// the OpenID Provider requires a 32-byte key for (token) encryption // the OpenID Provider requires a 32-byte key for (token) encryption
// be sure to create a proper crypto random key and manage it securely! // be sure to create a proper crypto random key and manage it securely!
key := sha256.Sum256([]byte("test")) key := sha256.Sum256([]byte("test"))
@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router
registerDeviceAuth(storage, r) registerDeviceAuth(storage, r)
}) })
handler := http.Handler(provider)
if wrapServer {
handler = op.NewLegacyServer(provider)
}
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
// is served on the correct path // is served on the correct path
// //
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
// then you would have to set the path prefix (/custom/path/) // then you would have to set the path prefix (/custom/path/)
router.Mount("/", provider) router.Mount("/", handler)
return router return router
} }

View file

@ -27,7 +27,7 @@ func main() {
Level: slog.LevelDebug, Level: slog.LevelDebug,
}), }),
) )
router := exampleop.SetupServer(issuer, storage, logger) router := exampleop.SetupServer(issuer, storage, logger, false)
server := &http.Server{ server := &http.Server{
Addr: ":" + port, Addr: ":" + port,

View file

@ -3,6 +3,7 @@ package client_test
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
@ -50,6 +51,14 @@ func TestMain(m *testing.M) {
} }
func TestRelyingPartySession(t *testing.T) { func TestRelyingPartySession(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testRelyingPartySession(t, wrapServer)
})
}
}
func testRelyingPartySession(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------") t.Log("------- start example OP ------")
targetURL := "http://local-site" targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) {
opServer := httptest.NewServer(&dh) opServer := httptest.NewServer(&dh)
defer opServer.Close() defer opServer.Close()
t.Logf("auth server at %s", opServer.URL) t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) {
} }
func TestResourceServerTokenExchange(t *testing.T) { func TestResourceServerTokenExchange(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testResourceServerTokenExchange(t, wrapServer)
})
}
}
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------") t.Log("------- start example OP ------")
targetURL := "http://local-site" targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
opServer := httptest.NewServer(&dh) opServer := httptest.NewServer(&dh)
defer opServer.Close() defer opServer.Close()
t.Logf("auth server at %s", opServer.URL) t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)

View file

@ -3,6 +3,7 @@ package op
import ( import (
"context" "context"
"net/http" "net/http"
"net/url"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/rs/cors" "github.com/rs/cors"
@ -66,17 +67,24 @@ func (s *webServer) createRouter() {
s.Handler = router s.Handler = router
} }
func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) { func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
if err := r.ParseForm(); err != nil { if err = r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
} }
cc := new(ClientCredentials) cc := new(ClientCredentials)
if err := s.decoder.Decode(cc, r.Form); err != nil { if err = s.decoder.Decode(cc, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
} }
// Basic auth takes precedence, so if set it overwrites the form data. // Basic auth takes precedence, so if set it overwrites the form data.
if clientID, clientSecret, ok := r.BasicAuth(); ok { if clientID, clientSecret, ok := r.BasicAuth(); ok {
cc.ClientID, cc.ClientSecret = clientID, clientSecret cc.ClientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
} }
if cc.ClientID == "" && cc.ClientAssertion == "" { if cc.ClientID == "" && cc.ClientAssertion == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided") return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
@ -15,9 +16,15 @@ type LegacyServer struct {
} }
func NewLegacyServer(provider OpenIDProvider) http.Handler { func NewLegacyServer(provider OpenIDProvider) http.Handler {
return RegisterServer(&LegacyServer{ server := RegisterServer(&LegacyServer{
provider: provider, provider: provider,
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) }, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter()
router.Mount("/", server)
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
return router
} }
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {