feat: add slog logging (#432)
* feat(op): user slog for logging integrate with golang.org/x/exp/slog for logging. provide a middleware for request scoped logging. BREAKING CHANGES: 1. OpenIDProvider and sub-interfaces get a Logger() method to return the configured logger; 2. AuthRequestError now takes the complete Authorizer, instead of only the encoder. So that it may use its Logger() method. 3. RequestError now takes a Logger as argument. * use zitadel/logging * finish op and testing without middleware for now * minimum go version 1.19 * update go mod * log value testing only on go 1.20 or later * finish the RP and example * ping logging release
This commit is contained in:
parent
6708ef4c24
commit
0879c88399
34 changed files with 800 additions and 85 deletions
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
|
@ -16,7 +16,7 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.18', '1.19', '1.20']
|
||||
go: ['1.19', '1.20', '1.21']
|
||||
name: Go ${{ matrix.go }} test
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
|
|
@ -115,10 +115,10 @@ Versions that also build are marked with :warning:.
|
|||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| <1.18 | :x: |
|
||||
| 1.18 | :warning: |
|
||||
| 1.19 | :white_check_mark: |
|
||||
| <1.19 | :x: |
|
||||
| 1.19 | :warning: |
|
||||
| 1.20 | :white_check_mark: |
|
||||
| 1.21 | :white_check_mark: |
|
||||
|
||||
## Why another library
|
||||
|
||||
|
|
|
@ -7,11 +7,14 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
@ -33,9 +36,25 @@ func main() {
|
|||
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
|
||||
logger := slog.New(
|
||||
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelDebug,
|
||||
}),
|
||||
)
|
||||
client := &http.Client{
|
||||
Timeout: time.Minute,
|
||||
}
|
||||
// enable outgoing request logging
|
||||
logging.EnableHTTPClient(client,
|
||||
logging.WithClientGroup("client"),
|
||||
)
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithCookieHandler(cookieHandler),
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
|
||||
rp.WithHTTPClient(client),
|
||||
rp.WithLogger(logger),
|
||||
}
|
||||
if clientSecret == "" {
|
||||
options = append(options, rp.WithPKCE(cookieHandler))
|
||||
|
@ -44,7 +63,10 @@ func main() {
|
|||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||
}
|
||||
|
||||
provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||
// One can add a logger to the context,
|
||||
// pre-defining log attributes as required.
|
||||
ctx := logging.ToContext(context.TODO(), logger)
|
||||
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||
if err != nil {
|
||||
logrus.Fatalf("error creating provider %s", err.Error())
|
||||
}
|
||||
|
@ -119,8 +141,22 @@ func main() {
|
|||
//
|
||||
// http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider))
|
||||
|
||||
// simple counter for request IDs
|
||||
var counter atomic.Int64
|
||||
// enable incomming request logging
|
||||
mw := logging.Middleware(
|
||||
logging.WithLogger(logger),
|
||||
logging.WithGroup("server"),
|
||||
logging.WithIDFunc(func() slog.Attr {
|
||||
return slog.Int64("id", counter.Add(1))
|
||||
}),
|
||||
)
|
||||
|
||||
lis := fmt.Sprintf("127.0.0.1:%s", port)
|
||||
logrus.Infof("listening on http://%s/", lis)
|
||||
logrus.Info("press ctrl+c to stop")
|
||||
logrus.Fatal(http.ListenAndServe(lis, nil))
|
||||
logger.Info("server listening, press ctrl+c to stop", "addr", lis)
|
||||
err = http.ListenAndServe(lis, mw(http.DefaultServeMux))
|
||||
if err != http.ErrServerClosed {
|
||||
logger.Error("server terminated", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,12 @@ import (
|
|||
"crypto/sha256"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
|
@ -31,26 +34,33 @@ type Storage interface {
|
|||
deviceAuthenticate
|
||||
}
|
||||
|
||||
// simple counter for request IDs
|
||||
var counter atomic.Int64
|
||||
|
||||
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
|
||||
//
|
||||
// Use one of the pre-made clients in storage/clients.go or register a new one.
|
||||
func SetupServer(issuer string, storage Storage) chi.Router {
|
||||
func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router {
|
||||
// the OpenID Provider requires a 32-byte key for (token) encryption
|
||||
// be sure to create a proper crypto random key and manage it securely!
|
||||
key := sha256.Sum256([]byte("test"))
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(logging.Middleware(
|
||||
logging.WithLogger(logger),
|
||||
logging.WithIDFunc(func() slog.Attr {
|
||||
return slog.Int64("id", counter.Add(1))
|
||||
}),
|
||||
))
|
||||
|
||||
// for simplicity, we provide a very small default page for users who have signed out
|
||||
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
|
||||
_, err := w.Write([]byte("signed out successfully"))
|
||||
if err != nil {
|
||||
log.Printf("error serving logged out page: %v", err)
|
||||
}
|
||||
w.Write([]byte("signed out successfully"))
|
||||
// no need to check/log error, this will be handeled by the middleware.
|
||||
})
|
||||
|
||||
// creation of the OpenIDProvider with the just created in-memory Storage
|
||||
provider, err := newOP(storage, issuer, key)
|
||||
provider, err := newOP(storage, issuer, key, logger)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -80,7 +90,7 @@ func SetupServer(issuer string, storage Storage) chi.Router {
|
|||
// newOP will create an OpenID Provider for localhost on a specified port with a given encryption key
|
||||
// and a predefined default logout uri
|
||||
// it will enable all options (see descriptions)
|
||||
func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, error) {
|
||||
func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger) (op.OpenIDProvider, error) {
|
||||
config := &op.Config{
|
||||
CryptoKey: key,
|
||||
|
||||
|
@ -117,6 +127,8 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
|
|||
op.WithAllowInsecure(),
|
||||
// as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
|
||||
op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
|
||||
// Pass our logger to the OP
|
||||
op.WithLogger(logger.WithGroup("op")),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -2,11 +2,12 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/zitadel/oidc/v3/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -20,16 +21,22 @@ func main() {
|
|||
// in this example it will be handled in-memory
|
||||
storage := storage.NewStorage(storage.NewUserStore(issuer))
|
||||
|
||||
router := exampleop.SetupServer(issuer, storage)
|
||||
logger := slog.New(
|
||||
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelDebug,
|
||||
}),
|
||||
)
|
||||
router := exampleop.SetupServer(issuer, storage, logger)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":" + port,
|
||||
Handler: router,
|
||||
}
|
||||
log.Printf("server listening on http://localhost:%s/", port)
|
||||
log.Println("press ctrl+c to stop")
|
||||
logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port))
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
if err != http.ErrServerClosed {
|
||||
logger.Error("server terminated", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package storage
|
|||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
@ -41,6 +42,19 @@ type AuthRequest struct {
|
|||
authTime time.Time
|
||||
}
|
||||
|
||||
// LogValue allows you to define which fields will be logged.
|
||||
// Implements the [slog.LogValuer]
|
||||
func (a *AuthRequest) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("id", a.ID),
|
||||
slog.Time("creation_date", a.CreationDate),
|
||||
slog.Any("scopes", a.Scopes),
|
||||
slog.String("response_type", string(a.ResponseType)),
|
||||
slog.String("app_id", a.ApplicationID),
|
||||
slog.String("callback_uri", a.CallbackURI),
|
||||
)
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
|
8
go.mod
8
go.mod
|
@ -1,6 +1,6 @@
|
|||
module github.com/zitadel/oidc/v3
|
||||
|
||||
go 1.18
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi v1.5.4
|
||||
|
@ -11,9 +11,11 @@ require (
|
|||
github.com/jeremija/gosubmit v0.2.7
|
||||
github.com/muhlemmer/gu v0.3.1
|
||||
github.com/rs/cors v1.9.0
|
||||
github.com/sirupsen/logrus v1.9.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.8.2
|
||||
github.com/zitadel/logging v0.4.0
|
||||
github.com/zitadel/schema v1.3.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/oauth2 v0.7.0
|
||||
golang.org/x/text v0.9.0
|
||||
gopkg.in/square/go-jose.v2 v2.6.0
|
||||
|
@ -27,7 +29,7 @@ require (
|
|||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/crypto v0.7.0 // indirect
|
||||
golang.org/x/net v0.9.0 // indirect
|
||||
golang.org/x/sys v0.7.0 // indirect
|
||||
golang.org/x/sys v0.11.0 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/protobuf v1.29.1 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
|
|
13
go.sum
13
go.sum
|
@ -36,8 +36,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE=
|
||||
github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
|
||||
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
|
@ -47,12 +47,16 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
|||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA=
|
||||
github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc=
|
||||
github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=
|
||||
github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
|
||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
|
@ -73,8 +77,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
|
@ -100,6 +104,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
|
|||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
|
||||
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
@ -37,6 +38,10 @@ func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellK
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if logger, ok := logging.FromContext(ctx); ok {
|
||||
logger.Debug("discover", "config", discoveryConfig)
|
||||
}
|
||||
|
||||
if discoveryConfig.Issuer != issuer {
|
||||
return nil, oidc.ErrIssuerInvalid
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/jeremija/gosubmit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/zitadel/oidc/v3/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
|
@ -29,6 +30,13 @@ import (
|
|||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
var Logger = slog.New(
|
||||
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelDebug,
|
||||
}),
|
||||
)
|
||||
|
||||
var CTX context.Context
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -49,7 +57,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
|
@ -100,7 +108,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
|||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
|
|
|
@ -33,6 +33,7 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.
|
|||
// in RFC 8628, section 3.1 and 3.2:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||
func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")
|
||||
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -45,6 +46,7 @@ func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty,
|
|||
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
||||
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")
|
||||
req := &client.DeviceAccessTokenRequest{
|
||||
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
|
|
17
pkg/client/rp/log.go
Normal file
17
pkg/client/rp/log.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context {
|
||||
logger, ok := rp.Logger(ctx)
|
||||
if !ok {
|
||||
return ctx
|
||||
}
|
||||
logger = logger.With(slog.Group("rp", attrs...))
|
||||
return logging.ToContext(ctx, logger)
|
||||
}
|
|
@ -10,6 +10,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
|
@ -67,6 +69,9 @@ type RelyingParty interface {
|
|||
// ErrorHandler returns the handler used for callback errors
|
||||
|
||||
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
|
||||
// Logger from the context, or a fallback if set.
|
||||
Logger(context.Context) (logger *slog.Logger, ok bool)
|
||||
}
|
||||
|
||||
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
|
||||
|
@ -90,6 +95,7 @@ type relyingParty struct {
|
|||
idTokenVerifier *IDTokenVerifier
|
||||
verifierOpts []VerifierOption
|
||||
signer jose.Signer
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (rp *relyingParty) OAuthConfig() *oauth2.Config {
|
||||
|
@ -150,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
|
|||
return rp.errorHandler
|
||||
}
|
||||
|
||||
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
|
||||
logger, ok = logging.FromContext(ctx)
|
||||
if ok {
|
||||
return logger, ok
|
||||
}
|
||||
return rp.logger, rp.logger != nil
|
||||
}
|
||||
|
||||
// NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given
|
||||
// OAuth2 Config and possible configOptions
|
||||
// it will use the AuthURL and TokenURL set in config
|
||||
|
@ -194,6 +208,7 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC")
|
||||
discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -281,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithLogger sets a logger that is used
|
||||
// in case the request context does not contain a logger.
|
||||
func WithLogger(logger *slog.Logger) Option {
|
||||
return func(rp *relyingParty) error {
|
||||
rp.logger = logger
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type SignerFromKey func() (jose.Signer, error)
|
||||
|
||||
func SignerFromKeyPath(path string) SignerFromKey {
|
||||
|
@ -378,6 +402,7 @@ func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Tok
|
|||
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
||||
// returning it parsed together with the oauth2 tokens (access, refresh)
|
||||
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
|
@ -467,6 +492,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
|
|||
// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
|
||||
var nilU U
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
|
||||
if err != nil {
|
||||
|
@ -546,7 +572,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
|
|||
// This is the generalized, unexported, function used by both
|
||||
// URLParamOpt and AuthURLOpt.
|
||||
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
|
||||
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
|
||||
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String())
|
||||
}
|
||||
|
||||
type URLParamOpt func() []oauth2.AuthCodeOption
|
||||
|
@ -621,6 +647,7 @@ type RefreshTokenRequest struct {
|
|||
// the IDToken and AccessToken will be verfied
|
||||
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
|
||||
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
|
||||
request := RefreshTokenRequest{
|
||||
RefreshToken: refreshToken,
|
||||
Scopes: rp.OAuthConfig().Scopes,
|
||||
|
@ -644,6 +671,7 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres
|
|||
}
|
||||
|
||||
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
|
||||
request := oidc.EndSessionRequest{
|
||||
IdTokenHint: idToken,
|
||||
ClientID: rp.OAuthConfig().ClientID,
|
||||
|
@ -659,6 +687,7 @@ func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectU
|
|||
//
|
||||
// tokenTypeHint should be either "id_token" or "refresh_token".
|
||||
func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
|
||||
request := client.RevokeRequest{
|
||||
Token: token,
|
||||
TokenTypeHint: tokenTypeHint,
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
const (
|
||||
// ScopeOpenID defines the scope `openid`
|
||||
// OpenID Connect requests MUST contain the `openid` scope value
|
||||
|
@ -86,6 +90,15 @@ type AuthRequest struct {
|
|||
RequestParam string `schema:"request"`
|
||||
}
|
||||
|
||||
func (a *AuthRequest) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Any("scopes", a.Scopes),
|
||||
slog.String("response_type", string(a.ResponseType)),
|
||||
slog.String("client_id", a.ClientID),
|
||||
slog.String("redirect_uri", a.RedirectURI),
|
||||
)
|
||||
}
|
||||
|
||||
// GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface
|
||||
func (a *AuthRequest) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
|
|
27
pkg/oidc/authorization_test.go
Normal file
27
pkg/oidc/authorization_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
//go:build go1.20
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestAuthRequest_LogValue(t *testing.T) {
|
||||
a := &AuthRequest{
|
||||
Scopes: SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "respType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
}
|
||||
want := slog.GroupValue(
|
||||
slog.Any("scopes", SpaceDelimitedArray{"a", "b"}),
|
||||
slog.String("response_type", "respType"),
|
||||
slog.String("client_id", "123"),
|
||||
slog.String("redirect_uri", "http://example.com/callback"),
|
||||
)
|
||||
got := a.LogValue()
|
||||
assert.Equal(t, want, got)
|
||||
}
|
|
@ -3,6 +3,8 @@ package oidc
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type errorType string
|
||||
|
@ -171,3 +173,34 @@ func DefaultToServerError(err error, description string) *Error {
|
|||
}
|
||||
return oauth
|
||||
}
|
||||
|
||||
func (e *Error) LogLevel() slog.Level {
|
||||
level := slog.LevelWarn
|
||||
if e.ErrorType == ServerError {
|
||||
level = slog.LevelError
|
||||
}
|
||||
if e.ErrorType == AuthorizationPending {
|
||||
level = slog.LevelInfo
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
func (e *Error) LogValue() slog.Value {
|
||||
attrs := make([]slog.Attr, 0, 5)
|
||||
if e.Parent != nil {
|
||||
attrs = append(attrs, slog.Any("parent", e.Parent))
|
||||
}
|
||||
if e.Description != "" {
|
||||
attrs = append(attrs, slog.String("description", e.Description))
|
||||
}
|
||||
if e.ErrorType != "" {
|
||||
attrs = append(attrs, slog.String("type", string(e.ErrorType)))
|
||||
}
|
||||
if e.State != "" {
|
||||
attrs = append(attrs, slog.String("state", e.State))
|
||||
}
|
||||
if e.redirectDisabled {
|
||||
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
|
||||
}
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
|
83
pkg/oidc/error_go120_test.go
Normal file
83
pkg/oidc/error_go120_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
//go:build go1.20
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestError_LogValue(t *testing.T) {
|
||||
type fields struct {
|
||||
Parent error
|
||||
ErrorType errorType
|
||||
Description string
|
||||
State string
|
||||
redirectDisabled bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want slog.Value
|
||||
}{
|
||||
{
|
||||
name: "parent",
|
||||
fields: fields{
|
||||
Parent: io.EOF,
|
||||
},
|
||||
want: slog.GroupValue(slog.Any("parent", io.EOF)),
|
||||
},
|
||||
{
|
||||
name: "description",
|
||||
fields: fields{
|
||||
Description: "oops",
|
||||
},
|
||||
want: slog.GroupValue(slog.String("description", "oops")),
|
||||
},
|
||||
{
|
||||
name: "errorType",
|
||||
fields: fields{
|
||||
ErrorType: ExpiredToken,
|
||||
},
|
||||
want: slog.GroupValue(slog.String("type", string(ExpiredToken))),
|
||||
},
|
||||
{
|
||||
name: "state",
|
||||
fields: fields{
|
||||
State: "123",
|
||||
},
|
||||
want: slog.GroupValue(slog.String("state", "123")),
|
||||
},
|
||||
{
|
||||
name: "all fields",
|
||||
fields: fields{
|
||||
Parent: io.EOF,
|
||||
Description: "oops",
|
||||
ErrorType: ExpiredToken,
|
||||
State: "123",
|
||||
},
|
||||
want: slog.GroupValue(
|
||||
slog.Any("parent", io.EOF),
|
||||
slog.String("description", "oops"),
|
||||
slog.String("type", string(ExpiredToken)),
|
||||
slog.String("state", "123"),
|
||||
),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &Error{
|
||||
Parent: tt.fields.Parent,
|
||||
ErrorType: tt.fields.ErrorType,
|
||||
Description: tt.fields.Description,
|
||||
State: tt.fields.State,
|
||||
redirectDisabled: tt.fields.redirectDisabled,
|
||||
}
|
||||
got := e.LogValue()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
81
pkg/oidc/error_test.go
Normal file
81
pkg/oidc/error_test.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestDefaultToServerError(t *testing.T) {
|
||||
type args struct {
|
||||
err error
|
||||
description string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Error
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
args: args{
|
||||
err: io.ErrClosedPipe,
|
||||
description: "oops",
|
||||
},
|
||||
want: &Error{
|
||||
ErrorType: ServerError,
|
||||
Description: "oops",
|
||||
Parent: io.ErrClosedPipe,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "our Error",
|
||||
args: args{
|
||||
err: ErrAccessDenied(),
|
||||
description: "oops",
|
||||
},
|
||||
want: &Error{
|
||||
ErrorType: AccessDenied,
|
||||
Description: "The authorization request was denied.",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DefaultToServerError(tt.args.err, tt.args.description)
|
||||
assert.ErrorIs(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError_LogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *Error
|
||||
want slog.Level
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
err: ErrServerError(),
|
||||
want: slog.LevelError,
|
||||
},
|
||||
{
|
||||
name: "authorization pending",
|
||||
err: ErrAuthorizationPending(),
|
||||
want: slog.LevelInfo,
|
||||
},
|
||||
{
|
||||
name: "some other error",
|
||||
err: ErrAccessDenied(),
|
||||
want: slog.LevelWarn,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.err.LogLevel()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -106,7 +106,7 @@ type ResponseType string
|
|||
|
||||
type ResponseMode string
|
||||
|
||||
func (s SpaceDelimitedArray) Encode() string {
|
||||
func (s SpaceDelimitedArray) String() string {
|
||||
return strings.Join(s, " ")
|
||||
}
|
||||
|
||||
|
@ -116,11 +116,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error {
|
|||
}
|
||||
|
||||
func (s SpaceDelimitedArray) MarshalText() ([]byte, error) {
|
||||
return []byte(s.Encode()), nil
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal((s).Encode())
|
||||
return json.Marshal((s).String())
|
||||
}
|
||||
|
||||
func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error {
|
||||
|
@ -165,7 +165,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
|
|||
func NewEncoder() *schema.Encoder {
|
||||
e := schema.NewEncoder()
|
||||
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
|
||||
return value.Interface().(SpaceDelimitedArray).Encode()
|
||||
return value.Interface().(SpaceDelimitedArray).String()
|
||||
})
|
||||
return e
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
str "github.com/zitadel/oidc/v3/pkg/strings"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type AuthRequest interface {
|
||||
|
@ -41,6 +42,7 @@ type Authorizer interface {
|
|||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
Crypto() Crypto
|
||||
RequestObjectSupported() bool
|
||||
Logger() *slog.Logger
|
||||
}
|
||||
|
||||
// AuthorizeValidator is an extension of Authorizer interface
|
||||
|
@ -67,23 +69,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
|
|||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, nil, err, authorizer)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
|
||||
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
}
|
||||
if authReq.ClientID == "" {
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer)
|
||||
return
|
||||
}
|
||||
if authReq.RedirectURI == "" {
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
|
||||
return
|
||||
}
|
||||
validation := ValidateAuthRequest
|
||||
|
@ -92,21 +94,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
}
|
||||
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
if authReq.RequestParam != "" {
|
||||
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer)
|
||||
return
|
||||
}
|
||||
req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer)
|
||||
return
|
||||
}
|
||||
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer)
|
||||
return
|
||||
}
|
||||
RedirectToLogin(req.GetID(), client, w, r)
|
||||
|
@ -406,18 +408,18 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *
|
|||
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
id, err := ParseAuthorizeCallbackRequest(r)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, nil, err, authorizer)
|
||||
return
|
||||
}
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, nil, err, authorizer)
|
||||
return
|
||||
}
|
||||
if !authReq.Done() {
|
||||
AuthRequestError(w, r, authReq,
|
||||
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
|
||||
authorizer.Encoder())
|
||||
authorizer)
|
||||
return
|
||||
}
|
||||
AuthResponse(authReq, authorizer, w, r)
|
||||
|
@ -438,7 +440,7 @@ func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
|
|||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
|
@ -452,7 +454,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
|
|||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
codeResponse := struct {
|
||||
|
@ -464,7 +466,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
|
|||
}
|
||||
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
|
@ -475,12 +477,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
|
|||
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "")
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/op/mock"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
|
@ -38,7 +39,7 @@ func TestAuthorize(t *testing.T) {
|
|||
|
||||
expect := authorizer.EXPECT()
|
||||
expect.Decoder().Return(schema.NewDecoder())
|
||||
expect.Encoder().Return(schema.NewEncoder())
|
||||
expect.Logger().Return(slog.Default())
|
||||
|
||||
if tt.expect != nil {
|
||||
tt.expect(expect)
|
||||
|
|
|
@ -57,7 +57,7 @@ var (
|
|||
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := DeviceAuthorization(w, r, o); err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, o.Logger())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -190,7 +190,7 @@ func (r *deviceAccessTokenRequest) GetScopes() []string {
|
|||
|
||||
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
if err := deviceAccessToken(w, r, exchanger); err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
|
@ -13,13 +14,31 @@ type ErrAuthRequest interface {
|
|||
GetState() string
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) {
|
||||
// LogAuthRequest is an optional interface,
|
||||
// that allows logging AuthRequest fields.
|
||||
// If the AuthRequest does not implement this interface,
|
||||
// no details shall be printed to the logs.
|
||||
type LogAuthRequest interface {
|
||||
ErrAuthRequest
|
||||
slog.LogValuer
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) {
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
logger := authorizer.Logger().With("oidc_error", e)
|
||||
|
||||
if authReq == nil {
|
||||
logger.Log(r.Context(), e.LogLevel(), "auth request")
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
|
||||
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
|
||||
logger = logger.With("auth_request", logAuthReq)
|
||||
}
|
||||
|
||||
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
|
||||
logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting")
|
||||
http.Error(w, e.Description, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
@ -28,19 +47,22 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
|
|||
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
||||
responseMode = rm.GetResponseMode()
|
||||
}
|
||||
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
|
||||
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder())
|
||||
if err != nil {
|
||||
logger.ErrorContext(r.Context(), "auth response URL", "error", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
logger.Log(r.Context(), e.LogLevel(), "auth request")
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func RequestError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
status := http.StatusBadRequest
|
||||
if e.ErrorType == oidc.InvalidClient {
|
||||
status = 401
|
||||
status = http.StatusUnauthorized
|
||||
}
|
||||
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
||||
httphelper.MarshalJSONWithStatus(w, e, status)
|
||||
}
|
||||
|
|
277
pkg/op/error_test.go
Normal file
277
pkg/op/error_test.go
Normal file
|
@ -0,0 +1,277 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestAuthRequestError(t *testing.T) {
|
||||
type args struct {
|
||||
authReq ErrAuthRequest
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantCode int
|
||||
wantHeaders map[string]string
|
||||
wantBody string
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "nil auth request",
|
||||
args: args{
|
||||
authReq: nil,
|
||||
err: io.ErrClosedPipe,
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "io: read/write on closed pipe\n",
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth request",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, no redirect URI",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "sign in\n",
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, redirect disabled",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "oops\n",
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"invalid_request",
|
||||
"redirect_disabled":true
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, url parse error",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "can't parse this!\n",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n",
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth response URL",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"can't parse this!\n",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"error":{
|
||||
"type":"server_error",
|
||||
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request redirect",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"},
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
authorizer := &Provider{
|
||||
encoder: schema.NewEncoder(),
|
||||
logger: slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/path", nil)
|
||||
AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tt.wantCode, res.StatusCode)
|
||||
for key, wantHeader := range tt.wantHeaders {
|
||||
gotHeader := res.Header.Get(key)
|
||||
assert.Equalf(t, wantHeader, gotHeader, "header %q", key)
|
||||
}
|
||||
gotBody, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err, "read result body")
|
||||
assert.Equal(t, tt.wantBody, string(gotBody), "result body")
|
||||
|
||||
gotLog := logOut.String()
|
||||
t.Log(gotLog)
|
||||
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int
|
||||
wantBody string
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
err: io.ErrClosedPipe,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"request error",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"description":"io: read/write on closed pipe",
|
||||
"type":"server_error"}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "invalid client",
|
||||
err: oidc.ErrInvalidClient().WithDescription("not good"),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
wantBody: `{"error":"invalid_client", "error_description":"not good"}`,
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"request error",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"description":"not good",
|
||||
"type":"invalid_client"}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
logger := slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/path", nil)
|
||||
RequestError(w, r, tt.err, logger)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tt.wantCode, res.StatusCode, "status code")
|
||||
|
||||
gotBody, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err, "read result body")
|
||||
assert.JSONEq(t, tt.wantBody, string(gotBody), "result body")
|
||||
|
||||
gotLog := logOut.String()
|
||||
t.Log(gotLog)
|
||||
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
|
||||
})
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
gomock "github.com/golang/mock/gomock"
|
||||
http "github.com/zitadel/oidc/v3/pkg/http"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
slog "golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// MockAuthorizer is a mock of Authorizer interface.
|
||||
|
@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
|
||||
}
|
||||
|
||||
// Logger mocks base method.
|
||||
func (m *MockAuthorizer) Logger() *slog.Logger {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Logger")
|
||||
ret0, _ := ret[0].(*slog.Logger)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Logger indicates an expected call of Logger.
|
||||
func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger))
|
||||
}
|
||||
|
||||
// RequestObjectSupported mocks base method.
|
||||
func (m *MockAuthorizer) RequestObjectSupported() bool {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
20
pkg/op/op.go
20
pkg/op/op.go
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/rs/cors"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
|
@ -79,6 +80,9 @@ type OpenIDProvider interface {
|
|||
DefaultLogoutRedirectURI() string
|
||||
Probes() []ProbesFn
|
||||
|
||||
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
|
||||
Logger() *slog.Logger
|
||||
|
||||
// Deprecated: Provider now implements http.Handler directly.
|
||||
HttpHandler() http.Handler
|
||||
}
|
||||
|
@ -174,6 +178,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
|
|||
storage: storage,
|
||||
endpoints: DefaultEndpoints,
|
||||
timer: make(<-chan time.Time),
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
for _, optFunc := range opOpts {
|
||||
|
@ -217,6 +222,7 @@ type Provider struct {
|
|||
timer <-chan time.Time
|
||||
accessTokenVerifierOpts []AccessTokenVerifierOpt
|
||||
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (o *Provider) IssuerFromRequest(r *http.Request) string {
|
||||
|
@ -375,6 +381,10 @@ func (o *Provider) Probes() []ProbesFn {
|
|||
}
|
||||
}
|
||||
|
||||
func (o *Provider) Logger() *slog.Logger {
|
||||
return o.logger
|
||||
}
|
||||
|
||||
// Deprecated: Provider now implements http.Handler directly.
|
||||
func (o *Provider) HttpHandler() http.Handler {
|
||||
return o
|
||||
|
@ -523,6 +533,16 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithLogger lets a logger other than slog.Default().
|
||||
//
|
||||
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
|
||||
func WithLogger(logger *slog.Logger) Option {
|
||||
return func(o *Provider) error {
|
||||
o.logger = logger
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
|
||||
issuerInterceptor := NewIssuerInterceptor(i)
|
||||
return func(handler http.Handler) http.Handler {
|
||||
|
|
|
@ -156,7 +156,7 @@ func TestRoutes(t *testing.T) {
|
|||
values: map[string]string{
|
||||
"client_id": client.GetID(),
|
||||
"redirect_uri": "https://example.com",
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"response_type": string(oidc.ResponseTypeCode),
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
|
@ -193,7 +193,7 @@ func TestRoutes(t *testing.T) {
|
|||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeBearer),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"assertion": jwtToken,
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
|
@ -206,7 +206,7 @@ func TestRoutes(t *testing.T) {
|
|||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeTokenExchange),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"subject_token": jwtToken,
|
||||
"subject_token_type": string(oidc.AccessTokenType),
|
||||
},
|
||||
|
@ -223,7 +223,7 @@ func TestRoutes(t *testing.T) {
|
|||
basicAuth: &basicAuth{"sid1", "verysecret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeClientCredentials),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
|
||||
|
@ -338,7 +338,7 @@ func TestRoutes(t *testing.T) {
|
|||
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type SessionEnder interface {
|
||||
|
@ -15,6 +16,7 @@ type SessionEnder interface {
|
|||
Storage() Storage
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
DefaultLogoutRedirectURI() string
|
||||
Logger() *slog.Logger
|
||||
}
|
||||
|
||||
func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -31,12 +33,12 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
|
|||
}
|
||||
session, err := ValidateEndSessionRequest(r.Context(), req, ender)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, ender.Logger())
|
||||
return
|
||||
}
|
||||
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
|
||||
if err != nil {
|
||||
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
|
||||
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, session.RedirectURI, http.StatusFound)
|
||||
|
|
|
@ -14,18 +14,18 @@ import (
|
|||
func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
request, err := ParseClientCredentialsRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
}
|
||||
|
||||
validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -13,20 +13,20 @@ import (
|
|||
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
}
|
||||
if tokenReq.Code == "" {
|
||||
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"))
|
||||
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger())
|
||||
return
|
||||
}
|
||||
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "")
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
httphelper.MarshalJSON(w, resp)
|
||||
|
|
|
@ -136,17 +136,17 @@ func (r *tokenExchangeRequest) SetSubject(subject string) {
|
|||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
}
|
||||
|
||||
tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
httphelper.MarshalJSON(w, resp)
|
||||
|
|
|
@ -18,23 +18,23 @@ type JWTAuthorizationGrantExchanger interface {
|
|||
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizationGrantExchanger) {
|
||||
profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
}
|
||||
|
||||
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context()))
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
|
||||
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
httphelper.MarshalJSON(w, resp)
|
||||
|
|
|
@ -26,16 +26,16 @@ type RefreshTokenRequest interface {
|
|||
func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
}
|
||||
validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
return
|
||||
}
|
||||
httphelper.MarshalJSON(w, resp)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type Exchanger interface {
|
||||
|
@ -22,6 +23,7 @@ type Exchanger interface {
|
|||
GrantTypeDeviceCodeSupported() bool
|
||||
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
Logger() *slog.Logger
|
||||
}
|
||||
|
||||
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -63,10 +65,10 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
return
|
||||
}
|
||||
case "":
|
||||
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
|
||||
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger())
|
||||
return
|
||||
}
|
||||
RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType))
|
||||
RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger())
|
||||
}
|
||||
|
||||
// AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue