From 4211fad1103a40faa0a992b4fe07030cf3399a66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 25 Aug 2023 12:24:09 +0300 Subject: [PATCH] finish op and testing without middleware for now --- example/server/exampleop/op.go | 13 -- example/server/main.go | 6 +- go.mod | 2 - go.sum | 5 - pkg/client/integration_test.go | 12 +- pkg/oidc/authorization.go | 4 +- pkg/oidc/authorization_test.go | 25 +++ pkg/oidc/error.go | 5 +- pkg/oidc/error_test.go | 153 ++++++++++++++++++ pkg/op/auth_request.go | 3 - pkg/op/auth_request_test.go | 3 +- pkg/op/error.go | 19 ++- pkg/op/error_test.go | 277 +++++++++++++++++++++++++++++++++ pkg/op/op.go | 5 + 14 files changed, 494 insertions(+), 38 deletions(-) create mode 100644 pkg/oidc/authorization_test.go create mode 100644 pkg/oidc/error_test.go create mode 100644 pkg/op/error_test.go diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 6334ebf..57db020 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -4,15 +4,12 @@ import ( "crypto/sha256" "log" "net/http" - "strconv" - "sync/atomic" "time" "github.com/go-chi/chi" "golang.org/x/exp/slog" "golang.org/x/text/language" - "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/example/server/storage" "github.com/zitadel/oidc/v3/pkg/op" ) @@ -35,9 +32,6 @@ type Storage interface { deviceAuthenticate } -// simple request id counter -var requestID atomic.Uint64 - // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. @@ -47,13 +41,6 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router key := sha256.Sum256([]byte("test")) router := chi.NewRouter() - // Enable request logging on INFO level - router.Use(logging.Middleware( - logging.MiddlewareWithLoggerOption(logger), - logging.MiddlewareWithIDOption(func() string { - return strconv.FormatUint(requestID.Add(1), 10) - }), - )) // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { diff --git a/example/server/main.go b/example/server/main.go index c7711f3..ee8422b 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -6,7 +6,6 @@ import ( "net/http" "os" - "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/storage" "golang.org/x/exp/slog" @@ -27,13 +26,12 @@ func main() { // data set to the context gets printed // as part of the log output. // This helps us tie log output to requests. - logger := slog.New(logging.WrapHandler( + logger := slog.New( slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ AddSource: true, Level: slog.LevelDebug, }), - logging.HandlerWithCTXGroupName("ctx"), - )) + ) router := exampleop.SetupServer(issuer, storage, logger) server := &http.Server{ diff --git a/go.mod b/go.mod index c4f1d87..8a50ca6 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/rs/cors v1.9.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 - github.com/zitadel/logging v0.3.5-0.20230824152050-9b8a8a0bdf73 github.com/zitadel/schema v1.3.0 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.7.0 @@ -22,7 +21,6 @@ require ( ) require ( - github.com/benbjohnson/clock v1.3.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.5.9 // indirect diff --git a/go.sum b/go.sum index 9188041..4b6e60d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= -github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -49,8 +47,6 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/zitadel/logging v0.3.5-0.20230824152050-9b8a8a0bdf73 h1:vYoTHKhD3upeAXNTBcoSmdZkRF9WOcJzozJX/1SiXUI= -github.com/zitadel/logging v0.3.5-0.20230824152050-9b8a8a0bdf73/go.mod h1:p1XQg2/CP5BGGwTrZ2thanc3cvGpkHALLCrMiO+ULsY= github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -106,7 +102,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 073efef..7cbb62e 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -19,6 +19,7 @@ import ( "github.com/jeremija/gosubmit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/slog" "github.com/zitadel/oidc/v3/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/storage" @@ -29,6 +30,13 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) +var Logger = slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), +) + var CTX context.Context func TestMain(m *testing.M) { @@ -49,7 +57,7 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) @@ -100,7 +108,7 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index 743d9c0..7e7c30c 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,8 +1,6 @@ package oidc import ( - "fmt" - "golang.org/x/exp/slog" ) @@ -94,7 +92,7 @@ type AuthRequest struct { func (a *AuthRequest) LogValue() slog.Value { return slog.GroupValue( - slog.Any("scopes", fmt.Stringer(a.Scopes)), + slog.Any("scopes", a.Scopes), slog.String("response_type", string(a.ResponseType)), slog.String("client_id", a.ClientID), slog.String("redirect_uri", a.RedirectURI), diff --git a/pkg/oidc/authorization_test.go b/pkg/oidc/authorization_test.go new file mode 100644 index 0000000..96a6bd7 --- /dev/null +++ b/pkg/oidc/authorization_test.go @@ -0,0 +1,25 @@ +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestAuthRequest_LogValue(t *testing.T) { + a := &AuthRequest{ + Scopes: SpaceDelimitedArray{"a", "b"}, + ResponseType: "respType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + } + want := slog.GroupValue( + slog.Any("scopes", SpaceDelimitedArray{"a", "b"}), + slog.String("response_type", "respType"), + slog.String("client_id", "123"), + slog.String("redirect_uri", "http://example.com/callback"), + ) + got := a.LogValue() + assert.Equal(t, want, got) +} diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index f6a5de5..07a9069 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -186,7 +186,7 @@ func (e *Error) LogLevel() slog.Level { } func (e *Error) LogValue() slog.Value { - attrs := make([]slog.Attr, 0, 4) + attrs := make([]slog.Attr, 0, 5) if e.Parent != nil { attrs = append(attrs, slog.Any("parent", e.Parent)) } @@ -199,5 +199,8 @@ func (e *Error) LogValue() slog.Value { if e.State != "" { attrs = append(attrs, slog.String("state", e.State)) } + if e.redirectDisabled { + attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) + } return slog.GroupValue(attrs...) } diff --git a/pkg/oidc/error_test.go b/pkg/oidc/error_test.go new file mode 100644 index 0000000..3ad29ec --- /dev/null +++ b/pkg/oidc/error_test.go @@ -0,0 +1,153 @@ +package oidc + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestDefaultToServerError(t *testing.T) { + type args struct { + err error + description string + } + tests := []struct { + name string + args args + want *Error + }{ + { + name: "default", + args: args{ + err: io.ErrClosedPipe, + description: "oops", + }, + want: &Error{ + ErrorType: ServerError, + Description: "oops", + Parent: io.ErrClosedPipe, + }, + }, + { + name: "our Error", + args: args{ + err: ErrAccessDenied(), + description: "oops", + }, + want: &Error{ + ErrorType: AccessDenied, + Description: "The authorization request was denied.", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DefaultToServerError(tt.args.err, tt.args.description) + assert.ErrorIs(t, got, tt.want) + }) + } +} + +func TestError_LogLevel(t *testing.T) { + tests := []struct { + name string + err *Error + want slog.Level + }{ + { + name: "server error", + err: ErrServerError(), + want: slog.LevelError, + }, + { + name: "authorization pending", + err: ErrAuthorizationPending(), + want: slog.LevelInfo, + }, + { + name: "some other error", + err: ErrAccessDenied(), + want: slog.LevelWarn, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.err.LogLevel() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestError_LogValue(t *testing.T) { + type fields struct { + Parent error + ErrorType errorType + Description string + State string + redirectDisabled bool + } + tests := []struct { + name string + fields fields + want slog.Value + }{ + { + name: "parent", + fields: fields{ + Parent: io.EOF, + }, + want: slog.GroupValue(slog.Any("parent", io.EOF)), + }, + { + name: "description", + fields: fields{ + Description: "oops", + }, + want: slog.GroupValue(slog.String("description", "oops")), + }, + { + name: "errorType", + fields: fields{ + ErrorType: ExpiredToken, + }, + want: slog.GroupValue(slog.String("type", string(ExpiredToken))), + }, + { + name: "state", + fields: fields{ + State: "123", + }, + want: slog.GroupValue(slog.String("state", "123")), + }, + { + name: "all fields", + fields: fields{ + Parent: io.EOF, + Description: "oops", + ErrorType: ExpiredToken, + State: "123", + }, + want: slog.GroupValue( + slog.Any("parent", io.EOF), + slog.String("description", "oops"), + slog.String("type", string(ExpiredToken)), + slog.String("state", "123"), + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Error{ + Parent: tt.fields.Parent, + ErrorType: tt.fields.ErrorType, + Description: tt.fields.Description, + State: tt.fields.State, + redirectDisabled: tt.fields.redirectDisabled, + } + got := e.LogValue() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index e163746..7610248 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -18,9 +18,6 @@ import ( ) type AuthRequest interface { - // LogValuer allows the implementation which fields to log, - // and which ones to redact for security reasons. - slog.LogValuer GetID() string GetACR() string GetAMR() []string diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index df340b6..42fd0aa 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -18,6 +18,7 @@ import ( "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op/mock" "github.com/zitadel/schema" + "golang.org/x/exp/slog" ) func TestAuthorize(t *testing.T) { @@ -38,7 +39,7 @@ func TestAuthorize(t *testing.T) { expect := authorizer.EXPECT() expect.Decoder().Return(schema.NewDecoder()) - expect.Encoder().Return(schema.NewEncoder()) + expect.Logger().Return(slog.Default()) if tt.expect != nil { tt.expect(expect) diff --git a/pkg/op/error.go b/pkg/op/error.go index 4898e69..9981fec 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -9,12 +9,20 @@ import ( ) type ErrAuthRequest interface { - slog.LogValuer GetRedirectURI() string GetResponseType() oidc.ResponseType GetState() string } +// LogAuthRequest is an optional interface, +// that allows logging AuthRequest fields. +// If the AuthRequest does not implement this interface, +// no details shall be printed to the logs. +type LogAuthRequest interface { + ErrAuthRequest + slog.LogValuer +} + func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) { e := oidc.DefaultToServerError(err, err.Error()) logger := authorizer.Logger().With("oidc_error", e) @@ -24,10 +32,13 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq http.Error(w, err.Error(), http.StatusBadRequest) return } - logger = logger.With("authRequest", authReq) + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { - logger.Log(r.Context(), e.LogLevel(), "auth request without redirect") + logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting") http.Error(w, e.Description, http.StatusBadRequest) return } @@ -50,7 +61,7 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest if e.ErrorType == oidc.InvalidClient { - status = 401 + status = http.StatusUnauthorized } logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go new file mode 100644 index 0000000..dc5ef11 --- /dev/null +++ b/pkg/op/error_test.go @@ -0,0 +1,277 @@ +package op + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +func TestAuthRequestError(t *testing.T) { + type args struct { + authReq ErrAuthRequest + err error + } + tests := []struct { + name string + args args + wantCode int + wantHeaders map[string]string + wantBody string + wantLog string + }{ + { + name: "nil auth request", + args: args{ + authReq: nil, + err: io.ErrClosedPipe, + }, + wantCode: http.StatusBadRequest, + wantBody: "io: read/write on closed pipe\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "sign in\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantCode: http.StatusBadRequest, + wantBody: "oops\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusFound, + wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"}, + wantLog: `{ + "level":"WARN", + "msg":"auth request", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + authorizer := &Provider{ + encoder: schema.NewEncoder(), + logger: slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode) + for key, wantHeader := range tt.wantHeaders { + gotHeader := res.Header.Get(key) + assert.Equalf(t, wantHeader, gotHeader, "header %q", key) + } + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.Equal(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestRequestError(t *testing.T) { + tests := []struct { + name string + err error + wantCode int + wantBody string + wantLog string + }{ + { + name: "server error", + err: io.ErrClosedPipe, + wantCode: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "time":"not", + "oidc_error":{ + "parent":"io: read/write on closed pipe", + "description":"io: read/write on closed pipe", + "type":"server_error"} + }`, + }, + { + name: "invalid client", + err: oidc.ErrInvalidClient().WithDescription("not good"), + wantCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_client", "error_description":"not good"}`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "time":"not", + "oidc_error":{ + "description":"not good", + "type":"invalid_client"} + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + RequestError(w, r, tt.err, logger) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode, "status code") + + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.JSONEq(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} diff --git a/pkg/op/op.go b/pkg/op/op.go index 26f731e..d8ae570 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -79,6 +79,8 @@ type OpenIDProvider interface { Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn + + // EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 Logger() *slog.Logger // Deprecated: Provider now implements http.Handler directly. @@ -531,6 +533,9 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +// WithLogger lets a logger other than slog.Default(). +// +// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 func WithLogger(logger *slog.Logger) Option { return func(o *Provider) error { o.logger = logger