Merge pull request #456 from zitadel/next-main

Merge next into main in order to release v3. Merge conflicts were handled in an intermediate branch.

BREAKING CHANGE - Just making sure v3 release is triggered.
This commit is contained in:
Tim Möhlmann 2023-10-13 08:44:41 +03:00 committed by GitHub
commit 976b40620c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
118 changed files with 6091 additions and 981 deletions

View file

@ -44,9 +44,9 @@ Check the `/example` folder where example code for different scenarios is locate
```bash
# start oidc op server
# oidc discovery http://localhost:9998/.well-known/openid-configuration
go run github.com/zitadel/oidc/v2/example/server
go run github.com/zitadel/oidc/v3/example/server
# start oidc web client (in a new terminal)
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
```
- open http://localhost:9999/login in your browser
@ -56,11 +56,11 @@ CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid
for the dynamic issuer, just start it with:
```bash
go run github.com/zitadel/oidc/v2/example/server/dynamic
go run github.com/zitadel/oidc/v3/example/server/dynamic
```
the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with:
```bash
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
```
> Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`)

View file

@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
@ -9,11 +10,11 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
"github.com/go-chi/chi"
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/client/rs"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client/rs"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
const (
@ -27,12 +28,12 @@ func main() {
port := os.Getenv("PORT")
issuer := os.Getenv("ISSUER")
provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath)
provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath)
if err != nil {
logrus.Fatalf("error creating provider %s", err.Error())
}
router := mux.NewRouter()
router := chi.NewRouter()
// public url accessible without any authorization
// will print `OK` and current timestamp
@ -47,7 +48,7 @@ func main() {
if !ok {
return
}
resp, err := rs.Introspect(r.Context(), provider, token)
resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
if err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
@ -68,14 +69,14 @@ func main() {
if !ok {
return
}
resp, err := rs.Introspect(r.Context(), provider, token)
resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
if err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
params := mux.Vars(r)
requestedClaim := params["claim"]
requestedValue := params["value"]
requestedClaim := chi.URLParam(r, "claim")
requestedValue := chi.URLParam(r, "value")
value, ok := resp.Claims[requestedClaim].(string)
if !ok || value == "" || value != requestedValue {
http.Error(w, "claim does not match", http.StatusForbidden)

View file

@ -1,19 +1,23 @@
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slog"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
var (
@ -32,9 +36,25 @@ func main() {
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
logger := slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)
client := &http.Client{
Timeout: time.Minute,
}
// enable outgoing request logging
logging.EnableHTTPClient(client,
logging.WithClientGroup("client"),
)
options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
rp.WithHTTPClient(client),
rp.WithLogger(logger),
}
if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
@ -43,7 +63,10 @@ func main() {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...)
// One can add a logger to the context,
// pre-defining log attributes as required.
ctx := logging.ToContext(context.TODO(), logger)
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...)
if err != nil {
logrus.Fatalf("error creating provider %s", err.Error())
}
@ -118,8 +141,22 @@ func main() {
//
// http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider))
// simple counter for request IDs
var counter atomic.Int64
// enable incomming request logging
mw := logging.Middleware(
logging.WithLogger(logger),
logging.WithGroup("server"),
logging.WithIDFunc(func() slog.Attr {
return slog.Int64("id", counter.Add(1))
}),
)
lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis)
logrus.Info("press ctrl+c to stop")
logrus.Fatal(http.ListenAndServe(lis, nil))
logger.Info("server listening, press ctrl+c to stop", "addr", lis)
err = http.ListenAndServe(lis, mw(http.DefaultServeMux))
if err != http.ErrServerClosed {
logger.Error("server terminated", "error", err)
os.Exit(1)
}
}

View file

@ -11,8 +11,8 @@ import (
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
)
var (
@ -39,13 +39,13 @@ func main() {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...)
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...)
if err != nil {
logrus.Fatalf("error creating provider %s", err.Error())
}
logrus.Info("starting device authorization flow")
resp, err := rp.DeviceAuthorization(scopes, provider)
resp, err := rp.DeviceAuthorization(ctx, scopes, provider, nil)
if err != nil {
logrus.Fatal(err)
}

View file

@ -10,10 +10,10 @@ import (
"golang.org/x/oauth2"
githubOAuth "golang.org/x/oauth2/github"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/client/rp/cli"
"github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/client/rp/cli"
"github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
var (

View file

@ -13,7 +13,7 @@ import (
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/zitadel/oidc/v2/pkg/client/profile"
"github.com/zitadel/oidc/v3/pkg/client/profile"
)
var client = http.DefaultClient
@ -25,7 +25,7 @@ func main() {
scopes := strings.Split(os.Getenv("SCOPES"), " ")
if keyPath != "" {
ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes)
ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes)
if err != nil {
logrus.Fatalf("error creating token source %s", err.Error())
}
@ -76,7 +76,7 @@ func main() {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes)
ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return

View file

@ -6,9 +6,9 @@ import (
"html/template"
"net/http"
"github.com/gorilla/mux"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/op"
)
const (
@ -43,7 +43,7 @@ var (
type login struct {
authenticate authenticate
router *mux.Router
router chi.Router
callback func(context.Context, string) string
}
@ -57,9 +57,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string)
}
func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
l.router = mux.NewRouter()
l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler)
l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler))
l.router = chi.NewRouter()
l.router.Get("/username", l.loginHandler)
l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler)
}
type authenticate interface {

View file

@ -7,11 +7,11 @@ import (
"log"
"net/http"
"github.com/gorilla/mux"
"github.com/go-chi/chi"
"golang.org/x/text/language"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v3/pkg/op"
)
const (
@ -47,7 +47,7 @@ func main() {
//be sure to create a proper crypto random key and manage it securely!
key := sha256.Sum256([]byte("test"))
router := mux.NewRouter()
router := chi.NewRouter()
//for simplicity, we provide a very small default page for users who have signed out
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
@ -76,7 +76,7 @@ func main() {
//regardless of how many pages / steps there are in the process, the UI must be registered in the router,
//so we will direct all calls to /login to the login UI
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
router.Mount("/login/", http.StripPrefix("/login", l.router))
//we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
//is served on the correct path
@ -84,7 +84,7 @@ func main() {
//if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
//then you would have to set the path prefix (/custom/path/):
//router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler()))
router.PathPrefix("/").Handler(provider.HttpHandler())
router.Mount("/", provider)
server := &http.Server{
Addr: ":" + port,

View file

@ -1,21 +1,34 @@
package exampleop
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"github.com/gorilla/mux"
"github.com/go-chi/chi"
"github.com/gorilla/securecookie"
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/op"
)
type deviceAuthenticate interface {
CheckUsernamePasswordSimple(username, password string) error
op.DeviceAuthorizationStorage
// GetDeviceAuthorizationByUserCode resturns the current state of the device authorization flow,
// identified by the user code.
GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error)
// CompleteDeviceAuthorization marks a device authorization entry as Completed,
// identified by userCode. The Subject is added to the state, so that
// GetDeviceAuthorizatonState can use it to create a new Access Token.
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
// DenyDeviceAuthorization marks a device authorization entry as Denied.
DenyDeviceAuthorization(ctx context.Context, userCode string) error
}
type deviceLogin struct {
@ -23,14 +36,14 @@ type deviceLogin struct {
cookie *securecookie.SecureCookie
}
func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) {
func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) {
l := &deviceLogin{
storage: storage,
cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil),
}
router.HandleFunc("", l.userCodeHandler)
router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler)
router.HandleFunc("/", l.userCodeHandler)
router.Post("/login", l.loginHandler)
router.HandleFunc("/confirm", l.confirmHandler)
}

View file

@ -5,13 +5,13 @@ import (
"fmt"
"net/http"
"github.com/gorilla/mux"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v3/pkg/op"
)
type login struct {
authenticate authenticate
router *mux.Router
router chi.Router
callback func(context.Context, string) string
}
@ -25,9 +25,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string)
}
func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
l.router = mux.NewRouter()
l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler)
l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler))
l.router = chi.NewRouter()
l.router.Get("/username", l.loginHandler)
l.router.Post("/username", issuerInterceptor.HandlerFunc(l.checkLoginHandler))
}
type authenticate interface {

View file

@ -4,13 +4,16 @@ import (
"crypto/sha256"
"log"
"net/http"
"sync/atomic"
"time"
"github.com/gorilla/mux"
"github.com/go-chi/chi"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
"golang.org/x/text/language"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v3/pkg/op"
)
const (
@ -31,26 +34,33 @@ type Storage interface {
deviceAuthenticate
}
// simple counter for request IDs
var counter atomic.Int64
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
//
// Use one of the pre-made clients in storage/clients.go or register a new one.
func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux.Router {
func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool, extraOptions ...op.Option) chi.Router {
// the OpenID Provider requires a 32-byte key for (token) encryption
// be sure to create a proper crypto random key and manage it securely!
key := sha256.Sum256([]byte("test"))
router := mux.NewRouter()
router := chi.NewRouter()
router.Use(logging.Middleware(
logging.WithLogger(logger),
logging.WithIDFunc(func() slog.Attr {
return slog.Int64("id", counter.Add(1))
}),
))
// for simplicity, we provide a very small default page for users who have signed out
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
_, err := w.Write([]byte("signed out successfully"))
if err != nil {
log.Printf("error serving logged out page: %v", err)
}
w.Write([]byte("signed out successfully"))
// no need to check/log error, this will be handeled by the middleware.
})
// creation of the OpenIDProvider with the just created in-memory Storage
provider, err := newOP(storage, issuer, key, extraOptions...)
provider, err := newOP(storage, issuer, key, logger, extraOptions...)
if err != nil {
log.Fatal(err)
}
@ -62,17 +72,23 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux
// regardless of how many pages / steps there are in the process, the UI must be registered in the router,
// so we will direct all calls to /login to the login UI
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
router.Mount("/login/", http.StripPrefix("/login", l.router))
router.PathPrefix("/device").Subrouter()
registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter())
router.Route("/device", func(r chi.Router) {
registerDeviceAuth(storage, r)
})
handler := http.Handler(provider)
if wrapServer {
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
}
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
// is served on the correct path
//
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
// then you would have to set the path prefix (/custom/path/)
router.PathPrefix("/").Handler(provider.HttpHandler())
router.Mount("/", handler)
return router
}
@ -80,7 +96,7 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux
// newOP will create an OpenID Provider for localhost on a specified port with a given encryption key
// and a predefined default logout uri
// it will enable all options (see descriptions)
func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.Option) (op.OpenIDProvider, error) {
func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger, extraOptions ...op.Option) (op.OpenIDProvider, error) {
config := &op.Config{
CryptoKey: key,
@ -114,10 +130,12 @@ func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.O
}
handler, err := op.NewOpenIDProvider(issuer, config, storage,
append([]op.Option{
// we must explicitly allow the use of the http issuer
//we must explicitly allow the use of the http issuer
op.WithAllowInsecure(),
// as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
// Pass our logger to the OP
op.WithLogger(logger.WithGroup("op")),
}, extraOptions...)...,
)
if err != nil {

View file

@ -2,11 +2,12 @@ package main
import (
"fmt"
"log"
"net/http"
"os"
"github.com/zitadel/oidc/v2/example/server/exampleop"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v3/example/server/exampleop"
"github.com/zitadel/oidc/v3/example/server/storage"
"golang.org/x/exp/slog"
)
func main() {
@ -20,16 +21,22 @@ func main() {
// in this example it will be handled in-memory
storage := storage.NewStorage(storage.NewUserStore(issuer))
router := exampleop.SetupServer(issuer, storage)
logger := slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)
router := exampleop.SetupServer(issuer, storage, logger, false)
server := &http.Server{
Addr: ":" + port,
Handler: router,
}
log.Printf("server listening on http://localhost:%s/", port)
log.Println("press ctrl+c to stop")
logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port))
err := server.ListenAndServe()
if err != nil {
log.Fatal(err)
if err != http.ErrServerClosed {
logger.Error("server terminated", "error", err)
os.Exit(1)
}
}

View file

@ -3,8 +3,8 @@ package storage
import (
"time"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
var (
@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
authMethod: oidc.AuthMethodBasic,
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeTokenExchange},
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
idTokenUserinfoClaimsAssertion: false,

View file

@ -3,10 +3,11 @@ package storage
import (
"time"
"golang.org/x/exp/slog"
"golang.org/x/text/language"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
const (
@ -41,6 +42,19 @@ type AuthRequest struct {
authTime time.Time
}
// LogValue allows you to define which fields will be logged.
// Implements the [slog.LogValuer]
func (a *AuthRequest) LogValue() slog.Value {
return slog.GroupValue(
slog.String("id", a.ID),
slog.Time("creation_date", a.CreationDate),
slog.Any("scopes", a.Scopes),
slog.String("response_type", string(a.ResponseType)),
slog.String("app_id", a.ApplicationID),
slog.String("callback_uri", a.CallbackURI),
)
}
func (a *AuthRequest) GetID() string {
return a.ID
}

View file

@ -11,11 +11,11 @@ import (
"sync"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
// serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant

View file

@ -4,10 +4,10 @@ import (
"context"
"time"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
type multiStorage struct {

10
go.mod
View file

@ -1,13 +1,13 @@
module github.com/zitadel/oidc/v2
module github.com/zitadel/oidc/v3
go 1.19
require (
github.com/go-chi/chi v1.5.4
github.com/go-jose/go-jose/v3 v3.0.0
github.com/golang/mock v1.6.0
github.com/google/go-github/v31 v31.0.0
github.com/google/uuid v1.3.1
github.com/gorilla/mux v1.8.0
github.com/gorilla/schema v1.2.0
github.com/gorilla/securecookie v1.1.1
github.com/jeremija/gosubmit v0.2.7
github.com/muhlemmer/gu v0.3.1
@ -15,11 +15,13 @@ require (
github.com/rs/cors v1.10.1
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4
github.com/zitadel/logging v0.4.0
github.com/zitadel/schema v1.3.0
go.opentelemetry.io/otel v1.19.0
go.opentelemetry.io/otel/trace v1.19.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/oauth2 v0.13.0
golang.org/x/text v0.13.0
gopkg.in/square/go-jose.v2 v2.6.0
)
require (

20
go.sum
View file

@ -1,6 +1,10 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs=
github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg=
github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=
github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@ -13,6 +17,7 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
@ -23,10 +28,6 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc=
github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc=
@ -44,10 +45,15 @@ github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA=
github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc=
github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=
github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc=
go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs=
go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY=
go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE=
@ -55,9 +61,12 @@ go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319
go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg=
go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
@ -102,8 +111,7 @@ google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -8,8 +8,8 @@ import (
"fmt"
"os"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
var custom = map[string]any{

View file

@ -8,8 +8,9 @@ import (
"errors"
"time"
"github.com/zitadel/oidc/v2/pkg/oidc"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// KeySet implements oidc.Keys
@ -17,7 +18,7 @@ type KeySet struct{}
// VerifySignature implments op.KeySet.
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
if ctx.Err() != nil {
if err = ctx.Err(); err != nil {
return nil, err
}
@ -45,6 +46,16 @@ func init() {
}
}
type JWTProfileKeyStorage struct{}
func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
return gu.Ptr(WebKey.Public()), nil
}
func signEncodeTokenClaims(claims any) string {
payload, err := json.Marshal(claims)
if err != nil {
@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
}
func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
req := &oidc.JWTTokenRequest{
Issuer: issuer,
Subject: clientID,
Audience: audience,
ExpiresAt: oidc.FromTime(expiration),
IssuedAt: oidc.FromTime(issuedAt),
}
// make sure the private claim map is set correctly
data, err := json.Marshal(req)
if err != nil {
panic(err)
}
if err = json.Unmarshal(data, req); err != nil {
panic(err)
}
return signEncodeTokenClaims(req), req
}
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
// These variables always result in a valid token
@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
}
func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
}
// ACRVerify is a oidc.ACRVerifier func.
func ACRVerify(acr string) error {
if acr != ValidACR {

View file

@ -11,24 +11,25 @@ import (
"strings"
"time"
jose "github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/crypto"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/crypto"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
var Encoder = httphelper.Encoder(oidc.NewEncoder())
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" {
wellKnown = wellKnownUrl[0]
}
req, err := http.NewRequest("GET", wellKnown, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
if err != nil {
return nil, err
}
@ -37,6 +38,10 @@ func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*
if err != nil {
return nil, err
}
if logger, ok := logging.FromContext(ctx); ok {
logger.Debug("discover", "config", discoveryConfig)
}
if discoveryConfig.Issuer != issuer {
return nil, oidc.ErrIssuerInvalid
}
@ -48,12 +53,12 @@ type TokenEndpointCaller interface {
HttpClient() *http.Client
}
func CallTokenEndpoint(request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
return callTokenEndpoint(request, nil, caller)
func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
return callTokenEndpoint(ctx, request, nil, caller)
}
func callTokenEndpoint(request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
@ -80,8 +85,8 @@ type EndSessionCaller interface {
HttpClient() *http.Client
}
func CallEndSessionEndpoint(request any, authFn any, caller EndSessionCaller) (*url.URL, error) {
req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn)
func CallEndSessionEndpoint(ctx context.Context, request any, authFn any, caller EndSessionCaller) (*url.URL, error) {
req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
@ -123,8 +128,8 @@ type RevokeRequest struct {
ClientSecret string `schema:"client_secret"`
}
func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error {
req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn)
func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error {
req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn)
if err != nil {
return err
}
@ -151,8 +156,8 @@ func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error {
return nil
}
func CallTokenExchangeEndpoint(request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
func CallTokenExchangeEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
@ -192,8 +197,8 @@ type DeviceAuthorizationCaller interface {
HttpClient() *http.Client
}
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, authFn)
if err != nil {
return nil, err
}
@ -214,7 +219,7 @@ type DeviceAccessTokenRequest struct {
}
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package client
import (
"context"
"net/http"
"testing"
@ -36,7 +37,7 @@ func TestDiscover(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Discover(tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...)
got, err := Discover(context.Background(), tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...)
if tt.wantErr {
assert.Error(t, err)
return

View file

@ -2,33 +2,64 @@ package client_test
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"os/signal"
"strconv"
"syscall"
"testing"
"time"
"github.com/jeremija/gosubmit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slog"
"github.com/zitadel/oidc/v2/example/server/exampleop"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/client/rs"
"github.com/zitadel/oidc/v2/pkg/client/tokenexchange"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/example/server/exampleop"
"github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/client/rs"
"github.com/zitadel/oidc/v3/pkg/client/tokenexchange"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
var Logger = slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)
var CTX context.Context
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
defer cancel()
CTX, cancel = context.WithTimeout(ctx, time.Minute)
defer cancel()
return m.Run()
}())
}
func TestRelyingPartySession(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testRelyingPartySession(t, wrapServer)
})
}
}
func testRelyingPartySession(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -36,17 +67,17 @@ func TestRelyingPartySession(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
t.Log("------- refresh tokens ------")
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken)
@ -54,11 +85,13 @@ func TestRelyingPartySession(t *testing.T) {
t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken")
assert.NotEmpty(t, newTokens.IDToken, "new idToken")
assert.NotNil(t, newTokens.IDTokenClaims)
assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject)
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
@ -67,17 +100,25 @@ func TestRelyingPartySession(t *testing.T) {
}
t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken)
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
t.Log("trying original refresh token", tokens.RefreshToken)
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement")
}
}
func TestResourceServerTokenExchange(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testResourceServerTokenExchange(t, wrapServer)
})
}
}
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -85,23 +126,24 @@ func TestResourceServerTokenExchange(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
clientSecret := "secret"
t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret)
resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
require.NoError(t, err, "new resource server")
t.Log("------- exchage refresh tokens (impersonation) ------")
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
CTX,
resourceServer,
refreshToken,
tokens.RefreshToken,
oidc.RefreshTokenType,
"",
"",
@ -119,7 +161,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "")
newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout")
if newLoc != nil {
t.Logf("redirect to %s", newLoc)
@ -130,8 +172,9 @@ func TestResourceServerTokenExchange(t *testing.T) {
t.Log("------- attempt exchage again (should fail) ------")
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
CTX,
resourceServer,
refreshToken,
tokens.RefreshToken,
oidc.RefreshTokenType,
"",
"",
@ -145,7 +188,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
require.Nil(t, tokenExchangeResponse, "token exchange response")
}
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
targetURL := "http://local-site"
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
require.NoError(t, err, "local url")
@ -167,6 +210,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
key := []byte("test1234test1234")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
provider, err = rp.NewRelyingPartyOIDC(
CTX,
opServer.URL,
clientID,
clientSecret,
@ -241,7 +285,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
}
var email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
tokens = newTokens
require.NotNil(t, tokens, "tokens")
require.NotNil(t, info, "info")
t.Log("access token", tokens.AccessToken)
@ -249,9 +294,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
t.Log("id token", tokens.IDToken)
t.Log("email", info.Email)
accessToken = tokens.AccessToken
refreshToken = tokens.RefreshToken
idToken = tokens.IDToken
email = info.Email
http.Redirect(w, r, targetURL, 302)
}
@ -273,12 +315,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
require.NoError(t, err, "get fully-authorizied redirect location")
require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location")
require.NotEmpty(t, idToken, "id token")
assert.NotEmpty(t, refreshToken, "refresh token")
assert.NotEmpty(t, accessToken, "access token")
require.NotEmpty(t, tokens.IDToken, "id token")
assert.NotEmpty(t, tokens.RefreshToken, "refresh token")
assert.NotEmpty(t, tokens.AccessToken, "access token")
assert.NotEmpty(t, email, "email")
return provider, accessToken, refreshToken, idToken
return provider, tokens
}
func TestErrorFromPromptNone(t *testing.T) {
@ -299,7 +341,7 @@ func TestErrorFromPromptNone(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, op.WithHttpInterceptors(
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, false, op.WithHttpInterceptors(
func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("request to %s", r.URL)
@ -317,6 +359,7 @@ func TestErrorFromPromptNone(t *testing.T) {
key := []byte("test1234test1234")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
provider, err := rp.NewRelyingPartyOIDC(
CTX,
opServer.URL,
clientID,
clientSecret,
@ -412,7 +455,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [
func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL {
// TODO: switch to io.NopCloser when go1.15 support is dropped
req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)...,
)
if req.URL.Scheme == "" {

View file

@ -1,17 +1,18 @@
package client
import (
"context"
"net/url"
"golang.org/x/oauth2"
"github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// JWTProfileExchange handles the oauth2 jwt profile exchange
func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
return CallTokenEndpoint(jwtProfileGrantRequest, caller)
func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller)
}
func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption {

View file

@ -10,7 +10,7 @@ const (
applicationKey = "application"
)
type keyFile struct {
type KeyFile struct {
Type string `json:"type"` // serviceaccount or application
KeyID string `json:"keyId"`
Key string `json:"key"`
@ -23,7 +23,7 @@ type keyFile struct {
ClientID string `json:"clientId"`
}
func ConfigFromKeyFile(path string) (*keyFile, error) {
func ConfigFromKeyFile(path string) (*KeyFile, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
@ -31,8 +31,8 @@ func ConfigFromKeyFile(path string) (*keyFile, error) {
return ConfigFromKeyFileData(data)
}
func ConfigFromKeyFileData(data []byte) (*keyFile, error) {
var f keyFile
func ConfigFromKeyFileData(data []byte) (*KeyFile, error) {
var f KeyFile
if err := json.Unmarshal(data, &f); err != nil {
return nil, err
}

View file

@ -1,16 +1,22 @@
package profile
import (
"context"
"net/http"
"time"
jose "github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type TokenSource interface {
oauth2.TokenSource
TokenCtx(context.Context) (*oauth2.Token, error)
}
// jwtProfileTokenSource implement the oauth2.TokenSource
// it will request a token using the OAuth2 JWT Profile Grant
// therefore sending an `assertion` by signing a JWT with the provided private key
@ -23,23 +29,38 @@ type jwtProfileTokenSource struct {
tokenEndpoint string
}
func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
keyData, err := client.ConfigFromKeyFile(keyPath)
// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource
// It will request a token using the OAuth2 JWT Profile Grant,
// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile.
//
// The passed context is only used for the call to the Discover endpoint.
func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
keyData, err := client.ConfigFromKeyFile(jsonFile)
if err != nil {
return nil, err
}
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
}
func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
keyData, err := client.ConfigFromKeyFileData(data)
// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource
// It will request a token using the OAuth2 JWT Profile Grant,
// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData.
//
// The passed context is only used for the call to the Discover endpoint.
func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
keyData, err := client.ConfigFromKeyFileData(jsonData)
if err != nil {
return nil, err
}
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
}
func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
// NewJWTProfileSource returns an implementation of oauth2.TokenSource
// It will request a token using the OAuth2 JWT Profile Grant,
// therefore sending an `assertion` by singing a JWT with the provided private key.
//
// The passed context is only used for the call to the Discover endpoint.
func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
if err != nil {
return nil, err
@ -55,7 +76,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
opt(source)
}
if source.tokenEndpoint == "" {
config, err := client.Discover(issuer, source.httpClient)
config, err := client.Discover(ctx, issuer, source.httpClient)
if err != nil {
return nil, err
}
@ -64,13 +85,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
return source, nil
}
func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) {
func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) {
return func(source *jwtProfileTokenSource) {
source.httpClient = client
}
}
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) {
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) {
return func(source *jwtProfileTokenSource) {
source.tokenEndpoint = tokenEndpoint
}
@ -85,9 +106,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client {
}
func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) {
return j.TokenCtx(context.Background())
}
func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) {
assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer)
if err != nil {
return nil, err
}
return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
}

View file

@ -4,9 +4,9 @@ import (
"context"
"net/http"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
const (

View file

@ -1,7 +1,7 @@
package rp
import (
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
"github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange"
)
// DelegationTokenRequest is an implementation of TokenExchangeRequest

View file

@ -5,8 +5,8 @@ import (
"fmt"
"time"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
@ -32,19 +32,21 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.
// DeviceAuthorization starts a new Device Authorization flow as defined
// in RFC 8628, section 3.1 and 3.2:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")
req, err := newDeviceClientCredentialsRequest(scopes, rp)
if err != nil {
return nil, err
}
return client.CallDeviceAuthorizationEndpoint(req, rp)
return client.CallDeviceAuthorizationEndpoint(ctx, req, rp, authFn)
}
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
// by means of polling as defined in RFC, section 3.3 and 3.4:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")
req := &client.DeviceAccessTokenRequest{
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
GrantType: oidc.GrantTypeDeviceCode,

View file

@ -7,10 +7,10 @@ import (
"net/http"
"sync"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {

17
pkg/client/rp/log.go Normal file
View file

@ -0,0 +1,17 @@
package rp
import (
"context"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
)
func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context {
logger, ok := rp.Logger(ctx)
if !ok {
return ctx
}
logger = logger.With(slog.Group("rp", attrs...))
return logging.ToContext(ctx, logger)
}

View file

@ -7,16 +7,17 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
const (
@ -63,11 +64,14 @@ type RelyingParty interface {
// be used to start a DeviceAuthorization flow.
GetDeviceAuthorizationEndpoint() string
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier
// IDTokenVerifier returns the verifier used for oidc id_token verification
IDTokenVerifier() *IDTokenVerifier
// ErrorHandler returns the handler used for callback errors
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
// Logger from the context, or a fallback if set.
Logger(context.Context) (logger *slog.Logger, ok bool)
}
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
@ -88,9 +92,10 @@ type relyingParty struct {
cookieHandler *httphelper.CookieHandler
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
idTokenVerifier IDTokenVerifier
idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption
signer jose.Signer
logger *slog.Logger
}
func (rp *relyingParty) OAuthConfig() *oauth2.Config {
@ -137,7 +142,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string {
return rp.endpoints.RevokeURL
}
func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier {
func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
if rp.idTokenVerifier == nil {
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
}
@ -151,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
return rp.errorHandler
}
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
logger, ok = logging.FromContext(ctx)
if ok {
return logger, ok
}
return rp.logger, rp.logger != nil
}
// NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given
// OAuth2 Config and possible configOptions
// it will use the AuthURL and TokenURL set in config
@ -177,7 +190,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart
// NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given
// issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions
// it will run discovery on the provided issuer and use the found endpoints
func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
rp := &relyingParty{
issuer: issuer,
oauthConfig: &oauth2.Config{
@ -195,7 +208,8 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco
return nil, err
}
}
discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC")
discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
if err != nil {
return nil, err
}
@ -282,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option {
}
}
// WithLogger sets a logger that is used
// in case the request context does not contain a logger.
func WithLogger(logger *slog.Logger) Option {
return func(rp *relyingParty) error {
rp.logger = logger
return nil
}
}
type SignerFromKey func() (jose.Signer, error)
func SignerFromKeyPath(path string) SignerFromKey {
@ -310,26 +333,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {
}
}
// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
//
// deprecated: use client.Discover
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return Endpoints{}, err
}
discoveryConfig := new(oidc.DiscoveryConfiguration)
err = httphelper.HttpRequest(httpClient, req, &discoveryConfig)
if err != nil {
return Endpoints{}, err
}
if discoveryConfig.Issuer != issuer {
return Endpoints{}, fmt.Errorf("%w: Expected: %s, got: %s", oidc.ErrIssuerInvalid, discoveryConfig.Issuer, issuer)
}
return GetEndpoints(discoveryConfig), nil
}
// AuthURL returns the auth request url
// (wrapping the oauth2 `AuthCodeURL`)
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
@ -377,9 +380,29 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
return oidc.NewSHACodeChallenge(codeVerifier), nil
}
// ErrMissingIDToken is returned when an id_token was expected,
// but not received in the token response.
var ErrMissingIDToken = errors.New("id_token missing")
func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken
}
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
// returning it parsed together with the oauth2 tokens (access, refresh)
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
codeOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
@ -390,22 +413,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
if err != nil {
return nil, err
}
if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return nil, errors.New("id_token missing")
}
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
return verifyTokenResponse[C](ctx, token, rp)
}
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
@ -457,14 +465,18 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
}
}
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo)
type SubjectGetter interface {
GetSubject() string
}
type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info U)
// UserinfoCallback wraps the callback function of the CodeExchangeHandler
// and calls the userinfo endpoint with the access token
// on success it will pass the userinfo into its callback function as well
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
return
@ -473,19 +485,26 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx
}
}
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
// Userinfo will call the OIDC [UserInfo] Endpoint with the provided token and returns
// the response in an instance of type U.
// [*oidc.UserInfo] can be used as a good example, or use a custom type if type-safe
// access to custom claims is needed.
//
// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
var nilU U
ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
if err != nil {
return nil, err
return nilU, err
}
req.Header.Set("authorization", tokenType+" "+token)
userinfo := new(oidc.UserInfo)
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
return nil, err
return nilU, err
}
if userinfo.Subject != subject {
return nil, ErrUserInfoSubNotMatching
if userinfo.GetSubject() != subject {
return nilU, ErrUserInfoSubNotMatching
}
return userinfo, nil
}
@ -554,7 +573,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
// This is the generalized, unexported, function used by both
// URLParamOpt and AuthURLOpt.
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String())
}
type URLParamOpt func() []oauth2.AuthCodeOption
@ -626,11 +645,15 @@ type RefreshTokenRequest struct {
GrantType oidc.GrantType `schema:"grant_type"`
}
// RefreshAccessToken performs a token refresh. If it doesn't error, it will always
// RefreshTokens performs a token refresh. If it doesn't error, it will always
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
// the old one should be considered invalid. It may also provide a new IDToken. The
// new IDToken can be retrieved with token.Extra("id_token").
func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
// the old one should be considered invalid.
//
// In case the RP is not OAuth2 only and an IDToken was part of the response,
// the IDToken and AccessToken will be verfied
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
request := RefreshTokenRequest{
RefreshToken: refreshToken,
Scopes: rp.OAuthConfig().Scopes,
@ -640,17 +663,28 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs
ClientAssertionType: clientAssertionType,
GrantType: oidc.GrantTypeRefreshToken,
}
return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp})
newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
if err != nil {
return nil, err
}
tokens, err := verifyTokenResponse[C](ctx, newToken, rp)
if err == nil || errors.Is(err, ErrMissingIDToken) {
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
// ...except that it might not contain an id_token.
return tokens, nil
}
return nil, err
}
func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
request := oidc.EndSessionRequest{
IdTokenHint: idToken,
ClientID: rp.OAuthConfig().ClientID,
PostLogoutRedirectURI: optionalRedirectURI,
State: optionalState,
}
return client.CallEndSessionEndpoint(request, nil, rp)
return client.CallEndSessionEndpoint(ctx, request, nil, rp)
}
// RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty
@ -658,7 +692,8 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str
// NewRelyingPartyOAuth() does not.
//
// tokenTypeHint should be either "id_token" or "refresh_token".
func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
request := client.RevokeRequest{
Token: token,
TokenTypeHint: tokenTypeHint,
@ -666,7 +701,7 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
ClientSecret: rp.OAuthConfig().ClientSecret,
}
if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
return client.CallRevokeEndpoint(request, nil, rc)
return client.CallRevokeEndpoint(ctx, request, nil, rc)
}
return fmt.Errorf("RelyingParty does not support RevokeCaller")
}

View file

@ -0,0 +1,107 @@
package rp
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
)
func Test_verifyTokenResponse(t *testing.T) {
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: 2 * time.Minute,
ClientID: tu.ValidClientID,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
KeySet: tu.KeySet{},
MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce },
}
tests := []struct {
name string
oauth2Only bool
tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims])
wantErr error
}{
{
name: "succes, oauth2 only",
oauth2Only: true,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
},
{
name: "id_token missing error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
wantErr: ErrMissingIDToken,
},
{
name: "verify tokens error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
token = token.WithExtra(map[string]any{
"id_token": "foobar",
})
return token, nil
},
wantErr: oidc.ErrParse,
},
{
name: "success, with id_token",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
idToken, claims := tu.ValidIDToken()
token = token.WithExtra(map[string]any{
"id_token": idToken,
})
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
IDTokenClaims: claims,
IDToken: idToken,
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &relyingParty{
oauth2Only: tt.oauth2Only,
idTokenVerifier: verifier,
}
token, want := tt.tokens()
got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, want, got)
})
}
}

View file

@ -5,7 +5,7 @@ import (
"golang.org/x/oauth2"
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
"github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange"
)
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`

View file

@ -0,0 +1,45 @@
package rp_test
import (
"context"
"fmt"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type UserInfo struct {
Subject string `json:"sub,omitempty"`
oidc.UserInfoProfile
oidc.UserInfoEmail
oidc.UserInfoPhone
Address *oidc.UserInfoAddress `json:"address,omitempty"`
// Foo and Bar are custom claims
Foo string `json:"foo,omitempty"`
Bar struct {
Val1 string `json:"val_1,omitempty"`
Val2 string `json:"val_2,omitempty"`
} `json:"bar,omitempty"`
// Claims are all the combined claims, including custom.
Claims map[string]any `json:"-,omitempty"`
}
func (u *UserInfo) GetSubject() string {
return u.Subject
}
func ExampleUserinfo_custom() {
rpo, err := rp.NewRelyingPartyOIDC(context.TODO(), "http://localhost:8080", "clientid", "clientsecret", "http://example.com/redirect", []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone})
if err != nil {
panic(err)
}
info, err := rp.Userinfo[*UserInfo](context.TODO(), "accesstokenstring", "Bearer", "userid", rpo)
if err != nil {
panic(err)
}
fmt.Println(info)
}

View file

@ -4,24 +4,14 @@ import (
"context"
"time"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type IDTokenVerifier interface {
oidc.Verifier
ClientID() string
SupportedSignAlgs() []string
KeySet() oidc.KeySet
Nonce(context.Context) string
ACR() oidc.ACRVerifier
MaxAge() time.Duration
}
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C
claims, err = VerifyIDToken[C](ctx, idToken, v)
@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str
// VerifyIDToken validates the id token according to
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C
decrypted, err := oidc.DecryptToken(token)
@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
return nilClaims, err
}
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
return nilClaims, err
}
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
if err = oidc.CheckAudience(claims, v.ClientID); err != nil {
return nilClaims, err
}
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil {
return nilClaims, err
}
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
return nilClaims, err
}
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
return nilClaims, err
}
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
return nilClaims, err
}
@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
return nilClaims, err
}
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
return nilClaims, err
}
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
return nilClaims, err
}
return claims, nil
}
type IDTokenVerifier oidc.Verifier
// VerifyAccessToken validates the access token according to
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl
return nil
}
// NewIDTokenVerifier returns an implementation of `IDTokenVerifier`
// for `VerifyTokens` and `VerifyIDToken`
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier {
v := &idTokenVerifier{
issuer: issuer,
clientID: clientID,
keySet: keySet,
offset: time.Second,
nonce: func(_ context.Context) string {
// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
v := &IDTokenVerifier{
Issuer: issuer,
ClientID: clientID,
KeySet: keySet,
Offset: time.Second,
Nonce: func(_ context.Context) string {
return ""
},
}
@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
}
// VerifierOption is the type for providing dynamic options to the IDTokenVerifier
type VerifierOption func(*idTokenVerifier)
type VerifierOption func(*IDTokenVerifier)
// WithIssuedAtOffset mitigates the risk of iat to be in the future
// because of clock skews with the ability to add an offset to the current time
func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
return func(v *idTokenVerifier) {
v.offset = offset
func WithIssuedAtOffset(offset time.Duration) VerifierOption {
return func(v *IDTokenVerifier) {
v.Offset = offset
}
}
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
return func(v *idTokenVerifier) {
v.maxAgeIAT = maxAge
func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption {
return func(v *IDTokenVerifier) {
v.MaxAgeIAT = maxAge
}
}
// WithNonce sets the function to check the nonce
func WithNonce(nonce func(context.Context) string) VerifierOption {
return func(v *idTokenVerifier) {
v.nonce = nonce
return func(v *IDTokenVerifier) {
v.Nonce = nonce
}
}
// WithACRVerifier sets the verifier for the acr claim
func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
return func(v *idTokenVerifier) {
v.acr = verifier
return func(v *IDTokenVerifier) {
v.ACR = verifier
}
}
// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
return func(v *idTokenVerifier) {
v.maxAge = maxAge
return func(v *IDTokenVerifier) {
v.MaxAge = maxAge
}
}
// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
return func(v *idTokenVerifier) {
v.supportedSignAlgs = algs
return func(v *IDTokenVerifier) {
v.SupportedSignAlgs = algs
}
}
type idTokenVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
clientID string
supportedSignAlgs []string
keySet oidc.KeySet
acr oidc.ACRVerifier
maxAge time.Duration
nonce func(ctx context.Context) string
}
func (i *idTokenVerifier) Issuer() string {
return i.issuer
}
func (i *idTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
func (i *idTokenVerifier) Offset() time.Duration {
return i.offset
}
func (i *idTokenVerifier) ClientID() string {
return i.clientID
}
func (i *idTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
func (i *idTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func (i *idTokenVerifier) Nonce(ctx context.Context) string {
return i.nonce(ctx)
}
func (i *idTokenVerifier) ACR() oidc.ACRVerifier {
return i.acr
}
func (i *idTokenVerifier) MaxAge() time.Duration {
return i.maxAge
}

View file

@ -5,24 +5,24 @@ import (
"testing"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
"gopkg.in/square/go-jose.v2"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
func TestVerifyTokens(t *testing.T) {
verifier := &idTokenVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{},
maxAge: 2 * time.Minute,
acr: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce },
clientID: tu.ValidClientID,
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: 2 * time.Minute,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
KeySet: tu.KeySet{},
MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce },
ClientID: tu.ValidClientID,
}
accessToken, _ := tu.ValidAccessToken()
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) {
}
func TestVerifyIDToken(t *testing.T) {
verifier := &idTokenVerifier{
issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute,
offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{},
maxAge: 2 * time.Minute,
acr: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce },
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: 2 * time.Minute,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
KeySet: tu.KeySet{},
MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce },
}
tests := []struct {
@ -231,7 +231,7 @@ func TestVerifyIDToken(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims()
verifier.clientID = tt.clientID
verifier.ClientID = tt.clientID
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr {
assert.Error(t, err)
@ -312,7 +312,7 @@ func TestNewIDTokenVerifier(t *testing.T) {
tests := []struct {
name string
args args
want IDTokenVerifier
want *IDTokenVerifier
}{
{
name: "nil nonce", // otherwise assert.Equal will fail on the function
@ -329,16 +329,16 @@ func TestNewIDTokenVerifier(t *testing.T) {
WithSupportedSigningAlgorithms("ABC", "DEF"),
},
},
want: &idTokenVerifier{
issuer: tu.ValidIssuer,
offset: time.Minute,
maxAgeIAT: time.Hour,
clientID: tu.ValidClientID,
keySet: tu.KeySet{},
nonce: nil,
acr: nil,
maxAge: 2 * time.Hour,
supportedSignAlgs: []string{"ABC", "DEF"},
want: &IDTokenVerifier{
Issuer: tu.ValidIssuer,
Offset: time.Minute,
MaxAgeIAT: time.Hour,
ClientID: tu.ValidClientID,
KeySet: tu.KeySet{},
Nonce: nil,
ACR: nil,
MaxAge: 2 * time.Hour,
SupportedSignAlgs: []string{"ABC", "DEF"},
},
},
}

View file

@ -4,9 +4,9 @@ import (
"context"
"fmt"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// MyCustomClaims extends the TokenClaims base,

View file

@ -0,0 +1,52 @@
package rs_test
import (
"context"
"fmt"
"github.com/zitadel/oidc/v3/pkg/client/rs"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type IntrospectionResponse struct {
Active bool `json:"active"`
Scope oidc.SpaceDelimitedArray `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
TokenType string `json:"token_type,omitempty"`
Expiration oidc.Time `json:"exp,omitempty"`
IssuedAt oidc.Time `json:"iat,omitempty"`
NotBefore oidc.Time `json:"nbf,omitempty"`
Subject string `json:"sub,omitempty"`
Audience oidc.Audience `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"`
JWTID string `json:"jti,omitempty"`
Username string `json:"username,omitempty"`
oidc.UserInfoProfile
oidc.UserInfoEmail
oidc.UserInfoPhone
Address *oidc.UserInfoAddress `json:"address,omitempty"`
// Foo and Bar are custom claims
Foo string `json:"foo,omitempty"`
Bar struct {
Val1 string `json:"val_1,omitempty"`
Val2 string `json:"val_2,omitempty"`
} `json:"bar,omitempty"`
// Claims are all the combined claims, including custom.
Claims map[string]any `json:"-,omitempty"`
}
func ExampleIntrospect_custom() {
rss, err := rs.NewResourceServerClientCredentials(context.TODO(), "http://localhost:8080", "clientid", "clientsecret")
if err != nil {
panic(err)
}
resp, err := rs.Introspect[*IntrospectionResponse](context.TODO(), rss, "accesstokenstring")
if err != nil {
panic(err)
}
fmt.Println(resp)
}

View file

@ -6,9 +6,9 @@ import (
"net/http"
"time"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type ResourceServer interface {
@ -42,14 +42,14 @@ func (r *resourceServer) AuthFn() (any, error) {
return r.authFn()
}
func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
authorizer := func() (any, error) {
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
}
return newResourceServer(issuer, authorizer, option...)
return newResourceServer(ctx, issuer, authorizer, option...)
}
func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
if err != nil {
return nil, err
@ -61,10 +61,10 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt
}
return client.ClientAssertionFormAuthorization(assertion), nil
}
return newResourceServer(issuer, authorizer, options...)
return newResourceServer(ctx, issuer, authorizer, options...)
}
func newResourceServer(issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) {
func newResourceServer(ctx context.Context, issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) {
rs := &resourceServer{
issuer: issuer,
httpClient: httphelper.DefaultHTTPClient,
@ -73,7 +73,7 @@ func newResourceServer(issuer string, authorizer func() (any, error), options ..
optFunc(rs)
}
if rs.introspectURL == "" || rs.tokenURL == "" {
config, err := client.Discover(rs.issuer, rs.httpClient)
config, err := client.Discover(ctx, rs.issuer, rs.httpClient)
if err != nil {
return nil, err
}
@ -91,12 +91,12 @@ func newResourceServer(issuer string, authorizer func() (any, error), options ..
return rs, nil
}
func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) {
func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) {
c, err := client.ConfigFromKeyFile(path)
if err != nil {
return nil, err
}
return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
}
type Option func(*resourceServer)
@ -116,21 +116,27 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
}
}
func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) {
// Introspect calls the [RFC7662] Token Introspection
// endpoint and returns the response in an instance of type R.
// [*oidc.IntrospectionResponse] can be used as a good example, or use a custom type if type-safe
// access to custom claims is needed.
//
// [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662
func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) {
if rp.IntrospectionURL() == "" {
return nil, errors.New("resource server: introspection URL is empty")
return resp, errors.New("resource server: introspection URL is empty")
}
authFn, err := rp.AuthFn()
if err != nil {
return nil, err
return resp, err
}
req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
if err != nil {
return nil, err
return resp, err
}
resp := new(oidc.IntrospectionResponse)
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
return nil, err
if err := httphelper.HttpRequest(rp.HttpClient(), req, &resp); err != nil {
return resp, err
}
return resp, nil
}

View file

@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
func TestNewResourceServer(t *testing.T) {
@ -164,7 +165,7 @@ func TestNewResourceServer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newResourceServer(tt.args.issuer, tt.args.authorizer, tt.args.options...)
got, err := newResourceServer(context.Background(), tt.args.issuer, tt.args.authorizer, tt.args.options...)
if tt.wantErr {
assert.Error(t, err)
return
@ -187,6 +188,7 @@ func TestIntrospect(t *testing.T) {
token string
}
rp, err := newResourceServer(
context.Background(),
"https://accounts.spotify.com",
nil,
)
@ -208,7 +210,7 @@ func TestIntrospect(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Introspect(tt.args.ctx, tt.args.rp, tt.args.token)
_, err := Introspect[*oidc.IntrospectionResponse](tt.args.ctx, tt.args.rp, tt.args.token)
if tt.wantErr {
assert.Error(t, err)
return

View file

@ -1,12 +1,13 @@
package tokenexchange
import (
"context"
"errors"
"net/http"
"github.com/zitadel/oidc/v2/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type TokenExchanger interface {
@ -21,18 +22,18 @@ type OAuthTokenExchange struct {
authFn func() (any, error)
}
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
return newOAuthTokenExchange(issuer, nil, options...)
func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
return newOAuthTokenExchange(ctx, issuer, nil, options...)
}
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
authorizer := func() (any, error) {
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
}
return newOAuthTokenExchange(issuer, authorizer, options...)
return newOAuthTokenExchange(ctx, issuer, authorizer, options...)
}
func newOAuthTokenExchange(issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
te := &OAuthTokenExchange{
httpClient: httphelper.DefaultHTTPClient,
}
@ -41,7 +42,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (any, error), option
}
if te.tokenEndpoint == "" {
config, err := client.Discover(issuer, te.httpClient)
config, err := client.Discover(ctx, issuer, te.httpClient)
if err != nil {
return nil, err
}
@ -89,6 +90,7 @@ func (te *OAuthTokenExchange) AuthFn() (any, error) {
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
// SubjectToken and SubjectTokenType are required parameters.
func ExchangeToken(
ctx context.Context,
te TokenExchanger,
SubjectToken string,
SubjectTokenType oidc.TokenType,
@ -123,5 +125,5 @@ func ExchangeToken(
RequestedTokenType: RequestedTokenType,
}
return client.CallTokenExchangeEndpoint(request, authFn, te)
return client.CallTokenExchangeEndpoint(ctx, request, authFn, te)
}

View file

@ -8,7 +8,7 @@ import (
"fmt"
"hash"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
)
var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm")

View file

@ -4,7 +4,7 @@ import (
"encoding/json"
"errors"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
)
func Sign(object any, signer jose.Signer) (string, error) {

View file

@ -33,7 +33,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization {
}
}
func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) {
func FormRequest(ctx context.Context, endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) {
form := url.Values{}
if err := encoder.Encode(request, form); err != nil {
return nil, err
@ -42,7 +42,7 @@ func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*ht
fn(form)
}
body := strings.NewReader(form.Encode())
req, err := http.NewRequest("POST", endpoint, body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
if err != nil {
return nil, err
}

View file

@ -1,5 +1,9 @@
package oidc
import (
"golang.org/x/exp/slog"
)
const (
// ScopeOpenID defines the scope `openid`
// OpenID Connect requests MUST contain the `openid` scope value
@ -77,7 +81,7 @@ type AuthRequest struct {
UILocales Locales `json:"ui_locales" schema:"ui_locales"`
IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"`
LoginHint string `json:"login_hint" schema:"login_hint"`
ACRValues []string `json:"acr_values" schema:"acr_values"`
ACRValues SpaceDelimitedArray `json:"acr_values" schema:"acr_values"`
CodeChallenge string `json:"code_challenge" schema:"code_challenge"`
CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"`
@ -86,6 +90,15 @@ type AuthRequest struct {
RequestParam string `schema:"request"`
}
func (a *AuthRequest) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("scopes", a.Scopes),
slog.String("response_type", string(a.ResponseType)),
slog.String("client_id", a.ClientID),
slog.String("redirect_uri", a.RedirectURI),
)
}
// GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface
func (a *AuthRequest) GetRedirectURI() string {
return a.RedirectURI

View file

@ -0,0 +1,27 @@
//go:build go1.20
package oidc
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slog"
)
func TestAuthRequest_LogValue(t *testing.T) {
a := &AuthRequest{
Scopes: SpaceDelimitedArray{"a", "b"},
ResponseType: "respType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
}
want := slog.GroupValue(
slog.Any("scopes", SpaceDelimitedArray{"a", "b"}),
slog.String("response_type", "respType"),
slog.String("client_id", "123"),
slog.String("redirect_uri", "http://example.com/callback"),
)
got := a.LogValue()
assert.Equal(t, want, got)
}

View file

@ -3,7 +3,7 @@ package oidc
import (
"crypto/sha256"
"github.com/zitadel/oidc/v2/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/crypto"
)
const (

View file

@ -3,6 +3,8 @@ package oidc
import (
"errors"
"fmt"
"golang.org/x/exp/slog"
)
type errorType string
@ -171,3 +173,34 @@ func DefaultToServerError(err error, description string) *Error {
}
return oauth
}
func (e *Error) LogLevel() slog.Level {
level := slog.LevelWarn
if e.ErrorType == ServerError {
level = slog.LevelError
}
if e.ErrorType == AuthorizationPending {
level = slog.LevelInfo
}
return level
}
func (e *Error) LogValue() slog.Value {
attrs := make([]slog.Attr, 0, 5)
if e.Parent != nil {
attrs = append(attrs, slog.Any("parent", e.Parent))
}
if e.Description != "" {
attrs = append(attrs, slog.String("description", e.Description))
}
if e.ErrorType != "" {
attrs = append(attrs, slog.String("type", string(e.ErrorType)))
}
if e.State != "" {
attrs = append(attrs, slog.String("state", e.State))
}
if e.redirectDisabled {
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
}
return slog.GroupValue(attrs...)
}

View file

@ -0,0 +1,83 @@
//go:build go1.20
package oidc
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slog"
)
func TestError_LogValue(t *testing.T) {
type fields struct {
Parent error
ErrorType errorType
Description string
State string
redirectDisabled bool
}
tests := []struct {
name string
fields fields
want slog.Value
}{
{
name: "parent",
fields: fields{
Parent: io.EOF,
},
want: slog.GroupValue(slog.Any("parent", io.EOF)),
},
{
name: "description",
fields: fields{
Description: "oops",
},
want: slog.GroupValue(slog.String("description", "oops")),
},
{
name: "errorType",
fields: fields{
ErrorType: ExpiredToken,
},
want: slog.GroupValue(slog.String("type", string(ExpiredToken))),
},
{
name: "state",
fields: fields{
State: "123",
},
want: slog.GroupValue(slog.String("state", "123")),
},
{
name: "all fields",
fields: fields{
Parent: io.EOF,
Description: "oops",
ErrorType: ExpiredToken,
State: "123",
},
want: slog.GroupValue(
slog.Any("parent", io.EOF),
slog.String("description", "oops"),
slog.String("type", string(ExpiredToken)),
slog.String("state", "123"),
),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Error{
Parent: tt.fields.Parent,
ErrorType: tt.fields.ErrorType,
Description: tt.fields.Description,
State: tt.fields.State,
redirectDisabled: tt.fields.redirectDisabled,
}
got := e.LogValue()
assert.Equal(t, tt.want, got)
})
}
}

81
pkg/oidc/error_test.go Normal file
View file

@ -0,0 +1,81 @@
package oidc
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slog"
)
func TestDefaultToServerError(t *testing.T) {
type args struct {
err error
description string
}
tests := []struct {
name string
args args
want *Error
}{
{
name: "default",
args: args{
err: io.ErrClosedPipe,
description: "oops",
},
want: &Error{
ErrorType: ServerError,
Description: "oops",
Parent: io.ErrClosedPipe,
},
},
{
name: "our Error",
args: args{
err: ErrAccessDenied(),
description: "oops",
},
want: &Error{
ErrorType: AccessDenied,
Description: "The authorization request was denied.",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DefaultToServerError(tt.args.err, tt.args.description)
assert.ErrorIs(t, got, tt.want)
})
}
}
func TestError_LogLevel(t *testing.T) {
tests := []struct {
name string
err *Error
want slog.Level
}{
{
name: "server error",
err: ErrServerError(),
want: slog.LevelError,
},
{
name: "authorization pending",
err: ErrAuthorizationPending(),
want: slog.LevelInfo,
},
{
name: "some other error",
err: ErrAccessDenied(),
want: slog.LevelWarn,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.err.LogLevel()
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -7,7 +7,7 @@ import (
"crypto/rsa"
"errors"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
)
const (

View file

@ -7,7 +7,7 @@ import (
"reflect"
"testing"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
)
func TestFindKey(t *testing.T) {

View file

@ -5,11 +5,11 @@ import (
"os"
"time"
jose "github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v2/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/crypto"
)
const (

View file

@ -5,7 +5,7 @@ import (
"fmt"
"time"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
)
const (

View file

@ -4,9 +4,9 @@ import (
"testing"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
var (

View file

@ -8,10 +8,10 @@ import (
"strings"
"time"
"github.com/gorilla/schema"
jose "github.com/go-jose/go-jose/v3"
"github.com/muhlemmer/gu"
"github.com/zitadel/schema"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
type Audience []string
@ -151,7 +151,7 @@ type ResponseType string
type ResponseMode string
func (s SpaceDelimitedArray) Encode() string {
func (s SpaceDelimitedArray) String() string {
return strings.Join(s, " ")
}
@ -161,11 +161,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error {
}
func (s SpaceDelimitedArray) MarshalText() ([]byte, error) {
return []byte(s.Encode()), nil
return []byte(s.String()), nil
}
func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) {
return json.Marshal((s).Encode())
return json.Marshal((s).String())
}
func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error {
@ -210,7 +210,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
func NewEncoder() *schema.Encoder {
e := schema.NewEncoder()
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(SpaceDelimitedArray).Encode()
return value.Interface().(SpaceDelimitedArray).String()
})
return e
}

View file

@ -9,9 +9,9 @@ import (
"testing"
"time"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/schema"
"golang.org/x/text/language"
)

View file

@ -29,6 +29,11 @@ func (u *UserInfo) GetAddress() *UserInfoAddress {
return u.Address
}
// GetSubject implements [rp.SubjectGetter]
func (u *UserInfo) GetSubject() string {
return u.Subject
}
type uiAlias UserInfo
func (u *UserInfo) MarshalJSON() ([]byte, error) {

View file

@ -10,9 +10,9 @@ import (
"strings"
"time"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
str "github.com/zitadel/oidc/v2/pkg/strings"
str "github.com/zitadel/oidc/v3/pkg/strings"
)
type Claims interface {
@ -61,10 +61,19 @@ var (
ErrAtHash = errors.New("at_hash does not correspond to access token")
)
type Verifier interface {
Issuer() string
MaxAgeIAT() time.Duration
Offset() time.Duration
// Verifier caries configuration for the various token verification
// functions. Use package specific constructor functions to know
// which values need to be set.
type Verifier struct {
Issuer string
MaxAgeIAT time.Duration
Offset time.Duration
ClientID string
SupportedSignAlgs []string
MaxAge time.Duration
ACR ACRVerifier
KeySet KeySet
Nonce func(ctx context.Context) string
}
// ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error {
return nil
}
// CheckAuthorizedParty checks azp (authorized party) claim requirements.
//
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func CheckAuthorizedParty(claims Claims, clientID string) error {
if len(claims.GetAudience()) > 1 {
if claims.GetAuthorizedParty() == "" {
@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
}
func CheckExpiration(claims Claims, offset time.Duration) error {
expiration := claims.GetExpiration().Round(time.Second)
if !time.Now().UTC().Add(offset).Before(expiration) {
expiration := claims.GetExpiration()
if !time.Now().Add(offset).Before(expiration) {
return ErrExpired
}
return nil
}
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
issuedAt := claims.GetIssuedAt().Round(time.Second)
issuedAt := claims.GetIssuedAt()
if issuedAt.IsZero() {
return ErrIatMissing
}
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second)
nowWithOffset := time.Now().Add(offset).Round(time.Second)
if issuedAt.After(nowWithOffset) {
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
}
if maxAgeIAT == 0 {
return nil
}
maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second)
maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second)
if issuedAt.Before(maxAge) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
}
@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error {
if claims.GetAuthTime().IsZero() {
return ErrAuthTimeNotPresent
}
authTime := claims.GetAuthTime().Round(time.Second)
maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second)
authTime := claims.GetAuthTime()
maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
if authTime.Before(maxAuthTime) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
}

View file

@ -0,0 +1,128 @@
package oidc_test
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
func TestParseToken(t *testing.T) {
token, wantClaims := tu.ValidIDToken()
wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload
wantPayload, err := json.Marshal(wantClaims)
require.NoError(t, err)
tests := []struct {
name string
tokenString string
wantErr bool
}{
{
name: "split error",
tokenString: "nope",
wantErr: true,
},
{
name: "base64 error",
tokenString: "foo.~.bar",
wantErr: true,
},
{
name: "success",
tokenString: token,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClaims := new(oidc.IDTokenClaims)
gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, wantClaims, gotClaims)
assert.JSONEq(t, string(wantPayload), string(gotPayload))
})
}
}
func TestCheckSignature(t *testing.T) {
errCtx, cancel := context.WithCancel(context.Background())
cancel()
token, _ := tu.ValidIDToken()
payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{})
require.NoError(t, err)
type args struct {
ctx context.Context
token string
payload []byte
supportedSigAlgs []string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "parse error",
args: args{
ctx: context.Background(),
token: "~",
payload: payload,
},
wantErr: oidc.ErrParse,
},
{
name: "default sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
},
},
{
name: "unsupported sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
supportedSigAlgs: []string{"foo", "bar"},
},
wantErr: oidc.ErrSignatureUnsupportedAlg,
},
{
name: "verify error",
args: args{
ctx: errCtx,
token: token,
payload: payload,
},
wantErr: oidc.ErrSignatureInvalid,
},
{
name: "inequal payloads",
args: args{
ctx: context.Background(),
token: token,
payload: []byte{0, 1, 2},
},
wantErr: oidc.ErrSignatureInvalidPayload,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims := new(oidc.TokenClaims)
err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{})
assert.ErrorIs(t, err, tt.wantErr)
})
}
}

374
pkg/oidc/verifier_test.go Normal file
View file

@ -0,0 +1,374 @@
package oidc
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecryptToken(t *testing.T) {
const tokenString = "ABC"
got, err := DecryptToken(tokenString)
require.NoError(t, err)
assert.Equal(t, tokenString, got)
}
func TestDefaultACRVerifier(t *testing.T) {
acrVerfier := DefaultACRVerifier([]string{"foo", "bar"})
tests := []struct {
name string
acr string
wantErr string
}{
{
name: "ok",
acr: "bar",
},
{
name: "error",
acr: "hello",
wantErr: "expected one of: [foo bar], got: \"hello\"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := acrVerfier(tt.acr)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
return
}
require.NoError(t, err)
})
}
}
func TestCheckSubject(t *testing.T) {
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrSubjectMissing,
},
{
name: "ok",
claims: &TokenClaims{
Subject: "foo",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckSubject(tt.claims)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckIssuer(t *testing.T) {
const issuer = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrIssuerInvalid,
},
{
name: "wrong",
claims: &TokenClaims{
Issuer: "wrong",
},
wantErr: ErrIssuerInvalid,
},
{
name: "ok",
claims: &TokenClaims{
Issuer: issuer,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckIssuer(tt.claims, issuer)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAudience(t *testing.T) {
const clientID = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrAudience,
},
{
name: "wrong",
claims: &TokenClaims{
Audience: []string{"wrong"},
},
wantErr: ErrAudience,
},
{
name: "ok",
claims: &TokenClaims{
Audience: []string{clientID},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAudience(tt.claims, clientID)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthorizedParty(t *testing.T) {
const clientID = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "single audience, no azp",
claims: &TokenClaims{
Audience: []string{clientID},
},
},
{
name: "multiple audience, no azp",
claims: &TokenClaims{
Audience: []string{clientID, "other"},
},
wantErr: ErrAzpMissing,
},
{
name: "single audience, with azp",
claims: &TokenClaims{
Audience: []string{clientID},
AuthorizedParty: clientID,
},
},
{
name: "multiple audience, with azp",
claims: &TokenClaims{
Audience: []string{clientID, "other"},
AuthorizedParty: clientID,
},
},
{
name: "wrong azp",
claims: &TokenClaims{
AuthorizedParty: "wrong",
},
wantErr: ErrAzpInvalid,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthorizedParty(tt.claims, clientID)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckExpiration(t *testing.T) {
const offset = time.Minute
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrExpired,
},
{
name: "expired",
claims: &TokenClaims{
Expiration: FromTime(time.Now().Add(-2 * offset)),
},
wantErr: ErrExpired,
},
{
name: "valid",
claims: &TokenClaims{
Expiration: FromTime(time.Now().Add(2 * offset)),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckExpiration(tt.claims, offset)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckIssuedAt(t *testing.T) {
const offset = time.Minute
tests := []struct {
name string
maxAgeIAT time.Duration
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrIatMissing,
},
{
name: "future",
claims: &TokenClaims{
IssuedAt: FromTime(time.Now().Add(time.Hour)),
},
wantErr: ErrIatInFuture,
},
{
name: "no max",
claims: &TokenClaims{
IssuedAt: FromTime(time.Now()),
},
},
{
name: "past max",
maxAgeIAT: time.Minute,
claims: &TokenClaims{
IssuedAt: FromTime(time.Now().Add(-time.Hour)),
},
wantErr: ErrIatToOld,
},
{
name: "within max",
maxAgeIAT: time.Hour,
claims: &TokenClaims{
IssuedAt: FromTime(time.Now()),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckNonce(t *testing.T) {
const nonce = "123"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrNonceInvalid,
},
{
name: "wrong",
claims: &TokenClaims{
Nonce: "wrong",
},
wantErr: ErrNonceInvalid,
},
{
name: "ok",
claims: &TokenClaims{
Nonce: nonce,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckNonce(tt.claims, nonce)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthorizationContextClassReference(t *testing.T) {
tests := []struct {
name string
acr ACRVerifier
wantErr error
}{
{
name: "error",
acr: func(s string) error { return errors.New("oops") },
wantErr: ErrAcrInvalid,
},
{
name: "ok",
acr: func(s string) error { return nil },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthTime(t *testing.T) {
tests := []struct {
name string
claims Claims
maxAge time.Duration
wantErr error
}{
{
name: "no max age",
claims: &TokenClaims{},
},
{
name: "missing",
claims: &TokenClaims{},
maxAge: time.Minute,
wantErr: ErrAuthTimeNotPresent,
},
{
name: "expired",
maxAge: time.Minute,
claims: &TokenClaims{
AuthTime: FromTime(time.Now().Add(-time.Hour)),
},
wantErr: ErrAuthTimeToOld,
},
{
name: "ok",
maxAge: time.Minute,
claims: &TokenClaims{
AuthTime: NowTime(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthTime(tt.claims, tt.maxAge)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}

View file

@ -2,6 +2,7 @@ package op
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@ -10,11 +11,10 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
str "github.com/zitadel/oidc/v2/pkg/strings"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
str "github.com/zitadel/oidc/v3/pkg/strings"
"golang.org/x/exp/slog"
)
type AuthRequest interface {
@ -39,16 +39,17 @@ type Authorizer interface {
Storage() Storage
Decoder() httphelper.Decoder
Encoder() httphelper.Encoder
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
Crypto() Crypto
RequestObjectSupported() bool
Logger() *slog.Logger
}
// AuthorizeValidator is an extension of Authorizer interface
// implementing its own validation mechanism for the auth request
type AuthorizeValidator interface {
Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
}
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
@ -68,23 +69,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
AuthRequestError(w, r, nil, err, authorizer)
return
}
ctx := r.Context()
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
}
if authReq.ClientID == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder())
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer)
return
}
if authReq.RedirectURI == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder())
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
return
}
validation := ValidateAuthRequest
@ -93,21 +94,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
}
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
if authReq.RequestParam != "" {
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer)
return
}
req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
if err != nil {
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer)
return
}
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer)
return
}
RedirectToLogin(req.GetID(), client, w, r)
@ -129,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A
// ParseRequestObject parse the `request` parameter, validates the token including the signature
// and copies the token claims into the auth request
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) {
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error {
requestObject := new(oidc.RequestObject)
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
if err != nil {
return nil, err
return err
}
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if requestObject.Issuer != requestObject.ClientID {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if !str.Contains(requestObject.Audience, issuer) {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
return authReq, err
return err
}
CopyRequestObjectToAuthRequest(authReq, requestObject)
return authReq, nil
return nil
}
// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
@ -205,7 +206,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
}
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return "", err
@ -385,7 +386,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType)
// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
// and returns the `sub` claim
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) {
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) {
if idTokenHint == "" {
return "", nil
}
@ -405,32 +406,41 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *
// AuthorizeCallback handles the callback after authentication in the Login UI
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
params := mux.Vars(r)
id := params["id"]
if id == "" {
AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder())
id, err := ParseAuthorizeCallbackRequest(r)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer)
return
}
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
AuthRequestError(w, r, nil, err, authorizer)
return
}
if !authReq.Done() {
AuthRequestError(w, r, authReq,
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
authorizer.Encoder())
authorizer)
return
}
AuthResponse(authReq, authorizer, w, r)
}
func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
if err = r.ParseForm(); err != nil {
return "", fmt.Errorf("cannot parse form: %w", err)
}
id = r.Form.Get("id")
if id == "" {
return "", errors.New("auth request callback is missing id")
}
return id, nil
}
// AuthResponse creates the successful authentication response (either code or tokens)
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
if authReq.GetResponseType() == oidc.ResponseTypeCode {
@ -444,7 +454,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
codeResponse := struct {
@ -456,7 +466,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
}
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
http.Redirect(w, r, callback, http.StatusFound)
@ -471,12 +481,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "")
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, err, authorizer)
return
}
http.Redirect(w, r, callback, http.StatusFound)

View file

@ -11,14 +11,16 @@ import (
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
"github.com/zitadel/oidc/v3/example/server/storage"
tu "github.com/zitadel/oidc/v3/internal/testutil"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v3/pkg/op/mock"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
)
func TestAuthorize(t *testing.T) {
@ -39,7 +41,7 @@ func TestAuthorize(t *testing.T) {
expect := authorizer.EXPECT()
expect.Decoder().Return(schema.NewDecoder())
expect.Encoder().Return(schema.NewEncoder())
expect.Logger().Return(slog.Default())
if tt.expect != nil {
tt.expect(expect)
@ -123,7 +125,7 @@ func TestValidateAuthRequest(t *testing.T) {
type args struct {
authRequest *oidc.AuthRequest
storage op.Storage
verifier op.IDTokenHintVerifier
verifier *op.IDTokenHintVerifier
}
tests := []struct {
name string
@ -996,7 +998,7 @@ func TestAuthResponseCode(t *testing.T) {
authorizer.EXPECT().Crypto().Return(&mockCrypto{
returnErr: io.ErrClosedPipe,
})
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
authorizer.EXPECT().Logger().Return(slog.Default())
return authorizer
},
},
@ -1071,3 +1073,71 @@ func TestAuthResponseCode(t *testing.T) {
})
}
}
func Test_parseAuthorizeCallbackRequest(t *testing.T) {
tests := []struct {
name string
url string
wantId string
wantErr bool
}{
{
name: "parse error",
url: "/?id;=99",
wantErr: true,
},
{
name: "missing id",
url: "/",
wantErr: true,
},
{
name: "ok",
url: "/?id=99",
wantId: "99",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
gotId, err := op.ParseAuthorizeCallbackRequest(r)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantId, gotId)
})
}
}
func TestValidateAuthReqIDTokenHint(t *testing.T) {
token, _ := tu.ValidIDToken()
tests := []struct {
name string
idTokenHint string
want string
wantErr error
}{
{
name: "empty",
},
{
name: "verify err",
idTokenHint: "foo",
wantErr: oidc.ErrLoginRequired(),
},
{
name: "ok",
idTokenHint: token,
want: tu.ValidSubject,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -7,8 +7,8 @@ import (
"net/url"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
//go:generate go get github.com/dmarkham/enumer
@ -87,7 +87,7 @@ var (
)
type ClientJWTProfile interface {
JWTProfileVerifier(context.Context) JWTProfileVerifier
JWTProfileVerifier(context.Context) *JWTProfileVerifier
}
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
}
return data.ClientID, false, nil
}
type ClientCredentials struct {
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body
ClientAssertion string `schema:"client_assertion"` // JWT
ClientAssertionType string `schema:"client_assertion_type"`
}

View file

@ -11,18 +11,18 @@ import (
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v3/pkg/op/mock"
"github.com/zitadel/schema"
)
type testClientJWTProfile struct{}
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil }
func TestClientJWTAuth(t *testing.T) {
type args struct {

View file

@ -22,14 +22,14 @@ var (
type Configuration interface {
IssuerFromRequest(r *http.Request) string
Insecure() bool
AuthorizationEndpoint() Endpoint
TokenEndpoint() Endpoint
IntrospectionEndpoint() Endpoint
UserinfoEndpoint() Endpoint
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint
AuthorizationEndpoint() *Endpoint
TokenEndpoint() *Endpoint
IntrospectionEndpoint() *Endpoint
UserinfoEndpoint() *Endpoint
RevocationEndpoint() *Endpoint
EndSessionEndpoint() *Endpoint
KeysEndpoint() *Endpoint
DeviceAuthorizationEndpoint() *Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool

View file

@ -1,7 +1,7 @@
package op
import (
"github.com/zitadel/oidc/v2/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/crypto"
)
type Crypto interface {

View file

@ -12,8 +12,8 @@ import (
"strings"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type DeviceAuthorizationConfig struct {
@ -57,47 +57,57 @@ var (
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if err := DeviceAuthorization(w, r, o); err != nil {
RequestError(w, r, err)
RequestError(w, r, err, o.Logger())
}
}
}
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return err
}
req, err := ParseDeviceCodeRequest(r, o)
if err != nil {
return err
}
response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o)
if err != nil {
return err
}
httphelper.MarshalJSON(w, response)
return nil
}
func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return nil, err
}
config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil {
return err
return nil, NewStatusError(err, http.StatusInternalServerError)
}
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
if err != nil {
return err
return nil, NewStatusError(err, http.StatusInternalServerError)
}
expires := time.Now().Add(config.Lifetime)
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes)
if err != nil {
return err
return nil, NewStatusError(err, http.StatusInternalServerError)
}
var verification *url.URL
if config.UserFormURL != "" {
if verification, err = url.Parse(config.UserFormURL); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
return nil, NewStatusError(err, http.StatusInternalServerError)
}
} else {
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil {
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
return nil, NewStatusError(err, http.StatusInternalServerError)
}
verification.Path = config.UserFormPath
}
@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
verification.RawQuery = "user_code=" + userCode
response.VerificationURIComplete = verification.String()
httphelper.MarshalJSON(w, response)
return nil
return response, nil
}
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
@ -201,7 +209,7 @@ func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
r = r.WithContext(ctx)
if err := deviceAccessToken(w, r, exchanger); err != nil {
RequestError(w, r, err)
RequestError(w, r, err, exchanger.Logger())
}
}

View file

@ -16,8 +16,9 @@ import (
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func Test_deviceAuthorizationHandler(t *testing.T) {
@ -319,7 +320,7 @@ func BenchmarkNewUserCode(b *testing.B) {
}
func TestDeviceAccessToken(t *testing.T) {
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
storage := testProvider.Storage().(*storage.Storage)
storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
@ -344,7 +345,7 @@ func TestDeviceAccessToken(t *testing.T) {
func TestCheckDeviceAuthorizationState(t *testing.T) {
now := time.Now()
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
storage := testProvider.Storage().(*storage.Storage)
storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})

View file

@ -4,10 +4,10 @@ import (
"context"
"net/http"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type DiscoverStorage interface {
@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(r, c, s))
Discover(w, CreateDiscoveryConfig(r.Context(), c, s))
}
}
@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
httphelper.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
issuer := config.IssuerFromRequest(r)
func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage),
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
ClaimsSupported: SupportedClaims(config),
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
UILocalesSupported: config.SupportedUILocales(),
RequestParameterSupported: config.RequestObjectSupported(),
}
}
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
TokenEndpoint: endpoints.Token.Absolute(issuer),
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
JwksURI: endpoints.JwksURI.Absolute(issuer),
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),

View file

@ -6,14 +6,14 @@ import (
"net/http/httptest"
"testing"
jose "github.com/go-jose/go-jose/v3"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v3/pkg/op/mock"
)
func TestDiscover(t *testing.T) {
@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) {
type args struct {
request *http.Request
c op.Configuration
s op.DiscoverStorage
ctx context.Context
c op.Configuration
s op.DiscoverStorage
}
tests := []struct {
name string
@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s)
got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s)
assert.Equal(t, tt.want, got)
})
}

View file

@ -1,32 +1,46 @@
package op
import "strings"
import (
"errors"
"strings"
)
type Endpoint struct {
path string
url string
}
func NewEndpoint(path string) Endpoint {
return Endpoint{path: path}
func NewEndpoint(path string) *Endpoint {
return &Endpoint{path: path}
}
func NewEndpointWithURL(path, url string) Endpoint {
return Endpoint{path: path, url: url}
func NewEndpointWithURL(path, url string) *Endpoint {
return &Endpoint{path: path, url: url}
}
func (e Endpoint) Relative() string {
func (e *Endpoint) Relative() string {
if e == nil {
return ""
}
return relativeEndpoint(e.path)
}
func (e Endpoint) Absolute(host string) string {
func (e *Endpoint) Absolute(host string) string {
if e == nil {
return ""
}
if e.url != "" {
return e.url
}
return absoluteEndpoint(host, e.path)
}
func (e Endpoint) Validate() error {
var ErrNilEndpoint = errors.New("nil endpoint")
func (e *Endpoint) Validate() error {
if e == nil {
return ErrNilEndpoint
}
return nil // TODO:
}

View file

@ -3,13 +3,14 @@ package op_test
import (
"testing"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/op"
)
func TestEndpoint_Path(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
e *op.Endpoint
want string
}{
{
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
op.NewEndpointWithURL("/test", "http://test.com/test"),
"/test",
},
{
"nil",
nil,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
}
tests := []struct {
name string
e op.Endpoint
e *op.Endpoint
args args
want string
}{
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
args{"https://host"},
"https://test.com/test",
},
{
"nil",
nil,
args{"https://host"},
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
func TestEndpoint_Validate(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
wantErr bool
e *op.Endpoint
wantErr error
}{
// TODO: Add test cases.
{
"nil",
nil,
op.ErrNilEndpoint,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
err := tt.e.Validate()
require.ErrorIs(t, err, tt.wantErr)
})
}
}

View file

@ -1,10 +1,14 @@
package op
import (
"context"
"errors"
"fmt"
"net/http"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slog"
)
type ErrAuthRequest interface {
@ -13,13 +17,31 @@ type ErrAuthRequest interface {
GetState() string
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) {
// LogAuthRequest is an optional interface,
// that allows logging AuthRequest fields.
// If the AuthRequest does not implement this interface,
// no details shall be printed to the logs.
type LogAuthRequest interface {
ErrAuthRequest
slog.LogValuer
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) {
e := oidc.DefaultToServerError(err, err.Error())
logger := authorizer.Logger().With("oidc_error", e)
if authReq == nil {
logger.Log(r.Context(), e.LogLevel(), "auth request")
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
e := oidc.DefaultToServerError(err, err.Error())
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
logger = logger.With("auth_request", logAuthReq)
}
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting")
http.Error(w, e.Description, http.StatusBadRequest)
return
}
@ -28,19 +50,120 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
}
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder())
if err != nil {
logger.ErrorContext(r.Context(), "auth response URL", "error", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
logger.Log(r.Context(), e.LogLevel(), "auth request")
http.Redirect(w, r, url, http.StatusFound)
}
func RequestError(w http.ResponseWriter, r *http.Request, err error) {
func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest
if e.ErrorType == oidc.InvalidClient {
status = 401
status = http.StatusUnauthorized
}
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, e, status)
}
// TryErrorRedirect tries to handle an error by redirecting a client.
// If this attempt fails, an error is returned that must be returned
// to the client instead.
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
e := oidc.DefaultToServerError(parent, parent.Error())
logger = logger.With("oidc_error", e)
if authReq == nil {
logger.Log(ctx, e.LogLevel(), "auth request")
return nil, AsStatusError(e, http.StatusBadRequest)
}
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
logger = logger.With("auth_request", logAuthReq)
}
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
return nil, AsStatusError(e, http.StatusBadRequest)
}
e.State = authReq.GetState()
var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
}
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
if err != nil {
logger.ErrorContext(ctx, "auth response URL", "error", err)
return nil, AsStatusError(err, http.StatusBadRequest)
}
logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
return NewRedirect(url), nil
}
// StatusError wraps an error with a HTTP status code.
// The status code is passed to the handler's writer.
type StatusError struct {
parent error
statusCode int
}
// NewStatusError sets the parent and statusCode to a new StatusError.
// It is recommended for parent to be an [oidc.Error].
//
// Typically implementations should only use this to signal something
// very specific, like an internal server error.
// If a returned error is not a StatusError, the framework
// will set a statusCode based on what the standard specifies,
// which is [http.StatusBadRequest] for most of the time.
// If the error encountered can described clearly with a [oidc.Error],
// do not use this function, as it might break standard rules!
func NewStatusError(parent error, statusCode int) StatusError {
return StatusError{
parent: parent,
statusCode: statusCode,
}
}
// AsStatusError unwraps a StatusError from err
// and returns it unmodified if found.
// If no StatuError was found, a new one is returned
// with statusCode set to it as a default.
func AsStatusError(err error, statusCode int) (target StatusError) {
if errors.As(err, &target) {
return target
}
return NewStatusError(err, statusCode)
}
func (e StatusError) Error() string {
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
}
func (e StatusError) Unwrap() error {
return e.parent
}
func (e StatusError) Is(err error) bool {
var target StatusError
if !errors.As(err, &target) {
return false
}
return errors.Is(e.parent, target.parent) &&
e.statusCode == target.statusCode
}
// WriteError asserts for a StatusError containing an [oidc.Error].
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
statusError := AsStatusError(err, http.StatusBadRequest)
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
}

677
pkg/op/error_test.go Normal file
View file

@ -0,0 +1,677 @@
package op
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
)
func TestAuthRequestError(t *testing.T) {
type args struct {
authReq ErrAuthRequest
err error
}
tests := []struct {
name string
args args
wantCode int
wantHeaders map[string]string
wantBody string
wantLog string
}{
{
name: "nil auth request",
args: args{
authReq: nil,
err: io.ErrClosedPipe,
},
wantCode: http.StatusBadRequest,
wantBody: "io: read/write on closed pipe\n",
wantLog: `{
"level":"ERROR",
"msg":"auth request",
"time":"not",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
}
}`,
},
{
name: "auth request, no redirect URI",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantCode: http.StatusBadRequest,
wantBody: "sign in\n",
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request, redirect disabled",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
},
wantCode: http.StatusBadRequest,
wantBody: "oops\n",
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"oops",
"type":"invalid_request",
"redirect_disabled":true
}
}`,
},
{
name: "auth request, url parse error",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "can't parse this!\n",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantCode: http.StatusBadRequest,
wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n",
wantLog: `{
"level":"ERROR",
"msg":"auth response URL",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"can't parse this!\n",
"response_type":"responseType",
"scopes":"a b"
},
"error":{
"type":"server_error",
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request redirect",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantCode: http.StatusFound,
wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"},
wantLog: `{
"level":"WARN",
"msg":"auth request",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
authorizer := &Provider{
encoder: schema.NewEncoder(),
logger: slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
),
}
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/path", nil)
AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer)
res := w.Result()
defer res.Body.Close()
assert.Equal(t, tt.wantCode, res.StatusCode)
for key, wantHeader := range tt.wantHeaders {
gotHeader := res.Header.Get(key)
assert.Equalf(t, wantHeader, gotHeader, "header %q", key)
}
gotBody, err := io.ReadAll(res.Body)
require.NoError(t, err, "read result body")
assert.Equal(t, tt.wantBody, string(gotBody), "result body")
gotLog := logOut.String()
t.Log(gotLog)
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
})
}
}
func TestRequestError(t *testing.T) {
tests := []struct {
name string
err error
wantCode int
wantBody string
wantLog string
}{
{
name: "server error",
err: io.ErrClosedPipe,
wantCode: http.StatusBadRequest,
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"time":"not",
"oidc_error":{
"parent":"io: read/write on closed pipe",
"description":"io: read/write on closed pipe",
"type":"server_error"}
}`,
},
{
name: "invalid client",
err: oidc.ErrInvalidClient().WithDescription("not good"),
wantCode: http.StatusUnauthorized,
wantBody: `{"error":"invalid_client", "error_description":"not good"}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"time":"not",
"oidc_error":{
"description":"not good",
"type":"invalid_client"}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/path", nil)
RequestError(w, r, tt.err, logger)
res := w.Result()
defer res.Body.Close()
assert.Equal(t, tt.wantCode, res.StatusCode, "status code")
gotBody, err := io.ReadAll(res.Body)
require.NoError(t, err, "read result body")
assert.JSONEq(t, tt.wantBody, string(gotBody), "result body")
gotLog := logOut.String()
t.Log(gotLog)
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
})
}
}
func TestTryErrorRedirect(t *testing.T) {
type args struct {
ctx context.Context
authReq ErrAuthRequest
parent error
}
tests := []struct {
name string
args args
want *Redirect
wantErr error
wantLog string
}{
{
name: "nil auth request",
args: args{
ctx: context.Background(),
authReq: nil,
parent: io.ErrClosedPipe,
},
wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest),
wantLog: `{
"level":"ERROR",
"msg":"auth request",
"time":"not",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
}
}`,
},
{
name: "auth request, no redirect URI",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest),
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request, redirect disabled",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
},
wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest),
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"oops",
"type":"invalid_request",
"redirect_disabled":true
}
}`,
},
{
name: "auth request, url parse error",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "can't parse this!\n",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantErr: func() error {
//lint:ignore SA1007 just recreating the error for testing
_, err := url.Parse("can't parse this!\n")
err = oidc.ErrServerError().WithParent(err)
return NewStatusError(err, http.StatusBadRequest)
}(),
wantLog: `{
"level":"ERROR",
"msg":"auth response URL",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"can't parse this!\n",
"response_type":"responseType",
"scopes":"a b"
},
"error":{
"type":"server_error",
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request redirect",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
want: &Redirect{
URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
},
wantLog: `{
"level":"WARN",
"msg":"auth request redirect",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
},
"url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
encoder := schema.NewEncoder()
got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
gotLog := logOut.String()
t.Log(gotLog)
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
})
}
}
func TestNewStatusError(t *testing.T) {
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
want := "Internal Server Error: io: read/write on closed pipe"
got := fmt.Sprint(err)
assert.Equal(t, want, got)
}
func TestAsStatusError(t *testing.T) {
type args struct {
err error
statusCode int
}
tests := []struct {
name string
args args
want string
}{
{
name: "already status error",
args: args{
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
statusCode: http.StatusBadRequest,
},
want: "Internal Server Error: io: read/write on closed pipe",
},
{
name: "oidc error",
args: args{
err: oidc.ErrAcrInvalid,
statusCode: http.StatusBadRequest,
},
want: "Bad Request: acr is invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := AsStatusError(tt.args.err, tt.args.statusCode)
got := fmt.Sprint(err)
assert.Equal(t, tt.want, got)
})
}
}
func TestStatusError_Unwrap(t *testing.T) {
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
require.ErrorIs(t, err, io.ErrClosedPipe)
}
func TestStatusError_Is(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil error",
args: args{err: nil},
want: false,
},
{
name: "other error",
args: args{err: io.EOF},
want: false,
},
{
name: "other parent",
args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)},
want: false,
},
{
name: "other status",
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)},
want: false,
},
{
name: "same",
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)},
want: true,
},
{
name: "wrapped",
args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
if got := e.Is(tt.args.err); got != tt.want {
t.Errorf("StatusError.Is() = %v, want %v", got, tt.want)
}
})
}
}
func TestWriteError(t *testing.T) {
tests := []struct {
name string
err error
wantStatus int
wantBody string
wantLog string
}{
{
name: "not a status or oidc error",
err: io.ErrClosedPipe,
wantStatus: http.StatusBadRequest,
wantBody: `{
"error":"server_error",
"error_description":"io: read/write on closed pipe"
}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
},
"time":"not"
}`,
},
{
name: "status error w/o oidc",
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
wantStatus: http.StatusInternalServerError,
wantBody: `{
"error":"server_error",
"error_description":"io: read/write on closed pipe"
}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
},
"time":"not"
}`,
},
{
name: "oidc error w/o status",
err: oidc.ErrInvalidRequest().WithDescription("oops"),
wantStatus: http.StatusBadRequest,
wantBody: `{
"error":"invalid_request",
"error_description":"oops"
}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"oidc_error":{
"description":"oops",
"type":"invalid_request"
},
"time":"not"
}`,
},
{
name: "status with oidc error",
err: NewStatusError(
oidc.ErrUnauthorizedClient().WithDescription("oops"),
http.StatusUnauthorized,
),
wantStatus: http.StatusUnauthorized,
wantBody: `{
"error":"unauthorized_client",
"error_description":"oops"
}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"oidc_error":{
"description":"oops",
"type":"unauthorized_client"
},
"time":"not"
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
r := httptest.NewRequest("GET", "/target", nil)
w := httptest.NewRecorder()
WriteError(w, r, tt.err, logger)
res := w.Result()
assert.Equal(t, tt.wantStatus, res.StatusCode, "status code")
gotBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.JSONEq(t, tt.wantBody, string(gotBody), "body")
assert.JSONEq(t, tt.wantLog, logOut.String())
})
}
}

View file

@ -4,9 +4,9 @@ import (
"context"
"net/http"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
)
type KeyProvider interface {

View file

@ -7,13 +7,13 @@ import (
"net/http/httptest"
"testing"
jose "github.com/go-jose/go-jose/v3"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v3/pkg/op/mock"
)
func TestKeys(t *testing.T) {

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer)
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Authorizer)
// Package mock is a generated GoMock package.
package mock
@ -9,8 +9,9 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
http "github.com/zitadel/oidc/v2/pkg/http"
op "github.com/zitadel/oidc/v2/pkg/op"
http "github.com/zitadel/oidc/v3/pkg/http"
op "github.com/zitadel/oidc/v3/pkg/op"
slog "golang.org/x/exp/slog"
)
// MockAuthorizer is a mock of Authorizer interface.
@ -79,10 +80,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
}
// IDTokenHintVerifier mocks base method.
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier {
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
ret0, _ := ret[0].(op.IDTokenHintVerifier)
ret0, _ := ret[0].(*op.IDTokenHintVerifier)
return ret0
}
@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
}
// Logger mocks base method.
func (m *MockAuthorizer) Logger() *slog.Logger {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logger")
ret0, _ := ret[0].(*slog.Logger)
return ret0
}
// Logger indicates an expected call of Logger.
func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger))
}
// RequestObjectSupported mocks base method.
func (m *MockAuthorizer) RequestObjectSupported() bool {
m.ctrl.T.Helper()

View file

@ -4,12 +4,12 @@ import (
"context"
"testing"
jose "github.com/go-jose/go-jose/v3"
"github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/schema"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func NewAuthorizer(t *testing.T) op.Authorizer {
@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) {
func ExpectVerifier(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
func() op.IDTokenHintVerifier {
func() *op.IDTokenHintVerifier {
return op.NewIDTokenHintVerifier("", nil)
})
}

View file

@ -5,8 +5,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func NewClient(t *testing.T) op.Client {

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client)
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Client)
// Package mock is a generated GoMock package.
package mock
@ -9,8 +9,8 @@ import (
time "time"
gomock "github.com/golang/mock/gomock"
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
op "github.com/zitadel/oidc/v2/pkg/op"
oidc "github.com/zitadel/oidc/v3/pkg/oidc"
op "github.com/zitadel/oidc/v3/pkg/op"
)
// MockClient is a mock of Client interface.

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration)
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Configuration)
// Package mock is a generated GoMock package.
package mock
@ -9,7 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
op "github.com/zitadel/oidc/v2/pkg/op"
op "github.com/zitadel/oidc/v3/pkg/op"
language "golang.org/x/text/language"
)
@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
}
// AuthorizationEndpoint mocks base method.
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
}
// DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
}
// EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EndSessionEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
}
// IntrospectionEndpoint mocks base method.
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
}
// KeysEndpoint mocks base method.
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeysEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
}
// RevocationEndpoint mocks base method.
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
}
// TokenEndpoint mocks base method.
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
}
// UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserinfoEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage)
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: DiscoverStorage)
// Package mock is a generated GoMock package.
package mock
@ -8,8 +8,8 @@ import (
context "context"
reflect "reflect"
jose "github.com/go-jose/go-jose/v3"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
)
// MockDiscoverStorage is a mock of DiscoverStorage interface.

View file

@ -1,10 +1,10 @@
package mock
//go:generate go install github.com/golang/mock/mockgen@v1.6.0
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v3/pkg/op Storage
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v3/pkg/op Authorizer
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v3/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v3/pkg/op Configuration
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v3/pkg/op DiscoverStorage
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v3/pkg/op SigningKey,Key
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v3/pkg/op KeyProvider

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider)
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: KeyProvider)
// Package mock is a generated GoMock package.
package mock
@ -9,7 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
op "github.com/zitadel/oidc/v2/pkg/op"
op "github.com/zitadel/oidc/v3/pkg/op"
)
// MockKeyProvider is a mock of KeyProvider interface.

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key)
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: SigningKey,Key)
// Package mock is a generated GoMock package.
package mock
@ -7,8 +7,8 @@ package mock
import (
reflect "reflect"
jose "github.com/go-jose/go-jose/v3"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
)
// MockSigningKey is a mock of SigningKey interface.

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage)
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Storage)
// Package mock is a generated GoMock package.
package mock
@ -9,10 +9,10 @@ import (
reflect "reflect"
time "time"
jose "github.com/go-jose/go-jose/v3"
gomock "github.com/golang/mock/gomock"
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
op "github.com/zitadel/oidc/v2/pkg/op"
jose "gopkg.in/square/go-jose.v2"
oidc "github.com/zitadel/oidc/v3/pkg/oidc"
op "github.com/zitadel/oidc/v3/pkg/op"
)
// MockStorage is a mock of Storage interface.

View file

@ -8,8 +8,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func NewStorage(t *testing.T) op.Storage {

View file

@ -6,16 +6,17 @@ import (
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/schema"
"github.com/go-chi/chi"
jose "github.com/go-jose/go-jose/v3"
"github.com/rs/cors"
"github.com/zitadel/schema"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"golang.org/x/exp/slog"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
const (
@ -33,7 +34,7 @@ const (
)
var (
DefaultEndpoints = &endpoints{
DefaultEndpoints = &Endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint),
@ -76,29 +77,35 @@ func init() {
}
type OpenIDProvider interface {
http.Handler
Configuration
Storage() Storage
Decoder() httphelper.Decoder
Encoder() httphelper.Encoder
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
AccessTokenVerifier(context.Context) AccessTokenVerifier
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
AccessTokenVerifier(context.Context) *AccessTokenVerifier
Crypto() Crypto
DefaultLogoutRedirectURI() string
Probes() []ProbesFn
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
Logger() *slog.Logger
// Deprecated: Provider now implements http.Handler directly.
HttpHandler() http.Handler
}
type HttpInterceptor func(http.Handler) http.Handler
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
router := mux.NewRouter()
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(intercept(o.IssuerFromRequest, interceptors...))
router.HandleFunc(healthEndpoint, healthHandler)
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o))
router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o))
router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
@ -132,16 +139,17 @@ type Config struct {
DeviceAuthorization DeviceAuthorizationConfig
}
type endpoints struct {
Authorization Endpoint
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
DeviceAuthorization Endpoint
// Endpoints defines endpoint routes.
type Endpoints struct {
Authorization *Endpoint
Token *Endpoint
Introspection *Endpoint
Userinfo *Endpoint
Revocation *Endpoint
EndSession *Endpoint
CheckSessionIframe *Endpoint
JwksURI *Endpoint
DeviceAuthorization *Endpoint
}
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -186,6 +194,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
storage: storage,
endpoints: DefaultEndpoints,
timer: make(<-chan time.Time),
logger: slog.Default(),
}
for _, optFunc := range opOpts {
@ -199,7 +208,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
return nil, err
}
o.httpHandler = CreateRouter(o, o.interceptors...)
o.Handler = CreateRouter(o, o.interceptors...)
o.decoder = schema.NewDecoder()
o.decoder.IgnoreUnknownKeys(true)
@ -215,20 +224,21 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
}
type Provider struct {
http.Handler
config *Config
issuer IssuerFromRequest
insecure bool
endpoints *endpoints
endpoints *Endpoints
storage Storage
keySet *openIDKeySet
crypto Crypto
httpHandler http.Handler
decoder *schema.Decoder
encoder *schema.Encoder
interceptors []HttpInterceptor
timer <-chan time.Time
accessTokenVerifierOpts []AccessTokenVerifierOpt
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
logger *slog.Logger
}
func (o *Provider) IssuerFromRequest(r *http.Request) string {
@ -239,35 +249,35 @@ func (o *Provider) Insecure() bool {
return o.insecure
}
func (o *Provider) AuthorizationEndpoint() Endpoint {
func (o *Provider) AuthorizationEndpoint() *Endpoint {
return o.endpoints.Authorization
}
func (o *Provider) TokenEndpoint() Endpoint {
func (o *Provider) TokenEndpoint() *Endpoint {
return o.endpoints.Token
}
func (o *Provider) IntrospectionEndpoint() Endpoint {
func (o *Provider) IntrospectionEndpoint() *Endpoint {
return o.endpoints.Introspection
}
func (o *Provider) UserinfoEndpoint() Endpoint {
func (o *Provider) UserinfoEndpoint() *Endpoint {
return o.endpoints.Userinfo
}
func (o *Provider) RevocationEndpoint() Endpoint {
func (o *Provider) RevocationEndpoint() *Endpoint {
return o.endpoints.Revocation
}
func (o *Provider) EndSessionEndpoint() Endpoint {
func (o *Provider) EndSessionEndpoint() *Endpoint {
return o.endpoints.EndSession
}
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) KeysEndpoint() Endpoint {
func (o *Provider) KeysEndpoint() *Endpoint {
return o.endpoints.JwksURI
}
@ -354,15 +364,15 @@ func (o *Provider) Encoder() httphelper.Encoder {
return o.encoder
}
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier {
func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
}
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier {
func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
}
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier {
func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
}
@ -387,8 +397,13 @@ func (o *Provider) Probes() []ProbesFn {
}
}
func (o *Provider) Logger() *slog.Logger {
return o.logger
}
// Deprecated: Provider now implements http.Handler directly.
func (o *Provider) HttpHandler() http.Handler {
return o.httpHandler
return o
}
type openIDKeySet struct {
@ -421,7 +436,7 @@ func WithAllowInsecure() Option {
}
}
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -431,7 +446,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -441,7 +456,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -451,7 +466,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -461,7 +476,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -471,7 +486,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -481,7 +496,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -491,7 +506,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -501,8 +516,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
// WithCustomEndpoints sets multiple endpoints at once.
// Non of the endpoints may be nil, or an error will
// be returned when the Option used by the Provider.
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
return func(o *Provider) error {
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
if err := e.Validate(); err != nil {
return err
}
}
o.endpoints.Authorization = auth
o.endpoints.Token = token
o.endpoints.Userinfo = userInfo
@ -534,6 +557,16 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
}
}
// WithLogger lets a logger other than slog.Default().
//
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
func WithLogger(logger *slog.Logger) Option {
return func(o *Provider) error {
o.logger = logger
return nil
}
}
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
issuerInterceptor := NewIssuerInterceptor(i)
return func(handler http.Handler) http.Handler {

View file

@ -14,9 +14,9 @@ import (
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"golang.org/x/text/language"
)
@ -157,7 +157,7 @@ func TestRoutes(t *testing.T) {
values: map[string]string{
"client_id": client.GetID(),
"redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"response_type": string(oidc.ResponseTypeCode),
},
wantCode: http.StatusFound,
@ -194,7 +194,7 @@ func TestRoutes(t *testing.T) {
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"assertion": jwtToken,
},
wantCode: http.StatusBadRequest,
@ -207,7 +207,7 @@ func TestRoutes(t *testing.T) {
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType),
},
@ -224,7 +224,7 @@ func TestRoutes(t *testing.T) {
basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
@ -339,7 +339,7 @@ func TestRoutes(t *testing.T) {
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"device", "secret"},
values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{
@ -371,7 +371,7 @@ func TestRoutes(t *testing.T) {
}
rec := httptest.NewRecorder()
testProvider.HttpHandler().ServeHTTP(rec, req)
testProvider.ServeHTTP(rec, req)
resp := rec.Result()
require.NoError(t, err)
@ -396,3 +396,54 @@ func TestRoutes(t *testing.T) {
})
}
}
func TestWithCustomEndpoints(t *testing.T) {
type args struct {
auth *op.Endpoint
token *op.Endpoint
userInfo *op.Endpoint
revocation *op.Endpoint
endSession *op.Endpoint
keys *op.Endpoint
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "all nil",
args: args{},
wantErr: op.ErrNilEndpoint,
},
{
name: "all set",
args: args{
auth: op.NewEndpoint("/authorize"),
token: op.NewEndpoint("/oauth/token"),
userInfo: op.NewEndpoint("/userinfo"),
revocation: op.NewEndpoint("/revoke"),
endSession: op.NewEndpoint("/end_session"),
keys: op.NewEndpoint("/keys"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
storage.NewStorage(storage.NewUserStore(testIssuer)),
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
)
require.ErrorIs(t, err, tt.wantErr)
if tt.wantErr != nil {
return
}
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
})
}
}

View file

@ -5,7 +5,7 @@ import (
"errors"
"net/http"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
)
type ProbesFn func(context.Context) error
@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn {
}
func ok(w http.ResponseWriter) {
httphelper.MarshalJSON(w, status{"ok"})
httphelper.MarshalJSON(w, Status{"ok"})
}
type status struct {
type Status struct {
Status string `json:"status,omitempty"`
}

346
pkg/op/server.go Normal file
View file

@ -0,0 +1,346 @@
package op
import (
"context"
"net/http"
"net/url"
"github.com/muhlemmer/gu"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// Server describes the interface that needs to be implemented to serve
// OpenID Connect and Oauth2 standard requests.
//
// Methods are called after the HTTP route is resolved and
// the request body is parsed into the Request's Data field.
// When a method is called, it can be assumed that required fields,
// as described in their relevant standard, are validated already.
// The Response Data field may be of any type to allow flexibility
// to extend responses with custom fields. There are however requirements
// in the standards regarding the response models. Where applicable
// the method documentation gives a recommended type which can be used
// directly or extended upon.
//
// The addition of new methods is not considered a breaking change
// as defined by semver rules.
// Implementations MUST embed [UnimplementedServer] to maintain
// forward compatibility.
//
// EXPERIMENTAL: may change until v4
type Server interface {
// Health returns a status of "ok" once the Server is listening.
// The recommended Response Data type is [Status].
Health(context.Context, *Request[struct{}]) (*Response, error)
// Ready returns a status of "ok" once all dependencies,
// such as database storage, are ready.
// An error can be returned to explain what is not ready.
// The recommended Response Data type is [Status].
Ready(context.Context, *Request[struct{}]) (*Response, error)
// Discovery returns the OpenID Provider Configuration Information for this server.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
// The recommended Response Data type is [oidc.DiscoveryConfiguration].
Discovery(context.Context, *Request[struct{}]) (*Response, error)
// Keys serves the JWK set which the client can use verify signatures from the op.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key.
// The recommended Response Data type is [jose.JSONWebKeySet].
Keys(context.Context, *Request[struct{}]) (*Response, error)
// VerifyAuthRequest verifies the Auth Request and
// adds the Client to the request.
//
// When the `request` field is populated with a
// "Request Object" JWT, it needs to be Validated
// and its claims overwrite any fields in the AuthRequest.
// If the implementation does not support "Request Object",
// it MUST return an [oidc.ErrRequestNotSupported].
// https://openid.net/specs/openid-connect-core-1_0.html#RequestObject
VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error)
// Authorize initiates the authorization flow and redirects to a login page.
// See the various https://openid.net/specs/openid-connect-core-1_0.html
// authorize endpoint sections (one for each type of flow).
Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error)
// DeviceAuthorization initiates the device authorization flow.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
// The recommended Response Data type is [oidc.DeviceAuthorizationResponse].
DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error)
// VerifyClient is called on most oauth/token handlers to authenticate,
// using either a secret (POST, Basic) or assertion (JWT).
// If no secrets are provided, the client must be public.
// This method is called before each method that takes a
// [ClientRequest] argument.
VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error)
// CodeExchange returns Tokens after an authorization code
// is obtained in a successful Authorize flow.
// It is called by the Token endpoint handler when
// grant_type has the value authorization_code
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
// The recommended Response Data type is [oidc.AccessTokenResponse].
CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error)
// RefreshToken returns new Tokens after verifying a Refresh token.
// It is called by the Token endpoint handler when
// grant_type has the value refresh_token
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
// The recommended Response Data type is [oidc.AccessTokenResponse].
RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error)
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer
// https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
// The recommended Response Data type is [oidc.AccessTokenResponse].
JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error)
// TokenExchange handles the OAuth 2.0 token exchange grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange
// https://datatracker.ietf.org/doc/html/rfc8693
// The recommended Response Data type is [oidc.AccessTokenResponse].
TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error)
// ClientCredentialsExchange handles the OAuth 2.0 client credentials grant
// It is called by the Token endpoint handler when
// grant_type has the value client_credentials
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
// The recommended Response Data type is [oidc.AccessTokenResponse].
ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error)
// DeviceToken handles the OAuth 2.0 Device Authorization Grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:device_code.
// It is typically called in a polling fashion and appropriate errors
// should be returned to signal authorization_pending or access_denied etc.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4,
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5.
// The recommended Response Data type is [oidc.AccessTokenResponse].
DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error)
// Introspect handles the OAuth 2.0 Token Introspection endpoint.
// https://datatracker.ietf.org/doc/html/rfc7662
// The recommended Response Data type is [oidc.IntrospectionResponse].
Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error)
// UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User.
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
// The recommended Response Data type is [oidc.UserInfo].
UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error)
// Revocation handles token revocation using an access or refresh token.
// https://datatracker.ietf.org/doc/html/rfc7009
// There are no response requirements. Data may remain empty.
Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error)
// EndSession handles the OpenID Connect RP-Initiated Logout.
// https://openid.net/specs/openid-connect-rpinitiated-1_0.html
// There are no response requirements. Data may remain empty.
EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error)
// mustImpl forces implementations to embed the UnimplementedServer for forward
// compatibility with the interface.
mustImpl()
}
// Request contains the [http.Request] informational fields
// and parsed Data from the request body (POST) or URL parameters (GET).
// Data can be assumed to be validated according to the applicable
// standard for the specific endpoints.
//
// EXPERIMENTAL: may change until v4
type Request[T any] struct {
Method string
URL *url.URL
Header http.Header
Form url.Values
PostForm url.Values
Data *T
}
func (r *Request[_]) path() string {
return r.URL.Path
}
func newRequest[T any](r *http.Request, data *T) *Request[T] {
return &Request[T]{
Method: r.Method,
URL: r.URL,
Header: r.Header,
Form: r.Form,
PostForm: r.PostForm,
Data: data,
}
}
// ClientRequest is a Request with a verified client attached to it.
// Methods that receive this argument may assume the client was authenticated,
// or verified to be a public client.
//
// EXPERIMENTAL: may change until v4
type ClientRequest[T any] struct {
*Request[T]
Client Client
}
func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] {
return &ClientRequest[T]{
Request: newRequest[T](r, data),
Client: client,
}
}
// Response object for most [Server] methods.
//
// EXPERIMENTAL: may change until v4
type Response struct {
// Header map will be merged with the
// header on the [http.ResponseWriter].
Header http.Header
// Data will be JSON marshaled to
// the response body.
// We allow any type, so that implementations
// can extend the standard types as they wish.
// However, each method will recommend which
// (base) type to use as model, in order to
// be compliant with the standards.
Data any
}
// NewResponse creates a new response for data,
// without custom headers.
func NewResponse(data any) *Response {
return &Response{
Data: data,
}
}
func (resp *Response) writeOut(w http.ResponseWriter) {
gu.MapMerge(resp.Header, w.Header())
httphelper.MarshalJSON(w, resp.Data)
}
// Redirect is a special response type which will
// initiate a [http.StatusFound] redirect.
// The Params field will be encoded and set to the
// URL's RawQuery field before building the URL.
//
// EXPERIMENTAL: may change until v4
type Redirect struct {
// Header map will be merged with the
// header on the [http.ResponseWriter].
Header http.Header
URL string
}
func NewRedirect(url string) *Redirect {
return &Redirect{URL: url}
}
func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {
gu.MapMerge(r.Header, w.Header())
http.Redirect(w, r, red.URL, http.StatusFound)
}
type UnimplementedServer struct{}
// UnimplementedStatusCode is the status code returned for methods
// that are not yet implemented.
// Note that this means methods in the sense of the Go interface,
// and not http methods covered by "501 Not Implemented".
var UnimplementedStatusCode = http.StatusNotFound
func unimplementedError(r interface{ path() string }) StatusError {
err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path())
return NewStatusError(err, UnimplementedStatusCode)
}
func unimplementedGrantError(gt oidc.GrantType) StatusError {
err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt)
return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
}
func (UnimplementedServer) mustImpl() {}
func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
if r.Data.RequestParam != "" {
return nil, oidc.ErrRequestNotSupported()
}
return nil, unimplementedError(r)
}
func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeCode)
}
func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
}
func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
}
func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
}
func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
}
func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
}
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
return nil, unimplementedError(r)
}

480
pkg/op/server_http.go Normal file
View file

@ -0,0 +1,480 @@
package op
import (
"context"
"net/http"
"net/url"
"github.com/go-chi/chi"
"github.com/rs/cors"
"github.com/zitadel/logging"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
)
// RegisterServer registers an implementation of Server.
// The resulting handler takes care of routing and request parsing,
// with some basic validation of required fields.
// The routes can be customized with [WithEndpoints].
//
// EXPERIMENTAL: may change until v4
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
ws := &webServer{
server: server,
endpoints: endpoints,
decoder: decoder,
logger: slog.Default(),
}
for _, option := range options {
option(ws)
}
ws.createRouter()
return ws
}
type ServerOption func(s *webServer)
// WithHTTPMiddleware sets the passed middleware chain to the root of
// the Server's router.
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
return func(s *webServer) {
s.middleware = m
}
}
// WithDecoder overrides the default decoder,
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
func WithDecoder(decoder httphelper.Decoder) ServerOption {
return func(s *webServer) {
s.decoder = decoder
}
}
// WithFallbackLogger overrides the fallback logger, which
// is used when no logger was found in the context.
// Defaults to [slog.Default].
func WithFallbackLogger(logger *slog.Logger) ServerOption {
return func(s *webServer) {
s.logger = logger
}
}
type webServer struct {
http.Handler
server Server
middleware []func(http.Handler) http.Handler
endpoints Endpoints
decoder httphelper.Decoder
logger *slog.Logger
}
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
if logger, ok := logging.FromContext(ctx); ok {
return logger
}
return s.logger
}
func (s *webServer) createRouter() {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(s.middleware...)
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
s.Handler = router
}
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
if e != nil {
router.HandleFunc(e.Relative(), hf)
s.logger.Info("registered route", "endpoint", e.Relative())
}
}
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
if !ValidateGrantType(client, grantType) {
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
return
}
}
handler(w, r, client)
}
}
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
if err = r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
cc := new(ClientCredentials)
if err = s.decoder.Decode(cc, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
// Basic auth takes precedence, so if set it overwrites the form data.
if clientID, clientSecret, ok := r.BasicAuth(); ok {
cc.ClientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
}
if cc.ClientID == "" && cc.ClientAssertion == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
}
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
}
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
Method: r.Method,
URL: r.URL,
Header: r.Header,
Form: r.Form,
Data: cc,
})
}
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
redirect, err := s.authorize(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
redirect.writeOut(w, r)
}
func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
cr, err := s.server.VerifyAuthRequest(ctx, r)
if err != nil {
return nil, err
}
authReq := cr.Data
if authReq.RedirectURI == "" {
return nil, ErrAuthReqMissingRedirectURI
}
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return nil, err
}
authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
if err != nil {
return nil, err
}
if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return nil, err
}
if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
return nil, err
}
return s.server.Authorize(ctx, cr)
}
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return
}
switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
case oidc.GrantTypeCode:
s.withClient(s.codeExchangeHandler)(w, r)
case oidc.GrantTypeRefreshToken:
s.withClient(s.refreshTokenHandler)(w, r)
case oidc.GrantTypeClientCredentials:
s.withClient(s.clientCredentialsHandler)(w, r)
case oidc.GrantTypeBearer:
s.jwtProfileHandler(w, r)
case oidc.GrantTypeTokenExchange:
s.withClient(s.tokenExchangeHandler)(w, r)
case oidc.GrantTypeDeviceCode:
s.withClient(s.deviceTokenHandler)(w, r)
case "":
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
default:
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
}
}
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Assertion == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Code == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
return
}
if request.RedirectURI == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.RefreshToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.SubjectToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
return
}
if request.SubjectTokenType == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
return
}
if !request.SubjectTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
return
}
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
return
}
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
return
}
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return
}
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.DeviceCode == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return
}
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if token, err := getAccessToken(r); err == nil {
request.AccessToken = token
}
if request.AccessToken == "" {
err = NewStatusError(
oidc.ErrInvalidRequest().WithDescription("access token missing"),
http.StatusUnauthorized,
)
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w, r)
}
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return
}
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
}
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
dst := new(R)
if err := r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
form := r.Form
if postOnly {
form = r.PostForm
}
if err := decoder.Decode(dst, form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
return dst, nil
}

View file

@ -0,0 +1,345 @@
package op_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func jwtProfile() (string, error) {
keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json")
if err != nil {
return "", err
}
signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID)
if err != nil {
return "", err
}
return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer)
}
func TestServerRoutes(t *testing.T) {
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
client, err := storage.GetClientByClientID(ctx, "web")
require.NoError(t, err)
oidcAuthReq := &oidc.AuthRequest{
ClientID: client.GetID(),
RedirectURI: "https://example.com",
MaxAge: gu.Ptr[uint](300),
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
ResponseType: oidc.ResponseTypeCode,
}
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
require.NoError(t, err)
storage.AuthRequestDone(authReq.GetID())
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
require.NoError(t, err)
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
require.NoError(t, err)
jwtProfileToken, err := jwtProfile()
require.NoError(t, err)
oidcAuthReq.IDTokenHint = idToken
serverURL, err := url.Parse(testIssuer)
require.NoError(t, err)
type basicAuth struct {
username, password string
}
tests := []struct {
name string
method string
path string
basicAuth *basicAuth
header map[string]string
values map[string]string
body map[string]string
wantCode int
headerContains map[string]string
json string // test for exact json output
contains []string // when the body output is not constant, we just check for snippets to be present in the response
}{
{
name: "health",
method: http.MethodGet,
path: "/healthz",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "ready",
method: http.MethodGet,
path: "/ready",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "discovery",
method: http.MethodGet,
path: oidc.DiscoveryEndpoint,
wantCode: http.StatusOK,
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
},
{
name: "authorization",
method: http.MethodGet,
path: testProvider.AuthorizationEndpoint().Relative(),
values: map[string]string{
"client_id": client.GetID(),
"redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"response_type": string(oidc.ResponseTypeCode),
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
},
{
// This call will fail. A successfull test is already
// part of client/integration_test.go
name: "code exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeCode),
"client_id": client.GetID(),
"client_secret": "secret",
"redirect_uri": "https://example.com",
"code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_grant", "error_description":"invalid code"}`,
},
{
name: "JWT authorization",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"assertion": jwtProfileToken,
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`},
},
{
name: "Token exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType),
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
},
},
{
name: "Client credentials exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
},
{
// This call will fail. A successfull test is already
// part of device_test.go
name: "device token",
method: http.MethodPost,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"device", "secret"},
header: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
body: map[string]string{
"grant_type": string(oidc.GrantTypeDeviceCode),
"device_code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
},
{
name: "missing grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
},
{
name: "unsupported grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": "foo",
},
wantCode: http.StatusBadRequest,
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
},
{
name: "introspection",
method: http.MethodGet,
path: testProvider.IntrospectionEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessToken,
},
wantCode: http.StatusOK,
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "user info",
method: http.MethodGet,
path: testProvider.UserinfoEndpoint().Relative(),
header: map[string]string{
"authorization": "Bearer " + accessToken,
},
wantCode: http.StatusOK,
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "refresh token",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeRefreshToken),
"refresh_token": refreshToken,
"client_id": client.GetID(),
"client_secret": "secret",
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","token_type":"Bearer","refresh_token":"`,
`","expires_in":299,"id_token":"`,
},
},
{
name: "revoke",
method: http.MethodGet,
path: testProvider.RevocationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessTokenRevoke,
},
wantCode: http.StatusOK,
},
{
name: "end session",
method: http.MethodGet,
path: testProvider.EndSessionEndpoint().Relative(),
values: map[string]string{
"id_token_hint": idToken,
"client_id": "web",
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/logged-out"},
contains: []string{`<a href="/logged-out">Found</a>.`},
},
{
name: "keys",
method: http.MethodGet,
path: testProvider.KeysEndpoint().Relative(),
wantCode: http.StatusOK,
contains: []string{
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
},
},
{
name: "device authorization",
method: http.MethodGet,
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"device", "secret"},
values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{
`{"device_code":"`, `","user_code":"`,
`","verification_uri":"https://localhost:9998/device"`,
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
`","expires_in":300,"interval":5}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u := gu.PtrCopy(serverURL)
u.Path = tt.path
if tt.values != nil {
u.RawQuery = mapAsValues(tt.values)
}
var body io.Reader
if tt.body != nil {
body = strings.NewReader(mapAsValues(tt.body))
}
req := httptest.NewRequest(tt.method, u.String(), body)
for k, v := range tt.header {
req.Header.Set(k, v)
}
if tt.basicAuth != nil {
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
rec := httptest.NewRecorder()
server.ServeHTTP(rec, req)
resp := rec.Result()
require.NoError(t, err)
assert.Equal(t, tt.wantCode, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
respBodyString := string(respBody)
t.Log(respBodyString)
t.Log(resp.Header)
if tt.json != "" {
assert.JSONEq(t, tt.json, respBodyString)
}
for _, c := range tt.contains {
assert.Contains(t, respBodyString, c)
}
for k, v := range tt.headerContains {
assert.Contains(t, resp.Header.Get(k), v)
}
})
}
}

1333
pkg/op/server_http_test.go Normal file

File diff suppressed because it is too large Load diff

344
pkg/op/server_legacy.go Normal file
View file

@ -0,0 +1,344 @@
package op
import (
"context"
"errors"
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// LegacyServer is an implementation of [Server[] that
// simply wraps a [OpenIDProvider].
// It can be used to transition from the former Provider/Storage
// interfaces to the new Server interface.
type LegacyServer struct {
UnimplementedServer
provider OpenIDProvider
endpoints Endpoints
}
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
// the Server's router.
//
// Only non-nil endpoints will be registered on the router.
// Nil endpoints are disabled.
//
// The passed endpoints is also set to the provider,
// to be consistent with the discovery config.
// Any `With*Endpoint()` option used on the provider is
// therefore ineffective.
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
server := RegisterServer(&LegacyServer{
provider: provider,
endpoints: endpoints,
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter()
router.Mount("/", server)
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
return router
}
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse(Status{Status: "ok"}), nil
}
func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
for _, probe := range s.provider.Probes() {
// shouldn't we run probes in Go routines?
if err := probe(ctx); err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
}
return NewResponse(Status{Status: "ok"}), nil
}
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse(
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
), nil
}
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
keys, err := s.provider.Storage().KeySet(ctx)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
return NewResponse(jsonWebKeySet(keys)), nil
}
var (
ErrAuthReqMissingClientID = errors.New("auth request is missing client_id")
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
)
func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
if r.Data.RequestParam != "" {
if !s.provider.RequestObjectSupported() {
return nil, oidc.ErrRequestNotSupported()
}
err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
if err != nil {
return nil, err
}
}
if r.Data.ClientID == "" {
return nil, ErrAuthReqMissingClientID
}
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
if err != nil {
return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
}
return &ClientRequest[oidc.AuthRequest]{
Request: r,
Client: client,
}, nil
}
func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
if err != nil {
return nil, err
}
req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
if err != nil {
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
}
return NewRedirect(r.Client.LoginURL(req.GetID())), nil
}
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
return NewResponse(response), nil
}
func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
if !ok {
return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
}
return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret)
}
if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() {
return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
}
return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger)
}
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithParent(err)
}
switch client.AuthMethod() {
case oidc.AuthMethodNone:
return client, nil
case oidc.AuthMethodPrivateKeyJWT:
return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
case oidc.AuthMethodPost:
if !s.provider.AuthMethodPostSupported() {
return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
}
}
err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage())
if err != nil {
return nil, err
}
return client, nil
}
func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code)
if err != nil {
return nil, err
}
if r.Client.AuthMethod() == oidc.AuthMethodNone {
if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
return nil, err
}
}
resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
if !s.provider.GrantTypeRefreshTokenSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
}
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
if err != nil {
return nil, err
}
if r.Client.GetID() != request.GetClientID() {
return nil, oidc.ErrInvalidGrant()
}
if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil {
return nil, err
}
resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
if !ok {
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
}
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
if err != nil {
return nil, err
}
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
if err != nil {
return nil, err
}
resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
if !s.provider.GrantTypeTokenExchangeSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
}
tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider)
if err != nil {
return nil, err
}
resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
if !ok {
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
}
tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope)
if err != nil {
return nil, err
}
resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
if !s.provider.GrantTypeClientCredentialsSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
}
// use a limited context timeout shorter as the default
// poll interval of 5 seconds.
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
if err != nil {
return nil, err
}
tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{r.Client.GetID()},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
response := new(oidc.IntrospectionResponse)
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
if !ok {
return NewResponse(response), nil
}
err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID())
if err != nil {
return NewResponse(response), nil
}
response.Active = true
return NewResponse(response), nil
}
func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken)
if !ok {
return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized)
}
info := new(oidc.UserInfo)
err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin"))
if err != nil {
return nil, NewStatusError(err, http.StatusForbidden)
}
return NewResponse(info), nil
}
func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
var subject string
doDecrypt := true
if r.Data.TokenTypeHint != "access_token" {
userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token)
if err != nil {
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
if !errors.Is(err, ErrInvalidRefreshToken) {
return nil, RevocationError(oidc.ErrServerError().WithParent(err))
}
} else {
r.Data.Token = tokenID
subject = userID
doDecrypt = false
}
}
if doDecrypt {
tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token)
if ok {
r.Data.Token = tokenID
subject = userID
}
}
if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil {
return nil, RevocationError(err)
}
return NewResponse(nil), nil
}
func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider)
if err != nil {
return nil, err
}
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
if err != nil {
return nil, err
}
return NewRedirect(session.RedirectURI), nil
}

5
pkg/op/server_test.go Normal file
View file

@ -0,0 +1,5 @@
package op
// implementation check
var _ Server = &UnimplementedServer{}
var _ Server = &LegacyServer{}

View file

@ -6,15 +6,17 @@ import (
"net/url"
"path"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slog"
)
type SessionEnder interface {
Decoder() httphelper.Decoder
Storage() Storage
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
DefaultLogoutRedirectURI() string
Logger() *slog.Logger
}
func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
@ -31,7 +33,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
}
session, err := ValidateEndSessionRequest(r.Context(), req, ender)
if err != nil {
RequestError(w, r, err)
RequestError(w, r, err, ender.Logger())
return
}
redirect := session.RedirectURI
@ -41,7 +43,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
}
if err != nil {
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
return
}
http.Redirect(w, r, redirect, http.StatusFound)

View file

@ -3,7 +3,7 @@ package op
import (
"errors"
"gopkg.in/square/go-jose.v2"
jose "github.com/go-jose/go-jose/v3"
)
var ErrSignerCreationFailed = errors.New("signer creation failed")

Some files were not shown because too many files have changed in this diff Show more