define handlers, routes
This commit is contained in:
parent
fe3f98a4f9
commit
81d42b061d
3 changed files with 143 additions and 37 deletions
|
@ -183,8 +183,7 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
|
||||||
|
|
||||||
type ClientCredentials struct {
|
type ClientCredentials struct {
|
||||||
ClientID string `schema:"client_id"`
|
ClientID string `schema:"client_id"`
|
||||||
ClientSecret string `schema:"client_secret"` // Client secret from request body
|
ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body
|
||||||
ClientSecretBasic string `schema:"-"` // Obtained from http request
|
|
||||||
ClientAssertion string `schema:"client_assertion"` // JWT
|
ClientAssertion string `schema:"client_assertion"` // JWT
|
||||||
ClientAssertionType string `schema:"client_assertion_type"`
|
ClientAssertionType string `schema:"client_assertion_type"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -218,6 +218,11 @@ func NewRedirect(url string) *Redirect {
|
||||||
return &Redirect{URL: url}
|
return &Redirect{URL: url}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gu.MapMerge(r.Header, w.Header())
|
||||||
|
http.Redirect(w, r, red.URL, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
type UnimplementedServer struct{}
|
type UnimplementedServer struct{}
|
||||||
|
|
||||||
// UnimplementedStatusCode is the statuscode returned for methods
|
// UnimplementedStatusCode is the statuscode returned for methods
|
||||||
|
|
|
@ -1,75 +1,105 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
"github.com/zitadel/schema"
|
||||||
"golang.org/x/exp/slog"
|
"golang.org/x/exp/slog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type webServer struct {
|
func RegisterServer(server Server) http.Handler {
|
||||||
http.Handler
|
ws := &webServer{
|
||||||
decoder httphelper.Decoder
|
server: server,
|
||||||
server Server
|
endpoints: *DefaultEndpoints,
|
||||||
logger *slog.Logger
|
decoder: schema.NewDecoder(),
|
||||||
|
logger: slog.Default(),
|
||||||
|
}
|
||||||
|
ws.createRouter()
|
||||||
|
return ws
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) createRouter(endpoints *Endpoints, interceptors ...func(http.Handler) http.Handler) chi.Router {
|
type webServer struct {
|
||||||
|
http.Handler
|
||||||
|
server Server
|
||||||
|
endpoints Endpoints
|
||||||
|
decoder httphelper.Decoder
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) createRouter(interceptors ...func(http.Handler) http.Handler) {
|
||||||
router := chi.NewRouter()
|
router := chi.NewRouter()
|
||||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||||
router.Use(interceptors...)
|
router.Use(interceptors...)
|
||||||
router.HandleFunc(healthEndpoint, healthHandler)
|
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||||
//router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||||
//router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
|
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||||
//router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
|
router.HandleFunc(s.endpoints.Authorization.Relative(), redirectHandler(s, s.server.Authorize))
|
||||||
//router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o))
|
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler)
|
||||||
router.HandleFunc(endpoints.Token.Relative(), s.handleToken)
|
router.HandleFunc(s.endpoints.Introspection.Relative(), clientRequestHandler(s, s.server.Introspect))
|
||||||
//router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
|
router.HandleFunc(s.endpoints.Userinfo.Relative(), requestHandler(s, s.server.UserInfo))
|
||||||
//router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
|
router.HandleFunc(s.endpoints.Revocation.Relative(), clientRequestHandler(s, s.server.Revocation))
|
||||||
//router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
|
router.HandleFunc(s.endpoints.EndSession.Relative(), redirectHandler(s, s.server.EndSession))
|
||||||
//router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
|
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys))
|
||||||
//router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
|
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), clientRequestHandler(s, s.server.DeviceAuthorization))
|
||||||
//router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o))
|
s.Handler = router
|
||||||
return router
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) {
|
func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) {
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||||
}
|
}
|
||||||
clientCredentials := new(ClientCredentials)
|
cc := new(ClientCredentials)
|
||||||
if err := s.decoder.Decode(clientCredentials, r.Form); err != nil {
|
if err := s.decoder.Decode(cc, r.Form); err != nil {
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||||
}
|
}
|
||||||
// Basic auth takes precedence, so if set it overwrites the form data.
|
// Basic auth takes precedence, so if set it overwrites the form data.
|
||||||
if clientID, clientSecret, ok := r.BasicAuth(); ok {
|
if clientID, clientSecret, ok := r.BasicAuth(); ok {
|
||||||
clientCredentials.ClientID, clientCredentials.ClientSecret = clientID, clientSecret
|
cc.ClientID, cc.ClientSecret = clientID, clientSecret
|
||||||
|
}
|
||||||
|
if cc.ClientID == "" && cc.ClientAssertion == "" {
|
||||||
|
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
|
||||||
|
}
|
||||||
|
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
|
||||||
|
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
|
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
|
||||||
Method: r.Method,
|
Method: r.Method,
|
||||||
URL: r.URL,
|
URL: r.URL,
|
||||||
Header: r.Header,
|
Header: r.Header,
|
||||||
Form: r.Form,
|
Form: r.Form,
|
||||||
Data: clientCredentials,
|
Data: cc,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
grantType := oidc.GrantType(r.Form.Get("grant_type"))
|
||||||
|
if grantType == oidc.GrantTypeBearer {
|
||||||
|
callRequestMethod(s, w, r, s.server.JWTProfile)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
client, err := s.verifyRequestClient(r)
|
client, err := s.verifyRequestClient(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, slog.Default())
|
WriteError(w, r, err, slog.Default())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
grantType := oidc.GrantType(r.Form.Get("grant_type"))
|
|
||||||
switch grantType {
|
switch grantType {
|
||||||
case oidc.GrantTypeCode:
|
case oidc.GrantTypeCode:
|
||||||
s.handleCodeExchange(w, r, client)
|
callClientMethod(s, w, r, client, s.server.CodeExchange)
|
||||||
case oidc.GrantTypeRefreshToken:
|
case oidc.GrantTypeRefreshToken:
|
||||||
s.handleRefreshToken(w, r, client)
|
callClientMethod(s, w, r, client, s.server.RefreshToken)
|
||||||
|
case oidc.GrantTypeTokenExchange:
|
||||||
|
callClientMethod(s, w, r, client, s.server.TokenExchange)
|
||||||
|
case oidc.GrantTypeClientCredentials:
|
||||||
|
callClientMethod(s, w, r, client, s.server.ClientCredentialsExchange)
|
||||||
|
case oidc.GrantTypeDeviceCode:
|
||||||
|
callClientMethod(s, w, r, client, s.server.DeviceToken)
|
||||||
case "":
|
case "":
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
|
||||||
default:
|
default:
|
||||||
|
@ -77,13 +107,36 @@ func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, client Client) {
|
type requestMethod[T any] func(context.Context, *Request[T]) (*Response, error)
|
||||||
request, err := decodeRequest[*oidc.AccessTokenRequest](s.decoder, r.Form)
|
|
||||||
|
func simpleHandler(s *webServer, method requestMethod[struct{}]) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestHandler[T any](s *webServer, method requestMethod[T]) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callRequestMethod(s, w, r, method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callRequestMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, method requestMethod[T]) {
|
||||||
|
request, err := decodeRequest[T](s.decoder, r, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.logger)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
resp, err := method(r.Context(), newRequest[T](r, request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, r, err, s.logger)
|
WriteError(w, r, err, s.logger)
|
||||||
return
|
return
|
||||||
|
@ -91,13 +144,62 @@ func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, c
|
||||||
resp.writeOut(w)
|
resp.writeOut(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, client Client) {
|
type redirectMethod[T any] func(context.Context, *Request[T]) (*Redirect, error)
|
||||||
|
|
||||||
|
func redirectHandler[T any](s *webServer, method redirectMethod[T]) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
req, err := decodeRequest[T](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirect, err := method(r.Context(), newRequest(r, req))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirect.writeOut(w, r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeRequest[R any](decoder httphelper.Decoder, form map[string][]string) (request R, err error) {
|
type clientMethod[T any] func(context.Context, *ClientRequest[T]) (*Response, error)
|
||||||
if err := decoder.Decode(&request, form); err != nil {
|
|
||||||
return request, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
func clientRequestHandler[T any](s *webServer, method clientMethod[T]) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
client, err := s.verifyRequestClient(r)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, slog.Default())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callClientMethod(s, w, r, client, method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callClientMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, client Client, method clientMethod[T]) {
|
||||||
|
request, err := decodeRequest[T](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := method(r.Context(), newClientRequest[T](r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||||
|
}
|
||||||
|
form := r.Form
|
||||||
|
if postOnly {
|
||||||
|
form = r.PostForm
|
||||||
|
}
|
||||||
|
request := new(R)
|
||||||
|
if err := decoder.Decode(request, form); err != nil {
|
||||||
|
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||||
}
|
}
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue