diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index b5ee7b3..b807382 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -40,7 +40,7 @@ var counter atomic.Int64 // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router { +func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) @@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router registerDeviceAuth(storage, r) }) + handler := http.Handler(provider) + if wrapServer { + handler = op.NewLegacyServer(provider) + } + // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // then you would have to set the path prefix (/custom/path/) - router.Mount("/", provider) + router.Mount("/", handler) return router } diff --git a/example/server/main.go b/example/server/main.go index a1cc461..38057fb 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -27,7 +27,7 @@ func main() { Level: slog.LevelDebug, }), ) - router := exampleop.SetupServer(issuer, storage, logger) + router := exampleop.SetupServer(issuer, storage, logger, false) server := &http.Server{ Addr: ":" + port, diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 7cbb62e..1d3559e 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -3,6 +3,7 @@ package client_test import ( "bytes" "context" + "fmt" "io" "math/rand" "net/http" @@ -50,6 +51,14 @@ func TestMain(m *testing.M) { } func TestRelyingPartySession(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testRelyingPartySession(t, wrapServer) + }) + } +} + +func testRelyingPartySession(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) @@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) { } func TestResourceServerTokenExchange(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testResourceServerTokenExchange(t, wrapServer) + }) + } +} + +func testResourceServerTokenExchange(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index bd2019a..887e16c 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -3,6 +3,7 @@ package op import ( "context" "net/http" + "net/url" "github.com/go-chi/chi" "github.com/rs/cors" @@ -66,17 +67,24 @@ func (s *webServer) createRouter() { s.Handler = router } -func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) { - if err := r.ParseForm(); err != nil { +func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { + if err = r.ParseForm(); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) } cc := new(ClientCredentials) - if err := s.decoder.Decode(cc, r.Form); err != nil { + if err = s.decoder.Decode(cc, r.Form); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) } // Basic auth takes precedence, so if set it overwrites the form data. if clientID, clientSecret, ok := r.BasicAuth(); ok { - cc.ClientID, cc.ClientSecret = clientID, clientSecret + cc.ClientID, err = url.QueryUnescape(clientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + cc.ClientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } } if cc.ClientID == "" && cc.ClientAssertion == "" { return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided") diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index d27fd28..61782a1 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/go-chi/chi" "github.com/zitadel/oidc/v3/pkg/oidc" ) @@ -15,9 +16,15 @@ type LegacyServer struct { } func NewLegacyServer(provider OpenIDProvider) http.Handler { - return RegisterServer(&LegacyServer{ + server := RegisterServer(&LegacyServer{ provider: provider, }, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) + + router := chi.NewRouter() + router.Mount("/", server) + router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) + + return router } func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {