run integration tests against both Server and Provider
This commit is contained in:
parent
af2d2942a1
commit
46839e095b
5 changed files with 47 additions and 10 deletions
|
@ -40,7 +40,7 @@ var counter atomic.Int64
|
||||||
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
|
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
|
||||||
//
|
//
|
||||||
// Use one of the pre-made clients in storage/clients.go or register a new one.
|
// Use one of the pre-made clients in storage/clients.go or register a new one.
|
||||||
func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router {
|
func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router {
|
||||||
// the OpenID Provider requires a 32-byte key for (token) encryption
|
// the OpenID Provider requires a 32-byte key for (token) encryption
|
||||||
// be sure to create a proper crypto random key and manage it securely!
|
// be sure to create a proper crypto random key and manage it securely!
|
||||||
key := sha256.Sum256([]byte("test"))
|
key := sha256.Sum256([]byte("test"))
|
||||||
|
@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router
|
||||||
registerDeviceAuth(storage, r)
|
registerDeviceAuth(storage, r)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
handler := http.Handler(provider)
|
||||||
|
if wrapServer {
|
||||||
|
handler = op.NewLegacyServer(provider)
|
||||||
|
}
|
||||||
|
|
||||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||||
// is served on the correct path
|
// is served on the correct path
|
||||||
//
|
//
|
||||||
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
|
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
|
||||||
// then you would have to set the path prefix (/custom/path/)
|
// then you would have to set the path prefix (/custom/path/)
|
||||||
router.Mount("/", provider)
|
router.Mount("/", handler)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ func main() {
|
||||||
Level: slog.LevelDebug,
|
Level: slog.LevelDebug,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
router := exampleop.SetupServer(issuer, storage, logger)
|
router := exampleop.SetupServer(issuer, storage, logger, false)
|
||||||
|
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: ":" + port,
|
Addr: ":" + port,
|
||||||
|
|
|
@ -3,6 +3,7 @@ package client_test
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -50,6 +51,14 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRelyingPartySession(t *testing.T) {
|
func TestRelyingPartySession(t *testing.T) {
|
||||||
|
for _, wrapServer := range []bool{false, true} {
|
||||||
|
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
|
||||||
|
testRelyingPartySession(t, wrapServer)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testRelyingPartySession(t *testing.T, wrapServer bool) {
|
||||||
t.Log("------- start example OP ------")
|
t.Log("------- start example OP ------")
|
||||||
targetURL := "http://local-site"
|
targetURL := "http://local-site"
|
||||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||||
|
@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
opServer := httptest.NewServer(&dh)
|
opServer := httptest.NewServer(&dh)
|
||||||
defer opServer.Close()
|
defer opServer.Close()
|
||||||
t.Logf("auth server at %s", opServer.URL)
|
t.Logf("auth server at %s", opServer.URL)
|
||||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
|
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
|
||||||
|
|
||||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||||
|
@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResourceServerTokenExchange(t *testing.T) {
|
func TestResourceServerTokenExchange(t *testing.T) {
|
||||||
|
for _, wrapServer := range []bool{false, true} {
|
||||||
|
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
|
||||||
|
testResourceServerTokenExchange(t, wrapServer)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
|
||||||
t.Log("------- start example OP ------")
|
t.Log("------- start example OP ------")
|
||||||
targetURL := "http://local-site"
|
targetURL := "http://local-site"
|
||||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||||
|
@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
||||||
opServer := httptest.NewServer(&dh)
|
opServer := httptest.NewServer(&dh)
|
||||||
defer opServer.Close()
|
defer opServer.Close()
|
||||||
t.Logf("auth server at %s", opServer.URL)
|
t.Logf("auth server at %s", opServer.URL)
|
||||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
|
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
|
||||||
|
|
||||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package op
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
|
@ -66,17 +67,24 @@ func (s *webServer) createRouter() {
|
||||||
s.Handler = router
|
s.Handler = router
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) {
|
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
||||||
if err := r.ParseForm(); err != nil {
|
if err = r.ParseForm(); err != nil {
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||||
}
|
}
|
||||||
cc := new(ClientCredentials)
|
cc := new(ClientCredentials)
|
||||||
if err := s.decoder.Decode(cc, r.Form); err != nil {
|
if err = s.decoder.Decode(cc, r.Form); err != nil {
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||||
}
|
}
|
||||||
// Basic auth takes precedence, so if set it overwrites the form data.
|
// Basic auth takes precedence, so if set it overwrites the form data.
|
||||||
if clientID, clientSecret, ok := r.BasicAuth(); ok {
|
if clientID, clientSecret, ok := r.BasicAuth(); ok {
|
||||||
cc.ClientID, cc.ClientSecret = clientID, clientSecret
|
cc.ClientID, err = url.QueryUnescape(clientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||||
|
}
|
||||||
|
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if cc.ClientID == "" && cc.ClientAssertion == "" {
|
if cc.ClientID == "" && cc.ClientAssertion == "" {
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
|
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,9 +16,15 @@ type LegacyServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLegacyServer(provider OpenIDProvider) http.Handler {
|
func NewLegacyServer(provider OpenIDProvider) http.Handler {
|
||||||
return RegisterServer(&LegacyServer{
|
server := RegisterServer(&LegacyServer{
|
||||||
provider: provider,
|
provider: provider,
|
||||||
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
||||||
|
|
||||||
|
router := chi.NewRouter()
|
||||||
|
router.Mount("/", server)
|
||||||
|
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
|
||||||
|
|
||||||
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue