server options
This commit is contained in:
parent
e9c494041c
commit
a49ad31735
3 changed files with 127 additions and 46 deletions
|
@ -22,6 +22,13 @@ import (
|
|||
// in the standards regarding the response models. Where applicable
|
||||
// the method documentation gives a recommended type which can be used
|
||||
// directly or extended upon.
|
||||
//
|
||||
// The addition of new methods is not considered a breaking change
|
||||
// as defined by semver rules.
|
||||
// Implementations MUST embed [UnimplementedServer] to maintain
|
||||
// forward compatibility.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Server interface {
|
||||
// Health returns a status of "ok" once the Server is listening.
|
||||
// The recommended Response Data type is [Status].
|
||||
|
@ -146,6 +153,8 @@ type Server interface {
|
|||
// and parsed Data from the request body (POST) or URL parameters (GET).
|
||||
// Data can be assumed to be validated according to the applicable
|
||||
// standard for the specific endpoints.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Request[T any] struct {
|
||||
Method string
|
||||
URL *url.URL
|
||||
|
@ -173,6 +182,8 @@ func newRequest[T any](r *http.Request, data *T) *Request[T] {
|
|||
// ClientRequest is a Request with a verified client attached to it.
|
||||
// Methods the receive this argument may assume the client was authenticated,
|
||||
// or verified to be a public client.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type ClientRequest[T any] struct {
|
||||
*Request[T]
|
||||
Client Client
|
||||
|
@ -185,6 +196,9 @@ func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientReq
|
|||
}
|
||||
}
|
||||
|
||||
// Response object for most [Server] methods.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Response struct {
|
||||
// Header map will be merged with the
|
||||
// header on the [http.ResponseWriter].
|
||||
|
@ -200,6 +214,8 @@ type Response struct {
|
|||
Data any
|
||||
}
|
||||
|
||||
// NewResponse creates a new response for data,
|
||||
// without custom headers.
|
||||
func NewResponse(data any) *Response {
|
||||
return &Response{
|
||||
Data: data,
|
||||
|
@ -215,6 +231,8 @@ func (resp *Response) writeOut(w http.ResponseWriter) {
|
|||
// initiate a [http.StatusFound] redirect.
|
||||
// The Params field will be encoded and set to the
|
||||
// URL's RawQuery field before building the URL.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Redirect struct {
|
||||
// Header map will be merged with the
|
||||
// header on the [http.ResponseWriter].
|
||||
|
|
|
@ -7,12 +7,19 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/rs/cors"
|
||||
"github.com/zitadel/logging"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// RegisterServer registers an implementation of Server.
|
||||
// The resulting handler takes care of routing and request parsing,
|
||||
// with some basic validation of required fields.
|
||||
// The routes can be customized with [WithEndpoints].
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
func RegisterServer(server Server, options ...ServerOption) http.Handler {
|
||||
decoder := schema.NewDecoder()
|
||||
decoder.IgnoreUnknownKeys(true)
|
||||
|
@ -34,12 +41,38 @@ func RegisterServer(server Server, options ...ServerOption) http.Handler {
|
|||
|
||||
type ServerOption func(s *webServer)
|
||||
|
||||
// WithHTTPMiddler sets the passed middleware chain to the root of
|
||||
// the Server's router.
|
||||
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.middleware = m
|
||||
}
|
||||
}
|
||||
|
||||
// WithEndpoints overrides the [DefaultEndpoints]
|
||||
func WithEndpoints(endpoints Endpoints) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.endpoints = endpoints
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecoder overrides the default decoder,
|
||||
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
||||
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.decoder = decoder
|
||||
}
|
||||
}
|
||||
|
||||
// WithFallbackLogger overrides the fallback logger, which
|
||||
// is used when no logger was found in the context.
|
||||
// Defaults to [slog.Default].
|
||||
func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
type webServer struct {
|
||||
http.Handler
|
||||
server Server
|
||||
|
@ -49,6 +82,13 @@ type webServer struct {
|
|||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||
if logger, ok := logging.FromContext(ctx); ok {
|
||||
return logger
|
||||
}
|
||||
return s.logger
|
||||
}
|
||||
|
||||
func (s *webServer) createRouter() {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
|
@ -73,12 +113,12 @@ func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
client, err := s.verifyRequestClient(r)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
|
||||
if !ValidateGrantType(client, grantType) {
|
||||
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger)
|
||||
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -123,12 +163,12 @@ func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
|||
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
redirect, err := s.authorize(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
redirect.writeOut(w, r)
|
||||
|
@ -163,12 +203,12 @@ func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest])
|
|||
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -176,7 +216,7 @@ func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Re
|
|||
|
||||
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -194,25 +234,25 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
|||
case oidc.GrantTypeDeviceCode:
|
||||
s.withClient(s.deviceTokenHandler)(w, r)
|
||||
case "":
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
|
||||
default:
|
||||
WriteError(w, r, unimplementedGrantError(grantType), s.logger)
|
||||
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Assertion == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -221,20 +261,20 @@ func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
|||
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Code == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RedirectURI == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -243,16 +283,16 @@ func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request,
|
|||
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RefreshToken == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -261,32 +301,32 @@ func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request,
|
|||
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.SubjectToken == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.SubjectTokenType == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if !request.SubjectTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -294,18 +334,18 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request,
|
|||
|
||||
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
|
||||
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -314,16 +354,16 @@ func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Requ
|
|||
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.DeviceCode == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -331,21 +371,21 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c
|
|||
|
||||
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Token == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -354,7 +394,7 @@ func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request,
|
|||
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if token, err := getAccessToken(r); err == nil {
|
||||
|
@ -365,12 +405,12 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
|||
oidc.ErrInvalidRequest().WithDescription("access token missing"),
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -379,16 +419,16 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
|||
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Token == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
@ -397,12 +437,12 @@ func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, cl
|
|||
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w, r)
|
||||
|
@ -411,12 +451,12 @@ func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
|||
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.logger)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -21,6 +22,28 @@ import (
|
|||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestRegisterServer(t *testing.T) {
|
||||
server := UnimplementedServer{}
|
||||
endpoints := Endpoints{
|
||||
Authorization: Endpoint{
|
||||
path: "/auth",
|
||||
},
|
||||
}
|
||||
decoder := schema.NewDecoder()
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
||||
|
||||
h := RegisterServer(server,
|
||||
WithEndpoints(endpoints),
|
||||
WithDecoder(decoder),
|
||||
WithFallbackLogger(logger),
|
||||
)
|
||||
got := h.(*webServer)
|
||||
assert.Equal(t, got.server, server)
|
||||
assert.Equal(t, got.endpoints, endpoints)
|
||||
assert.Equal(t, got.decoder, decoder)
|
||||
assert.Equal(t, got.logger, logger)
|
||||
}
|
||||
|
||||
type testClient struct {
|
||||
id string
|
||||
appType ApplicationType
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue