1328 lines
38 KiB
Go
1328 lines
38 KiB
Go
package op
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/muhlemmer/gu"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
httphelper "github.com/zitadel/oidc/v4/pkg/http"
|
|
"github.com/zitadel/oidc/v4/pkg/oidc"
|
|
"github.com/zitadel/schema"
|
|
)
|
|
|
|
func TestRegisterServer(t *testing.T) {
|
|
server := UnimplementedServer{}
|
|
endpoints := Endpoints{
|
|
Authorization: &Endpoint{
|
|
path: "/auth",
|
|
},
|
|
}
|
|
decoder := schema.NewDecoder()
|
|
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
|
|
|
h := RegisterServer(server, endpoints,
|
|
WithDecoder(decoder),
|
|
WithFallbackLogger(logger),
|
|
)
|
|
got := h.(*webServer)
|
|
assert.Equal(t, got.server, server)
|
|
assert.Equal(t, got.endpoints, endpoints)
|
|
assert.Equal(t, got.decoder, decoder)
|
|
assert.Equal(t, got.logger, logger)
|
|
}
|
|
|
|
type testClient struct {
|
|
id string
|
|
appType ApplicationType
|
|
authMethod oidc.AuthMethod
|
|
accessTokenType AccessTokenType
|
|
responseTypes []oidc.ResponseType
|
|
grantTypes []oidc.GrantType
|
|
devMode bool
|
|
}
|
|
|
|
// clientType selects which preconfigured testClient newClient builds.
type clientType string

