feat(op): user slog for logging

integrate with golang.org/x/exp/slog for logging.
provide a middleware for request scoped logging.

BREAKING CHANGES:

1. OpenIDProvider and sub-interfaces get a Logger()
method to return the configured logger;
2. AuthRequestError now takes the complete Authorizer,
instead of only the encoder. So that it may use its Logger() method.
3. RequestError now takes a Logger as argument.
This commit is contained in:
Tim Möhlmann 2023-08-21 19:55:24 +02:00
parent 6708ef4c24
commit f30f0d3ead
22 changed files with 297 additions and 61 deletions

View file

@ -4,9 +4,11 @@ import (
"crypto/sha256"
"log"
"net/http"
"os"
"time"
"github.com/go-chi/chi"
"golang.org/x/exp/slog"
"golang.org/x/text/language"
"github.com/zitadel/oidc/v3/example/server/storage"
@ -43,10 +45,8 @@ func SetupServer(issuer string, storage Storage) chi.Router {
// for simplicity, we provide a very small default page for users who have signed out
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
_, err := w.Write([]byte("signed out successfully"))
if err != nil {
log.Printf("error serving logged out page: %v", err)
}
w.Write([]byte("signed out successfully"))
// no need to check/log error, this will be handeled by the middleware.
})
// creation of the OpenIDProvider with the just created in-memory Storage
@ -117,6 +117,12 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
op.WithAllowInsecure(),
// as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
op.WithLogger(slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)),
)
if err != nil {
return nil, err

View file

@ -3,6 +3,7 @@ package storage
import (
"time"
"golang.org/x/exp/slog"
"golang.org/x/text/language"
"github.com/zitadel/oidc/v3/pkg/oidc"
@ -41,6 +42,19 @@ type AuthRequest struct {
authTime time.Time
}
// LogValue allows you to define which fields will be logged.
// Implements the [slog.LogValuer]
func (a *AuthRequest) LogValue() slog.Value {
return slog.GroupValue(
slog.String("id", a.ID),
slog.Time("creation_date", a.CreationDate),
slog.Any("scopes", a.Scopes),
slog.String("response_type", string(a.ResponseType)),
slog.String("app_id", a.ApplicationID),
slog.String("callback_uri", a.CallbackURI),
)
}
func (a *AuthRequest) GetID() string {
return a.ID
}

4
go.mod
View file

@ -11,9 +11,11 @@ require (
github.com/jeremija/gosubmit v0.2.7
github.com/muhlemmer/gu v0.3.1
github.com/rs/cors v1.9.0
github.com/rs/xid v1.5.0
github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.2
github.com/zitadel/schema v1.3.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/oauth2 v0.7.0
golang.org/x/text v0.9.0
gopkg.in/square/go-jose.v2 v2.6.0
@ -27,7 +29,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/net v0.9.0 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/sys v0.11.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.29.1 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect

8
go.sum
View file

@ -36,6 +36,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -53,6 +55,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
@ -73,8 +77,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=

View file

@ -546,7 +546,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
// This is the generalized, unexported, function used by both
// URLParamOpt and AuthURLOpt.
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String())
}
type URLParamOpt func() []oauth2.AuthCodeOption

View file

@ -1,5 +1,11 @@
package oidc
import (
"fmt"
"golang.org/x/exp/slog"
)
const (
// ScopeOpenID defines the scope `openid`
// OpenID Connect requests MUST contain the `openid` scope value
@ -86,6 +92,15 @@ type AuthRequest struct {
RequestParam string `schema:"request"`
}
func (a *AuthRequest) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("scopes", fmt.Stringer(a.Scopes)),
slog.String("response_type", string(a.ResponseType)),
slog.String("client_id", a.ClientID),
slog.String("redirect_uri", a.RedirectURI),
)
}
// GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface
func (a *AuthRequest) GetRedirectURI() string {
return a.RedirectURI

View file

@ -3,6 +3,8 @@ package oidc
import (
"errors"
"fmt"
"golang.org/x/exp/slog"
)
type errorType string
@ -171,3 +173,31 @@ func DefaultToServerError(err error, description string) *Error {
}
return oauth
}
func (e *Error) LogLevel() slog.Level {
level := slog.LevelWarn
if e.ErrorType == ServerError {
level = slog.LevelError
}
if e.ErrorType == AuthorizationPending {
level = slog.LevelInfo
}
return level
}
func (e *Error) LogValue() slog.Value {
attrs := make([]slog.Attr, 0, 4)
if e.Parent != nil {
attrs = append(attrs, slog.Any("parent", e.Parent))
}
if e.Description != "" {
attrs = append(attrs, slog.String("description", e.Description))
}
if e.ErrorType != "" {
attrs = append(attrs, slog.String("type", string(e.ErrorType)))
}
if e.State != "" {
attrs = append(attrs, slog.String("state", e.State))
}
return slog.GroupValue(attrs...)
}

View file

@ -106,7 +106,7 @@ type ResponseType string
type ResponseMode string
func (s SpaceDelimitedArray) Encode() string {
func (s SpaceDelimitedArray) String() string {
return strings.Join(s, " ")
}
@ -116,11 +116,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error {
}
func (s SpaceDelimitedArray) MarshalText() ([]byte, error) {
return []byte(s.Encode()), nil
return []byte(s.String()), nil
}
func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) {
return json.Marshal((s).Encode())
return json.Marshal((s).String())
}
func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error {
@ -165,7 +165,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
func NewEncoder() *schema.Encoder {
e := schema.NewEncoder()
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(SpaceDelimitedArray).Encode()
return value.Interface().(SpaceDelimitedArray).String()
})
return e
}

View file

@ -14,9 +14,13 @@ import (
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
str "github.com/zitadel/oidc/v3/pkg/strings"
"golang.org/x/exp/slog"
)
type AuthRequest interface {
// LogValuer allows the implementation which fields to log,
// and which ones to redact for security reasons.
slog.LogValuer
GetID() string
GetACR() string
GetAMR() []string
@ -41,6 +45,7 @@ type Authorizer interface {
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
Crypto() Crypto
RequestObjectSupported() bool
Logger() *slog.Logger
}
// AuthorizeValidator is an extension of Authorizer interface
@ -67,23 +72,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
AuthRequestError(w, r, nil, err, authorizer)
return
}
ctx := r.Context()
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
}
if authReq.ClientID == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder())
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer)
return
}
if authReq.RedirectURI == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder())
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
return
}
validation := ValidateAuthRequest
@ -92,21 +97,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
}
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
if authReq.RequestParam != "" {
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer)
return
}
req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
if err != nil {
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer)
return
}
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer)
return
}
RedirectToLogin(req.GetID(), client, w, r)
@ -406,18 +411,18 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
id, err := ParseAuthorizeCallbackRequest(r)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
AuthRequestError(w, r, nil, err, authorizer)
return
}
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
AuthRequestError(w, r, nil, err, authorizer)
return
}
if !authReq.Done() {
AuthRequestError(w, r, authReq,
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
authorizer.Encoder())
authorizer)
return
}
AuthResponse(authReq, authorizer, w, r)
@ -438,7 +443,7 @@ func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
if authReq.GetResponseType() == oidc.ResponseTypeCode {
@ -452,7 +457,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
codeResponse := struct {
@ -464,7 +469,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
}
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
http.Redirect(w, r, callback, http.StatusFound)
@ -475,12 +480,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "")
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
http.Redirect(w, r, callback, http.StatusFound)

View file

@ -57,7 +57,7 @@ var (
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if err := DeviceAuthorization(w, r, o); err != nil {
RequestError(w, r, err)
RequestError(w, r, err, o.Logger())
}
}
}
@ -190,7 +190,7 @@ func (r *deviceAccessTokenRequest) GetScopes() []string {
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
if err := deviceAccessToken(w, r, exchanger); err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
}
}

View file

@ -5,21 +5,29 @@ import (
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slog"
)
type ErrAuthRequest interface {
slog.LogValuer
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) {
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) {
e := oidc.DefaultToServerError(err, err.Error())
logger := authorizer.Logger().With("oidc_error", e)
if authReq == nil {
logger.Log(r.Context(), e.LogLevel(), "auth request nil")
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
e := oidc.DefaultToServerError(err, err.Error())
logger = logger.With("authRequest", authReq)
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
logger.Log(r.Context(), e.LogLevel(), "auth request without redirect")
http.Error(w, e.Description, http.StatusBadRequest)
return
}
@ -28,19 +36,22 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
}
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder())
if err != nil {
logger.ErrorContext(r.Context(), "auth response URL", "error", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
logger.Log(r.Context(), e.LogLevel(), "auth request error")
http.Redirect(w, r, url, http.StatusFound)
}
func RequestError(w http.ResponseWriter, r *http.Request, err error) {
func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest
if e.ErrorType == oidc.InvalidClient {
status = 401
}
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, e, status)
}

114
pkg/op/logger.go Normal file
View file

@ -0,0 +1,114 @@
package op
import (
"context"
"net/http"
"time"
"github.com/rs/xid"
"golang.org/x/exp/slog"
)
func newLogger(logger *slog.Logger) *slog.Logger {
if logger == nil {
logger = slog.Default()
}
return slog.New(&logHandler{
handler: logger.Handler(),
})
}
type LogKey int
const (
RequestID LogKey = iota
maxLogKey
)
type logHandler struct {
handler slog.Handler
}
func (h *logHandler) Enabled(ctx context.Context, level slog.Level) bool {
return h.handler.Enabled(ctx, level)
}
type logAttributes []slog.Attr
func (attrs *logAttributes) appendFromContext(ctx context.Context, ctxKey any, logKey string) {
v := ctx.Value(RequestID)
if v == nil {
return
}
*attrs = append(*attrs, slog.Group("request", slog.Attr{
Key: "id",
Value: slog.AnyValue(v),
}))
}
func (h *logHandler) Handle(ctx context.Context, record slog.Record) error {
attrs := make(logAttributes, 0, maxLogKey)
attrs.appendFromContext(ctx, RequestID, "id")
handler := h.handler.WithAttrs(attrs)
return handler.Handle(ctx, record)
}
func (h *logHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return &logHandler{
handler: h.handler.WithAttrs(attrs),
}
}
func (h *logHandler) WithGroup(name string) slog.Handler {
return &logHandler{
handler: h.handler.WithGroup(name),
}
}
func (o *Provider) LogMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
r = r.WithContext(context.WithValue(r.Context(), RequestID, xid.New()))
lw := &loggedWriter{
ResponseWriter: w,
}
next.ServeHTTP(lw, r)
logger := o.logger.With(
slog.Group("request", "method", r.Method, "url", r.URL),
slog.Group("response", "duration", time.Since(start), "status", lw.statusCode, "written", lw.written),
)
if lw.err != nil {
logger.ErrorContext(r.Context(), "response writer", "error", lw.err)
return
}
logger.InfoContext(r.Context(), "done")
})
}
}
type loggedWriter struct {
http.ResponseWriter
statusCode int
written int
err error
}
func (w *loggedWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *loggedWriter) Write(b []byte) (int, error) {
if w.statusCode == 0 {
w.WriteHeader(http.StatusOK)
}
n, err := w.ResponseWriter.Write(b)
w.written += n
w.err = err
return n, err
}

View file

@ -11,6 +11,7 @@ import (
gomock "github.com/golang/mock/gomock"
http "github.com/zitadel/oidc/v3/pkg/http"
op "github.com/zitadel/oidc/v3/pkg/op"
slog "golang.org/x/exp/slog"
)
// MockAuthorizer is a mock of Authorizer interface.
@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
}
// Logger mocks base method.
func (m *MockAuthorizer) Logger() *slog.Logger {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logger")
ret0, _ := ret[0].(*slog.Logger)
return ret0
}
// Logger indicates an expected call of Logger.
func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger))
}
// RequestObjectSupported mocks base method.
func (m *MockAuthorizer) RequestObjectSupported() bool {
m.ctrl.T.Helper()

View file

@ -9,6 +9,7 @@ import (
"github.com/go-chi/chi"
"github.com/rs/cors"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
@ -78,6 +79,7 @@ type OpenIDProvider interface {
Crypto() Crypto
DefaultLogoutRedirectURI() string
Probes() []ProbesFn
Logger() *slog.Logger
// Deprecated: Provider now implements http.Handler directly.
HttpHandler() http.Handler
@ -85,8 +87,9 @@ type OpenIDProvider interface {
type HttpInterceptor func(http.Handler) http.Handler
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
func CreateRouter(o *Provider, interceptors ...HttpInterceptor) chi.Router {
router := chi.NewRouter()
router.Use(o.LogMiddleware())
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(intercept(o.IssuerFromRequest, interceptors...))
router.HandleFunc(healthEndpoint, healthHandler)
@ -174,6 +177,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
storage: storage,
endpoints: DefaultEndpoints,
timer: make(<-chan time.Time),
logger: slog.Default(),
}
for _, optFunc := range opOpts {
@ -217,6 +221,7 @@ type Provider struct {
timer <-chan time.Time
accessTokenVerifierOpts []AccessTokenVerifierOpt
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
logger *slog.Logger
}
func (o *Provider) IssuerFromRequest(r *http.Request) string {
@ -375,6 +380,10 @@ func (o *Provider) Probes() []ProbesFn {
}
}
func (o *Provider) Logger() *slog.Logger {
return o.logger
}
// Deprecated: Provider now implements http.Handler directly.
func (o *Provider) HttpHandler() http.Handler {
return o
@ -523,6 +532,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
}
}
func WithLogger(logger *slog.Logger) Option {
return func(o *Provider) error {
o.logger = newLogger(logger)
return nil
}
}
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
issuerInterceptor := NewIssuerInterceptor(i)
return func(handler http.Handler) http.Handler {

View file

@ -156,7 +156,7 @@ func TestRoutes(t *testing.T) {
values: map[string]string{
"client_id": client.GetID(),
"redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"response_type": string(oidc.ResponseTypeCode),
},
wantCode: http.StatusFound,
@ -193,7 +193,7 @@ func TestRoutes(t *testing.T) {
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"assertion": jwtToken,
},
wantCode: http.StatusBadRequest,
@ -206,7 +206,7 @@ func TestRoutes(t *testing.T) {
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType),
},
@ -223,7 +223,7 @@ func TestRoutes(t *testing.T) {
basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
@ -338,7 +338,7 @@ func TestRoutes(t *testing.T) {
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{

View file

@ -8,6 +8,7 @@ import (
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slog"
)
type SessionEnder interface {
@ -15,6 +16,7 @@ type SessionEnder interface {
Storage() Storage
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
DefaultLogoutRedirectURI() string
Logger() *slog.Logger
}
func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
@ -31,12 +33,12 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
}
session, err := ValidateEndSessionRequest(r.Context(), req, ender)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, ender.Logger())
return
}
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
if err != nil {
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
return
}
http.Redirect(w, r, session.RedirectURI, http.StatusFound)

View file

@ -14,18 +14,18 @@ import (
func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
request, err := ParseClientCredentialsRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
}
validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}

View file

@ -13,20 +13,20 @@ import (
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
}
if tokenReq.Code == "" {
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"))
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger())
return
}
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "")
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
httphelper.MarshalJSON(w, resp)

View file

@ -136,17 +136,17 @@ func (r *tokenExchangeRequest) SetSubject(subject string) {
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
}
tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
httphelper.MarshalJSON(w, resp)

View file

@ -18,23 +18,23 @@ type JWTAuthorizationGrantExchanger interface {
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizationGrantExchanger) {
profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
}
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context()))
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
httphelper.MarshalJSON(w, resp)

View file

@ -26,16 +26,16 @@ type RefreshTokenRequest interface {
func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
}
validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
return
}
httphelper.MarshalJSON(w, resp)

View file

@ -7,6 +7,7 @@ import (
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slog"
)
type Exchanger interface {
@ -22,6 +23,7 @@ type Exchanger interface {
GrantTypeDeviceCodeSupported() bool
AccessTokenVerifier(context.Context) *AccessTokenVerifier
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
Logger() *slog.Logger
}
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
@ -63,10 +65,10 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return
}
case "":
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger())
return
}
RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType))
RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger())
}
// AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest