server options
This commit is contained in:
parent
e9c494041c
commit
a49ad31735
3 changed files with 127 additions and 46 deletions
|
@ -22,6 +22,13 @@ import (
|
||||||
// in the standards regarding the response models. Where applicable
|
// in the standards regarding the response models. Where applicable
|
||||||
// the method documentation gives a recommended type which can be used
|
// the method documentation gives a recommended type which can be used
|
||||||
// directly or extended upon.
|
// directly or extended upon.
|
||||||
|
//
|
||||||
|
// The addition of new methods is not considered a breaking change
|
||||||
|
// as defined by semver rules.
|
||||||
|
// Implementations MUST embed [UnimplementedServer] to maintain
|
||||||
|
// forward compatibility.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
type Server interface {
|
type Server interface {
|
||||||
// Health returns a status of "ok" once the Server is listening.
|
// Health returns a status of "ok" once the Server is listening.
|
||||||
// The recommended Response Data type is [Status].
|
// The recommended Response Data type is [Status].
|
||||||
|
@ -146,6 +153,8 @@ type Server interface {
|
||||||
// and parsed Data from the request body (POST) or URL parameters (GET).
|
// and parsed Data from the request body (POST) or URL parameters (GET).
|
||||||
// Data can be assumed to be validated according to the applicable
|
// Data can be assumed to be validated according to the applicable
|
||||||
// standard for the specific endpoints.
|
// standard for the specific endpoints.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
type Request[T any] struct {
|
type Request[T any] struct {
|
||||||
Method string
|
Method string
|
||||||
URL *url.URL
|
URL *url.URL
|
||||||
|
@ -173,6 +182,8 @@ func newRequest[T any](r *http.Request, data *T) *Request[T] {
|
||||||
// ClientRequest is a Request with a verified client attached to it.
|
// ClientRequest is a Request with a verified client attached to it.
|
||||||
// Methods the receive this argument may assume the client was authenticated,
|
// Methods the receive this argument may assume the client was authenticated,
|
||||||
// or verified to be a public client.
|
// or verified to be a public client.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
type ClientRequest[T any] struct {
|
type ClientRequest[T any] struct {
|
||||||
*Request[T]
|
*Request[T]
|
||||||
Client Client
|
Client Client
|
||||||
|
@ -185,6 +196,9 @@ func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientReq
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Response object for most [Server] methods.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
type Response struct {
|
type Response struct {
|
||||||
// Header map will be merged with the
|
// Header map will be merged with the
|
||||||
// header on the [http.ResponseWriter].
|
// header on the [http.ResponseWriter].
|
||||||
|
@ -200,6 +214,8 @@ type Response struct {
|
||||||
Data any
|
Data any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewResponse creates a new response for data,
|
||||||
|
// without custom headers.
|
||||||
func NewResponse(data any) *Response {
|
func NewResponse(data any) *Response {
|
||||||
return &Response{
|
return &Response{
|
||||||
Data: data,
|
Data: data,
|
||||||
|
@ -215,6 +231,8 @@ func (resp *Response) writeOut(w http.ResponseWriter) {
|
||||||
// initiate a [http.StatusFound] redirect.
|
// initiate a [http.StatusFound] redirect.
|
||||||
// The Params field will be encoded and set to the
|
// The Params field will be encoded and set to the
|
||||||
// URL's RawQuery field before building the URL.
|
// URL's RawQuery field before building the URL.
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
type Redirect struct {
|
type Redirect struct {
|
||||||
// Header map will be merged with the
|
// Header map will be merged with the
|
||||||
// header on the [http.ResponseWriter].
|
// header on the [http.ResponseWriter].
|
||||||
|
|
|
@ -7,12 +7,19 @@ import (
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
|
"github.com/zitadel/logging"
|
||||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"github.com/zitadel/schema"
|
"github.com/zitadel/schema"
|
||||||
"golang.org/x/exp/slog"
|
"golang.org/x/exp/slog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RegisterServer registers an implementation of Server.
|
||||||
|
// The resulting handler takes care of routing and request parsing,
|
||||||
|
// with some basic validation of required fields.
|
||||||
|
// The routes can be customized with [WithEndpoints].
|
||||||
|
//
|
||||||
|
// EXPERIMENTAL: may change until v4
|
||||||
func RegisterServer(server Server, options ...ServerOption) http.Handler {
|
func RegisterServer(server Server, options ...ServerOption) http.Handler {
|
||||||
decoder := schema.NewDecoder()
|
decoder := schema.NewDecoder()
|
||||||
decoder.IgnoreUnknownKeys(true)
|
decoder.IgnoreUnknownKeys(true)
|
||||||
|
@ -34,12 +41,38 @@ func RegisterServer(server Server, options ...ServerOption) http.Handler {
|
||||||
|
|
||||||
type ServerOption func(s *webServer)
|
type ServerOption func(s *webServer)
|
||||||
|
|
||||||
|
// WithHTTPMiddler sets the passed middleware chain to the root of
|
||||||
|
// the Server's router.
|
||||||
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
||||||
return func(s *webServer) {
|
return func(s *webServer) {
|
||||||
s.middleware = m
|
s.middleware = m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithEndpoints overrides the [DefaultEndpoints]
|
||||||
|
func WithEndpoints(endpoints Endpoints) ServerOption {
|
||||||
|
return func(s *webServer) {
|
||||||
|
s.endpoints = endpoints
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDecoder overrides the default decoder,
|
||||||
|
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
||||||
|
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
||||||
|
return func(s *webServer) {
|
||||||
|
s.decoder = decoder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFallbackLogger overrides the fallback logger, which
|
||||||
|
// is used when no logger was found in the context.
|
||||||
|
// Defaults to [slog.Default].
|
||||||
|
func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
||||||
|
return func(s *webServer) {
|
||||||
|
s.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type webServer struct {
|
type webServer struct {
|
||||||
http.Handler
|
http.Handler
|
||||||
server Server
|
server Server
|
||||||
|
@ -49,6 +82,13 @@ type webServer struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||||
|
if logger, ok := logging.FromContext(ctx); ok {
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
return s.logger
|
||||||
|
}
|
||||||
|
|
||||||
func (s *webServer) createRouter() {
|
func (s *webServer) createRouter() {
|
||||||
router := chi.NewRouter()
|
router := chi.NewRouter()
|
||||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||||
|
@ -73,12 +113,12 @@ func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
client, err := s.verifyRequestClient(r)
|
client, err := s.verifyRequestClient(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
|
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
|
||||||
if !ValidateGrantType(client, grantType) {
|
if !ValidateGrantType(client, grantType) {
|
||||||
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger)
|
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,12 +163,12 @@ func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
||||||
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redirect, err := s.authorize(r.Context(), newRequest(r, request))
|
redirect, err := s.authorize(r.Context(), newRequest(r, request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redirect.writeOut(w, r)
|
redirect.writeOut(w, r)
|
||||||
|
@ -163,12 +203,12 @@ func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest])
|
||||||
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -176,7 +216,7 @@ func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Re
|
||||||
|
|
||||||
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,25 +234,25 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
case oidc.GrantTypeDeviceCode:
|
case oidc.GrantTypeDeviceCode:
|
||||||
s.withClient(s.deviceTokenHandler)(w, r)
|
s.withClient(s.deviceTokenHandler)(w, r)
|
||||||
case "":
|
case "":
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
|
||||||
default:
|
default:
|
||||||
WriteError(w, r, unimplementedGrantError(grantType), s.logger)
|
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.Assertion == "" {
|
if request.Assertion == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
|
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -221,20 +261,20 @@ func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.Code == "" {
|
if request.Code == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.RedirectURI == "" {
|
if request.RedirectURI == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -243,16 +283,16 @@ func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request,
|
||||||
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.RefreshToken == "" {
|
if request.RefreshToken == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -261,32 +301,32 @@ func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request,
|
||||||
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.SubjectToken == "" {
|
if request.SubjectToken == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.SubjectTokenType == "" {
|
if request.SubjectTokenType == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !request.SubjectTokenType.IsSupported() {
|
if !request.SubjectTokenType.IsSupported() {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
|
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
|
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -294,18 +334,18 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request,
|
||||||
|
|
||||||
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -314,16 +354,16 @@ func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Requ
|
||||||
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.DeviceCode == "" {
|
if request.DeviceCode == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -331,21 +371,21 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c
|
||||||
|
|
||||||
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.Token == "" {
|
if request.Token == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -354,7 +394,7 @@ func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request,
|
||||||
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if token, err := getAccessToken(r); err == nil {
|
if token, err := getAccessToken(r); err == nil {
|
||||||
|
@ -365,12 +405,12 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
oidc.ErrInvalidRequest().WithDescription("access token missing"),
|
oidc.ErrInvalidRequest().WithDescription("access token missing"),
|
||||||
http.StatusUnauthorized,
|
http.StatusUnauthorized,
|
||||||
)
|
)
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
|
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -379,16 +419,16 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if request.Token == "" {
|
if request.Token == "" {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
|
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
@ -397,12 +437,12 @@ func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, cl
|
||||||
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
|
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
|
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w, r)
|
resp.writeOut(w, r)
|
||||||
|
@ -411,12 +451,12 @@ func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
|
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
|
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -21,6 +22,28 @@ import (
|
||||||
"golang.org/x/exp/slog"
|
"golang.org/x/exp/slog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestRegisterServer(t *testing.T) {
|
||||||
|
server := UnimplementedServer{}
|
||||||
|
endpoints := Endpoints{
|
||||||
|
Authorization: Endpoint{
|
||||||
|
path: "/auth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
decoder := schema.NewDecoder()
|
||||||
|
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
||||||
|
|
||||||
|
h := RegisterServer(server,
|
||||||
|
WithEndpoints(endpoints),
|
||||||
|
WithDecoder(decoder),
|
||||||
|
WithFallbackLogger(logger),
|
||||||
|
)
|
||||||
|
got := h.(*webServer)
|
||||||
|
assert.Equal(t, got.server, server)
|
||||||
|
assert.Equal(t, got.endpoints, endpoints)
|
||||||
|
assert.Equal(t, got.decoder, decoder)
|
||||||
|
assert.Equal(t, got.logger, logger)
|
||||||
|
}
|
||||||
|
|
||||||
type testClient struct {
|
type testClient struct {
|
||||||
id string
|
id string
|
||||||
appType ApplicationType
|
appType ApplicationType
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue