From f30f0d3ead6da5f1c42a217ef39347f39c245a0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Mon, 21 Aug 2023 19:55:24 +0200 Subject: [PATCH] feat(op): user slog for logging integrate with golang.org/x/exp/slog for logging. provide a middleware for request scoped logging. BREAKING CHANGES: 1. OpenIDProvider and sub-interfaces get a Logger() method to return the configured logger; 2. AuthRequestError now takes the complete Authorizer, instead of only the encoder. So that it may use its Logger() method. 3. RequestError now takes a Logger as argument. --- example/server/exampleop/op.go | 14 +++- example/server/storage/oidc.go | 14 ++++ go.mod | 4 +- go.sum | 8 +- pkg/client/rp/relying_party.go | 2 +- pkg/oidc/authorization.go | 15 ++++ pkg/oidc/error.go | 30 ++++++++ pkg/oidc/types.go | 8 +- pkg/op/auth_request.go | 37 ++++++---- pkg/op/device.go | 4 +- pkg/op/error.go | 19 ++++- pkg/op/logger.go | 114 +++++++++++++++++++++++++++++ pkg/op/mock/authorizer.mock.go | 15 ++++ pkg/op/op.go | 18 ++++- pkg/op/op_test.go | 10 +-- pkg/op/session.go | 6 +- pkg/op/token_client_credentials.go | 6 +- pkg/op/token_code.go | 8 +- pkg/op/token_exchange.go | 6 +- pkg/op/token_jwt_profile.go | 8 +- pkg/op/token_refresh.go | 6 +- pkg/op/token_request.go | 6 +- 22 files changed, 297 insertions(+), 61 deletions(-) create mode 100644 pkg/op/logger.go diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 298bff6..8c1b56f 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -4,9 +4,11 @@ import ( "crypto/sha256" "log" "net/http" + "os" "time" "github.com/go-chi/chi" + "golang.org/x/exp/slog" "golang.org/x/text/language" "github.com/zitadel/oidc/v3/example/server/storage" @@ -43,10 +45,8 @@ func SetupServer(issuer string, storage Storage) chi.Router { // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("signed out successfully")) - if err != nil { - log.Printf("error serving logged out page: %v", err) - } + w.Write([]byte("signed out successfully")) + // no need to check/log error, this will be handeled by the middleware. }) // creation of the OpenIDProvider with the just created in-memory Storage @@ -117,6 +117,12 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, op.WithAllowInsecure(), // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), + op.WithLogger(slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + )), ) if err != nil { return nil, err diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index b56ad09..63afcf9 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -3,6 +3,7 @@ package storage import ( "time" + "golang.org/x/exp/slog" "golang.org/x/text/language" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -41,6 +42,19 @@ type AuthRequest struct { authTime time.Time } +// LogValue allows you to define which fields will be logged. +// Implements the [slog.LogValuer] +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.String("id", a.ID), + slog.Time("creation_date", a.CreationDate), + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("app_id", a.ApplicationID), + slog.String("callback_uri", a.CallbackURI), + ) +} + func (a *AuthRequest) GetID() string { return a.ID } diff --git a/go.mod b/go.mod index 610d2a1..c2f95de 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,11 @@ require ( github.com/jeremija/gosubmit v0.2.7 github.com/muhlemmer/gu v0.3.1 github.com/rs/cors v1.9.0 + github.com/rs/xid v1.5.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.2 github.com/zitadel/schema v1.3.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.7.0 golang.org/x/text v0.9.0 gopkg.in/square/go-jose.v2 v2.6.0 @@ -27,7 +29,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.7.0 // indirect golang.org/x/net v0.9.0 // indirect - golang.org/x/sys v0.7.0 // indirect + golang.org/x/sys v0.11.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.29.1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index c9c8562..573fdc0 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -53,6 +55,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -73,8 +77,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 5597c9d..29215a1 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -546,7 +546,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption { // This is the generalized, unexported, function used by both // URLParamOpt and AuthURLOpt. func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption { - return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()) + return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String()) } type URLParamOpt func() []oauth2.AuthCodeOption diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index d8bf336..743d9c0 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,5 +1,11 @@ package oidc +import ( + "fmt" + + "golang.org/x/exp/slog" +) + const ( // ScopeOpenID defines the scope `openid` // OpenID Connect requests MUST contain the `openid` scope value @@ -86,6 +92,15 @@ type AuthRequest struct { RequestParam string `schema:"request"` } +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.Any("scopes", fmt.Stringer(a.Scopes)), + slog.String("response_type", string(a.ResponseType)), + slog.String("client_id", a.ClientID), + slog.String("redirect_uri", a.RedirectURI), + ) +} + // GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface func (a *AuthRequest) GetRedirectURI() string { return a.RedirectURI diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 79acecd..f6a5de5 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -3,6 +3,8 @@ package oidc import ( "errors" "fmt" + + "golang.org/x/exp/slog" ) type errorType string @@ -171,3 +173,31 @@ func DefaultToServerError(err error, description string) *Error { } return oauth } + +func (e *Error) LogLevel() slog.Level { + level := slog.LevelWarn + if e.ErrorType == ServerError { + level = slog.LevelError + } + if e.ErrorType == AuthorizationPending { + level = slog.LevelInfo + } + return level +} + +func (e *Error) LogValue() slog.Value { + attrs := make([]slog.Attr, 0, 4) + if e.Parent != nil { + attrs = append(attrs, slog.Any("parent", e.Parent)) + } + if e.Description != "" { + attrs = append(attrs, slog.String("description", e.Description)) + } + if e.ErrorType != "" { + attrs = append(attrs, slog.String("type", string(e.ErrorType))) + } + if e.State != "" { + attrs = append(attrs, slog.String("state", e.State)) + } + return slog.GroupValue(attrs...) +} diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 86ee1e0..5db8bad 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -106,7 +106,7 @@ type ResponseType string type ResponseMode string -func (s SpaceDelimitedArray) Encode() string { +func (s SpaceDelimitedArray) String() string { return strings.Join(s, " ") } @@ -116,11 +116,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error { } func (s SpaceDelimitedArray) MarshalText() ([]byte, error) { - return []byte(s.Encode()), nil + return []byte(s.String()), nil } func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) { - return json.Marshal((s).Encode()) + return json.Marshal((s).String()) } func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { @@ -165,7 +165,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) { func NewEncoder() *schema.Encoder { e := schema.NewEncoder() e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string { - return value.Interface().(SpaceDelimitedArray).Encode() + return value.Interface().(SpaceDelimitedArray).String() }) return e } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7af3779..e163746 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -14,9 +14,13 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" str "github.com/zitadel/oidc/v3/pkg/strings" + "golang.org/x/exp/slog" ) type AuthRequest interface { + // LogValuer allows the implementation which fields to log, + // and which ones to redact for security reasons. + slog.LogValuer GetID() string GetACR() string GetAMR() []string @@ -41,6 +45,7 @@ type Authorizer interface { IDTokenHintVerifier(context.Context) *IDTokenHintVerifier Crypto() Crypto RequestObjectSupported() bool + Logger() *slog.Logger } // AuthorizeValidator is an extension of Authorizer interface @@ -67,23 +72,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, * func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } } if authReq.ClientID == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer) return } if authReq.RedirectURI == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer) return } validation := ValidateAuthRequest @@ -92,21 +97,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.RequestParam != "" { - AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer) return } req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID) if err != nil { - AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer) return } client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID()) if err != nil { - AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) + AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer) return } RedirectToLogin(req.GetID(), client, w, r) @@ -406,18 +411,18 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { id, err := ParseAuthorizeCallbackRequest(r) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } if !authReq.Done() { AuthRequestError(w, r, authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."), - authorizer.Encoder()) + authorizer) return } AuthResponse(authReq, authorizer, w, r) @@ -438,7 +443,7 @@ func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.GetResponseType() == oidc.ResponseTypeCode { @@ -452,7 +457,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } codeResponse := struct { @@ -464,7 +469,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) @@ -475,12 +480,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "") if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) diff --git a/pkg/op/device.go b/pkg/op/device.go index 09c7fca..029bed8 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -57,7 +57,7 @@ var ( func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if err := DeviceAuthorization(w, r, o); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, o.Logger()) } } } @@ -190,7 +190,7 @@ func (r *deviceAccessTokenRequest) GetScopes() []string { func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { if err := deviceAccessToken(w, r, exchanger); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index b2d84ae..ccb9000 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -5,21 +5,29 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type ErrAuthRequest interface { + slog.LogValuer GetRedirectURI() string GetResponseType() oidc.ResponseType GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) { +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) { + e := oidc.DefaultToServerError(err, err.Error()) + logger := authorizer.Logger().With("oidc_error", e) + if authReq == nil { + logger.Log(r.Context(), e.LogLevel(), "auth request nil") http.Error(w, err.Error(), http.StatusBadRequest) return } - e := oidc.DefaultToServerError(err, err.Error()) + logger = logger.With("authRequest", authReq) + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(r.Context(), e.LogLevel(), "auth request without redirect") http.Error(w, e.Description, http.StatusBadRequest) return } @@ -28,19 +36,22 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() } - url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder()) if err != nil { + logger.ErrorContext(r.Context(), "auth response URL", "error", err) http.Error(w, err.Error(), http.StatusBadRequest) return } + logger.Log(r.Context(), e.LogLevel(), "auth request error") http.Redirect(w, r, url, http.StatusFound) } -func RequestError(w http.ResponseWriter, r *http.Request, err error) { +func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest if e.ErrorType == oidc.InvalidClient { status = 401 } + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) } diff --git a/pkg/op/logger.go b/pkg/op/logger.go new file mode 100644 index 0000000..ea1251b --- /dev/null +++ b/pkg/op/logger.go @@ -0,0 +1,114 @@ +package op + +import ( + "context" + "net/http" + "time" + + "github.com/rs/xid" + "golang.org/x/exp/slog" +) + +func newLogger(logger *slog.Logger) *slog.Logger { + if logger == nil { + logger = slog.Default() + } + return slog.New(&logHandler{ + handler: logger.Handler(), + }) +} + +type LogKey int + +const ( + RequestID LogKey = iota + + maxLogKey +) + +type logHandler struct { + handler slog.Handler +} + +func (h *logHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.handler.Enabled(ctx, level) +} + +type logAttributes []slog.Attr + +func (attrs *logAttributes) appendFromContext(ctx context.Context, ctxKey any, logKey string) { + v := ctx.Value(RequestID) + if v == nil { + return + } + *attrs = append(*attrs, slog.Group("request", slog.Attr{ + Key: "id", + Value: slog.AnyValue(v), + })) +} + +func (h *logHandler) Handle(ctx context.Context, record slog.Record) error { + attrs := make(logAttributes, 0, maxLogKey) + attrs.appendFromContext(ctx, RequestID, "id") + + handler := h.handler.WithAttrs(attrs) + + return handler.Handle(ctx, record) +} + +func (h *logHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &logHandler{ + handler: h.handler.WithAttrs(attrs), + } +} + +func (h *logHandler) WithGroup(name string) slog.Handler { + return &logHandler{ + handler: h.handler.WithGroup(name), + } +} + +func (o *Provider) LogMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + r = r.WithContext(context.WithValue(r.Context(), RequestID, xid.New())) + lw := &loggedWriter{ + ResponseWriter: w, + } + next.ServeHTTP(lw, r) + logger := o.logger.With( + slog.Group("request", "method", r.Method, "url", r.URL), + slog.Group("response", "duration", time.Since(start), "status", lw.statusCode, "written", lw.written), + ) + if lw.err != nil { + logger.ErrorContext(r.Context(), "response writer", "error", lw.err) + return + } + logger.InfoContext(r.Context(), "done") + }) + } +} + +type loggedWriter struct { + http.ResponseWriter + + statusCode int + written int + err error +} + +func (w *loggedWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *loggedWriter) Write(b []byte) (int, error) { + if w.statusCode == 0 { + w.WriteHeader(http.StatusOK) + } + n, err := w.ResponseWriter.Write(b) + w.written += n + w.err = err + return n, err +} diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index a0c67e3..e4297cb 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -11,6 +11,7 @@ import ( gomock "github.com/golang/mock/gomock" http "github.com/zitadel/oidc/v3/pkg/http" op "github.com/zitadel/oidc/v3/pkg/op" + slog "golang.org/x/exp/slog" ) // MockAuthorizer is a mock of Authorizer interface. @@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0) } +// Logger mocks base method. +func (m *MockAuthorizer) Logger() *slog.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger") + ret0, _ := ret[0].(*slog.Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger)) +} + // RequestObjectSupported mocks base method. func (m *MockAuthorizer) RequestObjectSupported() bool { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index 1fbe780..8d9c57a 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -9,6 +9,7 @@ import ( "github.com/go-chi/chi" "github.com/rs/cors" "github.com/zitadel/schema" + "golang.org/x/exp/slog" "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" @@ -78,6 +79,7 @@ type OpenIDProvider interface { Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn + Logger() *slog.Logger // Deprecated: Provider now implements http.Handler directly. HttpHandler() http.Handler @@ -85,8 +87,9 @@ type OpenIDProvider interface { type HttpInterceptor func(http.Handler) http.Handler -func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { +func CreateRouter(o *Provider, interceptors ...HttpInterceptor) chi.Router { router := chi.NewRouter() + router.Use(o.LogMiddleware()) router.Use(cors.New(defaultCORSOptions).Handler) router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) @@ -174,6 +177,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), + logger: slog.Default(), } for _, optFunc := range opOpts { @@ -217,6 +221,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + logger *slog.Logger } func (o *Provider) IssuerFromRequest(r *http.Request) string { @@ -375,6 +380,10 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) Logger() *slog.Logger { + return o.logger +} + // Deprecated: Provider now implements http.Handler directly. func (o *Provider) HttpHandler() http.Handler { return o @@ -523,6 +532,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +func WithLogger(logger *slog.Logger) Option { + return func(o *Provider) error { + o.logger = newLogger(logger) + return nil + } +} + func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { issuerInterceptor := NewIssuerInterceptor(i) return func(handler http.Handler) http.Handler { diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index d347d04..d33b39d 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -156,7 +156,7 @@ func TestRoutes(t *testing.T) { values: map[string]string{ "client_id": client.GetID(), "redirect_uri": "https://example.com", - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "response_type": string(oidc.ResponseTypeCode), }, wantCode: http.StatusFound, @@ -193,7 +193,7 @@ func TestRoutes(t *testing.T) { path: testProvider.TokenEndpoint().Relative(), values: map[string]string{ "grant_type": string(oidc.GrantTypeBearer), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "assertion": jwtToken, }, wantCode: http.StatusBadRequest, @@ -206,7 +206,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeTokenExchange), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "subject_token": jwtToken, "subject_token_type": string(oidc.AccessTokenType), }, @@ -223,7 +223,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"sid1", "verysecret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeClientCredentials), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, @@ -338,7 +338,7 @@ func TestRoutes(t *testing.T) { path: testProvider.DeviceAuthorizationEndpoint().Relative(), basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{ diff --git a/pkg/op/session.go b/pkg/op/session.go index fd914d1..2467b20 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -8,6 +8,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type SessionEnder interface { @@ -15,6 +16,7 @@ type SessionEnder interface { Storage() Storage IDTokenHintVerifier(context.Context) *IDTokenHintVerifier DefaultLogoutRedirectURI() string + Logger() *slog.Logger } func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) { @@ -31,12 +33,12 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { } session, err := ValidateEndSessionRequest(r.Context(), req, ender) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, ender.Logger()) return } err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) if err != nil { - RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) + RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger()) return } http.Redirect(w, r, session.RedirectURI, http.StatusFound) diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index 0cf7796..043bb07 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -14,18 +14,18 @@ import ( func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { request, err := ParseClientCredentialsRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index b5e892a..baf377b 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -13,20 +13,20 @@ import ( func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } if tokenReq.Code == "" { - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger()) return } authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "") if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 93aa9b2..21db134 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -136,17 +136,17 @@ func (r *tokenExchangeRequest) SetSubject(subject string) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 4cd7b1e..357200e 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -18,23 +18,23 @@ type JWTAuthorizationGrantExchanger interface { func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizationGrantExchanger) { profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context())) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index aeaa5b4..9421033 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -26,16 +26,16 @@ type RefreshTokenRequest interface { func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index c06a51b..0df2fce 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -7,6 +7,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type Exchanger interface { @@ -22,6 +23,7 @@ type Exchanger interface { GrantTypeDeviceCodeSupported() bool AccessTokenVerifier(context.Context) *AccessTokenVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + Logger() *slog.Logger } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { @@ -63,10 +65,10 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } case "": - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger()) return } - RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType)) + RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger()) } // AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest