finish http unit tests

This commit is contained in:
Tim Möhlmann 2023-09-25 18:18:40 +03:00
parent f9a4b82b3b
commit d17e452122
2 changed files with 891 additions and 103 deletions

View file

@ -57,11 +57,11 @@ func (s *webServer) createRouter() {
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler)
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.deviceAuthorizationHandler)
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.withClient(s.deviceAuthorizationHandler))
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler)
router.HandleFunc(s.endpoints.Introspection.Relative(), s.introspectionHandler)
router.HandleFunc(s.endpoints.Introspection.Relative(), s.withClient(s.introspectionHandler))
router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler)
router.HandleFunc(s.endpoints.Revocation.Relative(), s.revokationHandler)
router.HandleFunc(s.endpoints.Revocation.Relative(), s.withClient(s.revocationHandler))
router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler)
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys))
s.Handler = router
@ -69,19 +69,21 @@ func (s *webServer) createRouter() {
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
func (s *webServer) withClient(w http.ResponseWriter, r *http.Request, handler clientHandler) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, slog.Default())
return
}
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
if !ValidateGrantType(client, grantType) {
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger)
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, s.logger)
return
}
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
if !ValidateGrantType(client, grantType) {
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger)
return
}
}
handler(w, r, client)
}
handler(w, r, client)
}
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
@ -158,12 +160,7 @@ func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest])
return s.server.Authorize(ctx, cr)
}
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, slog.Default())
return
}
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.logger)
@ -182,25 +179,22 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
return
}
grantType := oidc.GrantType(r.Form.Get("grant_type"))
if grantType == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
return
}
switch grantType {
switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
case oidc.GrantTypeCode:
s.withClient(w, r, s.codeExchangeHandler)
s.withClient(s.codeExchangeHandler)(w, r)
case oidc.GrantTypeRefreshToken:
s.withClient(w, r, s.refreshTokenHandler)
s.withClient(s.refreshTokenHandler)(w, r)
case oidc.GrantTypeClientCredentials:
s.withClient(w, r, s.clientCredentialsHandler)
s.withClient(s.clientCredentialsHandler)(w, r)
case oidc.GrantTypeBearer:
s.jwtProfileHandler(w, r)
case oidc.GrantTypeTokenExchange:
s.withClient(w, r, s.tokenExchangeHandler)
s.withClient(s.tokenExchangeHandler)(w, r)
case oidc.GrantTypeDeviceCode:
s.withClient(w, r, s.deviceTokenHandler)
s.withClient(s.deviceTokenHandler)(w, r)
case "":
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.logger)
default:
WriteError(w, r, unimplementedGrantError(grantType), s.logger)
}
@ -271,19 +265,19 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request,
return
}
if request.SubjectToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.logger)
return
}
if request.SubjectTokenType == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
return
}
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.logger)
return
}
if !request.SubjectTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.logger)
return
}
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.logger)
return
}
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
@ -300,8 +294,7 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request,
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
err := oidc.ErrInvalidClient().WithDescription("client must be authenticated")
WriteError(w, r, err, s.logger)
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger)
return
}
@ -336,10 +329,9 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c
resp.writeOut(w)
}
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, slog.Default())
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger)
return
}
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
@ -369,7 +361,7 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
request.AccessToken = token
}
if request.AccessToken == "" {
err = AsStatusError(
err = NewStatusError(
oidc.ErrInvalidRequest().WithDescription("access token missing"),
http.StatusUnauthorized,
)
@ -384,17 +376,16 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
resp.writeOut(w)
}
func (s *webServer) revokationHandler(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, slog.Default())
return
}
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.logger)
return
}
if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger)
return
}
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.logger)

File diff suppressed because it is too large Load diff