server options

This commit is contained in:
Tim Möhlmann 2023-09-25 20:02:11 +03:00
parent e9c494041c
commit a49ad31735
3 changed files with 127 additions and 46 deletions

View file

@ -22,6 +22,13 @@ import (
// in the standards regarding the response models. Where applicable // in the standards regarding the response models. Where applicable
// the method documentation gives a recommended type which can be used // the method documentation gives a recommended type which can be used
// directly or extended upon. // directly or extended upon.
//
// The addition of new methods is not considered a breaking change
// as defined by semver rules.
// Implementations MUST embed [UnimplementedServer] to maintain
// forward compatibility.
//
// EXPERIMENTAL: may change until v4
type Server interface { type Server interface {
// Health returns a status of "ok" once the Server is listening. // Health returns a status of "ok" once the Server is listening.
// The recommended Response Data type is [Status]. // The recommended Response Data type is [Status].
@ -146,6 +153,8 @@ type Server interface {
// and parsed Data from the request body (POST) or URL parameters (GET). // and parsed Data from the request body (POST) or URL parameters (GET).
// Data can be assumed to be validated according to the applicable // Data can be assumed to be validated according to the applicable
// standard for the specific endpoints. // standard for the specific endpoints.
//
// EXPERIMENTAL: may change until v4
type Request[T any] struct { type Request[T any] struct {
Method string Method string
URL *url.URL URL *url.URL
@ -173,6 +182,8 @@ func newRequest[T any](r *http.Request, data *T) *Request[T] {
// ClientRequest is a Request with a verified client attached to it. // ClientRequest is a Request with a verified client attached to it.
// Methods the receive this argument may assume the client was authenticated, // Methods the receive this argument may assume the client was authenticated,
// or verified to be a public client. // or verified to be a public client.
//
// EXPERIMENTAL: may change until v4
type ClientRequest[T any] struct { type ClientRequest[T any] struct {
*Request[T] *Request[T]
Client Client Client Client
@ -185,6 +196,9 @@ func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientReq
} }
} }
// Response object for most [Server] methods.
//
// EXPERIMENTAL: may change until v4
type Response struct { type Response struct {
// Header map will be merged with the // Header map will be merged with the
// header on the [http.ResponseWriter]. // header on the [http.ResponseWriter].
@ -200,6 +214,8 @@ type Response struct {
Data any Data any
} }
// NewResponse creates a new response for data,
// without custom headers.
func NewResponse(data any) *Response { func NewResponse(data any) *Response {
return &Response{ return &Response{
Data: data, Data: data,
@ -215,6 +231,8 @@ func (resp *Response) writeOut(w http.ResponseWriter) {
// initiate a [http.StatusFound] redirect. // initiate a [http.StatusFound] redirect.
// The Params field will be encoded and set to the // The Params field will be encoded and set to the
// URL's RawQuery field before building the URL. // URL's RawQuery field before building the URL.
//
// EXPERIMENTAL: may change until v4
type Redirect struct { type Redirect struct {
// Header map will be merged with the // Header map will be merged with the
// header on the [http.ResponseWriter]. // header on the [http.ResponseWriter].

View file

@ -7,12 +7,19 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/rs/cors" "github.com/rs/cors"
"github.com/zitadel/logging"
httphelper "github.com/zitadel/oidc/v3/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/schema" "github.com/zitadel/schema"
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
) )
// RegisterServer registers an implementation of Server.
// The resulting handler takes care of routing and request parsing,
// with some basic validation of required fields.
// The routes can be customized with [WithEndpoints].
//
// EXPERIMENTAL: may change until v4
func RegisterServer(server Server, options ...ServerOption) http.Handler { func RegisterServer(server Server, options ...ServerOption) http.Handler {
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true) decoder.IgnoreUnknownKeys(true)
@ -34,12 +41,38 @@ func RegisterServer(server Server, options ...ServerOption) http.Handler {
type ServerOption func(s *webServer) type ServerOption func(s *webServer)
// WithHTTPMiddler sets the passed middleware chain to the root of
// the Server's router.
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
return func(s *webServer) { return func(s *webServer) {
s.middleware = m s.middleware = m
} }
} }
// WithEndpoints overrides the [DefaultEndpoints]
func WithEndpoints(endpoints Endpoints) ServerOption {
return func(s *webServer) {
s.endpoints = endpoints
}
}
// WithDecoder overrides the default decoder,
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
func WithDecoder(decoder httphelper.Decoder) ServerOption {
return func(s *webServer) {
s.decoder = decoder
}
}
// WithFallbackLogger overrides the fallback logger, which
// is used when no logger was found in the context.
// Defaults to [slog.Default].
func WithFallbackLogger(logger *slog.Logger) ServerOption {
return func(s *webServer) {
s.logger = logger
}
}
type webServer struct { type webServer struct {
http.Handler http.Handler
server Server server Server
@ -49,6 +82,13 @@ type webServer struct {
logger *slog.Logger logger *slog.Logger
} }
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
if logger, ok := logging.FromContext(ctx); ok {
return logger
}
return s.logger
}
func (s *webServer) createRouter() { func (s *webServer) createRouter() {
router := chi.NewRouter() router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler) router.Use(cors.New(defaultCORSOptions).Handler)
@ -73,12 +113,12 @@ func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r) client, err := s.verifyRequestClient(r)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
if !ValidateGrantType(client, grantType) { if !ValidateGrantType(client, grantType) {
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger) WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
return return
} }
} }
@ -123,12 +163,12 @@ func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false) request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
redirect, err := s.authorize(r.Context(), newRequest(r, request)) redirect, err := s.authorize(r.Context(), newRequest(r, request))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
redirect.writeOut(w, r) redirect.writeOut(w, r)
@ -163,12 +203,12 @@ func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest])
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -176,7 +216,7 @@ func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Re
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return return
} }
@ -194,25 +234,25 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
case oidc.GrantTypeDeviceCode: case oidc.GrantTypeDeviceCode:
s.withClient(s.deviceTokenHandler)(w, r) s.withClient(s.deviceTokenHandler)(w, r)
case "": case "":
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
default: default:
WriteError(w, r, unimplementedGrantError(grantType), s.logger) WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
} }
} }
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false) request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if request.Assertion == "" { if request.Assertion == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
return return
} }
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request)) resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -221,20 +261,20 @@ func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false) request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if request.Code == "" { if request.Code == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
return return
} }
if request.RedirectURI == "" { if request.RedirectURI == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
return return
} }
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -243,16 +283,16 @@ func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request,
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false) request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if request.RefreshToken == "" { if request.RefreshToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
return return
} }
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -261,32 +301,32 @@ func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request,
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false) request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if request.SubjectToken == "" { if request.SubjectToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
return return
} }
if request.SubjectTokenType == "" { if request.SubjectTokenType == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
return return
} }
if !request.SubjectTokenType.IsSupported() { if !request.SubjectTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
return return
} }
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
return return
} }
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
return return
} }
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -294,18 +334,18 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request,
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone { if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger) WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return return
} }
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false) request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -314,16 +354,16 @@ func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Requ
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false) request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if request.DeviceCode == "" { if request.DeviceCode == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
return return
} }
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -331,21 +371,21 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone { if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger) WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return return
} }
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if request.Token == "" { if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return return
} }
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -354,7 +394,7 @@ func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request,
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false) request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if token, err := getAccessToken(r); err == nil { if token, err := getAccessToken(r); err == nil {
@ -365,12 +405,12 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
oidc.ErrInvalidRequest().WithDescription("access token missing"), oidc.ErrInvalidRequest().WithDescription("access token missing"),
http.StatusUnauthorized, http.StatusUnauthorized,
) )
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request)) resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -379,16 +419,16 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) { func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
if request.Token == "" { if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return return
} }
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)
@ -397,12 +437,12 @@ func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, cl
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false) request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp, err := s.server.EndSession(r.Context(), newRequest(r, request)) resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w, r) resp.writeOut(w, r)
@ -411,12 +451,12 @@ func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc { func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return return
} }
resp, err := method(r.Context(), newRequest(r, &struct{}{})) resp, err := method(r.Context(), newRequest(r, &struct{}{}))
if err != nil { if err != nil {
WriteError(w, r, err, s.logger) WriteError(w, r, err, s.getLogger(r.Context()))
return return
} }
resp.writeOut(w) resp.writeOut(w)

View file

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -21,6 +22,28 @@ import (
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
) )
func TestRegisterServer(t *testing.T) {
server := UnimplementedServer{}
endpoints := Endpoints{
Authorization: Endpoint{
path: "/auth",
},
}
decoder := schema.NewDecoder()
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
h := RegisterServer(server,
WithEndpoints(endpoints),
WithDecoder(decoder),
WithFallbackLogger(logger),
)
got := h.(*webServer)
assert.Equal(t, got.server, server)
assert.Equal(t, got.endpoints, endpoints)
assert.Equal(t, got.decoder, decoder)
assert.Equal(t, got.logger, logger)
}
type testClient struct { type testClient struct {
id string id string
appType ApplicationType appType ApplicationType