From 81d42b061df81c58ab61c1890ca00a92b3f0e5f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Wed, 13 Sep 2023 19:13:53 +0300 Subject: [PATCH] define handlers, routes --- pkg/op/client.go | 3 +- pkg/op/server.go | 5 ++ pkg/op/server_http.go | 172 +++++++++++++++++++++++++++++++++--------- 3 files changed, 143 insertions(+), 37 deletions(-) diff --git a/pkg/op/client.go b/pkg/op/client.go index 525af8d..04ef3c7 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -183,8 +183,7 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au type ClientCredentials struct { ClientID string `schema:"client_id"` - ClientSecret string `schema:"client_secret"` // Client secret from request body - ClientSecretBasic string `schema:"-"` // Obtained from http request + ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body ClientAssertion string `schema:"client_assertion"` // JWT ClientAssertionType string `schema:"client_assertion_type"` } diff --git a/pkg/op/server.go b/pkg/op/server.go index a9be613..0862f32 100644 --- a/pkg/op/server.go +++ b/pkg/op/server.go @@ -218,6 +218,11 @@ func NewRedirect(url string) *Redirect { return &Redirect{URL: url} } +func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) { + gu.MapMerge(r.Header, w.Header()) + http.Redirect(w, r, red.URL, http.StatusFound) +} + type UnimplementedServer struct{} // UnimplementedStatusCode is the statuscode returned for methods diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 4016db3..dedbfe5 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -1,75 +1,105 @@ package op import ( + "context" "net/http" "github.com/go-chi/chi" "github.com/rs/cors" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" "golang.org/x/exp/slog" ) -type webServer struct { - http.Handler - decoder httphelper.Decoder - server Server - logger *slog.Logger +func RegisterServer(server Server) http.Handler { + ws := &webServer{ + server: server, + endpoints: *DefaultEndpoints, + decoder: schema.NewDecoder(), + logger: slog.Default(), + } + ws.createRouter() + return ws } -func (s *webServer) createRouter(endpoints *Endpoints, interceptors ...func(http.Handler) http.Handler) chi.Router { +type webServer struct { + http.Handler + server Server + endpoints Endpoints + decoder httphelper.Decoder + logger *slog.Logger +} + +func (s *webServer) createRouter(interceptors ...func(http.Handler) http.Handler) { router := chi.NewRouter() router.Use(cors.New(defaultCORSOptions).Handler) router.Use(interceptors...) - router.HandleFunc(healthEndpoint, healthHandler) - //router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) - //router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage())) - //router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o)) - //router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o)) - router.HandleFunc(endpoints.Token.Relative(), s.handleToken) - //router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) - //router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) - //router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) - //router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) - //router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) - //router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o)) - return router + router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) + router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) + router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) + router.HandleFunc(s.endpoints.Authorization.Relative(), redirectHandler(s, s.server.Authorize)) + router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler) + router.HandleFunc(s.endpoints.Introspection.Relative(), clientRequestHandler(s, s.server.Introspect)) + router.HandleFunc(s.endpoints.Userinfo.Relative(), requestHandler(s, s.server.UserInfo)) + router.HandleFunc(s.endpoints.Revocation.Relative(), clientRequestHandler(s, s.server.Revocation)) + router.HandleFunc(s.endpoints.EndSession.Relative(), redirectHandler(s, s.server.EndSession)) + router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys)) + router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), clientRequestHandler(s, s.server.DeviceAuthorization)) + s.Handler = router } func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) { if err := r.ParseForm(); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) } - clientCredentials := new(ClientCredentials) - if err := s.decoder.Decode(clientCredentials, r.Form); err != nil { + cc := new(ClientCredentials) + if err := s.decoder.Decode(cc, r.Form); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) } // Basic auth takes precedence, so if set it overwrites the form data. if clientID, clientSecret, ok := r.BasicAuth(); ok { - clientCredentials.ClientID, clientCredentials.ClientSecret = clientID, clientSecret + cc.ClientID, cc.ClientSecret = clientID, clientSecret + } + if cc.ClientID == "" && cc.ClientAssertion == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided") + } + if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion { + return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType) } - return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ Method: r.Method, URL: r.URL, Header: r.Header, Form: r.Form, - Data: clientCredentials, + Data: cc, }) } -func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) { +func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { + grantType := oidc.GrantType(r.Form.Get("grant_type")) + if grantType == oidc.GrantTypeBearer { + callRequestMethod(s, w, r, s.server.JWTProfile) + return + } + client, err := s.verifyRequestClient(r) if err != nil { WriteError(w, r, err, slog.Default()) return } - grantType := oidc.GrantType(r.Form.Get("grant_type")) + switch grantType { case oidc.GrantTypeCode: - s.handleCodeExchange(w, r, client) + callClientMethod(s, w, r, client, s.server.CodeExchange) case oidc.GrantTypeRefreshToken: - s.handleRefreshToken(w, r, client) + callClientMethod(s, w, r, client, s.server.RefreshToken) + case oidc.GrantTypeTokenExchange: + callClientMethod(s, w, r, client, s.server.TokenExchange) + case oidc.GrantTypeClientCredentials: + callClientMethod(s, w, r, client, s.server.ClientCredentialsExchange) + case oidc.GrantTypeDeviceCode: + callClientMethod(s, w, r, client, s.server.DeviceToken) case "": WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) default: @@ -77,13 +107,36 @@ func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) { } } -func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, client Client) { - request, err := decodeRequest[*oidc.AccessTokenRequest](s.decoder, r.Form) +type requestMethod[T any] func(context.Context, *Request[T]) (*Response, error) + +func simpleHandler(s *webServer, method requestMethod[struct{}]) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) + return + } + resp, err := method(r.Context(), newRequest(r, &struct{}{})) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) + } +} + +func requestHandler[T any](s *webServer, method requestMethod[T]) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + callRequestMethod(s, w, r, method) + } +} + +func callRequestMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, method requestMethod[T]) { + request, err := decodeRequest[T](s.decoder, r, false) if err != nil { WriteError(w, r, err, s.logger) return } - resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) + resp, err := method(r.Context(), newRequest[T](r, request)) if err != nil { WriteError(w, r, err, s.logger) return @@ -91,13 +144,62 @@ func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, c resp.writeOut(w) } -func (s *webServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, client Client) { +type redirectMethod[T any] func(context.Context, *Request[T]) (*Redirect, error) +func redirectHandler[T any](s *webServer, method redirectMethod[T]) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + req, err := decodeRequest[T](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + redirect, err := method(r.Context(), newRequest(r, req)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + redirect.writeOut(w, r) + } } -func decodeRequest[R any](decoder httphelper.Decoder, form map[string][]string) (request R, err error) { - if err := decoder.Decode(&request, form); err != nil { - return request, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) +type clientMethod[T any] func(context.Context, *ClientRequest[T]) (*Response, error) + +func clientRequestHandler[T any](s *webServer, method clientMethod[T]) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, slog.Default()) + return + } + callClientMethod(s, w, r, client, method) + } +} + +func callClientMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, client Client, method clientMethod[T]) { + request, err := decodeRequest[T](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp, err := method(r.Context(), newClientRequest[T](r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) { + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + form := r.Form + if postOnly { + form = r.PostForm + } + request := new(R) + if err := decoder.Decode(request, form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) } return request, nil }