// Client kinds understood by newClient. The string value doubles as the
// client ID of the resulting testClient.
const (
	// clientTypeWeb selects the web application test client.
	clientTypeWeb clientType = "web"
	// clientTypeNative selects the native application test client.
	clientTypeNative clientType = "native"
	// clientTypeUserAgent selects the user agent application test client.
	clientTypeUserAgent clientType = "useragent"
)
|
|
|
|
func newClient(kind clientType) *testClient {
|
|
client := &testClient{
|
|
id: string(kind),
|
|
}
|
|
|
|
switch kind {
|
|
case clientTypeWeb:
|
|
client.appType = ApplicationTypeWeb
|
|
client.authMethod = oidc.AuthMethodBasic
|
|
client.accessTokenType = AccessTokenTypeBearer
|
|
client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
|
|
case clientTypeNative:
|
|
client.appType = ApplicationTypeNative
|
|
client.authMethod = oidc.AuthMethodNone
|
|
client.accessTokenType = AccessTokenTypeBearer
|
|
client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
|
|
case clientTypeUserAgent:
|
|
client.appType = ApplicationTypeUserAgent
|
|
client.authMethod = oidc.AuthMethodBasic
|
|
client.accessTokenType = AccessTokenTypeJWT
|
|
client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken}
|
|
default:
|
|
panic(fmt.Errorf("invalid client type %s", kind))
|
|
}
|
|
return client
|
|
}
|
|
|
|
func (c *testClient) RedirectURIs() []string {
|
|
return []string{
|
|
"https://registered.com/callback",
|
|
"http://registered.com/callback",
|
|
"http://localhost:9999/callback",
|
|
"custom://callback",
|
|
}
|
|
}
|
|
|
|
func (c *testClient) PostLogoutRedirectURIs() []string {
|
|
return []string{}
|
|
}
|
|
|
|
func (c *testClient) LoginURL(id string) string {
|
|
return "login?id=" + id
|
|
}
|
|
|
|
func (c *testClient) ApplicationType() ApplicationType {
|
|
return c.appType
|
|
}
|
|
|
|
func (c *testClient) AuthMethod() oidc.AuthMethod {
|
|
return c.authMethod
|
|
}
|
|
|
|
func (c *testClient) GetID() string {
|
|
return c.id
|
|
}
|
|
|
|
func (c *testClient) AccessTokenLifetime() time.Duration {
|
|
return 5 * time.Minute
|
|
}
|
|
|
|
func (c *testClient) IDTokenLifetime() time.Duration {
|
|
return 5 * time.Minute
|
|
}
|
|
|
|
func (c *testClient) AccessTokenType() AccessTokenType {
|
|
return c.accessTokenType
|
|
}
|
|
|
|
func (c *testClient) ResponseTypes() []oidc.ResponseType {
|
|
return c.responseTypes
|
|
}
|
|
|
|
func (c *testClient) GrantTypes() []oidc.GrantType {
|
|
return c.grantTypes
|
|
}
|
|
|
|
func (c *testClient) DevMode() bool {
|
|
return c.devMode
|
|
}
|
|
|
|
func (c *testClient) AllowedScopes() []string {
|
|
return nil
|
|
}
|
|
|
|
func (c *testClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
|
|
return func(scopes []string) []string {
|
|
return scopes
|
|
}
|
|
}
|
|
|
|
func (c *testClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
|
|
return func(scopes []string) []string {
|
|
return scopes
|
|
}
|
|
}
|
|
|
|
func (c *testClient) IsScopeAllowed(scope string) bool {
|
|
return false
|
|
}
|
|
|
|
func (c *testClient) IDTokenUserinfoClaimsAssertion() bool {
|
|
return false
|
|
}
|
|
|
|
func (c *testClient) ClockSkew() time.Duration {
|
|
return 0
|
|
}
|
|
|
|
type requestVerifier struct {
|
|
UnimplementedServer
|
|
client Client
|
|
}
|
|
|
|
func (s *requestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
|
if s.client == nil {
|
|
return nil, oidc.ErrServerError()
|
|
}
|
|
return &ClientRequest[oidc.AuthRequest]{
|
|
Request: r,
|
|
Client: s.client,
|
|
}, nil
|
|
}
|
|
|
|
func (s *requestVerifier) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
|
if s.client == nil {
|
|
return nil, oidc.ErrServerError()
|
|
}
|
|
return s.client, nil
|
|
}
|
|
|
|
var testDecoder = func() *schema.Decoder {
|
|
decoder := schema.NewDecoder()
|
|
decoder.IgnoreUnknownKeys(true)
|
|
return decoder
|
|
}()
|
|
|
|
// webServerResult describes the HTTP response a handler test expects.
type webServerResult struct {
	// wantStatus is the expected HTTP status code.
	wantStatus int
	// wantBody is the expected response body, compared as JSON.
	wantBody string
}
|
|
|
|
func runWebServerTest(t *testing.T, handler http.HandlerFunc, r *http.Request, want webServerResult) {
|
|
t.Helper()
|
|
if r.Method == http.MethodPost {
|
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
}
|
|
w := httptest.NewRecorder()
|
|
handler(w, r)
|
|
res := w.Result()
|
|
assert.Equal(t, want.wantStatus, res.StatusCode)
|
|
body, err := io.ReadAll(res.Body)
|
|
require.NoError(t, err)
|
|
assert.JSONEq(t, want.wantBody, string(body))
|
|
}
|
|
|
|
func Test_webServer_withClient(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "parse error",
|
|
r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "invalid grant type",
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&grant_type=bad&foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unauthorized_client", "error_description":"grant_type \"bad\" not allowed"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "no grant type",
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusOK,
|
|
wantBody: `{"foo":"bar"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeNative),
|
|
},
|
|
decoder: testDecoder,
|
|
logger: slog.Default(),
|
|
}
|
|
handler := func(w http.ResponseWriter, r *http.Request, client Client) {
|
|
fmt.Fprintf(w, `{"foo":%q}`, r.FormValue("foo"))
|
|
}
|
|
runWebServerTest(t, s.withClient(handler), tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_verifyRequestClient(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want Client
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "parse form error",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
|
|
wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"),
|
|
},
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"),
|
|
},
|
|
{
|
|
name: "basic auth, client_id error",
|
|
decoder: testDecoder,
|
|
r: func() *http.Request {
|
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
|
|
r.SetBasicAuth(`%%%`, "secret")
|
|
return r
|
|
}(),
|
|
wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"),
|
|
},
|
|
{
|
|
name: "basic auth, client_secret error",
|
|
decoder: testDecoder,
|
|
r: func() *http.Request {
|
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
|
|
r.SetBasicAuth("web", `%%%`)
|
|
return r
|
|
}(),
|
|
wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"),
|
|
},
|
|
{
|
|
name: "missing client id and assertion",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
wantErr: oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided"),
|
|
},
|
|
{
|
|
name: "wrong assertion type",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")),
|
|
wantErr: oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type wrong"),
|
|
},
|
|
{
|
|
name: "unimplemented verify client called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_id=web")),
|
|
wantErr: StatusError{
|
|
parent: oidc.ErrServerError().WithDescription("/ not implemented on this server"),
|
|
statusCode: UnimplementedStatusCode,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
tt.r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
got, err := s.verifyRequestClient(tt.r)
|
|
require.ErrorIs(t, err, tt.wantErr)
|
|
assert.Equal(t, tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_authorizeHandler(t *testing.T) {
|
|
type fields struct {
|
|
server Server
|
|
decoder httphelper.Decoder
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
fields: fields{
|
|
server: &requestVerifier{},
|
|
decoder: schema.NewDecoder(),
|
|
},
|
|
r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "server error",
|
|
fields: fields{
|
|
server: &requestVerifier{},
|
|
decoder: testDecoder,
|
|
},
|
|
r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusInternalServerError,
|
|
wantBody: `{"error":"server_error"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: tt.fields.server,
|
|
decoder: tt.fields.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerTest(t, s.authorizeHandler, tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_authorize(t *testing.T) {
|
|
type args struct {
|
|
ctx context.Context
|
|
r *Request[oidc.AuthRequest]
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
server Server
|
|
args args
|
|
want *Redirect
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "verify error",
|
|
server: &requestVerifier{},
|
|
args: args{
|
|
ctx: context.Background(),
|
|
r: &Request[oidc.AuthRequest]{
|
|
Data: &oidc.AuthRequest{
|
|
Scopes: oidc.SpaceDelimitedArray{"openid"},
|
|
ResponseType: oidc.ResponseTypeCode,
|
|
ClientID: "web",
|
|
RedirectURI: "https://registered.com/callback",
|
|
MaxAge: gu.Ptr[uint](300),
|
|
},
|
|
},
|
|
},
|
|
wantErr: oidc.ErrServerError(),
|
|
},
|
|
{
|
|
name: "missing redirect",
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeWeb),
|
|
},
|
|
args: args{
|
|
ctx: context.Background(),
|
|
r: &Request[oidc.AuthRequest]{
|
|
Data: &oidc.AuthRequest{
|
|
Scopes: oidc.SpaceDelimitedArray{"openid"},
|
|
ResponseType: oidc.ResponseTypeCode,
|
|
ClientID: "web",
|
|
MaxAge: gu.Ptr[uint](300),
|
|
},
|
|
},
|
|
},
|
|
wantErr: ErrAuthReqMissingRedirectURI,
|
|
},
|
|
{
|
|
name: "invalid prompt",
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeWeb),
|
|
},
|
|
args: args{
|
|
ctx: context.Background(),
|
|
r: &Request[oidc.AuthRequest]{
|
|
Data: &oidc.AuthRequest{
|
|
Scopes: oidc.SpaceDelimitedArray{"openid"},
|
|
ResponseType: oidc.ResponseTypeCode,
|
|
ClientID: "web",
|
|
RedirectURI: "https://registered.com/callback",
|
|
MaxAge: gu.Ptr[uint](300),
|
|
Prompt: []string{oidc.PromptNone, oidc.PromptLogin},
|
|
},
|
|
},
|
|
},
|
|
wantErr: oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value"),
|
|
},
|
|
{
|
|
name: "missing scopes",
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeWeb),
|
|
},
|
|
args: args{
|
|
ctx: context.Background(),
|
|
r: &Request[oidc.AuthRequest]{
|
|
Data: &oidc.AuthRequest{
|
|
ResponseType: oidc.ResponseTypeCode,
|
|
ClientID: "web",
|
|
RedirectURI: "https://registered.com/callback",
|
|
MaxAge: gu.Ptr[uint](300),
|
|
Prompt: []string{oidc.PromptNone},
|
|
},
|
|
},
|
|
},
|
|
wantErr: oidc.ErrInvalidRequest().
|
|
WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " +
|
|
"If you have any questions, you may contact the administrator of the application."),
|
|
},
|
|
{
|
|
name: "invalid redirect",
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeWeb),
|
|
},
|
|
args: args{
|
|
ctx: context.Background(),
|
|
r: &Request[oidc.AuthRequest]{
|
|
Data: &oidc.AuthRequest{
|
|
Scopes: oidc.SpaceDelimitedArray{"openid"},
|
|
ResponseType: oidc.ResponseTypeCode,
|
|
ClientID: "web",
|
|
RedirectURI: "https://example.com/callback",
|
|
MaxAge: gu.Ptr[uint](300),
|
|
Prompt: []string{oidc.PromptNone},
|
|
},
|
|
},
|
|
},
|
|
wantErr: oidc.ErrInvalidRequestRedirectURI().
|
|
WithDescription("The requested redirect_uri is missing in the client configuration. " +
|
|
"If you have any questions, you may contact the administrator of the application."),
|
|
},
|
|
{
|
|
name: "invalid response type",
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeWeb),
|
|
},
|
|
args: args{
|
|
ctx: context.Background(),
|
|
r: &Request[oidc.AuthRequest]{
|
|
Data: &oidc.AuthRequest{
|
|
Scopes: oidc.SpaceDelimitedArray{"openid"},
|
|
ResponseType: oidc.ResponseTypeIDToken,
|
|
ClientID: "web",
|
|
RedirectURI: "https://registered.com/callback",
|
|
MaxAge: gu.Ptr[uint](300),
|
|
Prompt: []string{oidc.PromptNone},
|
|
},
|
|
},
|
|
},
|
|
wantErr: oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " +
|
|
"If you have any questions, you may contact the administrator of the application."),
|
|
},
|
|
{
|
|
name: "unimplemented Authorize called",
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeWeb),
|
|
},
|
|
args: args{
|
|
ctx: context.Background(),
|
|
r: &Request[oidc.AuthRequest]{
|
|
URL: &url.URL{
|
|
Path: "/authorize",
|
|
},
|
|
Data: &oidc.AuthRequest{
|
|
Scopes: oidc.SpaceDelimitedArray{"openid"},
|
|
ResponseType: oidc.ResponseTypeCode,
|
|
ClientID: "web",
|
|
RedirectURI: "https://registered.com/callback",
|
|
MaxAge: gu.Ptr[uint](300),
|
|
Prompt: []string{oidc.PromptNone},
|
|
},
|
|
},
|
|
},
|
|
wantErr: StatusError{
|
|
parent: oidc.ErrServerError().WithDescription("/authorize not implemented on this server"),
|
|
statusCode: UnimplementedStatusCode,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: tt.server,
|
|
decoder: testDecoder,
|
|
logger: slog.Default(),
|
|
}
|
|
got, err := s.authorize(tt.args.ctx, tt.args.r)
|
|
require.ErrorIs(t, err, tt.wantErr)
|
|
assert.Equal(t, tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_deviceAuthorizationHandler(t *testing.T) {
|
|
type fields struct {
|
|
server Server
|
|
decoder httphelper.Decoder
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
fields: fields{
|
|
server: &requestVerifier{},
|
|
decoder: schema.NewDecoder(),
|
|
},
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented DeviceAuthorization called",
|
|
fields: fields{
|
|
server: &requestVerifier{
|
|
client: newClient(clientTypeNative),
|
|
},
|
|
decoder: testDecoder,
|
|
},
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native_client")),
|
|
want: webServerResult{
|
|
wantStatus: UnimplementedStatusCode,
|
|
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: tt.fields.server,
|
|
decoder: tt.fields.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
client := newClient(clientTypeUserAgent)
|
|
runWebServerClientTest(t, s.deviceAuthorizationHandler, tt.r, client, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_tokensHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "parse form error",
|
|
r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "missing grant type",
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"grant_type missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "invalid grant type",
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("grant_type=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unsupported_grant_type", "error_description":"bar not supported"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerTest(t, s.tokensHandler, tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_jwtProfileHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "assertion missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"assertion missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented JWTProfile called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("assertion=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:jwt-bearer not supported"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerTest(t, s.jwtProfileHandler, tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func runWebServerClientTest(t *testing.T, handler func(http.ResponseWriter, *http.Request, Client), r *http.Request, client Client, want webServerResult) {
|
|
t.Helper()
|
|
runWebServerTest(t, func(client Client) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
handler(w, r, client)
|
|
}
|
|
}(client), r, want)
|
|
}
|
|
|
|
func Test_webServer_codeExchangeHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "code missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"code missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "redirect missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"redirect_uri missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented CodeExchange called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123&redirect_uri=https://example.com/callback")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unsupported_grant_type", "error_description":"authorization_code not supported"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
client := newClient(clientTypeUserAgent)
|
|
runWebServerClientTest(t, s.codeExchangeHandler, tt.r, client, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_refreshTokenHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "refresh token missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"refresh_token missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented RefreshToken called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("refresh_token=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unsupported_grant_type", "error_description":"refresh_token not supported"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
client := newClient(clientTypeUserAgent)
|
|
runWebServerClientTest(t, s.refreshTokenHandler, tt.r, client, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_tokenExchangeHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "subject token missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"subject_token missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "subject token type missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"subject_token_type missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "subject token type unsupported",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=foo")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"subject_token_type is not supported"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unsupported requested token type",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=foo")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"requested_token_type is not supported"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unsupported actor token type",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=foo")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"actor_token_type is not supported"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented TokenExchange called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=urn:ietf:params:oauth:token-type:access_token")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:token-exchange not supported"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
client := newClient(clientTypeUserAgent)
|
|
runWebServerClientTest(t, s.tokenExchangeHandler, tt.r, client, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_clientCredentialsHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
client Client
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
client: newClient(clientTypeUserAgent),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "public client",
|
|
decoder: testDecoder,
|
|
client: newClient(clientTypeNative),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented ClientCredentialsExchange called",
|
|
decoder: testDecoder,
|
|
client: newClient(clientTypeUserAgent),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unsupported_grant_type", "error_description":"client_credentials not supported"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerClientTest(t, s.clientCredentialsHandler, tt.r, tt.client, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_deviceTokenHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "device code missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"device_code missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented DeviceToken called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("device_code=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:device_code not supported"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
client := newClient(clientTypeUserAgent)
|
|
runWebServerClientTest(t, s.deviceTokenHandler, tt.r, client, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_introspectionHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "public client",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "token missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"token missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented Introspect called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET&token=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: UnimplementedStatusCode,
|
|
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerTest(t, s.introspectionHandler, tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_userInfoHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "access token missing",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusUnauthorized,
|
|
wantBody: `{"error":"invalid_request", "error_description":"access token missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented UserInfo called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("access_token=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: UnimplementedStatusCode,
|
|
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "bearer",
|
|
decoder: testDecoder,
|
|
r: func() *http.Request {
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r.Header.Set("authorization", strings.Join([]string{"Bearer", "xxx"}, " "))
|
|
return r
|
|
}(),
|
|
want: webServerResult{
|
|
wantStatus: UnimplementedStatusCode,
|
|
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerTest(t, s.userInfoHandler, tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_revocationHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
client Client
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
client: newClient(clientTypeWeb),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "token missing",
|
|
decoder: testDecoder,
|
|
client: newClient(clientTypeWeb),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"token missing"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented Revocation called, confidential client",
|
|
decoder: testDecoder,
|
|
client: newClient(clientTypeWeb),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: UnimplementedStatusCode,
|
|
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented Revocation called, public client",
|
|
decoder: testDecoder,
|
|
client: newClient(clientTypeNative),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: UnimplementedStatusCode,
|
|
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerClientTest(t, s.revocationHandler, tt.r, tt.client, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_endSessionHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "decoder error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "unimplemented EndSession called",
|
|
decoder: testDecoder,
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("id_token_hint=xxx")),
|
|
want: webServerResult{
|
|
wantStatus: UnimplementedStatusCode,
|
|
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerTest(t, s.endSessionHandler, tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_webServer_simpleHandler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
decoder httphelper.Decoder
|
|
method func(context.Context, *Request[struct{}]) (*Response, error)
|
|
r *http.Request
|
|
want webServerResult
|
|
}{
|
|
{
|
|
name: "parse error",
|
|
decoder: schema.NewDecoder(),
|
|
r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`,
|
|
},
|
|
},
|
|
{
|
|
name: "method error",
|
|
decoder: schema.NewDecoder(),
|
|
method: func(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
|
return nil, io.ErrClosedPipe
|
|
},
|
|
r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))),
|
|
want: webServerResult{
|
|
wantStatus: http.StatusInternalServerError,
|
|
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &webServer{
|
|
server: UnimplementedServer{},
|
|
decoder: tt.decoder,
|
|
logger: slog.Default(),
|
|
}
|
|
runWebServerTest(t, simpleHandler(s, tt.method), tt.r, tt.want)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_decodeRequest(t *testing.T) {
|
|
type dst struct {
|
|
A string `schema:"a"`
|
|
B string `schema:"b"`
|
|
}
|
|
type args struct {
|
|
r *http.Request
|
|
postOnly bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want *dst
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "parse error",
|
|
args: args{
|
|
r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
|
|
},
|
|
wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"),
|
|
},
|
|
{
|
|
name: "decode error",
|
|
args: args{
|
|
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
|
},
|
|
wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"),
|
|
},
|
|
{
|
|
name: "success, get",
|
|
args: args{
|
|
r: httptest.NewRequest(http.MethodGet, "/?a=b&b=a", nil),
|
|
},
|
|
want: &dst{
|
|
A: "b",
|
|
B: "a",
|
|
},
|
|
},
|
|
{
|
|
name: "success, post only",
|
|
args: args{
|
|
r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")),
|
|
postOnly: true,
|
|
},
|
|
want: &dst{
|
|
A: "b",
|
|
},
|
|
},
|
|
{
|
|
name: "success, post mixed",
|
|
args: args{
|
|
r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")),
|
|
postOnly: false,
|
|
},
|
|
want: &dst{
|
|
A: "b",
|
|
B: "a",
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.args.r.Method == http.MethodPost {
|
|
tt.args.r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
}
|
|
got, err := decodeRequest[dst](schema.NewDecoder(), tt.args.r, tt.args.postOnly)
|
|
require.ErrorIs(t, err, tt.wantErr)
|
|
assert.Equal(t, tt.want, got)
|
|
})
|
|
}
|
|
}
|