From b12bb7a1f14120fc966d3ab9ce67f1e352b845be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 22 Sep 2023 14:40:56 +0300 Subject: [PATCH] cleanup tokenHandler --- pkg/op/server_http.go | 50 +++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 887e16c..3a22fff 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -67,6 +67,23 @@ func (s *webServer) createRouter() { s.Handler = router } +type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) + +func (s *webServer) withClient(w http.ResponseWriter, r *http.Request, handler clientHandler) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, slog.Default()) + return + } + if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { + if !ValidateGrantType(client, grantType) { + WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger) + return + } + } + handler(w, r, client) +} + func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { if err = r.ParseForm(); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) @@ -170,37 +187,20 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) return } - if !grantType.IsSupported() { - WriteError(w, r, unimplementedGrantError(grantType), s.logger) - return - } - - if grantType == oidc.GrantTypeBearer { - s.jwtProfileHandler(w, r) - return - } - - client, err := s.verifyRequestClient(r) - if err != nil { - WriteError(w, r, err, slog.Default()) - return - } - if !ValidateGrantType(client, grantType) { - WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger) - return - } switch grantType { case oidc.GrantTypeCode: - s.codeExchangeHandler(w, r, client) + s.withClient(w, r, s.codeExchangeHandler) case oidc.GrantTypeRefreshToken: - s.refreshTokenHandler(w, r, client) - case oidc.GrantTypeTokenExchange: - s.tokenExchangeHandler(w, r, client) + s.withClient(w, r, s.refreshTokenHandler) case oidc.GrantTypeClientCredentials: - s.clientCredentialsHandler(w, r, client) + s.withClient(w, r, s.clientCredentialsHandler) + case oidc.GrantTypeBearer: + s.jwtProfileHandler(w, r) + case oidc.GrantTypeTokenExchange: + s.withClient(w, r, s.tokenExchangeHandler) case oidc.GrantTypeDeviceCode: - s.deviceTokenHandler(w, r, client) + s.withClient(w, r, s.deviceTokenHandler) default: WriteError(w, r, unimplementedGrantError(grantType), s.logger) }