From d17e452122d563c861b9aff371d689c09c8dcd04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Mon, 25 Sep 2023 18:18:40 +0300 Subject: [PATCH] finish http unit tests --- pkg/op/server_http.go | 91 ++-- pkg/op/server_http_test.go | 903 ++++++++++++++++++++++++++++++++++--- 2 files changed, 891 insertions(+), 103 deletions(-) diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 3a22fff..e60b2ce 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -57,11 +57,11 @@ func (s *webServer) createRouter() { router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler) - router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.deviceAuthorizationHandler) + router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.withClient(s.deviceAuthorizationHandler)) router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler) - router.HandleFunc(s.endpoints.Introspection.Relative(), s.introspectionHandler) + router.HandleFunc(s.endpoints.Introspection.Relative(), s.withClient(s.introspectionHandler)) router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler) - router.HandleFunc(s.endpoints.Revocation.Relative(), s.revokationHandler) + router.HandleFunc(s.endpoints.Revocation.Relative(), s.withClient(s.revocationHandler)) router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler) router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys)) s.Handler = router @@ -69,19 +69,21 @@ func (s *webServer) createRouter() { type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) -func (s *webServer) withClient(w http.ResponseWriter, r *http.Request, handler clientHandler) { - client, err := s.verifyRequestClient(r) - if err != nil { - WriteError(w, r, err, slog.Default()) - return - } - if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { - if !ValidateGrantType(client, grantType) { - WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger) +func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, s.logger) return } + if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { + if !ValidateGrantType(client, grantType) { + WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.logger) + return + } + } + handler(w, r, client) } - handler(w, r, client) } func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { @@ -158,12 +160,7 @@ func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) return s.server.Authorize(ctx, cr) } -func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request) { - client, err := s.verifyRequestClient(r) - if err != nil { - WriteError(w, r, err, slog.Default()) - return - } +func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) if err != nil { WriteError(w, r, err, s.logger) @@ -182,25 +179,22 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) return } - grantType := oidc.GrantType(r.Form.Get("grant_type")) - if grantType == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) - return - } - switch grantType { + switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType { case oidc.GrantTypeCode: - s.withClient(w, r, s.codeExchangeHandler) + s.withClient(s.codeExchangeHandler)(w, r) case oidc.GrantTypeRefreshToken: - s.withClient(w, r, s.refreshTokenHandler) + s.withClient(s.refreshTokenHandler)(w, r) case oidc.GrantTypeClientCredentials: - s.withClient(w, r, s.clientCredentialsHandler) + s.withClient(s.clientCredentialsHandler)(w, r) case oidc.GrantTypeBearer: s.jwtProfileHandler(w, r) case oidc.GrantTypeTokenExchange: - s.withClient(w, r, s.tokenExchangeHandler) + s.withClient(s.tokenExchangeHandler)(w, r) case oidc.GrantTypeDeviceCode: - s.withClient(w, r, s.deviceTokenHandler) + s.withClient(s.deviceTokenHandler)(w, r) + case "": + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.logger) default: WriteError(w, r, unimplementedGrantError(grantType), s.logger) } @@ -271,19 +265,19 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, return } if request.SubjectToken == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.logger) return } if request.SubjectTokenType == "" { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) - return - } - if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.logger) return } if !request.SubjectTokenType.IsSupported() { - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.logger) + return + } + if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.logger) return } if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { @@ -300,8 +294,7 @@ func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { if client.AuthMethod() == oidc.AuthMethodNone { - err := oidc.ErrInvalidClient().WithDescription("client must be authenticated") - WriteError(w, r, err, s.logger) + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger) return } @@ -336,10 +329,9 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c resp.writeOut(w) } -func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) { - client, err := s.verifyRequestClient(r) - if err != nil { - WriteError(w, r, err, slog.Default()) +func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.logger) return } request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) @@ -369,7 +361,7 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { request.AccessToken = token } if request.AccessToken == "" { - err = AsStatusError( + err = NewStatusError( oidc.ErrInvalidRequest().WithDescription("access token missing"), http.StatusUnauthorized, ) @@ -384,17 +376,16 @@ func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { resp.writeOut(w) } -func (s *webServer) revokationHandler(w http.ResponseWriter, r *http.Request) { - client, err := s.verifyRequestClient(r) - if err != nil { - WriteError(w, r, err, slog.Default()) - return - } +func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) { request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) if err != nil { WriteError(w, r, err, s.logger) return } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger) + return + } resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) if err != nil { WriteError(w, r, err, s.logger) diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go index b3b034c..40fece5 100644 --- a/pkg/op/server_http_test.go +++ b/pkg/op/server_http_test.go @@ -3,6 +3,7 @@ package op import ( "bytes" "context" + "fmt" "io" "net/http" "net/http/httptest" @@ -30,27 +31,37 @@ type testClient struct { devMode bool } -func newClient(kind string) *testClient { +type clientType string + +const ( + clientTypeWeb clientType = "web" + clientTypeNative clientType = "native" + clientTypeUserAgent clientType = "useragent" +) + +func newClient(kind clientType) *testClient { client := &testClient{ - id: kind, + id: string(kind), } switch kind { - case "web_client": + case clientTypeWeb: client.appType = ApplicationTypeWeb client.authMethod = oidc.AuthMethodBasic client.accessTokenType = AccessTokenTypeBearer client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} - case "native_client": + case clientTypeNative: client.appType = ApplicationTypeNative client.authMethod = oidc.AuthMethodNone client.accessTokenType = AccessTokenTypeBearer client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} - case "useragent_client": + case clientTypeUserAgent: client.appType = ApplicationTypeUserAgent client.authMethod = oidc.AuthMethodBasic client.accessTokenType = AccessTokenTypeJWT client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken} + default: + panic(fmt.Errorf("invalid client type %s", kind)) } return client } @@ -142,11 +153,23 @@ var testDecoder = func() *schema.Decoder { return decoder }() -var testWebServer = &webServer{ - server: UnimplementedServer{}, - endpoints: *DefaultEndpoints, - decoder: testDecoder, - logger: slog.Default(), +type webServerResult struct { + wantStatus int + wantBody string +} + +func runWebServerTest(t *testing.T, handler http.HandlerFunc, r *http.Request, want webServerResult) { + t.Helper() + if r.Method == http.MethodPost { + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + w := httptest.NewRecorder() + handler(w, r) + res := w.Result() + assert.Equal(t, want.wantStatus, res.StatusCode) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, want.wantBody, string(body)) } func Test_webServer_verifyRequestClient(t *testing.T) { @@ -160,20 +183,20 @@ func Test_webServer_verifyRequestClient(t *testing.T) { { name: "parse form error", decoder: testDecoder, - r: httptest.NewRequest("POST", "/", bytes.NewReader(make([]byte, 11<<20))), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), }, { name: "decoder error", decoder: schema.NewDecoder(), - r: httptest.NewRequest("POST", "/", strings.NewReader("foo=bar")), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), }, { name: "basic auth, client_id error", decoder: testDecoder, r: func() *http.Request { - r := httptest.NewRequest("POST", "/", strings.NewReader("foo=bar")) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) r.SetBasicAuth(`%%%`, "secret") return r }(), @@ -183,7 +206,7 @@ func Test_webServer_verifyRequestClient(t *testing.T) { name: "basic auth, client_secret error", decoder: testDecoder, r: func() *http.Request { - r := httptest.NewRequest("POST", "/", strings.NewReader("foo=bar")) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) r.SetBasicAuth("web", `%%%`) return r }(), @@ -192,19 +215,19 @@ func Test_webServer_verifyRequestClient(t *testing.T) { { name: "missing client id and assertion", decoder: testDecoder, - r: httptest.NewRequest("POST", "/", strings.NewReader("foo=bar")), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), wantErr: oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided"), }, { name: "wrong assertion type", decoder: testDecoder, - r: httptest.NewRequest("POST", "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")), wantErr: oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type wrong"), }, { name: "unimplemented verify client called", decoder: testDecoder, - r: httptest.NewRequest("POST", "/", strings.NewReader("foo=bar&client_id=web")), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_id=web")), wantErr: StatusError{ parent: oidc.ErrServerError().WithDescription("/ not implemented on this server"), statusCode: UnimplementedStatusCode, @@ -226,12 +249,12 @@ func Test_webServer_verifyRequestClient(t *testing.T) { } } -type authRequestVerifier struct { +type requestVerifier struct { UnimplementedServer client Client } -func (s *authRequestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { +func (s *requestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { if s.client == nil { return nil, oidc.ErrServerError() } @@ -241,37 +264,47 @@ func (s *authRequestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[ }, nil } +func (s *requestVerifier) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return s.client, nil +} + func Test_webServer_authorizeHandler(t *testing.T) { type fields struct { server Server decoder httphelper.Decoder } tests := []struct { - name string - fields fields - r *http.Request - wantStatus int - wantBody string + name string + fields fields + r *http.Request + want webServerResult }{ { name: "decoder error", fields: fields{ - server: &authRequestVerifier{}, + server: &requestVerifier{}, decoder: schema.NewDecoder(), }, - r: httptest.NewRequest("POST", "/authorize", bytes.NewBufferString("foo=bar")), - wantStatus: http.StatusBadRequest, - wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, }, { name: "authorize error", fields: fields{ - server: &authRequestVerifier{}, + server: &requestVerifier{}, decoder: testDecoder, }, - r: httptest.NewRequest("POST", "/authorize", bytes.NewBufferString("foo=bar")), - wantStatus: http.StatusBadRequest, - wantBody: `{"error":"server_error"}`, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error"}`, + }, }, } for _, tt := range tests { @@ -281,14 +314,7 @@ func Test_webServer_authorizeHandler(t *testing.T) { decoder: tt.fields.decoder, logger: slog.Default(), } - tt.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - s.authorizeHandler(w, tt.r) - res := w.Result() - assert.Equal(t, tt.wantStatus, res.StatusCode) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - assert.JSONEq(t, tt.wantBody, string(body)) + runWebServerTest(t, s.authorizeHandler, tt.r, tt.want) }) } } @@ -307,7 +333,7 @@ func Test_webServer_authorize(t *testing.T) { }{ { name: "verify error", - server: &authRequestVerifier{}, + server: &requestVerifier{}, args: args{ ctx: context.Background(), r: &Request[oidc.AuthRequest]{ @@ -324,8 +350,8 @@ func Test_webServer_authorize(t *testing.T) { }, { name: "missing redirect", - server: &authRequestVerifier{ - client: newClient("web_client"), + server: &requestVerifier{ + client: newClient(clientTypeWeb), }, args: args{ ctx: context.Background(), @@ -342,8 +368,8 @@ func Test_webServer_authorize(t *testing.T) { }, { name: "invalid prompt", - server: &authRequestVerifier{ - client: newClient("web_client"), + server: &requestVerifier{ + client: newClient(clientTypeWeb), }, args: args{ ctx: context.Background(), @@ -362,8 +388,8 @@ func Test_webServer_authorize(t *testing.T) { }, { name: "missing scopes", - server: &authRequestVerifier{ - client: newClient("web_client"), + server: &requestVerifier{ + client: newClient(clientTypeWeb), }, args: args{ ctx: context.Background(), @@ -383,8 +409,8 @@ func Test_webServer_authorize(t *testing.T) { }, { name: "invalid redirect", - server: &authRequestVerifier{ - client: newClient("web_client"), + server: &requestVerifier{ + client: newClient(clientTypeWeb), }, args: args{ ctx: context.Background(), @@ -405,8 +431,8 @@ func Test_webServer_authorize(t *testing.T) { }, { name: "invalid response type", - server: &authRequestVerifier{ - client: newClient("web_client"), + server: &requestVerifier{ + client: newClient(clientTypeWeb), }, args: args{ ctx: context.Background(), @@ -426,8 +452,8 @@ func Test_webServer_authorize(t *testing.T) { }, { name: "unimplemented Authorize called", - server: &authRequestVerifier{ - client: newClient("web_client"), + server: &requestVerifier{ + client: newClient(clientTypeWeb), }, args: args{ ctx: context.Background(), @@ -464,3 +490,774 @@ func Test_webServer_authorize(t *testing.T) { }) } } + +func Test_webServer_deviceAuthorizationHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented DeviceAuthorization called", + fields: fields{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native_client")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceAuthorizationHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokensHandler(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse form error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "missing grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"grant_type missing"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("grant_type=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"bar not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + logger: slog.Default(), + } + runWebServerTest(t, s.tokensHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_jwtProfileHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "assertion missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"assertion missing"}`, + }, + }, + { + name: "unimplemented JWTProfile called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("assertion=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:jwt-bearer not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.jwtProfileHandler, tt.r, tt.want) + }) + } +} + +func runWebServerClientTest(t *testing.T, handler func(http.ResponseWriter, *http.Request, Client), r *http.Request, client Client, want webServerResult) { + t.Helper() + runWebServerTest(t, func(client Client) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, client) + } + }(client), r, want) +} + +func Test_webServer_codeExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"code missing"}`, + }, + }, + { + name: "redirect missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"redirect_uri missing"}`, + }, + }, + { + name: "unimplemented JWTProfile called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123&redirect_uri=https://example.com/callback")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"authorization_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.codeExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_refreshTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "refresh token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"refresh_token missing"}`, + }, + }, + { + name: "unimplemented RefreshToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("refresh_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"refresh_token not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.refreshTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokenExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "subject token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token missing"}`, + }, + }, + { + name: "subject token type missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type missing"}`, + }, + }, + { + name: "subject token type unsupported", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type is not supported"}`, + }, + }, + { + name: "unsupported requested token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"requested_token_type is not supported"}`, + }, + }, + { + name: "unsupported actor token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"actor_token_type is not supported"}`, + }, + }, + { + name: "unimplemented TokenExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=urn:ietf:params:oauth:token-type:access_token")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:token-exchange not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.tokenExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_clientCredentialsHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "unimplemented ClientCredentialsExchange called", + decoder: testDecoder, + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"client_credentials not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.clientCredentialsHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_deviceTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "device code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"device_code missing"}`, + }, + }, + { + name: "unimplemented DeviceToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("device_code=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:device_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_introspectionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Introspect called", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.introspectionHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_userInfoHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "access token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusUnauthorized, + wantBody: `{"error":"invalid_request", "error_description":"access token missing"}`, + }, + }, + { + name: "unimplemented UserInfo called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("access_token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "bearer", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("authorization", strings.Join([]string{"Bearer", "xxx"}, " ")) + return r + }(), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.userInfoHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_revocationHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Revocation called, confidential client", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "unimplemented Revocation called, public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.revocationHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_endSessionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented EndSession called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("id_token_hint=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.endSessionHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_simpleHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + method func(context.Context, *Request[struct{}]) (*Response, error) + r *http.Request + want webServerResult + }{ + { + name: "parse error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "method error", + decoder: schema.NewDecoder(), + method: func(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, io.ErrClosedPipe + }, + r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, simpleHandler(s, tt.method), tt.r, tt.want) + }) + } +} + +func Test_decodeRequest(t *testing.T) { + type dst struct { + A string `schema:"a"` + B string `schema:"b"` + } + type args struct { + r *http.Request + postOnly bool + } + tests := []struct { + name string + args args + want *dst + wantErr error + }{ + { + name: "parse error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decode error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "success, get", + args: args{ + r: httptest.NewRequest(http.MethodGet, "/?a=b&b=a", nil), + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + { + name: "success, post only", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: true, + }, + want: &dst{ + A: "b", + }, + }, + { + name: "success, post mixed", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: false, + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.r.Method == http.MethodPost { + tt.args.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + got, err := decodeRequest[dst](schema.NewDecoder(), tt.args.r, tt.args.postOnly) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +}