Merge pull request #456 from zitadel/next-main

Merge next into main in order to release v3. Merge conflicts were handled in an intermediate branch.

BREAKING CHANGE - Just making sure v3 release is triggered.
This commit is contained in:
Tim Möhlmann 2023-10-13 08:44:41 +03:00 committed by GitHub
commit 976b40620c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
118 changed files with 6091 additions and 981 deletions

View file

@ -44,9 +44,9 @@ Check the `/example` folder where example code for different scenarios is locate
```bash ```bash
# start oidc op server # start oidc op server
# oidc discovery http://localhost:9998/.well-known/openid-configuration # oidc discovery http://localhost:9998/.well-known/openid-configuration
go run github.com/zitadel/oidc/v2/example/server go run github.com/zitadel/oidc/v3/example/server
# start oidc web client (in a new terminal) # start oidc web client (in a new terminal)
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
``` ```
- open http://localhost:9999/login in your browser - open http://localhost:9999/login in your browser
@ -56,11 +56,11 @@ CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid
for the dynamic issuer, just start it with: for the dynamic issuer, just start it with:
```bash ```bash
go run github.com/zitadel/oidc/v2/example/server/dynamic go run github.com/zitadel/oidc/v3/example/server/dynamic
``` ```
the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with: the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with:
```bash ```bash
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
``` ```
> Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`) > Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`)

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
@ -9,11 +10,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/gorilla/mux" "github.com/go-chi/chi"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/client/rs" "github.com/zitadel/oidc/v3/pkg/client/rs"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
const ( const (
@ -27,12 +28,12 @@ func main() {
port := os.Getenv("PORT") port := os.Getenv("PORT")
issuer := os.Getenv("ISSUER") issuer := os.Getenv("ISSUER")
provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath) provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath)
if err != nil { if err != nil {
logrus.Fatalf("error creating provider %s", err.Error()) logrus.Fatalf("error creating provider %s", err.Error())
} }
router := mux.NewRouter() router := chi.NewRouter()
// public url accessible without any authorization // public url accessible without any authorization
// will print `OK` and current timestamp // will print `OK` and current timestamp
@ -47,7 +48,7 @@ func main() {
if !ok { if !ok {
return return
} }
resp, err := rs.Introspect(r.Context(), provider, token) resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusForbidden) http.Error(w, err.Error(), http.StatusForbidden)
return return
@ -68,14 +69,14 @@ func main() {
if !ok { if !ok {
return return
} }
resp, err := rs.Introspect(r.Context(), provider, token) resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusForbidden) http.Error(w, err.Error(), http.StatusForbidden)
return return
} }
params := mux.Vars(r) requestedClaim := chi.URLParam(r, "claim")
requestedClaim := params["claim"] requestedValue := chi.URLParam(r, "value")
requestedValue := params["value"]
value, ok := resp.Claims[requestedClaim].(string) value, ok := resp.Claims[requestedClaim].(string)
if !ok || value == "" || value != requestedValue { if !ok || value == "" || value != requestedValue {
http.Error(w, "claim does not match", http.StatusForbidden) http.Error(w, "claim does not match", http.StatusForbidden)

View file

@ -1,19 +1,23 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/exp/slog"
"github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/logging"
httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
) )
var ( var (
@ -32,9 +36,25 @@ func main() {
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
logger := slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)
client := &http.Client{
Timeout: time.Minute,
}
// enable outgoing request logging
logging.EnableHTTPClient(client,
logging.WithClientGroup("client"),
)
options := []rp.Option{ options := []rp.Option{
rp.WithCookieHandler(cookieHandler), rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
rp.WithHTTPClient(client),
rp.WithLogger(logger),
} }
if clientSecret == "" { if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler)) options = append(options, rp.WithPKCE(cookieHandler))
@ -43,7 +63,10 @@ func main() {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
} }
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...) // One can add a logger to the context,
// pre-defining log attributes as required.
ctx := logging.ToContext(context.TODO(), logger)
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...)
if err != nil { if err != nil {
logrus.Fatalf("error creating provider %s", err.Error()) logrus.Fatalf("error creating provider %s", err.Error())
} }
@ -118,8 +141,22 @@ func main() {
// //
// http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider)) // http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider))
// simple counter for request IDs
var counter atomic.Int64
// enable incomming request logging
mw := logging.Middleware(
logging.WithLogger(logger),
logging.WithGroup("server"),
logging.WithIDFunc(func() slog.Attr {
return slog.Int64("id", counter.Add(1))
}),
)
lis := fmt.Sprintf("127.0.0.1:%s", port) lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis) logger.Info("server listening, press ctrl+c to stop", "addr", lis)
logrus.Info("press ctrl+c to stop") err = http.ListenAndServe(lis, mw(http.DefaultServeMux))
logrus.Fatal(http.ListenAndServe(lis, nil)) if err != http.ErrServerClosed {
logger.Error("server terminated", "error", err)
os.Exit(1)
}
} }

View file

@ -11,8 +11,8 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
) )
var ( var (
@ -39,13 +39,13 @@ func main() {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
} }
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...) provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...)
if err != nil { if err != nil {
logrus.Fatalf("error creating provider %s", err.Error()) logrus.Fatalf("error creating provider %s", err.Error())
} }
logrus.Info("starting device authorization flow") logrus.Info("starting device authorization flow")
resp, err := rp.DeviceAuthorization(scopes, provider) resp, err := rp.DeviceAuthorization(ctx, scopes, provider, nil)
if err != nil { if err != nil {
logrus.Fatal(err) logrus.Fatal(err)
} }

View file

@ -10,10 +10,10 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
githubOAuth "golang.org/x/oauth2/github" githubOAuth "golang.org/x/oauth2/github"
"github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/client/rp/cli" "github.com/zitadel/oidc/v3/pkg/client/rp/cli"
"github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
var ( var (

View file

@ -13,7 +13,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/zitadel/oidc/v2/pkg/client/profile" "github.com/zitadel/oidc/v3/pkg/client/profile"
) )
var client = http.DefaultClient var client = http.DefaultClient
@ -25,7 +25,7 @@ func main() {
scopes := strings.Split(os.Getenv("SCOPES"), " ") scopes := strings.Split(os.Getenv("SCOPES"), " ")
if keyPath != "" { if keyPath != "" {
ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes) ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes)
if err != nil { if err != nil {
logrus.Fatalf("error creating token source %s", err.Error()) logrus.Fatalf("error creating token source %s", err.Error())
} }
@ -76,7 +76,7 @@ func main() {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes) ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return

View file

@ -6,9 +6,9 @@ import (
"html/template" "html/template"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/go-chi/chi"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
const ( const (
@ -43,7 +43,7 @@ var (
type login struct { type login struct {
authenticate authenticate authenticate authenticate
router *mux.Router router chi.Router
callback func(context.Context, string) string callback func(context.Context, string) string
} }
@ -57,9 +57,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string)
} }
func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
l.router = mux.NewRouter() l.router = chi.NewRouter()
l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) l.router.Get("/username", l.loginHandler)
l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler)) l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler)
} }
type authenticate interface { type authenticate interface {

View file

@ -7,11 +7,11 @@ import (
"log" "log"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/go-chi/chi"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/oidc/v2/example/server/storage" "github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
const ( const (
@ -47,7 +47,7 @@ func main() {
//be sure to create a proper crypto random key and manage it securely! //be sure to create a proper crypto random key and manage it securely!
key := sha256.Sum256([]byte("test")) key := sha256.Sum256([]byte("test"))
router := mux.NewRouter() router := chi.NewRouter()
//for simplicity, we provide a very small default page for users who have signed out //for simplicity, we provide a very small default page for users who have signed out
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
@ -76,7 +76,7 @@ func main() {
//regardless of how many pages / steps there are in the process, the UI must be registered in the router, //regardless of how many pages / steps there are in the process, the UI must be registered in the router,
//so we will direct all calls to /login to the login UI //so we will direct all calls to /login to the login UI
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) router.Mount("/login/", http.StripPrefix("/login", l.router))
//we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) //we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
//is served on the correct path //is served on the correct path
@ -84,7 +84,7 @@ func main() {
//if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
//then you would have to set the path prefix (/custom/path/): //then you would have to set the path prefix (/custom/path/):
//router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler())) //router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler()))
router.PathPrefix("/").Handler(provider.HttpHandler()) router.Mount("/", provider)
server := &http.Server{ server := &http.Server{
Addr: ":" + port, Addr: ":" + port,

View file

@ -1,21 +1,34 @@
package exampleop package exampleop
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"github.com/gorilla/mux" "github.com/go-chi/chi"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
type deviceAuthenticate interface { type deviceAuthenticate interface {
CheckUsernamePasswordSimple(username, password string) error CheckUsernamePasswordSimple(username, password string) error
op.DeviceAuthorizationStorage op.DeviceAuthorizationStorage
// GetDeviceAuthorizationByUserCode resturns the current state of the device authorization flow,
// identified by the user code.
GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error)
// CompleteDeviceAuthorization marks a device authorization entry as Completed,
// identified by userCode. The Subject is added to the state, so that
// GetDeviceAuthorizatonState can use it to create a new Access Token.
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
// DenyDeviceAuthorization marks a device authorization entry as Denied.
DenyDeviceAuthorization(ctx context.Context, userCode string) error
} }
type deviceLogin struct { type deviceLogin struct {
@ -23,14 +36,14 @@ type deviceLogin struct {
cookie *securecookie.SecureCookie cookie *securecookie.SecureCookie
} }
func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) { func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) {
l := &deviceLogin{ l := &deviceLogin{
storage: storage, storage: storage,
cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil), cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil),
} }
router.HandleFunc("", l.userCodeHandler) router.HandleFunc("/", l.userCodeHandler)
router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler) router.Post("/login", l.loginHandler)
router.HandleFunc("/confirm", l.confirmHandler) router.HandleFunc("/confirm", l.confirmHandler)
} }

View file

@ -5,13 +5,13 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/go-chi/chi"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
type login struct { type login struct {
authenticate authenticate authenticate authenticate
router *mux.Router router chi.Router
callback func(context.Context, string) string callback func(context.Context, string) string
} }
@ -25,9 +25,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string)
} }
func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
l.router = mux.NewRouter() l.router = chi.NewRouter()
l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) l.router.Get("/username", l.loginHandler)
l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler)) l.router.Post("/username", issuerInterceptor.HandlerFunc(l.checkLoginHandler))
} }
type authenticate interface { type authenticate interface {

View file

@ -4,13 +4,16 @@ import (
"crypto/sha256" "crypto/sha256"
"log" "log"
"net/http" "net/http"
"sync/atomic"
"time" "time"
"github.com/gorilla/mux" "github.com/go-chi/chi"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/oidc/v2/example/server/storage" "github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
const ( const (
@ -31,26 +34,33 @@ type Storage interface {
deviceAuthenticate deviceAuthenticate
} }
// simple counter for request IDs
var counter atomic.Int64
// SetupServer creates an OIDC server with Issuer=http://localhost:<port> // SetupServer creates an OIDC server with Issuer=http://localhost:<port>
// //
// Use one of the pre-made clients in storage/clients.go or register a new one. // Use one of the pre-made clients in storage/clients.go or register a new one.
func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux.Router { func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool, extraOptions ...op.Option) chi.Router {
// the OpenID Provider requires a 32-byte key for (token) encryption // the OpenID Provider requires a 32-byte key for (token) encryption
// be sure to create a proper crypto random key and manage it securely! // be sure to create a proper crypto random key and manage it securely!
key := sha256.Sum256([]byte("test")) key := sha256.Sum256([]byte("test"))
router := mux.NewRouter() router := chi.NewRouter()
router.Use(logging.Middleware(
logging.WithLogger(logger),
logging.WithIDFunc(func() slog.Attr {
return slog.Int64("id", counter.Add(1))
}),
))
// for simplicity, we provide a very small default page for users who have signed out // for simplicity, we provide a very small default page for users who have signed out
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
_, err := w.Write([]byte("signed out successfully")) w.Write([]byte("signed out successfully"))
if err != nil { // no need to check/log error, this will be handeled by the middleware.
log.Printf("error serving logged out page: %v", err)
}
}) })
// creation of the OpenIDProvider with the just created in-memory Storage // creation of the OpenIDProvider with the just created in-memory Storage
provider, err := newOP(storage, issuer, key, extraOptions...) provider, err := newOP(storage, issuer, key, logger, extraOptions...)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -62,17 +72,23 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux
// regardless of how many pages / steps there are in the process, the UI must be registered in the router, // regardless of how many pages / steps there are in the process, the UI must be registered in the router,
// so we will direct all calls to /login to the login UI // so we will direct all calls to /login to the login UI
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) router.Mount("/login/", http.StripPrefix("/login", l.router))
router.PathPrefix("/device").Subrouter() router.Route("/device", func(r chi.Router) {
registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter()) registerDeviceAuth(storage, r)
})
handler := http.Handler(provider)
if wrapServer {
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
}
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
// is served on the correct path // is served on the correct path
// //
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
// then you would have to set the path prefix (/custom/path/) // then you would have to set the path prefix (/custom/path/)
router.PathPrefix("/").Handler(provider.HttpHandler()) router.Mount("/", handler)
return router return router
} }
@ -80,7 +96,7 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux
// newOP will create an OpenID Provider for localhost on a specified port with a given encryption key // newOP will create an OpenID Provider for localhost on a specified port with a given encryption key
// and a predefined default logout uri // and a predefined default logout uri
// it will enable all options (see descriptions) // it will enable all options (see descriptions)
func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.Option) (op.OpenIDProvider, error) { func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger, extraOptions ...op.Option) (op.OpenIDProvider, error) {
config := &op.Config{ config := &op.Config{
CryptoKey: key, CryptoKey: key,
@ -118,6 +134,8 @@ func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.O
op.WithAllowInsecure(), op.WithAllowInsecure(),
// as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
// Pass our logger to the OP
op.WithLogger(logger.WithGroup("op")),
}, extraOptions...)..., }, extraOptions...)...,
) )
if err != nil { if err != nil {

View file

@ -2,11 +2,12 @@ package main
import ( import (
"fmt" "fmt"
"log"
"net/http" "net/http"
"os"
"github.com/zitadel/oidc/v2/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/exampleop"
"github.com/zitadel/oidc/v2/example/server/storage" "github.com/zitadel/oidc/v3/example/server/storage"
"golang.org/x/exp/slog"
) )
func main() { func main() {
@ -20,16 +21,22 @@ func main() {
// in this example it will be handled in-memory // in this example it will be handled in-memory
storage := storage.NewStorage(storage.NewUserStore(issuer)) storage := storage.NewStorage(storage.NewUserStore(issuer))
router := exampleop.SetupServer(issuer, storage) logger := slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)
router := exampleop.SetupServer(issuer, storage, logger, false)
server := &http.Server{ server := &http.Server{
Addr: ":" + port, Addr: ":" + port,
Handler: router, Handler: router,
} }
log.Printf("server listening on http://localhost:%s/", port) logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port))
log.Println("press ctrl+c to stop")
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != http.ErrServerClosed {
log.Fatal(err) logger.Error("server terminated", "error", err)
os.Exit(1)
} }
} }

View file

@ -3,8 +3,8 @@ package storage
import ( import (
"time" "time"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
var ( var (
@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
authMethod: oidc.AuthMethodBasic, authMethod: oidc.AuthMethodBasic,
loginURL: defaultLoginURL, loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeTokenExchange},
accessTokenType: op.AccessTokenTypeBearer, accessTokenType: op.AccessTokenTypeBearer,
devMode: false, devMode: false,
idTokenUserinfoClaimsAssertion: false, idTokenUserinfoClaimsAssertion: false,

View file

@ -3,10 +3,11 @@ package storage
import ( import (
"time" "time"
"golang.org/x/exp/slog"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
const ( const (
@ -41,6 +42,19 @@ type AuthRequest struct {
authTime time.Time authTime time.Time
} }
// LogValue allows you to define which fields will be logged.
// Implements the [slog.LogValuer]
func (a *AuthRequest) LogValue() slog.Value {
return slog.GroupValue(
slog.String("id", a.ID),
slog.Time("creation_date", a.CreationDate),
slog.Any("scopes", a.Scopes),
slog.String("response_type", string(a.ResponseType)),
slog.String("app_id", a.ApplicationID),
slog.String("callback_uri", a.CallbackURI),
)
}
func (a *AuthRequest) GetID() string { func (a *AuthRequest) GetID() string {
return a.ID return a.ID
} }

View file

@ -11,11 +11,11 @@ import (
"sync" "sync"
"time" "time"
jose "github.com/go-jose/go-jose/v3"
"github.com/google/uuid" "github.com/google/uuid"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
// serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant // serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant

View file

@ -4,10 +4,10 @@ import (
"context" "context"
"time" "time"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
type multiStorage struct { type multiStorage struct {

10
go.mod
View file

@ -1,13 +1,13 @@
module github.com/zitadel/oidc/v2 module github.com/zitadel/oidc/v3
go 1.19 go 1.19
require ( require (
github.com/go-chi/chi v1.5.4
github.com/go-jose/go-jose/v3 v3.0.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-github/v31 v31.0.0 github.com/google/go-github/v31 v31.0.0
github.com/google/uuid v1.3.1 github.com/google/uuid v1.3.1
github.com/gorilla/mux v1.8.0
github.com/gorilla/schema v1.2.0
github.com/gorilla/securecookie v1.1.1 github.com/gorilla/securecookie v1.1.1
github.com/jeremija/gosubmit v0.2.7 github.com/jeremija/gosubmit v0.2.7
github.com/muhlemmer/gu v0.3.1 github.com/muhlemmer/gu v0.3.1
@ -15,11 +15,13 @@ require (
github.com/rs/cors v1.10.1 github.com/rs/cors v1.10.1
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
github.com/zitadel/logging v0.4.0
github.com/zitadel/schema v1.3.0
go.opentelemetry.io/otel v1.19.0 go.opentelemetry.io/otel v1.19.0
go.opentelemetry.io/otel/trace v1.19.0 go.opentelemetry.io/otel/trace v1.19.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/oauth2 v0.13.0 golang.org/x/oauth2 v0.13.0
golang.org/x/text v0.13.0 golang.org/x/text v0.13.0
gopkg.in/square/go-jose.v2 v2.6.0
) )
require ( require (

20
go.sum
View file

@ -1,6 +1,10 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs=
github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg=
github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=
github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@ -13,6 +17,7 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
@ -23,10 +28,6 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc=
github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc=
@ -44,10 +45,15 @@ github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA=
github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc=
github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=
github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc=
go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs=
go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY=
go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE= go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE=
@ -55,9 +61,12 @@ go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319
go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg=
go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
@ -102,8 +111,7 @@ google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -8,8 +8,8 @@ import (
"fmt" "fmt"
"os" "os"
tu "github.com/zitadel/oidc/v2/internal/testutil" tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
var custom = map[string]any{ var custom = map[string]any{

View file

@ -8,8 +8,9 @@ import (
"errors" "errors"
"time" "time"
"github.com/zitadel/oidc/v2/pkg/oidc" jose "github.com/go-jose/go-jose/v3"
"gopkg.in/square/go-jose.v2" "github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/oidc"
) )
// KeySet implements oidc.Keys // KeySet implements oidc.Keys
@ -17,7 +18,7 @@ type KeySet struct{}
// VerifySignature implments op.KeySet. // VerifySignature implments op.KeySet.
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
if ctx.Err() != nil { if err = ctx.Err(); err != nil {
return nil, err return nil, err
} }
@ -45,6 +46,16 @@ func init() {
} }
} }
type JWTProfileKeyStorage struct{}
func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
return gu.Ptr(WebKey.Public()), nil
}
func signEncodeTokenClaims(claims any) string { func signEncodeTokenClaims(claims any) string {
payload, err := json.Marshal(claims) payload, err := json.Marshal(claims)
if err != nil { if err != nil {
@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil) return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
} }
func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
req := &oidc.JWTTokenRequest{
Issuer: issuer,
Subject: clientID,
Audience: audience,
ExpiresAt: oidc.FromTime(expiration),
IssuedAt: oidc.FromTime(issuedAt),
}
// make sure the private claim map is set correctly
data, err := json.Marshal(req)
if err != nil {
panic(err)
}
if err = json.Unmarshal(data, req); err != nil {
panic(err)
}
return signEncodeTokenClaims(req), req
}
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg` const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
// These variables always result in a valid token // These variables always result in a valid token
@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew) return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
} }
func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
}
// ACRVerify is a oidc.ACRVerifier func. // ACRVerify is a oidc.ACRVerifier func.
func ACRVerify(acr string) error { func ACRVerify(acr string) error {
if acr != ValidACR { if acr != ValidACR {

View file

@ -11,24 +11,25 @@ import (
"strings" "strings"
"time" "time"
jose "github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/crypto" "github.com/zitadel/logging"
httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v3/pkg/crypto"
"github.com/zitadel/oidc/v2/pkg/oidc" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
) )
var Encoder = httphelper.Encoder(oidc.NewEncoder()) var Encoder = httphelper.Encoder(oidc.NewEncoder())
// Discover calls the discovery endpoint of the provided issuer and returns its configuration // Discover calls the discovery endpoint of the provided issuer and returns its configuration
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url // It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" { if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" {
wellKnown = wellKnownUrl[0] wellKnown = wellKnownUrl[0]
} }
req, err := http.NewRequest("GET", wellKnown, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -37,6 +38,10 @@ func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
if logger, ok := logging.FromContext(ctx); ok {
logger.Debug("discover", "config", discoveryConfig)
}
if discoveryConfig.Issuer != issuer { if discoveryConfig.Issuer != issuer {
return nil, oidc.ErrIssuerInvalid return nil, oidc.ErrIssuerInvalid
} }
@ -48,12 +53,12 @@ type TokenEndpointCaller interface {
HttpClient() *http.Client HttpClient() *http.Client
} }
func CallTokenEndpoint(request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
return callTokenEndpoint(request, nil, caller) return callTokenEndpoint(ctx, request, nil, caller)
} }
func callTokenEndpoint(request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -80,8 +85,8 @@ type EndSessionCaller interface {
HttpClient() *http.Client HttpClient() *http.Client
} }
func CallEndSessionEndpoint(request any, authFn any, caller EndSessionCaller) (*url.URL, error) { func CallEndSessionEndpoint(ctx context.Context, request any, authFn any, caller EndSessionCaller) (*url.URL, error) {
req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn) req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -123,8 +128,8 @@ type RevokeRequest struct {
ClientSecret string `schema:"client_secret"` ClientSecret string `schema:"client_secret"`
} }
func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error { func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error {
req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn) req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn)
if err != nil { if err != nil {
return err return err
} }
@ -151,8 +156,8 @@ func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error {
return nil return nil
} }
func CallTokenExchangeEndpoint(request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { func CallTokenExchangeEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -192,8 +197,8 @@ type DeviceAuthorizationCaller interface {
HttpClient() *http.Client HttpClient() *http.Client
} }
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, authFn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,7 +219,7 @@ type DeviceAccessTokenRequest struct {
} }
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil) req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,6 +1,7 @@
package client package client
import ( import (
"context"
"net/http" "net/http"
"testing" "testing"
@ -36,7 +37,7 @@ func TestDiscover(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := Discover(tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...) got, err := Discover(context.Background(), tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
return return

View file

@ -2,33 +2,64 @@ package client_test
import ( import (
"bytes" "bytes"
"context"
"fmt"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os" "os"
"os/signal"
"strconv" "strconv"
"syscall"
"testing" "testing"
"time" "time"
"github.com/jeremija/gosubmit" "github.com/jeremija/gosubmit"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/slog"
"github.com/zitadel/oidc/v2/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/exampleop"
"github.com/zitadel/oidc/v2/example/server/storage" "github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/client/rs" "github.com/zitadel/oidc/v3/pkg/client/rs"
"github.com/zitadel/oidc/v2/pkg/client/tokenexchange" "github.com/zitadel/oidc/v3/pkg/client/tokenexchange"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
var Logger = slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)
var CTX context.Context
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
defer cancel()
CTX, cancel = context.WithTimeout(ctx, time.Minute)
defer cancel()
return m.Run()
}())
}
func TestRelyingPartySession(t *testing.T) { func TestRelyingPartySession(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testRelyingPartySession(t, wrapServer)
})
}
}
func testRelyingPartySession(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------") t.Log("------- start example OP ------")
targetURL := "http://local-site" targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -36,17 +67,17 @@ func TestRelyingPartySession(t *testing.T) {
opServer := httptest.NewServer(&dh) opServer := httptest.NewServer(&dh)
defer opServer.Close() defer opServer.Close()
t.Logf("auth server at %s", opServer.URL) t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
t.Log("------- run authorization code flow ------") t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
t.Log("------- refresh tokens ------") t.Log("------- refresh tokens ------")
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "") newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
require.NoError(t, err, "refresh token") require.NoError(t, err, "refresh token")
assert.NotNil(t, newTokens, "access token") assert.NotNil(t, newTokens, "access token")
t.Logf("new access token %s", newTokens.AccessToken) t.Logf("new access token %s", newTokens.AccessToken)
@ -54,11 +85,13 @@ func TestRelyingPartySession(t *testing.T) {
t.Logf("new token type %s", newTokens.TokenType) t.Logf("new token type %s", newTokens.TokenType)
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
require.NotEmpty(t, newTokens.AccessToken, "new accessToken") require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken") assert.NotEmpty(t, newTokens.IDToken, "new idToken")
assert.NotNil(t, newTokens.IDTokenClaims)
assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject)
t.Log("------ end session (logout) ------") t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "") newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout") require.NoError(t, err, "logout")
if newLoc != nil { if newLoc != nil {
t.Logf("redirect to %s", newLoc) t.Logf("redirect to %s", newLoc)
@ -67,17 +100,25 @@ func TestRelyingPartySession(t *testing.T) {
} }
t.Log("------ attempt refresh again (should fail) ------") t.Log("------ attempt refresh again (should fail) ------")
t.Log("trying original refresh token", refreshToken) t.Log("trying original refresh token", tokens.RefreshToken)
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "") _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with original") assert.Errorf(t, err, "refresh with original")
if newTokens.RefreshToken != "" { if newTokens.RefreshToken != "" {
t.Log("trying replacement refresh token", newTokens.RefreshToken) t.Log("trying replacement refresh token", newTokens.RefreshToken)
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement") assert.Errorf(t, err, "refresh with replacement")
} }
} }
func TestResourceServerTokenExchange(t *testing.T) { func TestResourceServerTokenExchange(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testResourceServerTokenExchange(t, wrapServer)
})
}
}
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------") t.Log("------- start example OP ------")
targetURL := "http://local-site" targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -85,23 +126,24 @@ func TestResourceServerTokenExchange(t *testing.T) {
opServer := httptest.NewServer(&dh) opServer := httptest.NewServer(&dh)
defer opServer.Close() defer opServer.Close()
t.Logf("auth server at %s", opServer.URL) t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
clientSecret := "secret" clientSecret := "secret"
t.Log("------- run authorization code flow ------") t.Log("------- run authorization code flow ------")
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret) resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
require.NoError(t, err, "new resource server") require.NoError(t, err, "new resource server")
t.Log("------- exchage refresh tokens (impersonation) ------") t.Log("------- exchage refresh tokens (impersonation) ------")
tokenExchangeResponse, err := tokenexchange.ExchangeToken( tokenExchangeResponse, err := tokenexchange.ExchangeToken(
CTX,
resourceServer, resourceServer,
refreshToken, tokens.RefreshToken,
oidc.RefreshTokenType, oidc.RefreshTokenType,
"", "",
"", "",
@ -119,7 +161,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
t.Log("------ end session (logout) ------") t.Log("------ end session (logout) ------")
newLoc, err := rp.EndSession(provider, idToken, "", "") newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
require.NoError(t, err, "logout") require.NoError(t, err, "logout")
if newLoc != nil { if newLoc != nil {
t.Logf("redirect to %s", newLoc) t.Logf("redirect to %s", newLoc)
@ -130,8 +172,9 @@ func TestResourceServerTokenExchange(t *testing.T) {
t.Log("------- attempt exchage again (should fail) ------") t.Log("------- attempt exchage again (should fail) ------")
tokenExchangeResponse, err = tokenexchange.ExchangeToken( tokenExchangeResponse, err = tokenexchange.ExchangeToken(
CTX,
resourceServer, resourceServer,
refreshToken, tokens.RefreshToken,
oidc.RefreshTokenType, oidc.RefreshTokenType,
"", "",
"", "",
@ -145,7 +188,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
require.Nil(t, tokenExchangeResponse, "token exchange response") require.Nil(t, tokenExchangeResponse, "token exchange response")
} }
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
targetURL := "http://local-site" targetURL := "http://local-site"
localURL, err := url.Parse(targetURL + "/login?requestID=1234") localURL, err := url.Parse(targetURL + "/login?requestID=1234")
require.NoError(t, err, "local url") require.NoError(t, err, "local url")
@ -167,6 +210,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
key := []byte("test1234test1234") key := []byte("test1234test1234")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
provider, err = rp.NewRelyingPartyOIDC( provider, err = rp.NewRelyingPartyOIDC(
CTX,
opServer.URL, opServer.URL,
clientID, clientID,
clientSecret, clientSecret,
@ -241,7 +285,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
} }
var email string var email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
tokens = newTokens
require.NotNil(t, tokens, "tokens") require.NotNil(t, tokens, "tokens")
require.NotNil(t, info, "info") require.NotNil(t, info, "info")
t.Log("access token", tokens.AccessToken) t.Log("access token", tokens.AccessToken)
@ -249,9 +294,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
t.Log("id token", tokens.IDToken) t.Log("id token", tokens.IDToken)
t.Log("email", info.Email) t.Log("email", info.Email)
accessToken = tokens.AccessToken
refreshToken = tokens.RefreshToken
idToken = tokens.IDToken
email = info.Email email = info.Email
http.Redirect(w, r, targetURL, 302) http.Redirect(w, r, targetURL, 302)
} }
@ -273,12 +315,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
require.NoError(t, err, "get fully-authorizied redirect location") require.NoError(t, err, "get fully-authorizied redirect location")
require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location") require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location")
require.NotEmpty(t, idToken, "id token") require.NotEmpty(t, tokens.IDToken, "id token")
assert.NotEmpty(t, refreshToken, "refresh token") assert.NotEmpty(t, tokens.RefreshToken, "refresh token")
assert.NotEmpty(t, accessToken, "access token") assert.NotEmpty(t, tokens.AccessToken, "access token")
assert.NotEmpty(t, email, "email") assert.NotEmpty(t, email, "email")
return provider, accessToken, refreshToken, idToken return provider, tokens
} }
func TestErrorFromPromptNone(t *testing.T) { func TestErrorFromPromptNone(t *testing.T) {
@ -299,7 +341,7 @@ func TestErrorFromPromptNone(t *testing.T) {
opServer := httptest.NewServer(&dh) opServer := httptest.NewServer(&dh)
defer opServer.Close() defer opServer.Close()
t.Logf("auth server at %s", opServer.URL) t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, op.WithHttpInterceptors( dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, false, op.WithHttpInterceptors(
func(next http.Handler) http.Handler { func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("request to %s", r.URL) t.Logf("request to %s", r.URL)
@ -317,6 +359,7 @@ func TestErrorFromPromptNone(t *testing.T) {
key := []byte("test1234test1234") key := []byte("test1234test1234")
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
provider, err := rp.NewRelyingPartyOIDC( provider, err := rp.NewRelyingPartyOIDC(
CTX,
opServer.URL, opServer.URL,
clientID, clientID,
clientSecret, clientSecret,
@ -412,7 +455,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [
func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL { func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL {
// TODO: switch to io.NopCloser when go1.15 support is dropped // TODO: switch to io.NopCloser when go1.15 support is dropped
req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)..., append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)...,
) )
if req.URL.Scheme == "" { if req.URL.Scheme == "" {

View file

@ -1,17 +1,18 @@
package client package client
import ( import (
"context"
"net/url" "net/url"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
// JWTProfileExchange handles the oauth2 jwt profile exchange // JWTProfileExchange handles the oauth2 jwt profile exchange
func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
return CallTokenEndpoint(jwtProfileGrantRequest, caller) return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller)
} }
func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption { func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption {

View file

@ -10,7 +10,7 @@ const (
applicationKey = "application" applicationKey = "application"
) )
type keyFile struct { type KeyFile struct {
Type string `json:"type"` // serviceaccount or application Type string `json:"type"` // serviceaccount or application
KeyID string `json:"keyId"` KeyID string `json:"keyId"`
Key string `json:"key"` Key string `json:"key"`
@ -23,7 +23,7 @@ type keyFile struct {
ClientID string `json:"clientId"` ClientID string `json:"clientId"`
} }
func ConfigFromKeyFile(path string) (*keyFile, error) { func ConfigFromKeyFile(path string) (*KeyFile, error) {
data, err := ioutil.ReadFile(path) data, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
@ -31,8 +31,8 @@ func ConfigFromKeyFile(path string) (*keyFile, error) {
return ConfigFromKeyFileData(data) return ConfigFromKeyFileData(data)
} }
func ConfigFromKeyFileData(data []byte) (*keyFile, error) { func ConfigFromKeyFileData(data []byte) (*KeyFile, error) {
var f keyFile var f KeyFile
if err := json.Unmarshal(data, &f); err != nil { if err := json.Unmarshal(data, &f); err != nil {
return nil, err return nil, err
} }

View file

@ -1,16 +1,22 @@
package profile package profile
import ( import (
"context"
"net/http" "net/http"
"time" "time"
jose "github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/client" "github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type TokenSource interface {
oauth2.TokenSource
TokenCtx(context.Context) (*oauth2.Token, error)
}
// jwtProfileTokenSource implement the oauth2.TokenSource // jwtProfileTokenSource implement the oauth2.TokenSource
// it will request a token using the OAuth2 JWT Profile Grant // it will request a token using the OAuth2 JWT Profile Grant
// therefore sending an `assertion` by signing a JWT with the provided private key // therefore sending an `assertion` by signing a JWT with the provided private key
@ -23,23 +29,38 @@ type jwtProfileTokenSource struct {
tokenEndpoint string tokenEndpoint string
} }
func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { // NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource
keyData, err := client.ConfigFromKeyFile(keyPath) // It will request a token using the OAuth2 JWT Profile Grant,
// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile.
//
// The passed context is only used for the call to the Discover endpoint.
func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
keyData, err := client.ConfigFromKeyFile(jsonFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
} }
func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { // NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource
keyData, err := client.ConfigFromKeyFileData(data) // It will request a token using the OAuth2 JWT Profile Grant,
// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData.
//
// The passed context is only used for the call to the Discover endpoint.
func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
keyData, err := client.ConfigFromKeyFileData(jsonData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
} }
func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { // NewJWTProfileSource returns an implementation of oauth2.TokenSource
// It will request a token using the OAuth2 JWT Profile Grant,
// therefore sending an `assertion` by singing a JWT with the provided private key.
//
// The passed context is only used for the call to the Discover endpoint.
func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -55,7 +76,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
opt(source) opt(source)
} }
if source.tokenEndpoint == "" { if source.tokenEndpoint == "" {
config, err := client.Discover(issuer, source.httpClient) config, err := client.Discover(ctx, issuer, source.httpClient)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -64,13 +85,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
return source, nil return source, nil
} }
func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) { func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) {
return func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) {
source.httpClient = client source.httpClient = client
} }
} }
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) { func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) {
return func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) {
source.tokenEndpoint = tokenEndpoint source.tokenEndpoint = tokenEndpoint
} }
@ -85,9 +106,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client {
} }
func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) { func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) {
return j.TokenCtx(context.Background())
}
func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) {
assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer) assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
} }

View file

@ -4,9 +4,9 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
const ( const (

View file

@ -1,7 +1,7 @@
package rp package rp
import ( import (
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" "github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange"
) )
// DelegationTokenRequest is an implementation of TokenExchangeRequest // DelegationTokenRequest is an implementation of TokenExchangeRequest

View file

@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/zitadel/oidc/v2/pkg/client" "github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) { func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
@ -32,19 +32,21 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.
// DeviceAuthorization starts a new Device Authorization flow as defined // DeviceAuthorization starts a new Device Authorization flow as defined
// in RFC 8628, section 3.1 and 3.2: // in RFC 8628, section 3.1 and 3.2:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1 // https://www.rfc-editor.org/rfc/rfc8628#section-3.1
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")
req, err := newDeviceClientCredentialsRequest(scopes, rp) req, err := newDeviceClientCredentialsRequest(scopes, rp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.CallDeviceAuthorizationEndpoint(req, rp) return client.CallDeviceAuthorizationEndpoint(ctx, req, rp, authFn)
} }
// DeviceAccessToken attempts to obtain tokens from a Device Authorization, // DeviceAccessToken attempts to obtain tokens from a Device Authorization,
// by means of polling as defined in RFC, section 3.3 and 3.4: // by means of polling as defined in RFC, section 3.3 and 3.4:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4 // https://www.rfc-editor.org/rfc/rfc8628#section-3.4
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")
req := &client.DeviceAccessTokenRequest{ req := &client.DeviceAccessTokenRequest{
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
GrantType: oidc.GrantTypeDeviceCode, GrantType: oidc.GrantTypeDeviceCode,

View file

@ -7,10 +7,10 @@ import (
"net/http" "net/http"
"sync" "sync"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet { func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {

17
pkg/client/rp/log.go Normal file
View file

@ -0,0 +1,17 @@
package rp
import (
"context"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
)
func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context {
logger, ok := rp.Logger(ctx)
if !ok {
return ctx
}
logger = logger.With(slog.Group("rp", attrs...))
return logging.ToContext(ctx, logger)
}

View file

@ -7,16 +7,17 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
jose "github.com/go-jose/go-jose/v3"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/client" "github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
const ( const (
@ -63,11 +64,14 @@ type RelyingParty interface {
// be used to start a DeviceAuthorization flow. // be used to start a DeviceAuthorization flow.
GetDeviceAuthorizationEndpoint() string GetDeviceAuthorizationEndpoint() string
// IDTokenVerifier returns the verifier interface used for oidc id_token verification // IDTokenVerifier returns the verifier used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier IDTokenVerifier() *IDTokenVerifier
// ErrorHandler returns the handler used for callback errors // ErrorHandler returns the handler used for callback errors
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
// Logger from the context, or a fallback if set.
Logger(context.Context) (logger *slog.Logger, ok bool)
} }
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
@ -88,9 +92,10 @@ type relyingParty struct {
cookieHandler *httphelper.CookieHandler cookieHandler *httphelper.CookieHandler
errorHandler func(http.ResponseWriter, *http.Request, string, string, string) errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
idTokenVerifier IDTokenVerifier idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption verifierOpts []VerifierOption
signer jose.Signer signer jose.Signer
logger *slog.Logger
} }
func (rp *relyingParty) OAuthConfig() *oauth2.Config { func (rp *relyingParty) OAuthConfig() *oauth2.Config {
@ -137,7 +142,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string {
return rp.endpoints.RevokeURL return rp.endpoints.RevokeURL
} }
func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier { func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
if rp.idTokenVerifier == nil { if rp.idTokenVerifier == nil {
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
} }
@ -151,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
return rp.errorHandler return rp.errorHandler
} }
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
logger, ok = logging.FromContext(ctx)
if ok {
return logger, ok
}
return rp.logger, rp.logger != nil
}
// NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given // NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given
// OAuth2 Config and possible configOptions // OAuth2 Config and possible configOptions
// it will use the AuthURL and TokenURL set in config // it will use the AuthURL and TokenURL set in config
@ -177,7 +190,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart
// NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given // NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given
// issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions // issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions
// it will run discovery on the provided issuer and use the found endpoints // it will run discovery on the provided issuer and use the found endpoints
func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
rp := &relyingParty{ rp := &relyingParty{
issuer: issuer, issuer: issuer,
oauthConfig: &oauth2.Config{ oauthConfig: &oauth2.Config{
@ -195,7 +208,8 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco
return nil, err return nil, err
} }
} }
discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC")
discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -282,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option {
} }
} }
// WithLogger sets a logger that is used
// in case the request context does not contain a logger.
func WithLogger(logger *slog.Logger) Option {
return func(rp *relyingParty) error {
rp.logger = logger
return nil
}
}
type SignerFromKey func() (jose.Signer, error) type SignerFromKey func() (jose.Signer, error)
func SignerFromKeyPath(path string) SignerFromKey { func SignerFromKeyPath(path string) SignerFromKey {
@ -310,26 +333,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {
} }
} }
// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
//
// deprecated: use client.Discover
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return Endpoints{}, err
}
discoveryConfig := new(oidc.DiscoveryConfiguration)
err = httphelper.HttpRequest(httpClient, req, &discoveryConfig)
if err != nil {
return Endpoints{}, err
}
if discoveryConfig.Issuer != issuer {
return Endpoints{}, fmt.Errorf("%w: Expected: %s, got: %s", oidc.ErrIssuerInvalid, discoveryConfig.Issuer, issuer)
}
return GetEndpoints(discoveryConfig), nil
}
// AuthURL returns the auth request url // AuthURL returns the auth request url
// (wrapping the oauth2 `AuthCodeURL`) // (wrapping the oauth2 `AuthCodeURL`)
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
@ -377,9 +380,29 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
return oidc.NewSHACodeChallenge(codeVerifier), nil return oidc.NewSHACodeChallenge(codeVerifier), nil
} }
// ErrMissingIDToken is returned when an id_token was expected,
// but not received in the token response.
var ErrMissingIDToken = errors.New("id_token missing")
func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken
}
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
// returning it parsed together with the oauth2 tokens (access, refresh) // returning it parsed together with the oauth2 tokens (access, refresh)
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
codeOpts := make([]oauth2.AuthCodeOption, 0) codeOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts { for _, opt := range opts {
@ -390,22 +413,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
if err != nil { if err != nil {
return nil, err return nil, err
} }
return verifyTokenResponse[C](ctx, token, rp)
if rp.IsOAuth2Only() {
return &oidc.Tokens[C]{Token: token}, nil
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
return nil, errors.New("id_token missing")
}
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil {
return nil, err
}
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
} }
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
@ -457,14 +465,18 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
} }
} }
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo) type SubjectGetter interface {
GetSubject() string
}
type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info U)
// UserinfoCallback wraps the callback function of the CodeExchangeHandler // UserinfoCallback wraps the callback function of the CodeExchangeHandler
// and calls the userinfo endpoint with the access token // and calls the userinfo endpoint with the access token
// on success it will pass the userinfo into its callback function as well // on success it will pass the userinfo into its callback function as well
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] { func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil { if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
return return
@ -473,19 +485,26 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx
} }
} }
// Userinfo will call the OIDC Userinfo Endpoint with the provided token // Userinfo will call the OIDC [UserInfo] Endpoint with the provided token and returns
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { // the response in an instance of type U.
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil) // [*oidc.UserInfo] can be used as a good example, or use a custom type if type-safe
// access to custom claims is needed.
//
// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
var nilU U
ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
if err != nil { if err != nil {
return nil, err return nilU, err
} }
req.Header.Set("authorization", tokenType+" "+token) req.Header.Set("authorization", tokenType+" "+token)
userinfo := new(oidc.UserInfo)
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
return nil, err return nilU, err
} }
if userinfo.Subject != subject { if userinfo.GetSubject() != subject {
return nil, ErrUserInfoSubNotMatching return nilU, ErrUserInfoSubNotMatching
} }
return userinfo, nil return userinfo, nil
} }
@ -554,7 +573,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
// This is the generalized, unexported, function used by both // This is the generalized, unexported, function used by both
// URLParamOpt and AuthURLOpt. // URLParamOpt and AuthURLOpt.
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption { func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()) return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String())
} }
type URLParamOpt func() []oauth2.AuthCodeOption type URLParamOpt func() []oauth2.AuthCodeOption
@ -626,11 +645,15 @@ type RefreshTokenRequest struct {
GrantType oidc.GrantType `schema:"grant_type"` GrantType oidc.GrantType `schema:"grant_type"`
} }
// RefreshAccessToken performs a token refresh. If it doesn't error, it will always // RefreshTokens performs a token refresh. If it doesn't error, it will always
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then // provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
// the old one should be considered invalid. It may also provide a new IDToken. The // the old one should be considered invalid.
// new IDToken can be retrieved with token.Extra("id_token"). //
func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { // In case the RP is not OAuth2 only and an IDToken was part of the response,
// the IDToken and AccessToken will be verfied
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
request := RefreshTokenRequest{ request := RefreshTokenRequest{
RefreshToken: refreshToken, RefreshToken: refreshToken,
Scopes: rp.OAuthConfig().Scopes, Scopes: rp.OAuthConfig().Scopes,
@ -640,17 +663,28 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs
ClientAssertionType: clientAssertionType, ClientAssertionType: clientAssertionType,
GrantType: oidc.GrantTypeRefreshToken, GrantType: oidc.GrantTypeRefreshToken,
} }
return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp}) newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
if err != nil {
return nil, err
}
tokens, err := verifyTokenResponse[C](ctx, newToken, rp)
if err == nil || errors.Is(err, ErrMissingIDToken) {
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
// ...except that it might not contain an id_token.
return tokens, nil
}
return nil, err
} }
func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
request := oidc.EndSessionRequest{ request := oidc.EndSessionRequest{
IdTokenHint: idToken, IdTokenHint: idToken,
ClientID: rp.OAuthConfig().ClientID, ClientID: rp.OAuthConfig().ClientID,
PostLogoutRedirectURI: optionalRedirectURI, PostLogoutRedirectURI: optionalRedirectURI,
State: optionalState, State: optionalState,
} }
return client.CallEndSessionEndpoint(request, nil, rp) return client.CallEndSessionEndpoint(ctx, request, nil, rp)
} }
// RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty // RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty
@ -658,7 +692,8 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str
// NewRelyingPartyOAuth() does not. // NewRelyingPartyOAuth() does not.
// //
// tokenTypeHint should be either "id_token" or "refresh_token". // tokenTypeHint should be either "id_token" or "refresh_token".
func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
request := client.RevokeRequest{ request := client.RevokeRequest{
Token: token, Token: token,
TokenTypeHint: tokenTypeHint, TokenTypeHint: tokenTypeHint,
@ -666,7 +701,7 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
ClientSecret: rp.OAuthConfig().ClientSecret, ClientSecret: rp.OAuthConfig().ClientSecret,
} }
if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" { if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
return client.CallRevokeEndpoint(request, nil, rc) return client.CallRevokeEndpoint(ctx, request, nil, rc)
} }
return fmt.Errorf("RelyingParty does not support RevokeCaller") return fmt.Errorf("RelyingParty does not support RevokeCaller")
} }

View file

@ -0,0 +1,107 @@
package rp
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
)
func Test_verifyTokenResponse(t *testing.T) {
verifier := &IDTokenVerifier{
Issuer: tu.ValidIssuer,
MaxAgeIAT: 2 * time.Minute,
ClientID: tu.ValidClientID,
Offset: time.Second,
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
KeySet: tu.KeySet{},
MaxAge: 2 * time.Minute,
ACR: tu.ACRVerify,
Nonce: func(context.Context) string { return tu.ValidNonce },
}
tests := []struct {
name string
oauth2Only bool
tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims])
wantErr error
}{
{
name: "succes, oauth2 only",
oauth2Only: true,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
},
{
name: "id_token missing error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
}
},
wantErr: ErrMissingIDToken,
},
{
name: "verify tokens error",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
token = token.WithExtra(map[string]any{
"id_token": "foobar",
})
return token, nil
},
wantErr: oidc.ErrParse,
},
{
name: "success, with id_token",
oauth2Only: false,
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
accesToken, _ := tu.ValidAccessToken()
token := &oauth2.Token{
AccessToken: accesToken,
}
idToken, claims := tu.ValidIDToken()
token = token.WithExtra(map[string]any{
"id_token": idToken,
})
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
Token: token,
IDTokenClaims: claims,
IDToken: idToken,
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &relyingParty{
oauth2Only: tt.oauth2Only,
idTokenVerifier: verifier,
}
token, want := tt.tokens()
got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, want, got)
})
}
}

View file

@ -5,7 +5,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" "github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange"
) )
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange` // TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`

View file

@ -0,0 +1,45 @@
package rp_test
import (
"context"
"fmt"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type UserInfo struct {
Subject string `json:"sub,omitempty"`
oidc.UserInfoProfile
oidc.UserInfoEmail
oidc.UserInfoPhone
Address *oidc.UserInfoAddress `json:"address,omitempty"`
// Foo and Bar are custom claims
Foo string `json:"foo,omitempty"`
Bar struct {
Val1 string `json:"val_1,omitempty"`
Val2 string `json:"val_2,omitempty"`
} `json:"bar,omitempty"`
// Claims are all the combined claims, including custom.
Claims map[string]any `json:"-,omitempty"`
}
func (u *UserInfo) GetSubject() string {
return u.Subject
}
func ExampleUserinfo_custom() {
rpo, err := rp.NewRelyingPartyOIDC(context.TODO(), "http://localhost:8080", "clientid", "clientsecret", "http://example.com/redirect", []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone})
if err != nil {
panic(err)
}
info, err := rp.Userinfo[*UserInfo](context.TODO(), "accesstokenstring", "Bearer", "userid", rpo)
if err != nil {
panic(err)
}
fmt.Println(info)
}

View file

@ -4,24 +4,14 @@ import (
"context" "context"
"time" "time"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type IDTokenVerifier interface {
oidc.Verifier
ClientID() string
SupportedSignAlgs() []string
KeySet() oidc.KeySet
Nonce(context.Context) string
ACR() oidc.ACRVerifier
MaxAge() time.Duration
}
// VerifyTokens implement the Token Response Validation as defined in OIDC specification // VerifyTokens implement the Token Response Validation as defined in OIDC specification
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) { func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C var nilClaims C
claims, err = VerifyIDToken[C](ctx, idToken, v) claims, err = VerifyIDToken[C](ctx, idToken, v)
@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str
// VerifyIDToken validates the id token according to // VerifyIDToken validates the id token according to
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) { func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C var nilClaims C
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { if err = oidc.CheckAudience(claims, v.ClientID); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
return nilClaims, err return nilClaims, err
} }
@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
return nilClaims, err return nilClaims, err
} }
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
return nilClaims, err return nilClaims, err
} }
return claims, nil return claims, nil
} }
type IDTokenVerifier oidc.Verifier
// VerifyAccessToken validates the access token according to // VerifyAccessToken validates the access token according to
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl
return nil return nil
} }
// NewIDTokenVerifier returns an implementation of `IDTokenVerifier` // NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
// for `VerifyTokens` and `VerifyIDToken` func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier { v := &IDTokenVerifier{
v := &idTokenVerifier{ Issuer: issuer,
issuer: issuer, ClientID: clientID,
clientID: clientID, KeySet: keySet,
keySet: keySet, Offset: time.Second,
offset: time.Second, Nonce: func(_ context.Context) string {
nonce: func(_ context.Context) string {
return "" return ""
}, },
} }
@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
} }
// VerifierOption is the type for providing dynamic options to the IDTokenVerifier // VerifierOption is the type for providing dynamic options to the IDTokenVerifier
type VerifierOption func(*idTokenVerifier) type VerifierOption func(*IDTokenVerifier)
// WithIssuedAtOffset mitigates the risk of iat to be in the future // WithIssuedAtOffset mitigates the risk of iat to be in the future
// because of clock skews with the ability to add an offset to the current time // because of clock skews with the ability to add an offset to the current time
func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) { func WithIssuedAtOffset(offset time.Duration) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.offset = offset v.Offset = offset
} }
} }
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now // WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) { func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.maxAgeIAT = maxAge v.MaxAgeIAT = maxAge
} }
} }
// WithNonce sets the function to check the nonce // WithNonce sets the function to check the nonce
func WithNonce(nonce func(context.Context) string) VerifierOption { func WithNonce(nonce func(context.Context) string) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.nonce = nonce v.Nonce = nonce
} }
} }
// WithACRVerifier sets the verifier for the acr claim // WithACRVerifier sets the verifier for the acr claim
func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption { func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.acr = verifier v.ACR = verifier
} }
} }
// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now // WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption { func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.maxAge = maxAge v.MaxAge = maxAge
} }
} }
// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm // WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
func WithSupportedSigningAlgorithms(algs ...string) VerifierOption { func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
return func(v *idTokenVerifier) { return func(v *IDTokenVerifier) {
v.supportedSignAlgs = algs v.SupportedSignAlgs = algs
} }
} }
type idTokenVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
clientID string
supportedSignAlgs []string
keySet oidc.KeySet
acr oidc.ACRVerifier
maxAge time.Duration
nonce func(ctx context.Context) string
}
func (i *idTokenVerifier) Issuer() string {
return i.issuer
}
func (i *idTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
func (i *idTokenVerifier) Offset() time.Duration {
return i.offset
}
func (i *idTokenVerifier) ClientID() string {
return i.clientID
}
func (i *idTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
func (i *idTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func (i *idTokenVerifier) Nonce(ctx context.Context) string {
return i.nonce(ctx)
}
func (i *idTokenVerifier) ACR() oidc.ACRVerifier {
return i.acr
}
func (i *idTokenVerifier) MaxAge() time.Duration {
return i.maxAge
}

View file

@ -5,24 +5,24 @@ import (
"testing" "testing"
"time" "time"
jose "github.com/go-jose/go-jose/v3"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v2/internal/testutil" tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"gopkg.in/square/go-jose.v2"
) )
func TestVerifyTokens(t *testing.T) { func TestVerifyTokens(t *testing.T) {
verifier := &idTokenVerifier{ verifier := &IDTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute, MaxAgeIAT: 2 * time.Minute,
offset: time.Second, Offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
maxAge: 2 * time.Minute, MaxAge: 2 * time.Minute,
acr: tu.ACRVerify, ACR: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce }, Nonce: func(context.Context) string { return tu.ValidNonce },
clientID: tu.ValidClientID, ClientID: tu.ValidClientID,
} }
accessToken, _ := tu.ValidAccessToken() accessToken, _ := tu.ValidAccessToken()
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm) atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) {
} }
func TestVerifyIDToken(t *testing.T) { func TestVerifyIDToken(t *testing.T) {
verifier := &idTokenVerifier{ verifier := &IDTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
maxAgeIAT: 2 * time.Minute, MaxAgeIAT: 2 * time.Minute,
offset: time.Second, Offset: time.Second,
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
maxAge: 2 * time.Minute, MaxAge: 2 * time.Minute,
acr: tu.ACRVerify, ACR: tu.ACRVerify,
nonce: func(context.Context) string { return tu.ValidNonce }, Nonce: func(context.Context) string { return tu.ValidNonce },
} }
tests := []struct { tests := []struct {
@ -231,7 +231,7 @@ func TestVerifyIDToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
token, want := tt.tokenClaims() token, want := tt.tokenClaims()
verifier.clientID = tt.clientID verifier.ClientID = tt.clientID
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
@ -312,7 +312,7 @@ func TestNewIDTokenVerifier(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want IDTokenVerifier want *IDTokenVerifier
}{ }{
{ {
name: "nil nonce", // otherwise assert.Equal will fail on the function name: "nil nonce", // otherwise assert.Equal will fail on the function
@ -329,16 +329,16 @@ func TestNewIDTokenVerifier(t *testing.T) {
WithSupportedSigningAlgorithms("ABC", "DEF"), WithSupportedSigningAlgorithms("ABC", "DEF"),
}, },
}, },
want: &idTokenVerifier{ want: &IDTokenVerifier{
issuer: tu.ValidIssuer, Issuer: tu.ValidIssuer,
offset: time.Minute, Offset: time.Minute,
maxAgeIAT: time.Hour, MaxAgeIAT: time.Hour,
clientID: tu.ValidClientID, ClientID: tu.ValidClientID,
keySet: tu.KeySet{}, KeySet: tu.KeySet{},
nonce: nil, Nonce: nil,
acr: nil, ACR: nil,
maxAge: 2 * time.Hour, MaxAge: 2 * time.Hour,
supportedSignAlgs: []string{"ABC", "DEF"}, SupportedSignAlgs: []string{"ABC", "DEF"},
}, },
}, },
} }

View file

@ -4,9 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
tu "github.com/zitadel/oidc/v2/internal/testutil" tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
// MyCustomClaims extends the TokenClaims base, // MyCustomClaims extends the TokenClaims base,

View file

@ -0,0 +1,52 @@
package rs_test
import (
"context"
"fmt"
"github.com/zitadel/oidc/v3/pkg/client/rs"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type IntrospectionResponse struct {
Active bool `json:"active"`
Scope oidc.SpaceDelimitedArray `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
TokenType string `json:"token_type,omitempty"`
Expiration oidc.Time `json:"exp,omitempty"`
IssuedAt oidc.Time `json:"iat,omitempty"`
NotBefore oidc.Time `json:"nbf,omitempty"`
Subject string `json:"sub,omitempty"`
Audience oidc.Audience `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"`
JWTID string `json:"jti,omitempty"`
Username string `json:"username,omitempty"`
oidc.UserInfoProfile
oidc.UserInfoEmail
oidc.UserInfoPhone
Address *oidc.UserInfoAddress `json:"address,omitempty"`
// Foo and Bar are custom claims
Foo string `json:"foo,omitempty"`
Bar struct {
Val1 string `json:"val_1,omitempty"`
Val2 string `json:"val_2,omitempty"`
} `json:"bar,omitempty"`
// Claims are all the combined claims, including custom.
Claims map[string]any `json:"-,omitempty"`
}
func ExampleIntrospect_custom() {
rss, err := rs.NewResourceServerClientCredentials(context.TODO(), "http://localhost:8080", "clientid", "clientsecret")
if err != nil {
panic(err)
}
resp, err := rs.Introspect[*IntrospectionResponse](context.TODO(), rss, "accesstokenstring")
if err != nil {
panic(err)
}
fmt.Println(resp)
}

View file

@ -6,9 +6,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/zitadel/oidc/v2/pkg/client" "github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type ResourceServer interface { type ResourceServer interface {
@ -42,14 +42,14 @@ func (r *resourceServer) AuthFn() (any, error) {
return r.authFn() return r.authFn()
} }
func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
authorizer := func() (any, error) { authorizer := func() (any, error) {
return httphelper.AuthorizeBasic(clientID, clientSecret), nil return httphelper.AuthorizeBasic(clientID, clientSecret), nil
} }
return newResourceServer(issuer, authorizer, option...) return newResourceServer(ctx, issuer, authorizer, option...)
} }
func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -61,10 +61,10 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt
} }
return client.ClientAssertionFormAuthorization(assertion), nil return client.ClientAssertionFormAuthorization(assertion), nil
} }
return newResourceServer(issuer, authorizer, options...) return newResourceServer(ctx, issuer, authorizer, options...)
} }
func newResourceServer(issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) { func newResourceServer(ctx context.Context, issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) {
rs := &resourceServer{ rs := &resourceServer{
issuer: issuer, issuer: issuer,
httpClient: httphelper.DefaultHTTPClient, httpClient: httphelper.DefaultHTTPClient,
@ -73,7 +73,7 @@ func newResourceServer(issuer string, authorizer func() (any, error), options ..
optFunc(rs) optFunc(rs)
} }
if rs.introspectURL == "" || rs.tokenURL == "" { if rs.introspectURL == "" || rs.tokenURL == "" {
config, err := client.Discover(rs.issuer, rs.httpClient) config, err := client.Discover(ctx, rs.issuer, rs.httpClient)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -91,12 +91,12 @@ func newResourceServer(issuer string, authorizer func() (any, error), options ..
return rs, nil return rs, nil
} }
func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) { func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) {
c, err := client.ConfigFromKeyFile(path) c, err := client.ConfigFromKeyFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
} }
type Option func(*resourceServer) type Option func(*resourceServer)
@ -116,21 +116,27 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
} }
} }
func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) { // Introspect calls the [RFC7662] Token Introspection
// endpoint and returns the response in an instance of type R.
// [*oidc.IntrospectionResponse] can be used as a good example, or use a custom type if type-safe
// access to custom claims is needed.
//
// [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662
func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) {
if rp.IntrospectionURL() == "" { if rp.IntrospectionURL() == "" {
return nil, errors.New("resource server: introspection URL is empty") return resp, errors.New("resource server: introspection URL is empty")
} }
authFn, err := rp.AuthFn() authFn, err := rp.AuthFn()
if err != nil { if err != nil {
return nil, err return resp, err
} }
req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
if err != nil { if err != nil {
return nil, err return resp, err
} }
resp := new(oidc.IntrospectionResponse)
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil { if err := httphelper.HttpRequest(rp.HttpClient(), req, &resp); err != nil {
return nil, err return resp, err
} }
return resp, nil return resp, nil
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
) )
func TestNewResourceServer(t *testing.T) { func TestNewResourceServer(t *testing.T) {
@ -164,7 +165,7 @@ func TestNewResourceServer(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := newResourceServer(tt.args.issuer, tt.args.authorizer, tt.args.options...) got, err := newResourceServer(context.Background(), tt.args.issuer, tt.args.authorizer, tt.args.options...)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
return return
@ -187,6 +188,7 @@ func TestIntrospect(t *testing.T) {
token string token string
} }
rp, err := newResourceServer( rp, err := newResourceServer(
context.Background(),
"https://accounts.spotify.com", "https://accounts.spotify.com",
nil, nil,
) )
@ -208,7 +210,7 @@ func TestIntrospect(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := Introspect(tt.args.ctx, tt.args.rp, tt.args.token) _, err := Introspect[*oidc.IntrospectionResponse](tt.args.ctx, tt.args.rp, tt.args.token)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
return return

View file

@ -1,12 +1,13 @@
package tokenexchange package tokenexchange
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"github.com/zitadel/oidc/v2/pkg/client" "github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type TokenExchanger interface { type TokenExchanger interface {
@ -21,18 +22,18 @@ type OAuthTokenExchange struct {
authFn func() (any, error) authFn func() (any, error)
} }
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
return newOAuthTokenExchange(issuer, nil, options...) return newOAuthTokenExchange(ctx, issuer, nil, options...)
} }
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
authorizer := func() (any, error) { authorizer := func() (any, error) {
return httphelper.AuthorizeBasic(clientID, clientSecret), nil return httphelper.AuthorizeBasic(clientID, clientSecret), nil
} }
return newOAuthTokenExchange(issuer, authorizer, options...) return newOAuthTokenExchange(ctx, issuer, authorizer, options...)
} }
func newOAuthTokenExchange(issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
te := &OAuthTokenExchange{ te := &OAuthTokenExchange{
httpClient: httphelper.DefaultHTTPClient, httpClient: httphelper.DefaultHTTPClient,
} }
@ -41,7 +42,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (any, error), option
} }
if te.tokenEndpoint == "" { if te.tokenEndpoint == "" {
config, err := client.Discover(issuer, te.httpClient) config, err := client.Discover(ctx, issuer, te.httpClient)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -89,6 +90,7 @@ func (te *OAuthTokenExchange) AuthFn() (any, error) {
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint. // ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
// SubjectToken and SubjectTokenType are required parameters. // SubjectToken and SubjectTokenType are required parameters.
func ExchangeToken( func ExchangeToken(
ctx context.Context,
te TokenExchanger, te TokenExchanger,
SubjectToken string, SubjectToken string,
SubjectTokenType oidc.TokenType, SubjectTokenType oidc.TokenType,
@ -123,5 +125,5 @@ func ExchangeToken(
RequestedTokenType: RequestedTokenType, RequestedTokenType: RequestedTokenType,
} }
return client.CallTokenExchangeEndpoint(request, authFn, te) return client.CallTokenExchangeEndpoint(ctx, request, authFn, te)
} }

View file

@ -8,7 +8,7 @@ import (
"fmt" "fmt"
"hash" "hash"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
) )
var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm") var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm")

View file

@ -4,7 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
) )
func Sign(object any, signer jose.Signer) (string, error) { func Sign(object any, signer jose.Signer) (string, error) {

View file

@ -33,7 +33,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization {
} }
} }
func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) { func FormRequest(ctx context.Context, endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) {
form := url.Values{} form := url.Values{}
if err := encoder.Encode(request, form); err != nil { if err := encoder.Encode(request, form); err != nil {
return nil, err return nil, err
@ -42,7 +42,7 @@ func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*ht
fn(form) fn(form)
} }
body := strings.NewReader(form.Encode()) body := strings.NewReader(form.Encode())
req, err := http.NewRequest("POST", endpoint, body) req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,5 +1,9 @@
package oidc package oidc
import (
"golang.org/x/exp/slog"
)
const ( const (
// ScopeOpenID defines the scope `openid` // ScopeOpenID defines the scope `openid`
// OpenID Connect requests MUST contain the `openid` scope value // OpenID Connect requests MUST contain the `openid` scope value
@ -77,7 +81,7 @@ type AuthRequest struct {
UILocales Locales `json:"ui_locales" schema:"ui_locales"` UILocales Locales `json:"ui_locales" schema:"ui_locales"`
IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"` IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"`
LoginHint string `json:"login_hint" schema:"login_hint"` LoginHint string `json:"login_hint" schema:"login_hint"`
ACRValues []string `json:"acr_values" schema:"acr_values"` ACRValues SpaceDelimitedArray `json:"acr_values" schema:"acr_values"`
CodeChallenge string `json:"code_challenge" schema:"code_challenge"` CodeChallenge string `json:"code_challenge" schema:"code_challenge"`
CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"` CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"`
@ -86,6 +90,15 @@ type AuthRequest struct {
RequestParam string `schema:"request"` RequestParam string `schema:"request"`
} }
func (a *AuthRequest) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("scopes", a.Scopes),
slog.String("response_type", string(a.ResponseType)),
slog.String("client_id", a.ClientID),
slog.String("redirect_uri", a.RedirectURI),
)
}
// GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface // GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface
func (a *AuthRequest) GetRedirectURI() string { func (a *AuthRequest) GetRedirectURI() string {
return a.RedirectURI return a.RedirectURI

View file

@ -0,0 +1,27 @@
//go:build go1.20
package oidc
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slog"
)
func TestAuthRequest_LogValue(t *testing.T) {
a := &AuthRequest{
Scopes: SpaceDelimitedArray{"a", "b"},
ResponseType: "respType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
}
want := slog.GroupValue(
slog.Any("scopes", SpaceDelimitedArray{"a", "b"}),
slog.String("response_type", "respType"),
slog.String("client_id", "123"),
slog.String("redirect_uri", "http://example.com/callback"),
)
got := a.LogValue()
assert.Equal(t, want, got)
}

View file

@ -3,7 +3,7 @@ package oidc
import ( import (
"crypto/sha256" "crypto/sha256"
"github.com/zitadel/oidc/v2/pkg/crypto" "github.com/zitadel/oidc/v3/pkg/crypto"
) )
const ( const (

View file

@ -3,6 +3,8 @@ package oidc
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.org/x/exp/slog"
) )
type errorType string type errorType string
@ -171,3 +173,34 @@ func DefaultToServerError(err error, description string) *Error {
} }
return oauth return oauth
} }
func (e *Error) LogLevel() slog.Level {
level := slog.LevelWarn
if e.ErrorType == ServerError {
level = slog.LevelError
}
if e.ErrorType == AuthorizationPending {
level = slog.LevelInfo
}
return level
}
func (e *Error) LogValue() slog.Value {
attrs := make([]slog.Attr, 0, 5)
if e.Parent != nil {
attrs = append(attrs, slog.Any("parent", e.Parent))
}
if e.Description != "" {
attrs = append(attrs, slog.String("description", e.Description))
}
if e.ErrorType != "" {
attrs = append(attrs, slog.String("type", string(e.ErrorType)))
}
if e.State != "" {
attrs = append(attrs, slog.String("state", e.State))
}
if e.redirectDisabled {
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
}
return slog.GroupValue(attrs...)
}

View file

@ -0,0 +1,83 @@
//go:build go1.20
package oidc
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slog"
)
func TestError_LogValue(t *testing.T) {
type fields struct {
Parent error
ErrorType errorType
Description string
State string
redirectDisabled bool
}
tests := []struct {
name string
fields fields
want slog.Value
}{
{
name: "parent",
fields: fields{
Parent: io.EOF,
},
want: slog.GroupValue(slog.Any("parent", io.EOF)),
},
{
name: "description",
fields: fields{
Description: "oops",
},
want: slog.GroupValue(slog.String("description", "oops")),
},
{
name: "errorType",
fields: fields{
ErrorType: ExpiredToken,
},
want: slog.GroupValue(slog.String("type", string(ExpiredToken))),
},
{
name: "state",
fields: fields{
State: "123",
},
want: slog.GroupValue(slog.String("state", "123")),
},
{
name: "all fields",
fields: fields{
Parent: io.EOF,
Description: "oops",
ErrorType: ExpiredToken,
State: "123",
},
want: slog.GroupValue(
slog.Any("parent", io.EOF),
slog.String("description", "oops"),
slog.String("type", string(ExpiredToken)),
slog.String("state", "123"),
),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Error{
Parent: tt.fields.Parent,
ErrorType: tt.fields.ErrorType,
Description: tt.fields.Description,
State: tt.fields.State,
redirectDisabled: tt.fields.redirectDisabled,
}
got := e.LogValue()
assert.Equal(t, tt.want, got)
})
}
}

81
pkg/oidc/error_test.go Normal file
View file

@ -0,0 +1,81 @@
package oidc
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slog"
)
func TestDefaultToServerError(t *testing.T) {
type args struct {
err error
description string
}
tests := []struct {
name string
args args
want *Error
}{
{
name: "default",
args: args{
err: io.ErrClosedPipe,
description: "oops",
},
want: &Error{
ErrorType: ServerError,
Description: "oops",
Parent: io.ErrClosedPipe,
},
},
{
name: "our Error",
args: args{
err: ErrAccessDenied(),
description: "oops",
},
want: &Error{
ErrorType: AccessDenied,
Description: "The authorization request was denied.",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DefaultToServerError(tt.args.err, tt.args.description)
assert.ErrorIs(t, got, tt.want)
})
}
}
func TestError_LogLevel(t *testing.T) {
tests := []struct {
name string
err *Error
want slog.Level
}{
{
name: "server error",
err: ErrServerError(),
want: slog.LevelError,
},
{
name: "authorization pending",
err: ErrAuthorizationPending(),
want: slog.LevelInfo,
},
{
name: "some other error",
err: ErrAccessDenied(),
want: slog.LevelWarn,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.err.LogLevel()
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -7,7 +7,7 @@ import (
"crypto/rsa" "crypto/rsa"
"errors" "errors"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
) )
const ( const (

View file

@ -7,7 +7,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
) )
func TestFindKey(t *testing.T) { func TestFindKey(t *testing.T) {

View file

@ -5,11 +5,11 @@ import (
"os" "os"
"time" "time"
jose "github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"github.com/muhlemmer/gu" "github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v2/pkg/crypto" "github.com/zitadel/oidc/v3/pkg/crypto"
) )
const ( const (

View file

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
) )
const ( const (

View file

@ -4,9 +4,9 @@ import (
"testing" "testing"
"time" "time"
jose "github.com/go-jose/go-jose/v3"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/text/language" "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
) )
var ( var (

View file

@ -8,10 +8,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/gorilla/schema" jose "github.com/go-jose/go-jose/v3"
"github.com/muhlemmer/gu" "github.com/muhlemmer/gu"
"github.com/zitadel/schema"
"golang.org/x/text/language" "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
) )
type Audience []string type Audience []string
@ -151,7 +151,7 @@ type ResponseType string
type ResponseMode string type ResponseMode string
func (s SpaceDelimitedArray) Encode() string { func (s SpaceDelimitedArray) String() string {
return strings.Join(s, " ") return strings.Join(s, " ")
} }
@ -161,11 +161,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error {
} }
func (s SpaceDelimitedArray) MarshalText() ([]byte, error) { func (s SpaceDelimitedArray) MarshalText() ([]byte, error) {
return []byte(s.Encode()), nil return []byte(s.String()), nil
} }
func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) { func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) {
return json.Marshal((s).Encode()) return json.Marshal((s).String())
} }
func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error {
@ -210,7 +210,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
func NewEncoder() *schema.Encoder { func NewEncoder() *schema.Encoder {
e := schema.NewEncoder() e := schema.NewEncoder()
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string { e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(SpaceDelimitedArray).Encode() return value.Interface().(SpaceDelimitedArray).String()
}) })
return e return e
} }

View file

@ -9,9 +9,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/schema"
"golang.org/x/text/language" "golang.org/x/text/language"
) )

View file

@ -29,6 +29,11 @@ func (u *UserInfo) GetAddress() *UserInfoAddress {
return u.Address return u.Address
} }
// GetSubject implements [rp.SubjectGetter]
func (u *UserInfo) GetSubject() string {
return u.Subject
}
type uiAlias UserInfo type uiAlias UserInfo
func (u *UserInfo) MarshalJSON() ([]byte, error) { func (u *UserInfo) MarshalJSON() ([]byte, error) {

View file

@ -10,9 +10,9 @@ import (
"strings" "strings"
"time" "time"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
str "github.com/zitadel/oidc/v2/pkg/strings" str "github.com/zitadel/oidc/v3/pkg/strings"
) )
type Claims interface { type Claims interface {
@ -61,10 +61,19 @@ var (
ErrAtHash = errors.New("at_hash does not correspond to access token") ErrAtHash = errors.New("at_hash does not correspond to access token")
) )
type Verifier interface { // Verifier caries configuration for the various token verification
Issuer() string // functions. Use package specific constructor functions to know
MaxAgeIAT() time.Duration // which values need to be set.
Offset() time.Duration type Verifier struct {
Issuer string
MaxAgeIAT time.Duration
Offset time.Duration
ClientID string
SupportedSignAlgs []string
MaxAge time.Duration
ACR ACRVerifier
KeySet KeySet
Nonce func(ctx context.Context) string
} }
// ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim // ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error {
return nil return nil
} }
// CheckAuthorizedParty checks azp (authorized party) claim requirements.
//
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func CheckAuthorizedParty(claims Claims, clientID string) error { func CheckAuthorizedParty(claims Claims, clientID string) error {
if len(claims.GetAudience()) > 1 { if len(claims.GetAudience()) > 1 {
if claims.GetAuthorizedParty() == "" { if claims.GetAuthorizedParty() == "" {
@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
} }
func CheckExpiration(claims Claims, offset time.Duration) error { func CheckExpiration(claims Claims, offset time.Duration) error {
expiration := claims.GetExpiration().Round(time.Second) expiration := claims.GetExpiration()
if !time.Now().UTC().Add(offset).Before(expiration) { if !time.Now().Add(offset).Before(expiration) {
return ErrExpired return ErrExpired
} }
return nil return nil
} }
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
issuedAt := claims.GetIssuedAt().Round(time.Second) issuedAt := claims.GetIssuedAt()
if issuedAt.IsZero() { if issuedAt.IsZero() {
return ErrIatMissing return ErrIatMissing
} }
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) nowWithOffset := time.Now().Add(offset).Round(time.Second)
if issuedAt.After(nowWithOffset) { if issuedAt.After(nowWithOffset) {
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
} }
if maxAgeIAT == 0 { if maxAgeIAT == 0 {
return nil return nil
} }
maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second) maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second)
if issuedAt.Before(maxAge) { if issuedAt.Before(maxAge) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
} }
@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error {
if claims.GetAuthTime().IsZero() { if claims.GetAuthTime().IsZero() {
return ErrAuthTimeNotPresent return ErrAuthTimeNotPresent
} }
authTime := claims.GetAuthTime().Round(time.Second) authTime := claims.GetAuthTime()
maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second) maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
if authTime.Before(maxAuthTime) { if authTime.Before(maxAuthTime) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime)) return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
} }

View file

@ -0,0 +1,128 @@
package oidc_test
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
func TestParseToken(t *testing.T) {
token, wantClaims := tu.ValidIDToken()
wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload
wantPayload, err := json.Marshal(wantClaims)
require.NoError(t, err)
tests := []struct {
name string
tokenString string
wantErr bool
}{
{
name: "split error",
tokenString: "nope",
wantErr: true,
},
{
name: "base64 error",
tokenString: "foo.~.bar",
wantErr: true,
},
{
name: "success",
tokenString: token,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClaims := new(oidc.IDTokenClaims)
gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, wantClaims, gotClaims)
assert.JSONEq(t, string(wantPayload), string(gotPayload))
})
}
}
func TestCheckSignature(t *testing.T) {
errCtx, cancel := context.WithCancel(context.Background())
cancel()
token, _ := tu.ValidIDToken()
payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{})
require.NoError(t, err)
type args struct {
ctx context.Context
token string
payload []byte
supportedSigAlgs []string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "parse error",
args: args{
ctx: context.Background(),
token: "~",
payload: payload,
},
wantErr: oidc.ErrParse,
},
{
name: "default sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
},
},
{
name: "unsupported sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
supportedSigAlgs: []string{"foo", "bar"},
},
wantErr: oidc.ErrSignatureUnsupportedAlg,
},
{
name: "verify error",
args: args{
ctx: errCtx,
token: token,
payload: payload,
},
wantErr: oidc.ErrSignatureInvalid,
},
{
name: "inequal payloads",
args: args{
ctx: context.Background(),
token: token,
payload: []byte{0, 1, 2},
},
wantErr: oidc.ErrSignatureInvalidPayload,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims := new(oidc.TokenClaims)
err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{})
assert.ErrorIs(t, err, tt.wantErr)
})
}
}

374
pkg/oidc/verifier_test.go Normal file
View file

@ -0,0 +1,374 @@
package oidc
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecryptToken(t *testing.T) {
const tokenString = "ABC"
got, err := DecryptToken(tokenString)
require.NoError(t, err)
assert.Equal(t, tokenString, got)
}
func TestDefaultACRVerifier(t *testing.T) {
acrVerfier := DefaultACRVerifier([]string{"foo", "bar"})
tests := []struct {
name string
acr string
wantErr string
}{
{
name: "ok",
acr: "bar",
},
{
name: "error",
acr: "hello",
wantErr: "expected one of: [foo bar], got: \"hello\"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := acrVerfier(tt.acr)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
return
}
require.NoError(t, err)
})
}
}
func TestCheckSubject(t *testing.T) {
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrSubjectMissing,
},
{
name: "ok",
claims: &TokenClaims{
Subject: "foo",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckSubject(tt.claims)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckIssuer(t *testing.T) {
const issuer = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrIssuerInvalid,
},
{
name: "wrong",
claims: &TokenClaims{
Issuer: "wrong",
},
wantErr: ErrIssuerInvalid,
},
{
name: "ok",
claims: &TokenClaims{
Issuer: issuer,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckIssuer(tt.claims, issuer)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAudience(t *testing.T) {
const clientID = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrAudience,
},
{
name: "wrong",
claims: &TokenClaims{
Audience: []string{"wrong"},
},
wantErr: ErrAudience,
},
{
name: "ok",
claims: &TokenClaims{
Audience: []string{clientID},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAudience(tt.claims, clientID)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthorizedParty(t *testing.T) {
const clientID = "foo.bar"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "single audience, no azp",
claims: &TokenClaims{
Audience: []string{clientID},
},
},
{
name: "multiple audience, no azp",
claims: &TokenClaims{
Audience: []string{clientID, "other"},
},
wantErr: ErrAzpMissing,
},
{
name: "single audience, with azp",
claims: &TokenClaims{
Audience: []string{clientID},
AuthorizedParty: clientID,
},
},
{
name: "multiple audience, with azp",
claims: &TokenClaims{
Audience: []string{clientID, "other"},
AuthorizedParty: clientID,
},
},
{
name: "wrong azp",
claims: &TokenClaims{
AuthorizedParty: "wrong",
},
wantErr: ErrAzpInvalid,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthorizedParty(tt.claims, clientID)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckExpiration(t *testing.T) {
const offset = time.Minute
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrExpired,
},
{
name: "expired",
claims: &TokenClaims{
Expiration: FromTime(time.Now().Add(-2 * offset)),
},
wantErr: ErrExpired,
},
{
name: "valid",
claims: &TokenClaims{
Expiration: FromTime(time.Now().Add(2 * offset)),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckExpiration(tt.claims, offset)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckIssuedAt(t *testing.T) {
const offset = time.Minute
tests := []struct {
name string
maxAgeIAT time.Duration
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrIatMissing,
},
{
name: "future",
claims: &TokenClaims{
IssuedAt: FromTime(time.Now().Add(time.Hour)),
},
wantErr: ErrIatInFuture,
},
{
name: "no max",
claims: &TokenClaims{
IssuedAt: FromTime(time.Now()),
},
},
{
name: "past max",
maxAgeIAT: time.Minute,
claims: &TokenClaims{
IssuedAt: FromTime(time.Now().Add(-time.Hour)),
},
wantErr: ErrIatToOld,
},
{
name: "within max",
maxAgeIAT: time.Hour,
claims: &TokenClaims{
IssuedAt: FromTime(time.Now()),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckNonce(t *testing.T) {
const nonce = "123"
tests := []struct {
name string
claims Claims
wantErr error
}{
{
name: "missing",
claims: &TokenClaims{},
wantErr: ErrNonceInvalid,
},
{
name: "wrong",
claims: &TokenClaims{
Nonce: "wrong",
},
wantErr: ErrNonceInvalid,
},
{
name: "ok",
claims: &TokenClaims{
Nonce: nonce,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckNonce(tt.claims, nonce)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthorizationContextClassReference(t *testing.T) {
tests := []struct {
name string
acr ACRVerifier
wantErr error
}{
{
name: "error",
acr: func(s string) error { return errors.New("oops") },
wantErr: ErrAcrInvalid,
},
{
name: "ok",
acr: func(s string) error { return nil },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestCheckAuthTime(t *testing.T) {
tests := []struct {
name string
claims Claims
maxAge time.Duration
wantErr error
}{
{
name: "no max age",
claims: &TokenClaims{},
},
{
name: "missing",
claims: &TokenClaims{},
maxAge: time.Minute,
wantErr: ErrAuthTimeNotPresent,
},
{
name: "expired",
maxAge: time.Minute,
claims: &TokenClaims{
AuthTime: FromTime(time.Now().Add(-time.Hour)),
},
wantErr: ErrAuthTimeToOld,
},
{
name: "ok",
maxAge: time.Minute,
claims: &TokenClaims{
AuthTime: NowTime(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckAuthTime(tt.claims, tt.maxAge)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}

View file

@ -2,6 +2,7 @@ package op
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -10,11 +11,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/gorilla/mux" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
httphelper "github.com/zitadel/oidc/v2/pkg/http" str "github.com/zitadel/oidc/v3/pkg/strings"
"github.com/zitadel/oidc/v2/pkg/oidc" "golang.org/x/exp/slog"
str "github.com/zitadel/oidc/v2/pkg/strings"
) )
type AuthRequest interface { type AuthRequest interface {
@ -39,16 +39,17 @@ type Authorizer interface {
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Encoder() httphelper.Encoder Encoder() httphelper.Encoder
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
Crypto() Crypto Crypto() Crypto
RequestObjectSupported() bool RequestObjectSupported() bool
Logger() *slog.Logger
} }
// AuthorizeValidator is an extension of Authorizer interface // AuthorizeValidator is an extension of Authorizer interface
// implementing its own validation mechanism for the auth request // implementing its own validation mechanism for the auth request
type AuthorizeValidator interface { type AuthorizeValidator interface {
Authorizer Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
} }
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
@ -68,23 +69,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
if err != nil { if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder()) AuthRequestError(w, r, nil, err, authorizer)
return return
} }
ctx := r.Context() ctx := r.Context()
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
} }
if authReq.ClientID == "" { if authReq.ClientID == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder()) AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer)
return return
} }
if authReq.RedirectURI == "" { if authReq.RedirectURI == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder()) AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
return return
} }
validation := ValidateAuthRequest validation := ValidateAuthRequest
@ -93,21 +94,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
} }
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
if authReq.RequestParam != "" { if authReq.RequestParam != "" {
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer)
return return
} }
req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID) req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer)
return return
} }
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer)
return return
} }
RedirectToLogin(req.GetID(), client, w, r) RedirectToLogin(req.GetID(), client, w, r)
@ -129,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A
// ParseRequestObject parse the `request` parameter, validates the token including the signature // ParseRequestObject parse the `request` parameter, validates the token including the signature
// and copies the token claims into the auth request // and copies the token claims into the auth request
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) { func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error {
requestObject := new(oidc.RequestObject) requestObject := new(oidc.RequestObject)
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject) payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
if err != nil { if err != nil {
return nil, err return err
} }
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
return authReq, oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest()
} }
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
return authReq, oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest()
} }
if requestObject.Issuer != requestObject.ClientID { if requestObject.Issuer != requestObject.ClientID {
return authReq, oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest()
} }
if !str.Contains(requestObject.Audience, issuer) { if !str.Contains(requestObject.Audience, issuer) {
return authReq, oidc.ErrInvalidRequest() return oidc.ErrInvalidRequest()
} }
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
return authReq, err return err
} }
CopyRequestObjectToAuthRequest(authReq, requestObject) CopyRequestObjectToAuthRequest(authReq, requestObject)
return authReq, nil return nil
} }
// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request // CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
@ -205,7 +206,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
} }
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed // ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil { if err != nil {
return "", err return "", err
@ -385,7 +386,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType)
// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request) // ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
// and returns the `sub` claim // and returns the `sub` claim
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) { func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) {
if idTokenHint == "" { if idTokenHint == "" {
return "", nil return "", nil
} }
@ -405,32 +406,41 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *
// AuthorizeCallback handles the callback after authentication in the Login UI // AuthorizeCallback handles the callback after authentication in the Login UI
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
params := mux.Vars(r) id, err := ParseAuthorizeCallbackRequest(r)
id := params["id"] if err != nil {
if id == "" { AuthRequestError(w, r, nil, err, authorizer)
AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder())
return return
} }
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil { if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder()) AuthRequestError(w, r, nil, err, authorizer)
return return
} }
if !authReq.Done() { if !authReq.Done() {
AuthRequestError(w, r, authReq, AuthRequestError(w, r, authReq,
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."), oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
authorizer.Encoder()) authorizer)
return return
} }
AuthResponse(authReq, authorizer, w, r) AuthResponse(authReq, authorizer, w, r)
} }
func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
if err = r.ParseForm(); err != nil {
return "", fmt.Errorf("cannot parse form: %w", err)
}
id = r.Form.Get("id")
if id == "" {
return "", errors.New("auth request callback is missing id")
}
return id, nil
}
// AuthResponse creates the successful authentication response (either code or tokens) // AuthResponse creates the successful authentication response (either code or tokens)
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
if authReq.GetResponseType() == oidc.ResponseTypeCode { if authReq.GetResponseType() == oidc.ResponseTypeCode {
@ -444,7 +454,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
codeResponse := struct { codeResponse := struct {
@ -456,7 +466,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
} }
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
@ -471,12 +481,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "") resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "")
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder()) callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)

View file

@ -11,14 +11,16 @@ import (
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage" "github.com/zitadel/oidc/v3/example/server/storage"
httphelper "github.com/zitadel/oidc/v2/pkg/http" tu "github.com/zitadel/oidc/v3/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op/mock" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v3/pkg/op/mock"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
) )
func TestAuthorize(t *testing.T) { func TestAuthorize(t *testing.T) {
@ -39,7 +41,7 @@ func TestAuthorize(t *testing.T) {
expect := authorizer.EXPECT() expect := authorizer.EXPECT()
expect.Decoder().Return(schema.NewDecoder()) expect.Decoder().Return(schema.NewDecoder())
expect.Encoder().Return(schema.NewEncoder()) expect.Logger().Return(slog.Default())
if tt.expect != nil { if tt.expect != nil {
tt.expect(expect) tt.expect(expect)
@ -123,7 +125,7 @@ func TestValidateAuthRequest(t *testing.T) {
type args struct { type args struct {
authRequest *oidc.AuthRequest authRequest *oidc.AuthRequest
storage op.Storage storage op.Storage
verifier op.IDTokenHintVerifier verifier *op.IDTokenHintVerifier
} }
tests := []struct { tests := []struct {
name string name string
@ -996,7 +998,7 @@ func TestAuthResponseCode(t *testing.T) {
authorizer.EXPECT().Crypto().Return(&mockCrypto{ authorizer.EXPECT().Crypto().Return(&mockCrypto{
returnErr: io.ErrClosedPipe, returnErr: io.ErrClosedPipe,
}) })
authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) authorizer.EXPECT().Logger().Return(slog.Default())
return authorizer return authorizer
}, },
}, },
@ -1071,3 +1073,71 @@ func TestAuthResponseCode(t *testing.T) {
}) })
} }
} }
func Test_parseAuthorizeCallbackRequest(t *testing.T) {
tests := []struct {
name string
url string
wantId string
wantErr bool
}{
{
name: "parse error",
url: "/?id;=99",
wantErr: true,
},
{
name: "missing id",
url: "/",
wantErr: true,
},
{
name: "ok",
url: "/?id=99",
wantId: "99",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
gotId, err := op.ParseAuthorizeCallbackRequest(r)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantId, gotId)
})
}
}
func TestValidateAuthReqIDTokenHint(t *testing.T) {
token, _ := tu.ValidIDToken()
tests := []struct {
name string
idTokenHint string
want string
wantErr error
}{
{
name: "empty",
},
{
name: "verify err",
idTokenHint: "foo",
wantErr: oidc.ErrLoginRequired(),
},
{
name: "ok",
idTokenHint: token,
want: tu.ValidSubject,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -7,8 +7,8 @@ import (
"net/url" "net/url"
"time" "time"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
//go:generate go get github.com/dmarkham/enumer //go:generate go get github.com/dmarkham/enumer
@ -87,7 +87,7 @@ var (
) )
type ClientJWTProfile interface { type ClientJWTProfile interface {
JWTProfileVerifier(context.Context) JWTProfileVerifier JWTProfileVerifier(context.Context) *JWTProfileVerifier
} }
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
} }
return data.ClientID, false, nil return data.ClientID, false, nil
} }
type ClientCredentials struct {
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body
ClientAssertion string `schema:"client_assertion"` // JWT
ClientAssertionType string `schema:"client_assertion_type"`
}

View file

@ -11,18 +11,18 @@ import (
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock" "github.com/zitadel/oidc/v3/pkg/op/mock"
"github.com/zitadel/schema"
) )
type testClientJWTProfile struct{} type testClientJWTProfile struct{}
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil }
func TestClientJWTAuth(t *testing.T) { func TestClientJWTAuth(t *testing.T) {
type args struct { type args struct {

View file

@ -22,14 +22,14 @@ var (
type Configuration interface { type Configuration interface {
IssuerFromRequest(r *http.Request) string IssuerFromRequest(r *http.Request) string
Insecure() bool Insecure() bool
AuthorizationEndpoint() Endpoint AuthorizationEndpoint() *Endpoint
TokenEndpoint() Endpoint TokenEndpoint() *Endpoint
IntrospectionEndpoint() Endpoint IntrospectionEndpoint() *Endpoint
UserinfoEndpoint() Endpoint UserinfoEndpoint() *Endpoint
RevocationEndpoint() Endpoint RevocationEndpoint() *Endpoint
EndSessionEndpoint() Endpoint EndSessionEndpoint() *Endpoint
KeysEndpoint() Endpoint KeysEndpoint() *Endpoint
DeviceAuthorizationEndpoint() Endpoint DeviceAuthorizationEndpoint() *Endpoint
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
CodeMethodS256Supported() bool CodeMethodS256Supported() bool

View file

@ -1,7 +1,7 @@
package op package op
import ( import (
"github.com/zitadel/oidc/v2/pkg/crypto" "github.com/zitadel/oidc/v3/pkg/crypto"
) )
type Crypto interface { type Crypto interface {

View file

@ -12,8 +12,8 @@ import (
"strings" "strings"
"time" "time"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type DeviceAuthorizationConfig struct { type DeviceAuthorizationConfig struct {
@ -57,47 +57,57 @@ var (
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if err := DeviceAuthorization(w, r, o); err != nil { if err := DeviceAuthorization(w, r, o); err != nil {
RequestError(w, r, err) RequestError(w, r, err, o.Logger())
} }
} }
} }
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return err
}
req, err := ParseDeviceCodeRequest(r, o) req, err := ParseDeviceCodeRequest(r, o)
if err != nil { if err != nil {
return err return err
} }
response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o)
if err != nil {
return err
}
httphelper.MarshalJSON(w, response)
return nil
}
func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return nil, err
}
config := o.DeviceAuthorization() config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil { if err != nil {
return err return nil, NewStatusError(err, http.StatusInternalServerError)
} }
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
if err != nil { if err != nil {
return err return nil, NewStatusError(err, http.StatusInternalServerError)
} }
expires := time.Now().Add(config.Lifetime) expires := time.Now().Add(config.Lifetime)
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes)
if err != nil { if err != nil {
return err return nil, NewStatusError(err, http.StatusInternalServerError)
} }
var verification *url.URL var verification *url.URL
if config.UserFormURL != "" { if config.UserFormURL != "" {
if verification, err = url.Parse(config.UserFormURL); err != nil { if verification, err = url.Parse(config.UserFormURL); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
return nil, NewStatusError(err, http.StatusInternalServerError)
} }
} else { } else {
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil { if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
return nil, NewStatusError(err, http.StatusInternalServerError)
} }
verification.Path = config.UserFormPath verification.Path = config.UserFormPath
} }
@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
verification.RawQuery = "user_code=" + userCode verification.RawQuery = "user_code=" + userCode
response.VerificationURIComplete = verification.String() response.VerificationURIComplete = verification.String()
return response, nil
httphelper.MarshalJSON(w, response)
return nil
} }
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
@ -201,7 +209,7 @@ func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
r = r.WithContext(ctx) r = r.WithContext(ctx)
if err := deviceAccessToken(w, r, exchanger); err != nil { if err := deviceAccessToken(w, r, exchanger); err != nil {
RequestError(w, r, err) RequestError(w, r, err, exchanger.Logger())
} }
} }

View file

@ -16,8 +16,9 @@ import (
"github.com/muhlemmer/gu" "github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
) )
func Test_deviceAuthorizationHandler(t *testing.T) { func Test_deviceAuthorizationHandler(t *testing.T) {
@ -319,7 +320,7 @@ func BenchmarkNewUserCode(b *testing.B) {
} }
func TestDeviceAccessToken(t *testing.T) { func TestDeviceAccessToken(t *testing.T) {
storage := testProvider.Storage().(op.DeviceAuthorizationStorage) storage := testProvider.Storage().(*storage.Storage)
storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim") storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
@ -344,7 +345,7 @@ func TestDeviceAccessToken(t *testing.T) {
func TestCheckDeviceAuthorizationState(t *testing.T) { func TestCheckDeviceAuthorizationState(t *testing.T) {
now := time.Now() now := time.Now()
storage := testProvider.Storage().(op.DeviceAuthorizationStorage) storage := testProvider.Storage().(*storage.Storage)
storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})

View file

@ -4,10 +4,10 @@ import (
"context" "context"
"net/http" "net/http"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
type DiscoverStorage interface { type DiscoverStorage interface {
@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) { func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(r, c, s)) Discover(w, CreateDiscoveryConfig(r.Context(), c, s))
} }
} }
@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
httphelper.MarshalJSON(w, config) httphelper.MarshalJSON(w, config)
} }
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
issuer := config.IssuerFromRequest(r) issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{ return &oidc.DiscoveryConfiguration{
Issuer: issuer, Issuer: issuer,
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer), AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
ResponseTypesSupported: ResponseTypes(config), ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config), GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config), SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage), IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
ClaimsSupported: SupportedClaims(config),
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
UILocalesSupported: config.SupportedUILocales(),
RequestParameterSupported: config.RequestObjectSupported(),
}
}
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
TokenEndpoint: endpoints.Token.Absolute(issuer),
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
JwksURI: endpoints.JwksURI.Absolute(issuer),
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),

View file

@ -6,14 +6,14 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
jose "github.com/go-jose/go-jose/v3"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock" "github.com/zitadel/oidc/v3/pkg/op/mock"
) )
func TestDiscover(t *testing.T) { func TestDiscover(t *testing.T) {
@ -48,7 +48,7 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) {
type args struct { type args struct {
request *http.Request ctx context.Context
c op.Configuration c op.Configuration
s op.DiscoverStorage s op.DiscoverStorage
} }
@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s) got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }

View file

@ -1,32 +1,46 @@
package op package op
import "strings" import (
"errors"
"strings"
)
type Endpoint struct { type Endpoint struct {
path string path string
url string url string
} }
func NewEndpoint(path string) Endpoint { func NewEndpoint(path string) *Endpoint {
return Endpoint{path: path} return &Endpoint{path: path}
} }
func NewEndpointWithURL(path, url string) Endpoint { func NewEndpointWithURL(path, url string) *Endpoint {
return Endpoint{path: path, url: url} return &Endpoint{path: path, url: url}
} }
func (e Endpoint) Relative() string { func (e *Endpoint) Relative() string {
if e == nil {
return ""
}
return relativeEndpoint(e.path) return relativeEndpoint(e.path)
} }
func (e Endpoint) Absolute(host string) string { func (e *Endpoint) Absolute(host string) string {
if e == nil {
return ""
}
if e.url != "" { if e.url != "" {
return e.url return e.url
} }
return absoluteEndpoint(host, e.path) return absoluteEndpoint(host, e.path)
} }
func (e Endpoint) Validate() error { var ErrNilEndpoint = errors.New("nil endpoint")
func (e *Endpoint) Validate() error {
if e == nil {
return ErrNilEndpoint
}
return nil // TODO: return nil // TODO:
} }

View file

@ -3,13 +3,14 @@ package op_test
import ( import (
"testing" "testing"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/op"
) )
func TestEndpoint_Path(t *testing.T) { func TestEndpoint_Path(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
e op.Endpoint e *op.Endpoint
want string want string
}{ }{
{ {
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
op.NewEndpointWithURL("/test", "http://test.com/test"), op.NewEndpointWithURL("/test", "http://test.com/test"),
"/test", "/test",
}, },
{
"nil",
nil,
"",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
} }
tests := []struct { tests := []struct {
name string name string
e op.Endpoint e *op.Endpoint
args args args args
want string want string
}{ }{
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
args{"https://host"}, args{"https://host"},
"https://test.com/test", "https://test.com/test",
}, },
{
"nil",
nil,
args{"https://host"},
"",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
func TestEndpoint_Validate(t *testing.T) { func TestEndpoint_Validate(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
e op.Endpoint e *op.Endpoint
wantErr bool wantErr error
}{ }{
// TODO: Add test cases. {
"nil",
nil,
op.ErrNilEndpoint,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Validate(); (err != nil) != tt.wantErr { err := tt.e.Validate()
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
}
}) })
} }
} }

View file

@ -1,10 +1,14 @@
package op package op
import ( import (
"context"
"errors"
"fmt"
"net/http" "net/http"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slog"
) )
type ErrAuthRequest interface { type ErrAuthRequest interface {
@ -13,13 +17,31 @@ type ErrAuthRequest interface {
GetState() string GetState() string
} }
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) { // LogAuthRequest is an optional interface,
// that allows logging AuthRequest fields.
// If the AuthRequest does not implement this interface,
// no details shall be printed to the logs.
type LogAuthRequest interface {
ErrAuthRequest
slog.LogValuer
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) {
e := oidc.DefaultToServerError(err, err.Error())
logger := authorizer.Logger().With("oidc_error", e)
if authReq == nil { if authReq == nil {
logger.Log(r.Context(), e.LogLevel(), "auth request")
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
e := oidc.DefaultToServerError(err, err.Error())
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
logger = logger.With("auth_request", logAuthReq)
}
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting")
http.Error(w, e.Description, http.StatusBadRequest) http.Error(w, e.Description, http.StatusBadRequest)
return return
} }
@ -28,19 +50,120 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode() responseMode = rm.GetResponseMode()
} }
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder())
if err != nil { if err != nil {
logger.ErrorContext(r.Context(), "auth response URL", "error", err)
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
logger.Log(r.Context(), e.LogLevel(), "auth request")
http.Redirect(w, r, url, http.StatusFound) http.Redirect(w, r, url, http.StatusFound)
} }
func RequestError(w http.ResponseWriter, r *http.Request, err error) { func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
e := oidc.DefaultToServerError(err, err.Error()) e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest status := http.StatusBadRequest
if e.ErrorType == oidc.InvalidClient { if e.ErrorType == oidc.InvalidClient {
status = 401 status = http.StatusUnauthorized
} }
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, e, status) httphelper.MarshalJSONWithStatus(w, e, status)
} }
// TryErrorRedirect tries to handle an error by redirecting a client.
// If this attempt fails, an error is returned that must be returned
// to the client instead.
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
e := oidc.DefaultToServerError(parent, parent.Error())
logger = logger.With("oidc_error", e)
if authReq == nil {
logger.Log(ctx, e.LogLevel(), "auth request")
return nil, AsStatusError(e, http.StatusBadRequest)
}
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
logger = logger.With("auth_request", logAuthReq)
}
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
return nil, AsStatusError(e, http.StatusBadRequest)
}
e.State = authReq.GetState()
var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
}
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
if err != nil {
logger.ErrorContext(ctx, "auth response URL", "error", err)
return nil, AsStatusError(err, http.StatusBadRequest)
}
logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
return NewRedirect(url), nil
}
// StatusError wraps an error with a HTTP status code.
// The status code is passed to the handler's writer.
type StatusError struct {
parent error
statusCode int
}
// NewStatusError sets the parent and statusCode to a new StatusError.
// It is recommended for parent to be an [oidc.Error].
//
// Typically implementations should only use this to signal something
// very specific, like an internal server error.
// If a returned error is not a StatusError, the framework
// will set a statusCode based on what the standard specifies,
// which is [http.StatusBadRequest] for most of the time.
// If the error encountered can described clearly with a [oidc.Error],
// do not use this function, as it might break standard rules!
func NewStatusError(parent error, statusCode int) StatusError {
return StatusError{
parent: parent,
statusCode: statusCode,
}
}
// AsStatusError unwraps a StatusError from err
// and returns it unmodified if found.
// If no StatuError was found, a new one is returned
// with statusCode set to it as a default.
func AsStatusError(err error, statusCode int) (target StatusError) {
if errors.As(err, &target) {
return target
}
return NewStatusError(err, statusCode)
}
func (e StatusError) Error() string {
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
}
func (e StatusError) Unwrap() error {
return e.parent
}
func (e StatusError) Is(err error) bool {
var target StatusError
if !errors.As(err, &target) {
return false
}
return errors.Is(e.parent, target.parent) &&
e.statusCode == target.statusCode
}
// WriteError asserts for a StatusError containing an [oidc.Error].
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
statusError := AsStatusError(err, http.StatusBadRequest)
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
}

677
pkg/op/error_test.go Normal file
View file

@ -0,0 +1,677 @@
package op
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
)
func TestAuthRequestError(t *testing.T) {
type args struct {
authReq ErrAuthRequest
err error
}
tests := []struct {
name string
args args
wantCode int
wantHeaders map[string]string
wantBody string
wantLog string
}{
{
name: "nil auth request",
args: args{
authReq: nil,
err: io.ErrClosedPipe,
},
wantCode: http.StatusBadRequest,
wantBody: "io: read/write on closed pipe\n",
wantLog: `{
"level":"ERROR",
"msg":"auth request",
"time":"not",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
}
}`,
},
{
name: "auth request, no redirect URI",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantCode: http.StatusBadRequest,
wantBody: "sign in\n",
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request, redirect disabled",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
},
wantCode: http.StatusBadRequest,
wantBody: "oops\n",
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"oops",
"type":"invalid_request",
"redirect_disabled":true
}
}`,
},
{
name: "auth request, url parse error",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "can't parse this!\n",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantCode: http.StatusBadRequest,
wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n",
wantLog: `{
"level":"ERROR",
"msg":"auth response URL",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"can't parse this!\n",
"response_type":"responseType",
"scopes":"a b"
},
"error":{
"type":"server_error",
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request redirect",
args: args{
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantCode: http.StatusFound,
wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"},
wantLog: `{
"level":"WARN",
"msg":"auth request",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
authorizer := &Provider{
encoder: schema.NewEncoder(),
logger: slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
),
}
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/path", nil)
AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer)
res := w.Result()
defer res.Body.Close()
assert.Equal(t, tt.wantCode, res.StatusCode)
for key, wantHeader := range tt.wantHeaders {
gotHeader := res.Header.Get(key)
assert.Equalf(t, wantHeader, gotHeader, "header %q", key)
}
gotBody, err := io.ReadAll(res.Body)
require.NoError(t, err, "read result body")
assert.Equal(t, tt.wantBody, string(gotBody), "result body")
gotLog := logOut.String()
t.Log(gotLog)
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
})
}
}
func TestRequestError(t *testing.T) {
tests := []struct {
name string
err error
wantCode int
wantBody string
wantLog string
}{
{
name: "server error",
err: io.ErrClosedPipe,
wantCode: http.StatusBadRequest,
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"time":"not",
"oidc_error":{
"parent":"io: read/write on closed pipe",
"description":"io: read/write on closed pipe",
"type":"server_error"}
}`,
},
{
name: "invalid client",
err: oidc.ErrInvalidClient().WithDescription("not good"),
wantCode: http.StatusUnauthorized,
wantBody: `{"error":"invalid_client", "error_description":"not good"}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"time":"not",
"oidc_error":{
"description":"not good",
"type":"invalid_client"}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/path", nil)
RequestError(w, r, tt.err, logger)
res := w.Result()
defer res.Body.Close()
assert.Equal(t, tt.wantCode, res.StatusCode, "status code")
gotBody, err := io.ReadAll(res.Body)
require.NoError(t, err, "read result body")
assert.JSONEq(t, tt.wantBody, string(gotBody), "result body")
gotLog := logOut.String()
t.Log(gotLog)
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
})
}
}
func TestTryErrorRedirect(t *testing.T) {
type args struct {
ctx context.Context
authReq ErrAuthRequest
parent error
}
tests := []struct {
name string
args args
want *Redirect
wantErr error
wantLog string
}{
{
name: "nil auth request",
args: args{
ctx: context.Background(),
authReq: nil,
parent: io.ErrClosedPipe,
},
wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest),
wantLog: `{
"level":"ERROR",
"msg":"auth request",
"time":"not",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
}
}`,
},
{
name: "auth request, no redirect URI",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest),
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request, redirect disabled",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
},
wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest),
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"oops",
"type":"invalid_request",
"redirect_disabled":true
}
}`,
},
{
name: "auth request, url parse error",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "can't parse this!\n",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantErr: func() error {
//lint:ignore SA1007 just recreating the error for testing
_, err := url.Parse("can't parse this!\n")
err = oidc.ErrServerError().WithParent(err)
return NewStatusError(err, http.StatusBadRequest)
}(),
wantLog: `{
"level":"ERROR",
"msg":"auth response URL",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"can't parse this!\n",
"response_type":"responseType",
"scopes":"a b"
},
"error":{
"type":"server_error",
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request redirect",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
want: &Redirect{
URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
},
wantLog: `{
"level":"WARN",
"msg":"auth request redirect",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
},
"url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
encoder := schema.NewEncoder()
got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
gotLog := logOut.String()
t.Log(gotLog)
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
})
}
}
func TestNewStatusError(t *testing.T) {
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
want := "Internal Server Error: io: read/write on closed pipe"
got := fmt.Sprint(err)
assert.Equal(t, want, got)
}
func TestAsStatusError(t *testing.T) {
type args struct {
err error
statusCode int
}
tests := []struct {
name string
args args
want string
}{
{
name: "already status error",
args: args{
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
statusCode: http.StatusBadRequest,
},
want: "Internal Server Error: io: read/write on closed pipe",
},
{
name: "oidc error",
args: args{
err: oidc.ErrAcrInvalid,
statusCode: http.StatusBadRequest,
},
want: "Bad Request: acr is invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := AsStatusError(tt.args.err, tt.args.statusCode)
got := fmt.Sprint(err)
assert.Equal(t, tt.want, got)
})
}
}
func TestStatusError_Unwrap(t *testing.T) {
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
require.ErrorIs(t, err, io.ErrClosedPipe)
}
func TestStatusError_Is(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil error",
args: args{err: nil},
want: false,
},
{
name: "other error",
args: args{err: io.EOF},
want: false,
},
{
name: "other parent",
args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)},
want: false,
},
{
name: "other status",
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)},
want: false,
},
{
name: "same",
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)},
want: true,
},
{
name: "wrapped",
args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
if got := e.Is(tt.args.err); got != tt.want {
t.Errorf("StatusError.Is() = %v, want %v", got, tt.want)
}
})
}
}
func TestWriteError(t *testing.T) {
tests := []struct {
name string
err error
wantStatus int
wantBody string
wantLog string
}{
{
name: "not a status or oidc error",
err: io.ErrClosedPipe,
wantStatus: http.StatusBadRequest,
wantBody: `{
"error":"server_error",
"error_description":"io: read/write on closed pipe"
}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
},
"time":"not"
}`,
},
{
name: "status error w/o oidc",
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
wantStatus: http.StatusInternalServerError,
wantBody: `{
"error":"server_error",
"error_description":"io: read/write on closed pipe"
}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
},
"time":"not"
}`,
},
{
name: "oidc error w/o status",
err: oidc.ErrInvalidRequest().WithDescription("oops"),
wantStatus: http.StatusBadRequest,
wantBody: `{
"error":"invalid_request",
"error_description":"oops"
}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"oidc_error":{
"description":"oops",
"type":"invalid_request"
},
"time":"not"
}`,
},
{
name: "status with oidc error",
err: NewStatusError(
oidc.ErrUnauthorizedClient().WithDescription("oops"),
http.StatusUnauthorized,
),
wantStatus: http.StatusUnauthorized,
wantBody: `{
"error":"unauthorized_client",
"error_description":"oops"
}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"oidc_error":{
"description":"oops",
"type":"unauthorized_client"
},
"time":"not"
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
r := httptest.NewRequest("GET", "/target", nil)
w := httptest.NewRecorder()
WriteError(w, r, tt.err, logger)
res := w.Result()
assert.Equal(t, tt.wantStatus, res.StatusCode, "status code")
gotBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.JSONEq(t, tt.wantBody, string(gotBody), "body")
assert.JSONEq(t, tt.wantLog, logOut.String())
})
}
}

View file

@ -4,9 +4,9 @@ import (
"context" "context"
"net/http" "net/http"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
) )
type KeyProvider interface { type KeyProvider interface {

View file

@ -7,13 +7,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
jose "github.com/go-jose/go-jose/v3"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock" "github.com/zitadel/oidc/v3/pkg/op/mock"
) )
func TestKeys(t *testing.T) { func TestKeys(t *testing.T) {

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer) // Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Authorizer)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
@ -9,8 +9,9 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
http "github.com/zitadel/oidc/v2/pkg/http" http "github.com/zitadel/oidc/v3/pkg/http"
op "github.com/zitadel/oidc/v2/pkg/op" op "github.com/zitadel/oidc/v3/pkg/op"
slog "golang.org/x/exp/slog"
) )
// MockAuthorizer is a mock of Authorizer interface. // MockAuthorizer is a mock of Authorizer interface.
@ -79,10 +80,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
} }
// IDTokenHintVerifier mocks base method. // IDTokenHintVerifier mocks base method.
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier { func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0) ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
ret0, _ := ret[0].(op.IDTokenHintVerifier) ret0, _ := ret[0].(*op.IDTokenHintVerifier)
return ret0 return ret0
} }
@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
} }
// Logger mocks base method.
func (m *MockAuthorizer) Logger() *slog.Logger {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logger")
ret0, _ := ret[0].(*slog.Logger)
return ret0
}
// Logger indicates an expected call of Logger.
func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger))
}
// RequestObjectSupported mocks base method. // RequestObjectSupported mocks base method.
func (m *MockAuthorizer) RequestObjectSupported() bool { func (m *MockAuthorizer) RequestObjectSupported() bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -4,12 +4,12 @@ import (
"context" "context"
"testing" "testing"
jose "github.com/go-jose/go-jose/v3"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/gorilla/schema" "github.com/zitadel/schema"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
func NewAuthorizer(t *testing.T) op.Authorizer { func NewAuthorizer(t *testing.T) op.Authorizer {
@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) {
func ExpectVerifier(a op.Authorizer, t *testing.T) { func ExpectVerifier(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer) mockA := a.(*MockAuthorizer)
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn( mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
func() op.IDTokenHintVerifier { func() *op.IDTokenHintVerifier {
return op.NewIDTokenHintVerifier("", nil) return op.NewIDTokenHintVerifier("", nil)
}) })
} }

View file

@ -5,8 +5,8 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
func NewClient(t *testing.T) op.Client { func NewClient(t *testing.T) op.Client {

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client) // Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Client)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
@ -9,8 +9,8 @@ import (
time "time" time "time"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
oidc "github.com/zitadel/oidc/v2/pkg/oidc" oidc "github.com/zitadel/oidc/v3/pkg/oidc"
op "github.com/zitadel/oidc/v2/pkg/op" op "github.com/zitadel/oidc/v3/pkg/op"
) )
// MockClient is a mock of Client interface. // MockClient is a mock of Client interface.

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration) // Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Configuration)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
@ -9,7 +9,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
op "github.com/zitadel/oidc/v2/pkg/op" op "github.com/zitadel/oidc/v3/pkg/op"
language "golang.org/x/text/language" language "golang.org/x/text/language"
) )
@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
} }
// AuthorizationEndpoint mocks base method. // AuthorizationEndpoint mocks base method.
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizationEndpoint") ret := m.ctrl.Call(m, "AuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
} }
// DeviceAuthorizationEndpoint mocks base method. // DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
} }
// EndSessionEndpoint mocks base method. // EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EndSessionEndpoint") ret := m.ctrl.Call(m, "EndSessionEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
} }
// IntrospectionEndpoint mocks base method. // IntrospectionEndpoint mocks base method.
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionEndpoint") ret := m.ctrl.Call(m, "IntrospectionEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
} }
// KeysEndpoint mocks base method. // KeysEndpoint mocks base method.
func (m *MockConfiguration) KeysEndpoint() op.Endpoint { func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeysEndpoint") ret := m.ctrl.Call(m, "KeysEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
} }
// RevocationEndpoint mocks base method. // RevocationEndpoint mocks base method.
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpoint") ret := m.ctrl.Call(m, "RevocationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
} }
// TokenEndpoint mocks base method. // TokenEndpoint mocks base method.
func (m *MockConfiguration) TokenEndpoint() op.Endpoint { func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpoint") ret := m.ctrl.Call(m, "TokenEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
} }
// UserinfoEndpoint mocks base method. // UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserinfoEndpoint") ret := m.ctrl.Call(m, "UserinfoEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage) // Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: DiscoverStorage)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
@ -8,8 +8,8 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
jose "github.com/go-jose/go-jose/v3"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
) )
// MockDiscoverStorage is a mock of DiscoverStorage interface. // MockDiscoverStorage is a mock of DiscoverStorage interface.

View file

@ -1,10 +1,10 @@
package mock package mock
//go:generate go install github.com/golang/mock/mockgen@v1.6.0 //go:generate go install github.com/golang/mock/mockgen@v1.6.0
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage //go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v3/pkg/op Storage
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer //go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v3/pkg/op Authorizer
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client //go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v3/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration //go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v3/pkg/op Configuration
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage //go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v3/pkg/op DiscoverStorage
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key //go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v3/pkg/op SigningKey,Key
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider //go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v3/pkg/op KeyProvider

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider) // Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: KeyProvider)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
@ -9,7 +9,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
op "github.com/zitadel/oidc/v2/pkg/op" op "github.com/zitadel/oidc/v3/pkg/op"
) )
// MockKeyProvider is a mock of KeyProvider interface. // MockKeyProvider is a mock of KeyProvider interface.

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key) // Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: SigningKey,Key)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
@ -7,8 +7,8 @@ package mock
import ( import (
reflect "reflect" reflect "reflect"
jose "github.com/go-jose/go-jose/v3"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
) )
// MockSigningKey is a mock of SigningKey interface. // MockSigningKey is a mock of SigningKey interface.

View file

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage) // Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Storage)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
@ -9,10 +9,10 @@ import (
reflect "reflect" reflect "reflect"
time "time" time "time"
jose "github.com/go-jose/go-jose/v3"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
oidc "github.com/zitadel/oidc/v2/pkg/oidc" oidc "github.com/zitadel/oidc/v3/pkg/oidc"
op "github.com/zitadel/oidc/v2/pkg/op" op "github.com/zitadel/oidc/v3/pkg/op"
jose "gopkg.in/square/go-jose.v2"
) )
// MockStorage is a mock of Storage interface. // MockStorage is a mock of Storage interface.

View file

@ -8,8 +8,8 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
func NewStorage(t *testing.T) op.Storage { func NewStorage(t *testing.T) op.Storage {

View file

@ -6,16 +6,17 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gorilla/mux" "github.com/go-chi/chi"
"github.com/gorilla/schema" jose "github.com/go-jose/go-jose/v3"
"github.com/rs/cors" "github.com/rs/cors"
"github.com/zitadel/schema"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
"golang.org/x/exp/slog"
"golang.org/x/text/language" "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
const ( const (
@ -33,7 +34,7 @@ const (
) )
var ( var (
DefaultEndpoints = &endpoints{ DefaultEndpoints = &Endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint), Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint), Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint),
@ -76,29 +77,35 @@ func init() {
} }
type OpenIDProvider interface { type OpenIDProvider interface {
http.Handler
Configuration Configuration
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Encoder() httphelper.Encoder Encoder() httphelper.Encoder
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
AccessTokenVerifier(context.Context) AccessTokenVerifier AccessTokenVerifier(context.Context) *AccessTokenVerifier
Crypto() Crypto Crypto() Crypto
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
Probes() []ProbesFn Probes() []ProbesFn
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
Logger() *slog.Logger
// Deprecated: Provider now implements http.Handler directly.
HttpHandler() http.Handler HttpHandler() http.Handler
} }
type HttpInterceptor func(http.Handler) http.Handler type HttpInterceptor func(http.Handler) http.Handler
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
router := mux.NewRouter() router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler) router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.Use(intercept(o.IssuerFromRequest, interceptors...))
router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(healthEndpoint, healthHandler)
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage())) router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o)) router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o)) router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o))
router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o)) router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
@ -132,16 +139,17 @@ type Config struct {
DeviceAuthorization DeviceAuthorizationConfig DeviceAuthorization DeviceAuthorizationConfig
} }
type endpoints struct { // Endpoints defines endpoint routes.
Authorization Endpoint type Endpoints struct {
Token Endpoint Authorization *Endpoint
Introspection Endpoint Token *Endpoint
Userinfo Endpoint Introspection *Endpoint
Revocation Endpoint Userinfo *Endpoint
EndSession Endpoint Revocation *Endpoint
CheckSessionIframe Endpoint EndSession *Endpoint
JwksURI Endpoint CheckSessionIframe *Endpoint
DeviceAuthorization Endpoint JwksURI *Endpoint
DeviceAuthorization *Endpoint
} }
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -186,6 +194,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
storage: storage, storage: storage,
endpoints: DefaultEndpoints, endpoints: DefaultEndpoints,
timer: make(<-chan time.Time), timer: make(<-chan time.Time),
logger: slog.Default(),
} }
for _, optFunc := range opOpts { for _, optFunc := range opOpts {
@ -199,7 +208,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
return nil, err return nil, err
} }
o.httpHandler = CreateRouter(o, o.interceptors...) o.Handler = CreateRouter(o, o.interceptors...)
o.decoder = schema.NewDecoder() o.decoder = schema.NewDecoder()
o.decoder.IgnoreUnknownKeys(true) o.decoder.IgnoreUnknownKeys(true)
@ -215,20 +224,21 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
} }
type Provider struct { type Provider struct {
http.Handler
config *Config config *Config
issuer IssuerFromRequest issuer IssuerFromRequest
insecure bool insecure bool
endpoints *endpoints endpoints *Endpoints
storage Storage storage Storage
keySet *openIDKeySet keySet *openIDKeySet
crypto Crypto crypto Crypto
httpHandler http.Handler
decoder *schema.Decoder decoder *schema.Decoder
encoder *schema.Encoder encoder *schema.Encoder
interceptors []HttpInterceptor interceptors []HttpInterceptor
timer <-chan time.Time timer <-chan time.Time
accessTokenVerifierOpts []AccessTokenVerifierOpt accessTokenVerifierOpts []AccessTokenVerifierOpt
idTokenHintVerifierOpts []IDTokenHintVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt
logger *slog.Logger
} }
func (o *Provider) IssuerFromRequest(r *http.Request) string { func (o *Provider) IssuerFromRequest(r *http.Request) string {
@ -239,35 +249,35 @@ func (o *Provider) Insecure() bool {
return o.insecure return o.insecure
} }
func (o *Provider) AuthorizationEndpoint() Endpoint { func (o *Provider) AuthorizationEndpoint() *Endpoint {
return o.endpoints.Authorization return o.endpoints.Authorization
} }
func (o *Provider) TokenEndpoint() Endpoint { func (o *Provider) TokenEndpoint() *Endpoint {
return o.endpoints.Token return o.endpoints.Token
} }
func (o *Provider) IntrospectionEndpoint() Endpoint { func (o *Provider) IntrospectionEndpoint() *Endpoint {
return o.endpoints.Introspection return o.endpoints.Introspection
} }
func (o *Provider) UserinfoEndpoint() Endpoint { func (o *Provider) UserinfoEndpoint() *Endpoint {
return o.endpoints.Userinfo return o.endpoints.Userinfo
} }
func (o *Provider) RevocationEndpoint() Endpoint { func (o *Provider) RevocationEndpoint() *Endpoint {
return o.endpoints.Revocation return o.endpoints.Revocation
} }
func (o *Provider) EndSessionEndpoint() Endpoint { func (o *Provider) EndSessionEndpoint() *Endpoint {
return o.endpoints.EndSession return o.endpoints.EndSession
} }
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
return o.endpoints.DeviceAuthorization return o.endpoints.DeviceAuthorization
} }
func (o *Provider) KeysEndpoint() Endpoint { func (o *Provider) KeysEndpoint() *Endpoint {
return o.endpoints.JwksURI return o.endpoints.JwksURI
} }
@ -354,15 +364,15 @@ func (o *Provider) Encoder() httphelper.Encoder {
return o.encoder return o.encoder
} }
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier { func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...) return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
} }
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier { func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second) return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
} }
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier { func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...) return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
} }
@ -387,8 +397,13 @@ func (o *Provider) Probes() []ProbesFn {
} }
} }
func (o *Provider) Logger() *slog.Logger {
return o.logger
}
// Deprecated: Provider now implements http.Handler directly.
func (o *Provider) HttpHandler() http.Handler { func (o *Provider) HttpHandler() http.Handler {
return o.httpHandler return o
} }
type openIDKeySet struct { type openIDKeySet struct {
@ -421,7 +436,7 @@ func WithAllowInsecure() Option {
} }
} }
func WithCustomAuthEndpoint(endpoint Endpoint) Option { func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -431,7 +446,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomTokenEndpoint(endpoint Endpoint) Option { func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -441,7 +456,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -451,7 +466,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -461,7 +476,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomRevocationEndpoint(endpoint Endpoint) Option { func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -471,7 +486,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -481,7 +496,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomKeysEndpoint(endpoint Endpoint) Option { func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -491,7 +506,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -501,8 +516,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { // WithCustomEndpoints sets multiple endpoints at once.
// Non of the endpoints may be nil, or an error will
// be returned when the Option used by the Provider.
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
if err := e.Validate(); err != nil {
return err
}
}
o.endpoints.Authorization = auth o.endpoints.Authorization = auth
o.endpoints.Token = token o.endpoints.Token = token
o.endpoints.Userinfo = userInfo o.endpoints.Userinfo = userInfo
@ -534,6 +557,16 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
} }
} }
// WithLogger lets a logger other than slog.Default().
//
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
func WithLogger(logger *slog.Logger) Option {
return func(o *Provider) error {
o.logger = logger
return nil
}
}
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
issuerInterceptor := NewIssuerInterceptor(i) issuerInterceptor := NewIssuerInterceptor(i)
return func(handler http.Handler) http.Handler { return func(handler http.Handler) http.Handler {

View file

@ -14,9 +14,9 @@ import (
"github.com/muhlemmer/gu" "github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage" "github.com/zitadel/oidc/v3/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"golang.org/x/text/language" "golang.org/x/text/language"
) )
@ -157,7 +157,7 @@ func TestRoutes(t *testing.T) {
values: map[string]string{ values: map[string]string{
"client_id": client.GetID(), "client_id": client.GetID(),
"redirect_uri": "https://example.com", "redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"response_type": string(oidc.ResponseTypeCode), "response_type": string(oidc.ResponseTypeCode),
}, },
wantCode: http.StatusFound, wantCode: http.StatusFound,
@ -194,7 +194,7 @@ func TestRoutes(t *testing.T) {
path: testProvider.TokenEndpoint().Relative(), path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{ values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer), "grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"assertion": jwtToken, "assertion": jwtToken,
}, },
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
@ -207,7 +207,7 @@ func TestRoutes(t *testing.T) {
basicAuth: &basicAuth{"web", "secret"}, basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{ values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange), "grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"subject_token": jwtToken, "subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType), "subject_token_type": string(oidc.AccessTokenType),
}, },
@ -224,7 +224,7 @@ func TestRoutes(t *testing.T) {
basicAuth: &basicAuth{"sid1", "verysecret"}, basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{ values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials), "grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
}, },
wantCode: http.StatusOK, wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
@ -339,7 +339,7 @@ func TestRoutes(t *testing.T) {
path: testProvider.DeviceAuthorizationEndpoint().Relative(), path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"device", "secret"}, basicAuth: &basicAuth{"device", "secret"},
values: map[string]string{ values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
}, },
wantCode: http.StatusOK, wantCode: http.StatusOK,
contains: []string{ contains: []string{
@ -371,7 +371,7 @@ func TestRoutes(t *testing.T) {
} }
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
testProvider.HttpHandler().ServeHTTP(rec, req) testProvider.ServeHTTP(rec, req)
resp := rec.Result() resp := rec.Result()
require.NoError(t, err) require.NoError(t, err)
@ -396,3 +396,54 @@ func TestRoutes(t *testing.T) {
}) })
} }
} }
func TestWithCustomEndpoints(t *testing.T) {
type args struct {
auth *op.Endpoint
token *op.Endpoint
userInfo *op.Endpoint
revocation *op.Endpoint
endSession *op.Endpoint
keys *op.Endpoint
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "all nil",
args: args{},
wantErr: op.ErrNilEndpoint,
},
{
name: "all set",
args: args{
auth: op.NewEndpoint("/authorize"),
token: op.NewEndpoint("/oauth/token"),
userInfo: op.NewEndpoint("/userinfo"),
revocation: op.NewEndpoint("/revoke"),
endSession: op.NewEndpoint("/end_session"),
keys: op.NewEndpoint("/keys"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
storage.NewStorage(storage.NewUserStore(testIssuer)),
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
)
require.ErrorIs(t, err, tt.wantErr)
if tt.wantErr != nil {
return
}
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
})
}
}

View file

@ -5,7 +5,7 @@ import (
"errors" "errors"
"net/http" "net/http"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
) )
type ProbesFn func(context.Context) error type ProbesFn func(context.Context) error
@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn {
} }
func ok(w http.ResponseWriter) { func ok(w http.ResponseWriter) {
httphelper.MarshalJSON(w, status{"ok"}) httphelper.MarshalJSON(w, Status{"ok"})
} }
type status struct { type Status struct {
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
} }

346
pkg/op/server.go Normal file
View file

@ -0,0 +1,346 @@
package op
import (
"context"
"net/http"
"net/url"
"github.com/muhlemmer/gu"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// Server describes the interface that needs to be implemented to serve
// OpenID Connect and Oauth2 standard requests.
//
// Methods are called after the HTTP route is resolved and
// the request body is parsed into the Request's Data field.
// When a method is called, it can be assumed that required fields,
// as described in their relevant standard, are validated already.
// The Response Data field may be of any type to allow flexibility
// to extend responses with custom fields. There are however requirements
// in the standards regarding the response models. Where applicable
// the method documentation gives a recommended type which can be used
// directly or extended upon.
//
// The addition of new methods is not considered a breaking change
// as defined by semver rules.
// Implementations MUST embed [UnimplementedServer] to maintain
// forward compatibility.
//
// EXPERIMENTAL: may change until v4
type Server interface {
// Health returns a status of "ok" once the Server is listening.
// The recommended Response Data type is [Status].
Health(context.Context, *Request[struct{}]) (*Response, error)
// Ready returns a status of "ok" once all dependencies,
// such as database storage, are ready.
// An error can be returned to explain what is not ready.
// The recommended Response Data type is [Status].
Ready(context.Context, *Request[struct{}]) (*Response, error)
// Discovery returns the OpenID Provider Configuration Information for this server.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
// The recommended Response Data type is [oidc.DiscoveryConfiguration].
Discovery(context.Context, *Request[struct{}]) (*Response, error)
// Keys serves the JWK set which the client can use verify signatures from the op.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key.
// The recommended Response Data type is [jose.JSONWebKeySet].
Keys(context.Context, *Request[struct{}]) (*Response, error)
// VerifyAuthRequest verifies the Auth Request and
// adds the Client to the request.
//
// When the `request` field is populated with a
// "Request Object" JWT, it needs to be Validated
// and its claims overwrite any fields in the AuthRequest.
// If the implementation does not support "Request Object",
// it MUST return an [oidc.ErrRequestNotSupported].
// https://openid.net/specs/openid-connect-core-1_0.html#RequestObject
VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error)
// Authorize initiates the authorization flow and redirects to a login page.
// See the various https://openid.net/specs/openid-connect-core-1_0.html
// authorize endpoint sections (one for each type of flow).
Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error)
// DeviceAuthorization initiates the device authorization flow.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
// The recommended Response Data type is [oidc.DeviceAuthorizationResponse].
DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error)
// VerifyClient is called on most oauth/token handlers to authenticate,
// using either a secret (POST, Basic) or assertion (JWT).
// If no secrets are provided, the client must be public.
// This method is called before each method that takes a
// [ClientRequest] argument.
VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error)
// CodeExchange returns Tokens after an authorization code
// is obtained in a successful Authorize flow.
// It is called by the Token endpoint handler when
// grant_type has the value authorization_code
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
// The recommended Response Data type is [oidc.AccessTokenResponse].
CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error)
// RefreshToken returns new Tokens after verifying a Refresh token.
// It is called by the Token endpoint handler when
// grant_type has the value refresh_token
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
// The recommended Response Data type is [oidc.AccessTokenResponse].
RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error)
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer
// https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
// The recommended Response Data type is [oidc.AccessTokenResponse].
JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error)
// TokenExchange handles the OAuth 2.0 token exchange grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange
// https://datatracker.ietf.org/doc/html/rfc8693
// The recommended Response Data type is [oidc.AccessTokenResponse].
TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error)
// ClientCredentialsExchange handles the OAuth 2.0 client credentials grant
// It is called by the Token endpoint handler when
// grant_type has the value client_credentials
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
// The recommended Response Data type is [oidc.AccessTokenResponse].
ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error)
// DeviceToken handles the OAuth 2.0 Device Authorization Grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:device_code.
// It is typically called in a polling fashion and appropriate errors
// should be returned to signal authorization_pending or access_denied etc.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4,
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5.
// The recommended Response Data type is [oidc.AccessTokenResponse].
DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error)
// Introspect handles the OAuth 2.0 Token Introspection endpoint.
// https://datatracker.ietf.org/doc/html/rfc7662
// The recommended Response Data type is [oidc.IntrospectionResponse].
Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error)
// UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User.
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
// The recommended Response Data type is [oidc.UserInfo].
UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error)
// Revocation handles token revocation using an access or refresh token.
// https://datatracker.ietf.org/doc/html/rfc7009
// There are no response requirements. Data may remain empty.
Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error)
// EndSession handles the OpenID Connect RP-Initiated Logout.
// https://openid.net/specs/openid-connect-rpinitiated-1_0.html
// There are no response requirements. Data may remain empty.
EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error)
// mustImpl forces implementations to embed the UnimplementedServer for forward
// compatibility with the interface.
mustImpl()
}
// Request contains the [http.Request] informational fields
// and parsed Data from the request body (POST) or URL parameters (GET).
// Data can be assumed to be validated according to the applicable
// standard for the specific endpoints.
//
// EXPERIMENTAL: may change until v4
type Request[T any] struct {
Method string
URL *url.URL
Header http.Header
Form url.Values
PostForm url.Values
Data *T
}
func (r *Request[_]) path() string {
return r.URL.Path
}
func newRequest[T any](r *http.Request, data *T) *Request[T] {
return &Request[T]{
Method: r.Method,
URL: r.URL,
Header: r.Header,
Form: r.Form,
PostForm: r.PostForm,
Data: data,
}
}
// ClientRequest is a Request with a verified client attached to it.
// Methods that receive this argument may assume the client was authenticated,
// or verified to be a public client.
//
// EXPERIMENTAL: may change until v4
type ClientRequest[T any] struct {
*Request[T]
Client Client
}
func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] {
return &ClientRequest[T]{
Request: newRequest[T](r, data),
Client: client,
}
}
// Response object for most [Server] methods.
//
// EXPERIMENTAL: may change until v4
type Response struct {
// Header map will be merged with the
// header on the [http.ResponseWriter].
Header http.Header
// Data will be JSON marshaled to
// the response body.
// We allow any type, so that implementations
// can extend the standard types as they wish.
// However, each method will recommend which
// (base) type to use as model, in order to
// be compliant with the standards.
Data any
}
// NewResponse creates a new response for data,
// without custom headers.
func NewResponse(data any) *Response {
return &Response{
Data: data,
}
}
func (resp *Response) writeOut(w http.ResponseWriter) {
gu.MapMerge(resp.Header, w.Header())
httphelper.MarshalJSON(w, resp.Data)
}
// Redirect is a special response type which will
// initiate a [http.StatusFound] redirect.
// The Params field will be encoded and set to the
// URL's RawQuery field before building the URL.
//
// EXPERIMENTAL: may change until v4
type Redirect struct {
// Header map will be merged with the
// header on the [http.ResponseWriter].
Header http.Header
URL string
}
func NewRedirect(url string) *Redirect {
return &Redirect{URL: url}
}
func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {
gu.MapMerge(r.Header, w.Header())
http.Redirect(w, r, red.URL, http.StatusFound)
}
type UnimplementedServer struct{}
// UnimplementedStatusCode is the status code returned for methods
// that are not yet implemented.
// Note that this means methods in the sense of the Go interface,
// and not http methods covered by "501 Not Implemented".
var UnimplementedStatusCode = http.StatusNotFound
func unimplementedError(r interface{ path() string }) StatusError {
err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path())
return NewStatusError(err, UnimplementedStatusCode)
}
func unimplementedGrantError(gt oidc.GrantType) StatusError {
err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt)
return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
}
func (UnimplementedServer) mustImpl() {}
func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
if r.Data.RequestParam != "" {
return nil, oidc.ErrRequestNotSupported()
}
return nil, unimplementedError(r)
}
func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeCode)
}
func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
}
func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
}
func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
}
func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
}
func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
}
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
return nil, unimplementedError(r)
}

480
pkg/op/server_http.go Normal file
View file

@ -0,0 +1,480 @@
package op
import (
"context"
"net/http"
"net/url"
"github.com/go-chi/chi"
"github.com/rs/cors"
"github.com/zitadel/logging"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
)
// RegisterServer registers an implementation of Server.
// The resulting handler takes care of routing and request parsing,
// with some basic validation of required fields.
// The routes can be customized with [WithEndpoints].
//
// EXPERIMENTAL: may change until v4
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
ws := &webServer{
server: server,
endpoints: endpoints,
decoder: decoder,
logger: slog.Default(),
}
for _, option := range options {
option(ws)
}
ws.createRouter()
return ws
}
type ServerOption func(s *webServer)
// WithHTTPMiddleware sets the passed middleware chain to the root of
// the Server's router.
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
return func(s *webServer) {
s.middleware = m
}
}
// WithDecoder overrides the default decoder,
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
func WithDecoder(decoder httphelper.Decoder) ServerOption {
return func(s *webServer) {
s.decoder = decoder
}
}
// WithFallbackLogger overrides the fallback logger, which
// is used when no logger was found in the context.
// Defaults to [slog.Default].
func WithFallbackLogger(logger *slog.Logger) ServerOption {
return func(s *webServer) {
s.logger = logger
}
}
type webServer struct {
http.Handler
server Server
middleware []func(http.Handler) http.Handler
endpoints Endpoints
decoder httphelper.Decoder
logger *slog.Logger
}
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
if logger, ok := logging.FromContext(ctx); ok {
return logger
}
return s.logger
}
func (s *webServer) createRouter() {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(s.middleware...)
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
s.Handler = router
}
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
if e != nil {
router.HandleFunc(e.Relative(), hf)
s.logger.Info("registered route", "endpoint", e.Relative())
}
}
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
if !ValidateGrantType(client, grantType) {
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
return
}
}
handler(w, r, client)
}
}
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
if err = r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
cc := new(ClientCredentials)
if err = s.decoder.Decode(cc, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
// Basic auth takes precedence, so if set it overwrites the form data.
if clientID, clientSecret, ok := r.BasicAuth(); ok {
cc.ClientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
}
if cc.ClientID == "" && cc.ClientAssertion == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
}
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
}
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
Method: r.Method,
URL: r.URL,
Header: r.Header,
Form: r.Form,
Data: cc,
})
}
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
redirect, err := s.authorize(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
redirect.writeOut(w, r)
}
func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
cr, err := s.server.VerifyAuthRequest(ctx, r)
if err != nil {
return nil, err
}
authReq := cr.Data
if authReq.RedirectURI == "" {
return nil, ErrAuthReqMissingRedirectURI
}
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return nil, err
}
authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
if err != nil {
return nil, err
}
if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return nil, err
}
if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
return nil, err
}
return s.server.Authorize(ctx, cr)
}
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return
}
switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
case oidc.GrantTypeCode:
s.withClient(s.codeExchangeHandler)(w, r)
case oidc.GrantTypeRefreshToken:
s.withClient(s.refreshTokenHandler)(w, r)
case oidc.GrantTypeClientCredentials:
s.withClient(s.clientCredentialsHandler)(w, r)
case oidc.GrantTypeBearer:
s.jwtProfileHandler(w, r)
case oidc.GrantTypeTokenExchange:
s.withClient(s.tokenExchangeHandler)(w, r)
case oidc.GrantTypeDeviceCode:
s.withClient(s.deviceTokenHandler)(w, r)
case "":
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
default:
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
}
}
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Assertion == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Code == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
return
}
if request.RedirectURI == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.RefreshToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.SubjectToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
return
}
if request.SubjectTokenType == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
return
}
if !request.SubjectTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
return
}
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
return
}
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
return
}
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return
}
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.DeviceCode == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return
}
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if token, err := getAccessToken(r); err == nil {
request.AccessToken = token
}
if request.AccessToken == "" {
err = NewStatusError(
oidc.ErrInvalidRequest().WithDescription("access token missing"),
http.StatusUnauthorized,
)
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w, r)
}
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return
}
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
}
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
dst := new(R)
if err := r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
form := r.Form
if postOnly {
form = r.PostForm
}
if err := decoder.Decode(dst, form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
return dst, nil
}

View file

@ -0,0 +1,345 @@
package op_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func jwtProfile() (string, error) {
keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json")
if err != nil {
return "", err
}
signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID)
if err != nil {
return "", err
}
return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer)
}
func TestServerRoutes(t *testing.T) {
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
client, err := storage.GetClientByClientID(ctx, "web")
require.NoError(t, err)
oidcAuthReq := &oidc.AuthRequest{
ClientID: client.GetID(),
RedirectURI: "https://example.com",
MaxAge: gu.Ptr[uint](300),
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
ResponseType: oidc.ResponseTypeCode,
}
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
require.NoError(t, err)
storage.AuthRequestDone(authReq.GetID())
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
require.NoError(t, err)
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
require.NoError(t, err)
jwtProfileToken, err := jwtProfile()
require.NoError(t, err)
oidcAuthReq.IDTokenHint = idToken
serverURL, err := url.Parse(testIssuer)
require.NoError(t, err)
type basicAuth struct {
username, password string
}
tests := []struct {
name string
method string
path string
basicAuth *basicAuth
header map[string]string
values map[string]string
body map[string]string
wantCode int
headerContains map[string]string
json string // test for exact json output
contains []string // when the body output is not constant, we just check for snippets to be present in the response
}{
{
name: "health",
method: http.MethodGet,
path: "/healthz",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "ready",
method: http.MethodGet,
path: "/ready",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "discovery",
method: http.MethodGet,
path: oidc.DiscoveryEndpoint,
wantCode: http.StatusOK,
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
},
{
name: "authorization",
method: http.MethodGet,
path: testProvider.AuthorizationEndpoint().Relative(),
values: map[string]string{
"client_id": client.GetID(),
"redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"response_type": string(oidc.ResponseTypeCode),
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
},
{
// This call will fail. A successfull test is already
// part of client/integration_test.go
name: "code exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeCode),
"client_id": client.GetID(),
"client_secret": "secret",
"redirect_uri": "https://example.com",
"code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_grant", "error_description":"invalid code"}`,
},
{
name: "JWT authorization",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"assertion": jwtProfileToken,
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`},
},
{
name: "Token exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType),
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
},
},
{
name: "Client credentials exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
},
{
// This call will fail. A successfull test is already
// part of device_test.go
name: "device token",
method: http.MethodPost,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"device", "secret"},
header: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
body: map[string]string{
"grant_type": string(oidc.GrantTypeDeviceCode),
"device_code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
},
{
name: "missing grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
},
{
name: "unsupported grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": "foo",
},
wantCode: http.StatusBadRequest,
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
},
{
name: "introspection",
method: http.MethodGet,
path: testProvider.IntrospectionEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessToken,
},
wantCode: http.StatusOK,
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "user info",
method: http.MethodGet,
path: testProvider.UserinfoEndpoint().Relative(),
header: map[string]string{
"authorization": "Bearer " + accessToken,
},
wantCode: http.StatusOK,
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "refresh token",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeRefreshToken),
"refresh_token": refreshToken,
"client_id": client.GetID(),
"client_secret": "secret",
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","token_type":"Bearer","refresh_token":"`,
`","expires_in":299,"id_token":"`,
},
},
{
name: "revoke",
method: http.MethodGet,
path: testProvider.RevocationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessTokenRevoke,
},
wantCode: http.StatusOK,
},
{
name: "end session",
method: http.MethodGet,
path: testProvider.EndSessionEndpoint().Relative(),
values: map[string]string{
"id_token_hint": idToken,
"client_id": "web",
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/logged-out"},
contains: []string{`<a href="/logged-out">Found</a>.`},
},
{
name: "keys",
method: http.MethodGet,
path: testProvider.KeysEndpoint().Relative(),
wantCode: http.StatusOK,
contains: []string{
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
},
},
{
name: "device authorization",
method: http.MethodGet,
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"device", "secret"},
values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{
`{"device_code":"`, `","user_code":"`,
`","verification_uri":"https://localhost:9998/device"`,
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
`","expires_in":300,"interval":5}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u := gu.PtrCopy(serverURL)
u.Path = tt.path
if tt.values != nil {
u.RawQuery = mapAsValues(tt.values)
}
var body io.Reader
if tt.body != nil {
body = strings.NewReader(mapAsValues(tt.body))
}
req := httptest.NewRequest(tt.method, u.String(), body)
for k, v := range tt.header {
req.Header.Set(k, v)
}
if tt.basicAuth != nil {
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
rec := httptest.NewRecorder()
server.ServeHTTP(rec, req)
resp := rec.Result()
require.NoError(t, err)
assert.Equal(t, tt.wantCode, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
respBodyString := string(respBody)
t.Log(respBodyString)
t.Log(resp.Header)
if tt.json != "" {
assert.JSONEq(t, tt.json, respBodyString)
}
for _, c := range tt.contains {
assert.Contains(t, respBodyString, c)
}
for k, v := range tt.headerContains {
assert.Contains(t, resp.Header.Get(k), v)
}
})
}
}

1333
pkg/op/server_http_test.go Normal file

File diff suppressed because it is too large Load diff

344
pkg/op/server_legacy.go Normal file
View file

@ -0,0 +1,344 @@
package op
import (
"context"
"errors"
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// LegacyServer is an implementation of [Server[] that
// simply wraps a [OpenIDProvider].
// It can be used to transition from the former Provider/Storage
// interfaces to the new Server interface.
type LegacyServer struct {
UnimplementedServer
provider OpenIDProvider
endpoints Endpoints
}
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
// the Server's router.
//
// Only non-nil endpoints will be registered on the router.
// Nil endpoints are disabled.
//
// The passed endpoints is also set to the provider,
// to be consistent with the discovery config.
// Any `With*Endpoint()` option used on the provider is
// therefore ineffective.
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
server := RegisterServer(&LegacyServer{
provider: provider,
endpoints: endpoints,
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter()
router.Mount("/", server)
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
return router
}
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse(Status{Status: "ok"}), nil
}
func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
for _, probe := range s.provider.Probes() {
// shouldn't we run probes in Go routines?
if err := probe(ctx); err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
}
return NewResponse(Status{Status: "ok"}), nil
}
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse(
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
), nil
}
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
keys, err := s.provider.Storage().KeySet(ctx)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
return NewResponse(jsonWebKeySet(keys)), nil
}
var (
ErrAuthReqMissingClientID = errors.New("auth request is missing client_id")
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
)
func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
if r.Data.RequestParam != "" {
if !s.provider.RequestObjectSupported() {
return nil, oidc.ErrRequestNotSupported()
}
err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
if err != nil {
return nil, err
}
}
if r.Data.ClientID == "" {
return nil, ErrAuthReqMissingClientID
}
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
if err != nil {
return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
}
return &ClientRequest[oidc.AuthRequest]{
Request: r,
Client: client,
}, nil
}
func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
if err != nil {
return nil, err
}
req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
if err != nil {
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
}
return NewRedirect(r.Client.LoginURL(req.GetID())), nil
}
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
return NewResponse(response), nil
}
func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
if !ok {
return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
}
return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret)
}
if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() {
return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
}
return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger)
}
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithParent(err)
}
switch client.AuthMethod() {
case oidc.AuthMethodNone:
return client, nil
case oidc.AuthMethodPrivateKeyJWT:
return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
case oidc.AuthMethodPost:
if !s.provider.AuthMethodPostSupported() {
return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
}
}
err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage())
if err != nil {
return nil, err
}
return client, nil
}
func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code)
if err != nil {
return nil, err
}
if r.Client.AuthMethod() == oidc.AuthMethodNone {
if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
return nil, err
}
}
resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
if !s.provider.GrantTypeRefreshTokenSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
}
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
if err != nil {
return nil, err
}
if r.Client.GetID() != request.GetClientID() {
return nil, oidc.ErrInvalidGrant()
}
if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil {
return nil, err
}
resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
if !ok {
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
}
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
if err != nil {
return nil, err
}
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
if err != nil {
return nil, err
}
resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
if !s.provider.GrantTypeTokenExchangeSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
}
tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider)
if err != nil {
return nil, err
}
resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
if !ok {
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
}
tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope)
if err != nil {
return nil, err
}
resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
if !s.provider.GrantTypeClientCredentialsSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
}
// use a limited context timeout shorter as the default
// poll interval of 5 seconds.
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
if err != nil {
return nil, err
}
tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{r.Client.GetID()},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
response := new(oidc.IntrospectionResponse)
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
if !ok {
return NewResponse(response), nil
}
err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID())
if err != nil {
return NewResponse(response), nil
}
response.Active = true
return NewResponse(response), nil
}
func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken)
if !ok {
return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized)
}
info := new(oidc.UserInfo)
err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin"))
if err != nil {
return nil, NewStatusError(err, http.StatusForbidden)
}
return NewResponse(info), nil
}
func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
var subject string
doDecrypt := true
if r.Data.TokenTypeHint != "access_token" {
userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token)
if err != nil {
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
if !errors.Is(err, ErrInvalidRefreshToken) {
return nil, RevocationError(oidc.ErrServerError().WithParent(err))
}
} else {
r.Data.Token = tokenID
subject = userID
doDecrypt = false
}
}
if doDecrypt {
tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token)
if ok {
r.Data.Token = tokenID
subject = userID
}
}
if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil {
return nil, RevocationError(err)
}
return NewResponse(nil), nil
}
func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider)
if err != nil {
return nil, err
}
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
if err != nil {
return nil, err
}
return NewRedirect(session.RedirectURI), nil
}

5
pkg/op/server_test.go Normal file
View file

@ -0,0 +1,5 @@
package op
// implementation check
var _ Server = &UnimplementedServer{}
var _ Server = &LegacyServer{}

View file

@ -6,15 +6,17 @@ import (
"net/url" "net/url"
"path" "path"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slog"
) )
type SessionEnder interface { type SessionEnder interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Storage() Storage Storage() Storage
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
Logger() *slog.Logger
} }
func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) { func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
@ -31,7 +33,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
} }
session, err := ValidateEndSessionRequest(r.Context(), req, ender) session, err := ValidateEndSessionRequest(r.Context(), req, ender)
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err, ender.Logger())
return return
} }
redirect := session.RedirectURI redirect := session.RedirectURI
@ -41,7 +43,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
} }
if err != nil { if err != nil {
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
return return
} }
http.Redirect(w, r, redirect, http.StatusFound) http.Redirect(w, r, redirect, http.StatusFound)

View file

@ -3,7 +3,7 @@ package op
import ( import (
"errors" "errors"
"gopkg.in/square/go-jose.v2" jose "github.com/go-jose/go-jose/v3"
) )
var ErrSignerCreationFailed = errors.New("signer creation failed") var ErrSignerCreationFailed = errors.New("signer creation failed")

Some files were not shown because too many files have changed in this diff Show more