From a49ad31735a466e142766b904d1b8fcc98a7fc74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Mon, 25 Sep 2023 20:02:11 +0300 Subject: [PATCH] server options --- pkg/op/server.go | 18 +++++ pkg/op/server_http.go | 132 ++++++++++++++++++++++++------------- pkg/op/server_http_test.go | 23 +++++++ 3 files changed, 127 insertions(+), 46 deletions(-) diff --git a/pkg/op/server.go b/pkg/op/server.go index 9d43bfa..47398ec 100644 --- a/pkg/op/server.go +++ b/pkg/op/server.go @@ -22,6 +22,13 @@ import ( // in the standards regarding the response models. Where applicable // the method documentation gives a recommended type which can be used // directly or extended upon. +// +// The addition of new methods is not considered a breaking change +// as defined by semver rules. +// Implementations MUST embed [UnimplementedServer] to maintain +// forward compatibility. +// +// EXPERIMENTAL: may change until v4 type Server interface { // Health returns a status of "ok" once the Server is listening. // The recommended Response Data type is [Status]. @@ -146,6 +153,8 @@ type Server interface { // and parsed Data from the request body (POST) or URL parameters (GET). // Data can be assumed to be validated according to the applicable // standard for the specific endpoints. +// +// EXPERIMENTAL: may change until v4 type Request[T any] struct { Method string URL *url.URL @@ -173,6 +182,8 @@ func newRequest[T any](r *http.Request, data *T) *Request[T] { // ClientRequest is a Request with a verified client attached to it. // Methods the receive this argument may assume the client was authenticated, // or verified to be a public client. +// +// EXPERIMENTAL: may change until v4 type ClientRequest[T any] struct { *Request[T] Client Client @@ -185,6 +196,9 @@ func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientReq } } +// Response object for most [Server] methods. +// +// EXPERIMENTAL: may change until v4 type Response struct { // Header map will be merged with the // header on the [http.ResponseWriter]. @@ -200,6 +214,8 @@ type Response struct { Data any } +// NewResponse creates a new response for data, +// without custom headers. func NewResponse(data any) *Response { return &Response{ Data: data, @@ -215,6 +231,8 @@ func (resp *Response) writeOut(w http.ResponseWriter) { // initiate a [http.StatusFound] redirect. // The Params field will be encoded and set to the // URL's RawQuery field before building the URL. +// +// EXPERIMENTAL: may change until v4 type Redirect struct { // Header map will be merged with the // header on the [http.ResponseWriter]. diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index e60b2ce..88b18cc 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -7,12 +7,19 @@ import ( "github.com/go-chi/chi" "github.com/rs/cors" + "github.com/zitadel/logging" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/schema" "golang.org/x/exp/slog" ) +// RegisterServer registers an implementation of Server. +// The resulting handler takes care of routing and request parsing, +// with some basic validation of required fields. +// The routes can be customized with [WithEndpoints]. +// +// EXPERIMENTAL: may change until v4 func RegisterServer(server Server, options ...ServerOption) http.Handler { decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(true) @@ -34,12 +41,38 @@ func RegisterServer(server Server, options ...ServerOption) http.Handler { type ServerOption func(s *webServer) +// WithHTTPMiddler sets the passed middleware chain to the root of +// the Server's router. func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { return func(s *webServer) { s.middleware = m } } +// WithEndpoints overrides the [DefaultEndpoints] +func WithEndpoints(endpoints Endpoints) ServerOption { + return func(s *webServer) { + s.endpoints = endpoints + } +} + +// WithDecoder overrides the default decoder, +// which is a [schema.Decoder] with IgnoreUnknownKeys set to true. +func WithDecoder(decoder httphelper.Decoder) ServerOption { + return func(s *webServer) { + s.decoder = decoder + } +} + +// WithFallbackLogger overrides the fallback logger, which +// is used when no logger was found in the context. +// Defaults to [slog.Default]. +func WithFallbackLogger(logger *slog.Logger) ServerOption { + return func(s *webServer) { + s.logger = logger + } +} + type webServer struct { http.Handler server Server @@ -49,6 +82,13 @@ type webServer struct { logger *slog.Logger } +func (s *webServer) getLogger(ctx context.Context) *slog.Logger { + if logger, ok := logging.FromContext(ctx); ok { + return logger + } + return s.logger +} + func (s *webServer) createRouter() { router := chi.NewRouter() router.Use(cors.New(defaultCORSOptions).Handler) @@ -73,12 +113,12 @@ func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { client, err := s.verifyRequestClient(r) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { if !ValidateGrantType(client, grantType) { - WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger) + WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context())) return } } @@ -123,12 +163,12 @@ func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } redirect, err := s.authorize(r.Context(), newRequest(r, request)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } redirect.writeOut(w, r) @@ -163,12 +203,12 @@ func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -176,7 +216,7 @@ func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Re func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) return } @@ -194,25 +234,25 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { case oidc.GrantTypeDeviceCode: s.withClient(s.deviceTokenHandler)(w, r) case "": - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context())) default: - WriteError(w, r, unimplementedGrantError(grantType), s.logger) + WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context())) } } func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if request.Assertion == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context())) return } resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -221,20 +261,20 @@ func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if request.Code == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context())) return } if request.RedirectURI == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context())) return } resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -243,16 +283,16 @@ func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if request.RefreshToken == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context())) return } resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -261,32 +301,32 @@ func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if request.SubjectToken == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context())) return } if request.SubjectTokenType == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context())) return } if !request.SubjectTokenType.IsSupported() { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context())) return } if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context())) return } if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context())) return } resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -294,18 +334,18 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { if client.AuthMethod() == oidc.AuthMethodNone { - WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger) + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) return } request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -314,16 +354,16 @@ func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Requ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if request.DeviceCode == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context())) return } resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -331,21 +371,21 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) { if client.AuthMethod() == oidc.AuthMethodNone { - WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger) + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) return } request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if request.Token == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) return } resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -354,7 +394,7 @@ func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if token, err := getAccessToken(r); err == nil { @@ -365,12 +405,12 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { oidc.ErrInvalidRequest().WithDescription("access token missing"), http.StatusUnauthorized, ) - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp, err := s.server.UserInfo(r.Context(), newRequest(r, request)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -379,16 +419,16 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } if request.Token == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) return } resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) @@ -397,12 +437,12 @@ func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, cl func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp, err := s.server.EndSession(r.Context(), newRequest(r, request)) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w, r) @@ -411,12 +451,12 @@ func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) return } resp, err := method(r.Context(), newRequest(r, &struct{}{})) if err != nil { - WriteError(w, r, err, s.logger) + WriteError(w, r, err, s.getLogger(r.Context())) return } resp.writeOut(w) diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go index 6ff4678..bbacf28 100644 --- a/pkg/op/server_http_test.go +++ b/pkg/op/server_http_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" "time" @@ -21,6 +22,28 @@ import ( "golang.org/x/exp/slog" ) +func TestRegisterServer(t *testing.T) { + server := UnimplementedServer{} + endpoints := Endpoints{ + Authorization: Endpoint{ + path: "/auth", + }, + } + decoder := schema.NewDecoder() + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + + h := RegisterServer(server, + WithEndpoints(endpoints), + WithDecoder(decoder), + WithFallbackLogger(logger), + ) + got := h.(*webServer) + assert.Equal(t, got.server, server) + assert.Equal(t, got.endpoints, endpoints) + assert.Equal(t, got.decoder, decoder) + assert.Equal(t, got.logger, logger) +} + type testClient struct { id string appType ApplicationType