copy and modify the routes test for the legacy server

This commit is contained in:
Tim Möhlmann 2023-09-21 18:13:38 +03:00
parent c98291a6a7
commit af2d2942a1
6 changed files with 390 additions and 17 deletions

View file

@ -12,29 +12,46 @@ import (
"golang.org/x/exp/slog"
)
func RegisterServer(server Server) http.Handler {
func RegisterServer(server Server, options ...ServerOption) http.Handler {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
ws := &webServer{
server: server,
endpoints: *DefaultEndpoints,
decoder: schema.NewDecoder(),
decoder: decoder,
logger: slog.Default(),
}
for _, option := range options {
option(ws)
}
ws.createRouter()
return ws
}
type webServer struct {
http.Handler
server Server
endpoints Endpoints
decoder httphelper.Decoder
logger *slog.Logger
type ServerOption func(s *webServer)
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
return func(s *webServer) {
s.middleware = m
}
}
func (s *webServer) createRouter(interceptors ...func(http.Handler) http.Handler) {
type webServer struct {
http.Handler
server Server
middleware []func(http.Handler) http.Handler
endpoints Endpoints
decoder httphelper.Decoder
logger *slog.Logger
}
func (s *webServer) createRouter() {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(interceptors...)
router.Use(s.middleware...)
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
@ -145,6 +162,11 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
return
}
if !grantType.IsSupported() {
WriteError(w, r, unimplementedGrantError(grantType), s.logger)
return
}
if grantType == oidc.GrantTypeBearer {
s.jwtProfileHandler(w, r)
return
@ -156,7 +178,7 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
return
}
if !ValidateGrantType(client, grantType) {
WriteError(w, r, oidc.ErrUnauthorizedClient(), s.logger)
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger)
return
}
@ -331,9 +353,16 @@ func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request)
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
if err != nil || request.AccessToken == "" {
if err != nil {
WriteError(w, r, err, s.logger)
return
}
if token, err := getAccessToken(r); err == nil {
request.AccessToken = token
}
if request.AccessToken == "" {
err = AsStatusError(
oidc.ErrInvalidRequest().WithParent(err).WithDescription("access token missing"),
oidc.ErrInvalidRequest().WithDescription("access token missing"),
http.StatusUnauthorized,
)
WriteError(w, r, err, s.logger)