524 lines
16 KiB
Go
524 lines
16 KiB
Go
package op
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
|
|
httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
|
|
"git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/rs/cors"
|
|
"github.com/zitadel/logging"
|
|
"github.com/zitadel/schema"
|
|
)
|
|
|
|
// RegisterServer registers an implementation of Server.
|
|
// The resulting handler takes care of routing and request parsing,
|
|
// with some basic validation of required fields.
|
|
// The routes can be customized with [WithEndpoints].
|
|
//
|
|
// EXPERIMENTAL: may change until v4
|
|
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
|
|
decoder := schema.NewDecoder()
|
|
decoder.IgnoreUnknownKeys(true)
|
|
|
|
ws := &webServer{
|
|
router: chi.NewRouter(),
|
|
server: server,
|
|
endpoints: endpoints,
|
|
decoder: decoder,
|
|
corsOpts: &defaultCORSOptions,
|
|
logger: slog.Default(),
|
|
}
|
|
|
|
for _, option := range options {
|
|
option(ws)
|
|
}
|
|
|
|
ws.createRouter()
|
|
ws.handler = ws.router
|
|
if ws.corsOpts != nil {
|
|
ws.handler = cors.New(*ws.corsOpts).Handler(ws.router)
|
|
}
|
|
return ws
|
|
}
|
|
|
|
type ServerOption func(s *webServer)
|
|
|
|
// WithHTTPMiddleware sets the passed middleware chain to the root of
|
|
// the Server's router.
|
|
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
|
return func(s *webServer) {
|
|
s.router.Use(m...)
|
|
}
|
|
}
|
|
|
|
// WithSetRouter allows customization or the Server's router.
|
|
func WithSetRouter(set func(chi.Router)) ServerOption {
|
|
return func(s *webServer) {
|
|
set(s.router)
|
|
}
|
|
}
|
|
|
|
// WithDecoder overrides the default decoder,
|
|
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
|
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
|
return func(s *webServer) {
|
|
s.decoder = decoder
|
|
}
|
|
}
|
|
|
|
// WithServerCORSOptions sets the CORS policy for the Server's router.
|
|
func WithServerCORSOptions(opts *cors.Options) ServerOption {
|
|
return func(s *webServer) {
|
|
s.corsOpts = opts
|
|
}
|
|
}
|
|
|
|
// WithFallbackLogger overrides the fallback logger, which
|
|
// is used when no logger was found in the context.
|
|
// Defaults to [slog.Default].
|
|
func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
|
return func(s *webServer) {
|
|
s.logger = logger
|
|
}
|
|
}
|
|
|
|
type webServer struct {
|
|
server Server
|
|
router *chi.Mux
|
|
handler http.Handler
|
|
endpoints Endpoints
|
|
decoder httphelper.Decoder
|
|
corsOpts *cors.Options
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
s.handler.ServeHTTP(w, r)
|
|
}
|
|
|
|
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
|
if logger, ok := logging.FromContext(ctx); ok {
|
|
return logger
|
|
}
|
|
return s.logger
|
|
}
|
|
|
|
func (s *webServer) createRouter() {
|
|
s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
|
s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
|
s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
|
|
|
s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler)
|
|
s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
|
s.endpointRoute(s.endpoints.Token, s.tokensHandler)
|
|
s.endpointRoute(s.endpoints.Introspection, s.introspectionHandler)
|
|
s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler)
|
|
s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
|
s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler)
|
|
s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
|
}
|
|
|
|
func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) {
|
|
if e != nil {
|
|
traceHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
ctx, span := tracer.Start(r.Context(), e.Relative())
|
|
r = r.WithContext(ctx)
|
|
hf(w, r)
|
|
defer span.End()
|
|
}
|
|
s.router.HandleFunc(e.Relative(), traceHandler)
|
|
s.logger.Info("registered route", "endpoint", e.Relative())
|
|
}
|
|
}
|
|
|
|
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
|
|
|
|
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ctx, span := tracer.Start(r.Context(), r.URL.Path)
|
|
defer span.End()
|
|
r = r.WithContext(ctx)
|
|
|
|
client, err := s.verifyRequestClient(r)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
|
|
if !ValidateGrantType(client, grantType) {
|
|
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
}
|
|
handler(w, r, client)
|
|
}
|
|
}
|
|
|
|
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
|
cc, err := s.parseClientCredentials(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
|
|
Method: r.Method,
|
|
URL: r.URL,
|
|
Header: r.Header,
|
|
Form: r.Form,
|
|
Data: cc,
|
|
})
|
|
}
|
|
|
|
func (s *webServer) parseClientCredentials(r *http.Request) (_ *ClientCredentials, err error) {
|
|
if err := r.ParseForm(); err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
|
}
|
|
cc := new(ClientCredentials)
|
|
if err = s.decoder.Decode(cc, r.Form); err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
|
}
|
|
// Basic auth takes precedence, so if set it overwrites the form data.
|
|
if clientID, clientSecret, ok := r.BasicAuth(); ok {
|
|
cc.ClientID, err = url.QueryUnescape(clientID)
|
|
if err != nil {
|
|
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
|
}
|
|
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
|
|
if err != nil {
|
|
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
|
}
|
|
}
|
|
if cc.ClientID == "" && cc.ClientAssertion == "" {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
|
|
}
|
|
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
|
|
}
|
|
return cc, nil
|
|
}
|
|
|
|
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
|
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
redirect, err := s.authorize(r.Context(), newRequest(r, request))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
redirect.writeOut(w, r)
|
|
}
|
|
|
|
func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
|
|
cr, err := s.server.VerifyAuthRequest(ctx, r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
authReq := cr.Data
|
|
if authReq.RedirectURI == "" {
|
|
return nil, ErrAuthReqMissingRedirectURI
|
|
}
|
|
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
|
|
return nil, err
|
|
}
|
|
return s.server.Authorize(ctx, cr)
|
|
}
|
|
|
|
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
|
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
|
|
switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
|
|
case oidc.GrantTypeCode:
|
|
s.withClient(s.codeExchangeHandler)(w, r)
|
|
case oidc.GrantTypeRefreshToken:
|
|
s.withClient(s.refreshTokenHandler)(w, r)
|
|
case oidc.GrantTypeClientCredentials:
|
|
s.withClient(s.clientCredentialsHandler)(w, r)
|
|
case oidc.GrantTypeBearer:
|
|
s.jwtProfileHandler(w, r)
|
|
case oidc.GrantTypeTokenExchange:
|
|
s.withClient(s.tokenExchangeHandler)(w, r)
|
|
case oidc.GrantTypeDeviceCode:
|
|
s.withClient(s.deviceTokenHandler)(w, r)
|
|
case "":
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
|
|
default:
|
|
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
|
|
}
|
|
}
|
|
|
|
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
|
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.Assertion == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
|
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.Code == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.RedirectURI == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
|
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.RefreshToken == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
|
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.SubjectToken == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.SubjectTokenType == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if !request.SubjectTokenType.IsSupported() {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
|
if client.AuthMethod() == oidc.AuthMethodNone {
|
|
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
|
|
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
|
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.DeviceCode == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) {
|
|
cc, err := s.parseClientCredentials(r)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if cc.ClientSecret == "" && cc.ClientAssertion == "" {
|
|
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.Token == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.Introspect(r.Context(), newRequest(r, &IntrospectionRequest{cc, request}))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
|
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if token, err := getAccessToken(r); err == nil {
|
|
request.AccessToken = token
|
|
}
|
|
if request.AccessToken == "" {
|
|
err = NewStatusError(
|
|
oidc.ErrInvalidRequest().WithDescription("access token missing"),
|
|
http.StatusUnauthorized,
|
|
)
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
|
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
if request.Token == "" {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
|
|
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
|
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w, r)
|
|
}
|
|
|
|
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
|
|
if err != nil {
|
|
WriteError(w, r, err, s.getLogger(r.Context()))
|
|
return
|
|
}
|
|
resp.writeOut(w)
|
|
}
|
|
}
|
|
|
|
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
|
|
dst := new(R)
|
|
if err := r.ParseForm(); err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
|
}
|
|
form := r.Form
|
|
if postOnly {
|
|
form = r.PostForm
|
|
}
|
|
if err := decoder.Decode(dst, form); err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
|
}
|
|
return dst, nil
|
|
}
|