Merge branch 'next' into next-main
This commit is contained in:
commit
d9487ef77d
118 changed files with 6091 additions and 981 deletions
|
@ -44,9 +44,9 @@ Check the `/example` folder where example code for different scenarios is locate
|
|||
```bash
|
||||
# start oidc op server
|
||||
# oidc discovery http://localhost:9998/.well-known/openid-configuration
|
||||
go run github.com/zitadel/oidc/v2/example/server
|
||||
go run github.com/zitadel/oidc/v3/example/server
|
||||
# start oidc web client (in a new terminal)
|
||||
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app
|
||||
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
|
||||
```
|
||||
|
||||
- open http://localhost:9999/login in your browser
|
||||
|
@ -56,11 +56,11 @@ CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid
|
|||
|
||||
for the dynamic issuer, just start it with:
|
||||
```bash
|
||||
go run github.com/zitadel/oidc/v2/example/server/dynamic
|
||||
go run github.com/zitadel/oidc/v3/example/server/dynamic
|
||||
```
|
||||
the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with:
|
||||
```bash
|
||||
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app
|
||||
CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
|
||||
```
|
||||
|
||||
> Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
|
@ -9,11 +10,11 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rs"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rs"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -27,12 +28,12 @@ func main() {
|
|||
port := os.Getenv("PORT")
|
||||
issuer := os.Getenv("ISSUER")
|
||||
|
||||
provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath)
|
||||
provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath)
|
||||
if err != nil {
|
||||
logrus.Fatalf("error creating provider %s", err.Error())
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router := chi.NewRouter()
|
||||
|
||||
// public url accessible without any authorization
|
||||
// will print `OK` and current timestamp
|
||||
|
@ -47,7 +48,7 @@ func main() {
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
resp, err := rs.Introspect(r.Context(), provider, token)
|
||||
resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
|
@ -68,14 +69,14 @@ func main() {
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
resp, err := rs.Introspect(r.Context(), provider, token)
|
||||
resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
params := mux.Vars(r)
|
||||
requestedClaim := params["claim"]
|
||||
requestedValue := params["value"]
|
||||
requestedClaim := chi.URLParam(r, "claim")
|
||||
requestedValue := chi.URLParam(r, "value")
|
||||
|
||||
value, ok := resp.Claims[requestedClaim].(string)
|
||||
if !ok || value == "" || value != requestedValue {
|
||||
http.Error(w, "claim does not match", http.StatusForbidden)
|
||||
|
|
|
@ -1,19 +1,23 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -32,9 +36,25 @@ func main() {
|
|||
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
|
||||
logger := slog.New(
|
||||
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelDebug,
|
||||
}),
|
||||
)
|
||||
client := &http.Client{
|
||||
Timeout: time.Minute,
|
||||
}
|
||||
// enable outgoing request logging
|
||||
logging.EnableHTTPClient(client,
|
||||
logging.WithClientGroup("client"),
|
||||
)
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithCookieHandler(cookieHandler),
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
|
||||
rp.WithHTTPClient(client),
|
||||
rp.WithLogger(logger),
|
||||
}
|
||||
if clientSecret == "" {
|
||||
options = append(options, rp.WithPKCE(cookieHandler))
|
||||
|
@ -43,7 +63,10 @@ func main() {
|
|||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||
}
|
||||
|
||||
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||
// One can add a logger to the context,
|
||||
// pre-defining log attributes as required.
|
||||
ctx := logging.ToContext(context.TODO(), logger)
|
||||
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||
if err != nil {
|
||||
logrus.Fatalf("error creating provider %s", err.Error())
|
||||
}
|
||||
|
@ -118,8 +141,22 @@ func main() {
|
|||
//
|
||||
// http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider))
|
||||
|
||||
// simple counter for request IDs
|
||||
var counter atomic.Int64
|
||||
// enable incomming request logging
|
||||
mw := logging.Middleware(
|
||||
logging.WithLogger(logger),
|
||||
logging.WithGroup("server"),
|
||||
logging.WithIDFunc(func() slog.Attr {
|
||||
return slog.Int64("id", counter.Add(1))
|
||||
}),
|
||||
)
|
||||
|
||||
lis := fmt.Sprintf("127.0.0.1:%s", port)
|
||||
logrus.Infof("listening on http://%s/", lis)
|
||||
logrus.Info("press ctrl+c to stop")
|
||||
logrus.Fatal(http.ListenAndServe(lis, nil))
|
||||
logger.Info("server listening, press ctrl+c to stop", "addr", lis)
|
||||
err = http.ListenAndServe(lis, mw(http.DefaultServeMux))
|
||||
if err != http.ErrServerClosed {
|
||||
logger.Error("server terminated", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,8 +11,8 @@ import (
|
|||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -39,13 +39,13 @@ func main() {
|
|||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||
}
|
||||
|
||||
provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...)
|
||||
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...)
|
||||
if err != nil {
|
||||
logrus.Fatalf("error creating provider %s", err.Error())
|
||||
}
|
||||
|
||||
logrus.Info("starting device authorization flow")
|
||||
resp, err := rp.DeviceAuthorization(scopes, provider)
|
||||
resp, err := rp.DeviceAuthorization(ctx, scopes, provider, nil)
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -10,10 +10,10 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
githubOAuth "golang.org/x/oauth2/github"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp/cli"
|
||||
"github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp/cli"
|
||||
"github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client/profile"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/profile"
|
||||
)
|
||||
|
||||
var client = http.DefaultClient
|
||||
|
@ -25,7 +25,7 @@ func main() {
|
|||
scopes := strings.Split(os.Getenv("SCOPES"), " ")
|
||||
|
||||
if keyPath != "" {
|
||||
ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes)
|
||||
ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes)
|
||||
if err != nil {
|
||||
logrus.Fatalf("error creating token source %s", err.Error())
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func main() {
|
|||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes)
|
||||
ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -43,7 +43,7 @@ var (
|
|||
|
||||
type login struct {
|
||||
authenticate authenticate
|
||||
router *mux.Router
|
||||
router chi.Router
|
||||
callback func(context.Context, string) string
|
||||
}
|
||||
|
||||
|
@ -57,9 +57,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string)
|
|||
}
|
||||
|
||||
func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
|
||||
l.router = mux.NewRouter()
|
||||
l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler)
|
||||
l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler))
|
||||
l.router = chi.NewRouter()
|
||||
l.router.Get("/username", l.loginHandler)
|
||||
l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler)
|
||||
}
|
||||
|
||||
type authenticate interface {
|
||||
|
|
|
@ -7,11 +7,11 @@ import (
|
|||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/go-chi/chi"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -47,7 +47,7 @@ func main() {
|
|||
//be sure to create a proper crypto random key and manage it securely!
|
||||
key := sha256.Sum256([]byte("test"))
|
||||
|
||||
router := mux.NewRouter()
|
||||
router := chi.NewRouter()
|
||||
|
||||
//for simplicity, we provide a very small default page for users who have signed out
|
||||
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
|
||||
|
@ -76,7 +76,7 @@ func main() {
|
|||
|
||||
//regardless of how many pages / steps there are in the process, the UI must be registered in the router,
|
||||
//so we will direct all calls to /login to the login UI
|
||||
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
|
||||
router.Mount("/login/", http.StripPrefix("/login", l.router))
|
||||
|
||||
//we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||
//is served on the correct path
|
||||
|
@ -84,7 +84,7 @@ func main() {
|
|||
//if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
|
||||
//then you would have to set the path prefix (/custom/path/):
|
||||
//router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler()))
|
||||
router.PathPrefix("/").Handler(provider.HttpHandler())
|
||||
router.Mount("/", provider)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":" + port,
|
||||
|
|
|
@ -1,21 +1,34 @@
|
|||
package exampleop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
type deviceAuthenticate interface {
|
||||
CheckUsernamePasswordSimple(username, password string) error
|
||||
op.DeviceAuthorizationStorage
|
||||
|
||||
// GetDeviceAuthorizationByUserCode resturns the current state of the device authorization flow,
|
||||
// identified by the user code.
|
||||
GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error)
|
||||
|
||||
// CompleteDeviceAuthorization marks a device authorization entry as Completed,
|
||||
// identified by userCode. The Subject is added to the state, so that
|
||||
// GetDeviceAuthorizatonState can use it to create a new Access Token.
|
||||
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
|
||||
|
||||
// DenyDeviceAuthorization marks a device authorization entry as Denied.
|
||||
DenyDeviceAuthorization(ctx context.Context, userCode string) error
|
||||
}
|
||||
|
||||
type deviceLogin struct {
|
||||
|
@ -23,14 +36,14 @@ type deviceLogin struct {
|
|||
cookie *securecookie.SecureCookie
|
||||
}
|
||||
|
||||
func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) {
|
||||
func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) {
|
||||
l := &deviceLogin{
|
||||
storage: storage,
|
||||
cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil),
|
||||
}
|
||||
|
||||
router.HandleFunc("", l.userCodeHandler)
|
||||
router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler)
|
||||
router.HandleFunc("/", l.userCodeHandler)
|
||||
router.Post("/login", l.loginHandler)
|
||||
router.HandleFunc("/confirm", l.confirmHandler)
|
||||
}
|
||||
|
||||
|
|
|
@ -5,13 +5,13 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
type login struct {
|
||||
authenticate authenticate
|
||||
router *mux.Router
|
||||
router chi.Router
|
||||
callback func(context.Context, string) string
|
||||
}
|
||||
|
||||
|
@ -25,9 +25,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string)
|
|||
}
|
||||
|
||||
func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
|
||||
l.router = mux.NewRouter()
|
||||
l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler)
|
||||
l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler))
|
||||
l.router = chi.NewRouter()
|
||||
l.router.Get("/username", l.loginHandler)
|
||||
l.router.Post("/username", issuerInterceptor.HandlerFunc(l.checkLoginHandler))
|
||||
}
|
||||
|
||||
type authenticate interface {
|
||||
|
|
|
@ -4,13 +4,16 @@ import (
|
|||
"crypto/sha256"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -31,26 +34,33 @@ type Storage interface {
|
|||
deviceAuthenticate
|
||||
}
|
||||
|
||||
// simple counter for request IDs
|
||||
var counter atomic.Int64
|
||||
|
||||
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
|
||||
//
|
||||
// Use one of the pre-made clients in storage/clients.go or register a new one.
|
||||
func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux.Router {
|
||||
func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool, extraOptions ...op.Option) chi.Router {
|
||||
// the OpenID Provider requires a 32-byte key for (token) encryption
|
||||
// be sure to create a proper crypto random key and manage it securely!
|
||||
key := sha256.Sum256([]byte("test"))
|
||||
|
||||
router := mux.NewRouter()
|
||||
router := chi.NewRouter()
|
||||
router.Use(logging.Middleware(
|
||||
logging.WithLogger(logger),
|
||||
logging.WithIDFunc(func() slog.Attr {
|
||||
return slog.Int64("id", counter.Add(1))
|
||||
}),
|
||||
))
|
||||
|
||||
// for simplicity, we provide a very small default page for users who have signed out
|
||||
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
|
||||
_, err := w.Write([]byte("signed out successfully"))
|
||||
if err != nil {
|
||||
log.Printf("error serving logged out page: %v", err)
|
||||
}
|
||||
w.Write([]byte("signed out successfully"))
|
||||
// no need to check/log error, this will be handeled by the middleware.
|
||||
})
|
||||
|
||||
// creation of the OpenIDProvider with the just created in-memory Storage
|
||||
provider, err := newOP(storage, issuer, key, extraOptions...)
|
||||
provider, err := newOP(storage, issuer, key, logger, extraOptions...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -62,17 +72,23 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux
|
|||
|
||||
// regardless of how many pages / steps there are in the process, the UI must be registered in the router,
|
||||
// so we will direct all calls to /login to the login UI
|
||||
router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router))
|
||||
router.Mount("/login/", http.StripPrefix("/login", l.router))
|
||||
|
||||
router.PathPrefix("/device").Subrouter()
|
||||
registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter())
|
||||
router.Route("/device", func(r chi.Router) {
|
||||
registerDeviceAuth(storage, r)
|
||||
})
|
||||
|
||||
handler := http.Handler(provider)
|
||||
if wrapServer {
|
||||
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
|
||||
}
|
||||
|
||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||
// is served on the correct path
|
||||
//
|
||||
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
|
||||
// then you would have to set the path prefix (/custom/path/)
|
||||
router.PathPrefix("/").Handler(provider.HttpHandler())
|
||||
router.Mount("/", handler)
|
||||
|
||||
return router
|
||||
}
|
||||
|
@ -80,7 +96,7 @@ func SetupServer(issuer string, storage Storage, extraOptions ...op.Option) *mux
|
|||
// newOP will create an OpenID Provider for localhost on a specified port with a given encryption key
|
||||
// and a predefined default logout uri
|
||||
// it will enable all options (see descriptions)
|
||||
func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.Option) (op.OpenIDProvider, error) {
|
||||
func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger, extraOptions ...op.Option) (op.OpenIDProvider, error) {
|
||||
config := &op.Config{
|
||||
CryptoKey: key,
|
||||
|
||||
|
@ -114,10 +130,12 @@ func newOP(storage op.Storage, issuer string, key [32]byte, extraOptions ...op.O
|
|||
}
|
||||
handler, err := op.NewOpenIDProvider(issuer, config, storage,
|
||||
append([]op.Option{
|
||||
// we must explicitly allow the use of the http issuer
|
||||
//we must explicitly allow the use of the http issuer
|
||||
op.WithAllowInsecure(),
|
||||
// as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
|
||||
op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
|
||||
// Pass our logger to the OP
|
||||
op.WithLogger(logger.WithGroup("op")),
|
||||
}, extraOptions...)...,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,11 +2,12 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/zitadel/oidc/v2/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v3/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -20,16 +21,22 @@ func main() {
|
|||
// in this example it will be handled in-memory
|
||||
storage := storage.NewStorage(storage.NewUserStore(issuer))
|
||||
|
||||
router := exampleop.SetupServer(issuer, storage)
|
||||
logger := slog.New(
|
||||
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelDebug,
|
||||
}),
|
||||
)
|
||||
router := exampleop.SetupServer(issuer, storage, logger, false)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":" + port,
|
||||
Handler: router,
|
||||
}
|
||||
log.Printf("server listening on http://localhost:%s/", port)
|
||||
log.Println("press ctrl+c to stop")
|
||||
logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port))
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
if err != http.ErrServerClosed {
|
||||
logger.Error("server terminated", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,8 @@ package storage
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
|
|||
authMethod: oidc.AuthMethodBasic,
|
||||
loginURL: defaultLoginURL,
|
||||
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
|
||||
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
|
||||
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeTokenExchange},
|
||||
accessTokenType: op.AccessTokenTypeBearer,
|
||||
devMode: false,
|
||||
idTokenUserinfoClaimsAssertion: false,
|
||||
|
|
|
@ -3,10 +3,11 @@ package storage
|
|||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -41,6 +42,19 @@ type AuthRequest struct {
|
|||
authTime time.Time
|
||||
}
|
||||
|
||||
// LogValue allows you to define which fields will be logged.
|
||||
// Implements the [slog.LogValuer]
|
||||
func (a *AuthRequest) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("id", a.ID),
|
||||
slog.Time("creation_date", a.CreationDate),
|
||||
slog.Any("scopes", a.Scopes),
|
||||
slog.String("response_type", string(a.ResponseType)),
|
||||
slog.String("app_id", a.ApplicationID),
|
||||
slog.String("callback_uri", a.CallbackURI),
|
||||
)
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
|
|
@ -11,11 +11,11 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/google/uuid"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
// serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
type multiStorage struct {
|
||||
|
|
10
go.mod
10
go.mod
|
@ -1,13 +1,13 @@
|
|||
module github.com/zitadel/oidc/v2
|
||||
module github.com/zitadel/oidc/v3
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi v1.5.4
|
||||
github.com/go-jose/go-jose/v3 v3.0.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-github/v31 v31.0.0
|
||||
github.com/google/uuid v1.3.1
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/schema v1.2.0
|
||||
github.com/gorilla/securecookie v1.1.1
|
||||
github.com/jeremija/gosubmit v0.2.7
|
||||
github.com/muhlemmer/gu v0.3.1
|
||||
|
@ -15,11 +15,13 @@ require (
|
|||
github.com/rs/cors v1.10.1
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/zitadel/logging v0.4.0
|
||||
github.com/zitadel/schema v1.3.0
|
||||
go.opentelemetry.io/otel v1.19.0
|
||||
go.opentelemetry.io/otel/trace v1.19.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/oauth2 v0.13.0
|
||||
golang.org/x/text v0.13.0
|
||||
gopkg.in/square/go-jose.v2 v2.6.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
20
go.sum
20
go.sum
|
@ -1,6 +1,10 @@
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs=
|
||||
github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg=
|
||||
github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=
|
||||
github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
|
@ -13,6 +17,7 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
|
|||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
|
@ -23,10 +28,6 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD
|
|||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc=
|
||||
github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc=
|
||||
|
@ -44,10 +45,15 @@ github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU
|
|||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA=
|
||||
github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc=
|
||||
github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=
|
||||
github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc=
|
||||
go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs=
|
||||
go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY=
|
||||
go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE=
|
||||
|
@ -55,9 +61,12 @@ go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319
|
|||
go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg=
|
||||
go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
|
@ -102,8 +111,7 @@ google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs
|
|||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
|
||||
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
var custom = map[string]any{
|
||||
|
|
|
@ -8,8 +8,9 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// KeySet implements oidc.Keys
|
||||
|
@ -17,7 +18,7 @@ type KeySet struct{}
|
|||
|
||||
// VerifySignature implments op.KeySet.
|
||||
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
|
||||
if ctx.Err() != nil {
|
||||
if err = ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -45,6 +46,16 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
type JWTProfileKeyStorage struct{}
|
||||
|
||||
func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gu.Ptr(WebKey.Public()), nil
|
||||
}
|
||||
|
||||
func signEncodeTokenClaims(claims any) string {
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
|
@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T
|
|||
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
|
||||
}
|
||||
|
||||
func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
|
||||
req := &oidc.JWTTokenRequest{
|
||||
Issuer: issuer,
|
||||
Subject: clientID,
|
||||
Audience: audience,
|
||||
ExpiresAt: oidc.FromTime(expiration),
|
||||
IssuedAt: oidc.FromTime(issuedAt),
|
||||
}
|
||||
// make sure the private claim map is set correctly
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err = json.Unmarshal(data, req); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return signEncodeTokenClaims(req), req
|
||||
}
|
||||
|
||||
const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
|
||||
|
||||
// These variables always result in a valid token
|
||||
|
@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
|
|||
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
|
||||
}
|
||||
|
||||
func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
|
||||
return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
|
||||
}
|
||||
|
||||
// ACRVerify is a oidc.ACRVerifier func.
|
||||
func ACRVerify(acr string) error {
|
||||
if acr != ValidACR {
|
||||
|
|
|
@ -11,24 +11,25 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
var Encoder = httphelper.Encoder(oidc.NewEncoder())
|
||||
|
||||
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
|
||||
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
|
||||
func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
|
||||
func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
||||
if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" {
|
||||
wellKnown = wellKnownUrl[0]
|
||||
}
|
||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -37,6 +38,10 @@ func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if logger, ok := logging.FromContext(ctx); ok {
|
||||
logger.Debug("discover", "config", discoveryConfig)
|
||||
}
|
||||
|
||||
if discoveryConfig.Issuer != issuer {
|
||||
return nil, oidc.ErrIssuerInvalid
|
||||
}
|
||||
|
@ -48,12 +53,12 @@ type TokenEndpointCaller interface {
|
|||
HttpClient() *http.Client
|
||||
}
|
||||
|
||||
func CallTokenEndpoint(request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
||||
return callTokenEndpoint(request, nil, caller)
|
||||
func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
||||
return callTokenEndpoint(ctx, request, nil, caller)
|
||||
}
|
||||
|
||||
func callTokenEndpoint(request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
|
||||
func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
|
||||
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -80,8 +85,8 @@ type EndSessionCaller interface {
|
|||
HttpClient() *http.Client
|
||||
}
|
||||
|
||||
func CallEndSessionEndpoint(request any, authFn any, caller EndSessionCaller) (*url.URL, error) {
|
||||
req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn)
|
||||
func CallEndSessionEndpoint(ctx context.Context, request any, authFn any, caller EndSessionCaller) (*url.URL, error) {
|
||||
req, err := httphelper.FormRequest(ctx, caller.GetEndSessionEndpoint(), request, Encoder, authFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -123,8 +128,8 @@ type RevokeRequest struct {
|
|||
ClientSecret string `schema:"client_secret"`
|
||||
}
|
||||
|
||||
func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error {
|
||||
req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn)
|
||||
func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error {
|
||||
req, err := httphelper.FormRequest(ctx, caller.GetRevokeEndpoint(), request, Encoder, authFn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -151,8 +156,8 @@ func CallRevokeEndpoint(request any, authFn any, caller RevokeCaller) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func CallTokenExchangeEndpoint(request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
|
||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
|
||||
func CallTokenExchangeEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
|
||||
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -192,8 +197,8 @@ type DeviceAuthorizationCaller interface {
|
|||
HttpClient() *http.Client
|
||||
}
|
||||
|
||||
func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil)
|
||||
func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
req, err := httphelper.FormRequest(ctx, caller.GetDeviceAuthorizationEndpoint(), request, Encoder, authFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -214,7 +219,7 @@ type DeviceAccessTokenRequest struct {
|
|||
}
|
||||
|
||||
func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
|
||||
req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil)
|
||||
req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
|
@ -36,7 +37,7 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Discover(tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...)
|
||||
got, err := Discover(context.Background(), tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
|
|
|
@ -2,33 +2,64 @@ package client_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jeremija/gosubmit"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/zitadel/oidc/v2/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rs"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/tokenexchange"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rs"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/tokenexchange"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
var Logger = slog.New(
|
||||
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelDebug,
|
||||
}),
|
||||
)
|
||||
|
||||
var CTX context.Context
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
|
||||
defer cancel()
|
||||
CTX, cancel = context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func TestRelyingPartySession(t *testing.T) {
|
||||
for _, wrapServer := range []bool{false, true} {
|
||||
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
|
||||
testRelyingPartySession(t, wrapServer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRelyingPartySession(t *testing.T, wrapServer bool) {
|
||||
t.Log("------- start example OP ------")
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
|
@ -36,17 +67,17 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
|
||||
t.Log("------- run authorization code flow ------")
|
||||
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
|
||||
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
|
||||
|
||||
t.Log("------- refresh tokens ------")
|
||||
|
||||
newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
|
||||
require.NoError(t, err, "refresh token")
|
||||
assert.NotNil(t, newTokens, "access token")
|
||||
t.Logf("new access token %s", newTokens.AccessToken)
|
||||
|
@ -54,11 +85,13 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
t.Logf("new token type %s", newTokens.TokenType)
|
||||
t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
|
||||
require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
|
||||
assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken")
|
||||
assert.NotEmpty(t, newTokens.IDToken, "new idToken")
|
||||
assert.NotNil(t, newTokens.IDTokenClaims)
|
||||
assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject)
|
||||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
|
@ -67,17 +100,25 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Log("------ attempt refresh again (should fail) ------")
|
||||
t.Log("trying original refresh token", refreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
t.Log("trying original refresh token", tokens.RefreshToken)
|
||||
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with original")
|
||||
if newTokens.RefreshToken != "" {
|
||||
t.Log("trying replacement refresh token", newTokens.RefreshToken)
|
||||
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
|
||||
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "")
|
||||
assert.Errorf(t, err, "refresh with replacement")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceServerTokenExchange(t *testing.T) {
|
||||
for _, wrapServer := range []bool{false, true} {
|
||||
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
|
||||
testResourceServerTokenExchange(t, wrapServer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
|
||||
t.Log("------- start example OP ------")
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
|
@ -85,23 +126,24 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
|||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
clientSecret := "secret"
|
||||
|
||||
t.Log("------- run authorization code flow ------")
|
||||
provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
|
||||
provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
|
||||
|
||||
resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret)
|
||||
resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
|
||||
require.NoError(t, err, "new resource server")
|
||||
|
||||
t.Log("------- exchage refresh tokens (impersonation) ------")
|
||||
|
||||
tokenExchangeResponse, err := tokenexchange.ExchangeToken(
|
||||
CTX,
|
||||
resourceServer,
|
||||
refreshToken,
|
||||
tokens.RefreshToken,
|
||||
oidc.RefreshTokenType,
|
||||
"",
|
||||
"",
|
||||
|
@ -119,7 +161,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
|||
|
||||
t.Log("------ end session (logout) ------")
|
||||
|
||||
newLoc, err := rp.EndSession(provider, idToken, "", "")
|
||||
newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
|
||||
require.NoError(t, err, "logout")
|
||||
if newLoc != nil {
|
||||
t.Logf("redirect to %s", newLoc)
|
||||
|
@ -130,8 +172,9 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
|||
t.Log("------- attempt exchage again (should fail) ------")
|
||||
|
||||
tokenExchangeResponse, err = tokenexchange.ExchangeToken(
|
||||
CTX,
|
||||
resourceServer,
|
||||
refreshToken,
|
||||
tokens.RefreshToken,
|
||||
oidc.RefreshTokenType,
|
||||
"",
|
||||
"",
|
||||
|
@ -145,7 +188,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
|||
require.Nil(t, tokenExchangeResponse, "token exchange response")
|
||||
}
|
||||
|
||||
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) {
|
||||
func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
targetURL := "http://local-site"
|
||||
localURL, err := url.Parse(targetURL + "/login?requestID=1234")
|
||||
require.NoError(t, err, "local url")
|
||||
|
@ -167,6 +210,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
|||
key := []byte("test1234test1234")
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
provider, err = rp.NewRelyingPartyOIDC(
|
||||
CTX,
|
||||
opServer.URL,
|
||||
clientID,
|
||||
clientSecret,
|
||||
|
@ -241,7 +285,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
|||
}
|
||||
|
||||
var email string
|
||||
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
|
||||
redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
|
||||
tokens = newTokens
|
||||
require.NotNil(t, tokens, "tokens")
|
||||
require.NotNil(t, info, "info")
|
||||
t.Log("access token", tokens.AccessToken)
|
||||
|
@ -249,9 +294,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
|||
t.Log("id token", tokens.IDToken)
|
||||
t.Log("email", info.Email)
|
||||
|
||||
accessToken = tokens.AccessToken
|
||||
refreshToken = tokens.RefreshToken
|
||||
idToken = tokens.IDToken
|
||||
email = info.Email
|
||||
http.Redirect(w, r, targetURL, 302)
|
||||
}
|
||||
|
@ -273,12 +315,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
|||
require.NoError(t, err, "get fully-authorizied redirect location")
|
||||
require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location")
|
||||
|
||||
require.NotEmpty(t, idToken, "id token")
|
||||
assert.NotEmpty(t, refreshToken, "refresh token")
|
||||
assert.NotEmpty(t, accessToken, "access token")
|
||||
require.NotEmpty(t, tokens.IDToken, "id token")
|
||||
assert.NotEmpty(t, tokens.RefreshToken, "refresh token")
|
||||
assert.NotEmpty(t, tokens.AccessToken, "access token")
|
||||
assert.NotEmpty(t, email, "email")
|
||||
|
||||
return provider, accessToken, refreshToken, idToken
|
||||
return provider, tokens
|
||||
}
|
||||
|
||||
func TestErrorFromPromptNone(t *testing.T) {
|
||||
|
@ -299,7 +341,7 @@ func TestErrorFromPromptNone(t *testing.T) {
|
|||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, op.WithHttpInterceptors(
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, false, op.WithHttpInterceptors(
|
||||
func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("request to %s", r.URL)
|
||||
|
@ -317,6 +359,7 @@ func TestErrorFromPromptNone(t *testing.T) {
|
|||
key := []byte("test1234test1234")
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
provider, err := rp.NewRelyingPartyOIDC(
|
||||
CTX,
|
||||
opServer.URL,
|
||||
clientID,
|
||||
clientSecret,
|
||||
|
@ -412,7 +455,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [
|
|||
|
||||
func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL {
|
||||
// TODO: switch to io.NopCloser when go1.15 support is dropped
|
||||
req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
|
||||
req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
|
||||
append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)...,
|
||||
)
|
||||
if req.URL.Scheme == "" {
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||
func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
|
||||
return CallTokenEndpoint(jwtProfileGrantRequest, caller)
|
||||
func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
|
||||
return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller)
|
||||
}
|
||||
|
||||
func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption {
|
||||
|
|
|
@ -10,7 +10,7 @@ const (
|
|||
applicationKey = "application"
|
||||
)
|
||||
|
||||
type keyFile struct {
|
||||
type KeyFile struct {
|
||||
Type string `json:"type"` // serviceaccount or application
|
||||
KeyID string `json:"keyId"`
|
||||
Key string `json:"key"`
|
||||
|
@ -23,7 +23,7 @@ type keyFile struct {
|
|||
ClientID string `json:"clientId"`
|
||||
}
|
||||
|
||||
func ConfigFromKeyFile(path string) (*keyFile, error) {
|
||||
func ConfigFromKeyFile(path string) (*KeyFile, error) {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -31,8 +31,8 @@ func ConfigFromKeyFile(path string) (*keyFile, error) {
|
|||
return ConfigFromKeyFileData(data)
|
||||
}
|
||||
|
||||
func ConfigFromKeyFileData(data []byte) (*keyFile, error) {
|
||||
var f keyFile
|
||||
func ConfigFromKeyFileData(data []byte) (*KeyFile, error) {
|
||||
var f KeyFile
|
||||
if err := json.Unmarshal(data, &f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,16 +1,22 @@
|
|||
package profile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type TokenSource interface {
|
||||
oauth2.TokenSource
|
||||
TokenCtx(context.Context) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
// jwtProfileTokenSource implement the oauth2.TokenSource
|
||||
// it will request a token using the OAuth2 JWT Profile Grant
|
||||
// therefore sending an `assertion` by signing a JWT with the provided private key
|
||||
|
@ -23,23 +29,38 @@ type jwtProfileTokenSource struct {
|
|||
tokenEndpoint string
|
||||
}
|
||||
|
||||
func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
|
||||
keyData, err := client.ConfigFromKeyFile(keyPath)
|
||||
// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource
|
||||
// It will request a token using the OAuth2 JWT Profile Grant,
|
||||
// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile.
|
||||
//
|
||||
// The passed context is only used for the call to the Discover endpoint.
|
||||
func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
|
||||
keyData, err := client.ConfigFromKeyFile(jsonFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
||||
return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
||||
}
|
||||
|
||||
func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
|
||||
keyData, err := client.ConfigFromKeyFileData(data)
|
||||
// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource
|
||||
// It will request a token using the OAuth2 JWT Profile Grant,
|
||||
// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData.
|
||||
//
|
||||
// The passed context is only used for the call to the Discover endpoint.
|
||||
func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
|
||||
keyData, err := client.ConfigFromKeyFileData(jsonData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
||||
return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
|
||||
}
|
||||
|
||||
func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) {
|
||||
// NewJWTProfileSource returns an implementation of oauth2.TokenSource
|
||||
// It will request a token using the OAuth2 JWT Profile Grant,
|
||||
// therefore sending an `assertion` by singing a JWT with the provided private key.
|
||||
//
|
||||
// The passed context is only used for the call to the Discover endpoint.
|
||||
func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
|
||||
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -55,7 +76,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
|
|||
opt(source)
|
||||
}
|
||||
if source.tokenEndpoint == "" {
|
||||
config, err := client.Discover(issuer, source.httpClient)
|
||||
config, err := client.Discover(ctx, issuer, source.httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -64,13 +85,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes
|
|||
return source, nil
|
||||
}
|
||||
|
||||
func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) {
|
||||
func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) {
|
||||
return func(source *jwtProfileTokenSource) {
|
||||
source.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) {
|
||||
func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) {
|
||||
return func(source *jwtProfileTokenSource) {
|
||||
source.tokenEndpoint = tokenEndpoint
|
||||
}
|
||||
|
@ -85,9 +106,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client {
|
|||
}
|
||||
|
||||
func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) {
|
||||
return j.TokenCtx(context.Background())
|
||||
}
|
||||
|
||||
func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) {
|
||||
assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
|
||||
return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
|
||||
}
|
||||
|
|
|
@ -4,9 +4,9 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// DelegationTokenRequest is an implementation of TokenExchangeRequest
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
|
||||
|
@ -32,19 +32,21 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.
|
|||
// DeviceAuthorization starts a new Device Authorization flow as defined
|
||||
// in RFC 8628, section 3.1 and 3.2:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
|
||||
func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")
|
||||
req, err := newDeviceClientCredentialsRequest(scopes, rp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.CallDeviceAuthorizationEndpoint(req, rp)
|
||||
return client.CallDeviceAuthorizationEndpoint(ctx, req, rp, authFn)
|
||||
}
|
||||
|
||||
// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
|
||||
// by means of polling as defined in RFC, section 3.3 and 3.4:
|
||||
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
|
||||
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")
|
||||
req := &client.DeviceAccessTokenRequest{
|
||||
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
|
||||
GrantType: oidc.GrantTypeDeviceCode,
|
||||
|
|
|
@ -7,10 +7,10 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
|
||||
|
|
17
pkg/client/rp/log.go
Normal file
17
pkg/client/rp/log.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context {
|
||||
logger, ok := rp.Logger(ctx)
|
||||
if !ok {
|
||||
return ctx
|
||||
}
|
||||
logger = logger.With(slog.Group("rp", attrs...))
|
||||
return logging.ToContext(ctx, logger)
|
||||
}
|
|
@ -7,16 +7,17 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -63,11 +64,14 @@ type RelyingParty interface {
|
|||
// be used to start a DeviceAuthorization flow.
|
||||
GetDeviceAuthorizationEndpoint() string
|
||||
|
||||
// IDTokenVerifier returns the verifier interface used for oidc id_token verification
|
||||
IDTokenVerifier() IDTokenVerifier
|
||||
// IDTokenVerifier returns the verifier used for oidc id_token verification
|
||||
IDTokenVerifier() *IDTokenVerifier
|
||||
// ErrorHandler returns the handler used for callback errors
|
||||
|
||||
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
|
||||
// Logger from the context, or a fallback if set.
|
||||
Logger(context.Context) (logger *slog.Logger, ok bool)
|
||||
}
|
||||
|
||||
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
|
||||
|
@ -88,9 +92,10 @@ type relyingParty struct {
|
|||
cookieHandler *httphelper.CookieHandler
|
||||
|
||||
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
idTokenVerifier IDTokenVerifier
|
||||
idTokenVerifier *IDTokenVerifier
|
||||
verifierOpts []VerifierOption
|
||||
signer jose.Signer
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (rp *relyingParty) OAuthConfig() *oauth2.Config {
|
||||
|
@ -137,7 +142,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string {
|
|||
return rp.endpoints.RevokeURL
|
||||
}
|
||||
|
||||
func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier {
|
||||
func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
|
||||
if rp.idTokenVerifier == nil {
|
||||
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
|
||||
}
|
||||
|
@ -151,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
|
|||
return rp.errorHandler
|
||||
}
|
||||
|
||||
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
|
||||
logger, ok = logging.FromContext(ctx)
|
||||
if ok {
|
||||
return logger, ok
|
||||
}
|
||||
return rp.logger, rp.logger != nil
|
||||
}
|
||||
|
||||
// NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given
|
||||
// OAuth2 Config and possible configOptions
|
||||
// it will use the AuthURL and TokenURL set in config
|
||||
|
@ -177,7 +190,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart
|
|||
// NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given
|
||||
// issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions
|
||||
// it will run discovery on the provided issuer and use the found endpoints
|
||||
func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
|
||||
func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
|
||||
rp := &relyingParty{
|
||||
issuer: issuer,
|
||||
oauthConfig: &oauth2.Config{
|
||||
|
@ -195,7 +208,8 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC")
|
||||
discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -282,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithLogger sets a logger that is used
|
||||
// in case the request context does not contain a logger.
|
||||
func WithLogger(logger *slog.Logger) Option {
|
||||
return func(rp *relyingParty) error {
|
||||
rp.logger = logger
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type SignerFromKey func() (jose.Signer, error)
|
||||
|
||||
func SignerFromKeyPath(path string) SignerFromKey {
|
||||
|
@ -310,26 +333,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {
|
|||
}
|
||||
}
|
||||
|
||||
// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
|
||||
//
|
||||
// deprecated: use client.Discover
|
||||
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
|
||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||
if err != nil {
|
||||
return Endpoints{}, err
|
||||
}
|
||||
discoveryConfig := new(oidc.DiscoveryConfiguration)
|
||||
err = httphelper.HttpRequest(httpClient, req, &discoveryConfig)
|
||||
if err != nil {
|
||||
return Endpoints{}, err
|
||||
}
|
||||
if discoveryConfig.Issuer != issuer {
|
||||
return Endpoints{}, fmt.Errorf("%w: Expected: %s, got: %s", oidc.ErrIssuerInvalid, discoveryConfig.Issuer, issuer)
|
||||
}
|
||||
return GetEndpoints(discoveryConfig), nil
|
||||
}
|
||||
|
||||
// AuthURL returns the auth request url
|
||||
// (wrapping the oauth2 `AuthCodeURL`)
|
||||
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
|
||||
|
@ -377,9 +380,29 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri
|
|||
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||
}
|
||||
|
||||
// ErrMissingIDToken is returned when an id_token was expected,
|
||||
// but not received in the token response.
|
||||
var ErrMissingIDToken = errors.New("id_token missing")
|
||||
|
||||
func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
|
||||
if rp.IsOAuth2Only() {
|
||||
return &oidc.Tokens[C]{Token: token}, nil
|
||||
}
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
if !ok {
|
||||
return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken
|
||||
}
|
||||
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
|
||||
// returning it parsed together with the oauth2 tokens (access, refresh)
|
||||
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
|
@ -390,22 +413,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rp.IsOAuth2Only() {
|
||||
return &oidc.Tokens[C]{Token: token}, nil
|
||||
}
|
||||
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
if !ok {
|
||||
return nil, errors.New("id_token missing")
|
||||
}
|
||||
|
||||
idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
return verifyTokenResponse[C](ctx, token, rp)
|
||||
}
|
||||
|
||||
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
|
||||
|
@ -457,14 +465,18 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
|||
}
|
||||
}
|
||||
|
||||
type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo)
|
||||
type SubjectGetter interface {
|
||||
GetSubject() string
|
||||
}
|
||||
|
||||
type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info U)
|
||||
|
||||
// UserinfoCallback wraps the callback function of the CodeExchangeHandler
|
||||
// and calls the userinfo endpoint with the access token
|
||||
// on success it will pass the userinfo into its callback function as well
|
||||
func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] {
|
||||
func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
if err != nil {
|
||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -473,19 +485,26 @@ func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeEx
|
|||
}
|
||||
}
|
||||
|
||||
// Userinfo will call the OIDC Userinfo Endpoint with the provided token
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
|
||||
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
|
||||
// Userinfo will call the OIDC [UserInfo] Endpoint with the provided token and returns
|
||||
// the response in an instance of type U.
|
||||
// [*oidc.UserInfo] can be used as a good example, or use a custom type if type-safe
|
||||
// access to custom claims is needed.
|
||||
//
|
||||
// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
|
||||
var nilU U
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nilU, err
|
||||
}
|
||||
req.Header.Set("authorization", tokenType+" "+token)
|
||||
userinfo := new(oidc.UserInfo)
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
|
||||
return nil, err
|
||||
return nilU, err
|
||||
}
|
||||
if userinfo.Subject != subject {
|
||||
return nil, ErrUserInfoSubNotMatching
|
||||
if userinfo.GetSubject() != subject {
|
||||
return nilU, ErrUserInfoSubNotMatching
|
||||
}
|
||||
return userinfo, nil
|
||||
}
|
||||
|
@ -554,7 +573,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
|
|||
// This is the generalized, unexported, function used by both
|
||||
// URLParamOpt and AuthURLOpt.
|
||||
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
|
||||
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
|
||||
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String())
|
||||
}
|
||||
|
||||
type URLParamOpt func() []oauth2.AuthCodeOption
|
||||
|
@ -626,11 +645,15 @@ type RefreshTokenRequest struct {
|
|||
GrantType oidc.GrantType `schema:"grant_type"`
|
||||
}
|
||||
|
||||
// RefreshAccessToken performs a token refresh. If it doesn't error, it will always
|
||||
// RefreshTokens performs a token refresh. If it doesn't error, it will always
|
||||
// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
|
||||
// the old one should be considered invalid. It may also provide a new IDToken. The
|
||||
// new IDToken can be retrieved with token.Extra("id_token").
|
||||
func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) {
|
||||
// the old one should be considered invalid.
|
||||
//
|
||||
// In case the RP is not OAuth2 only and an IDToken was part of the response,
|
||||
// the IDToken and AccessToken will be verfied
|
||||
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
|
||||
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
|
||||
request := RefreshTokenRequest{
|
||||
RefreshToken: refreshToken,
|
||||
Scopes: rp.OAuthConfig().Scopes,
|
||||
|
@ -640,17 +663,28 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs
|
|||
ClientAssertionType: clientAssertionType,
|
||||
GrantType: oidc.GrantTypeRefreshToken,
|
||||
}
|
||||
return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp})
|
||||
newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens, err := verifyTokenResponse[C](ctx, newToken, rp)
|
||||
if err == nil || errors.Is(err, ErrMissingIDToken) {
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
|
||||
// ...except that it might not contain an id_token.
|
||||
return tokens, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
||||
func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
|
||||
request := oidc.EndSessionRequest{
|
||||
IdTokenHint: idToken,
|
||||
ClientID: rp.OAuthConfig().ClientID,
|
||||
PostLogoutRedirectURI: optionalRedirectURI,
|
||||
State: optionalState,
|
||||
}
|
||||
return client.CallEndSessionEndpoint(request, nil, rp)
|
||||
return client.CallEndSessionEndpoint(ctx, request, nil, rp)
|
||||
}
|
||||
|
||||
// RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty
|
||||
|
@ -658,7 +692,8 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str
|
|||
// NewRelyingPartyOAuth() does not.
|
||||
//
|
||||
// tokenTypeHint should be either "id_token" or "refresh_token".
|
||||
func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
|
||||
func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
|
||||
request := client.RevokeRequest{
|
||||
Token: token,
|
||||
TokenTypeHint: tokenTypeHint,
|
||||
|
@ -666,7 +701,7 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
|
|||
ClientSecret: rp.OAuthConfig().ClientSecret,
|
||||
}
|
||||
if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
|
||||
return client.CallRevokeEndpoint(request, nil, rc)
|
||||
return client.CallRevokeEndpoint(ctx, request, nil, rc)
|
||||
}
|
||||
return fmt.Errorf("RelyingParty does not support RevokeCaller")
|
||||
}
|
||||
|
|
107
pkg/client/rp/relying_party_test.go
Normal file
107
pkg/client/rp/relying_party_test.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func Test_verifyTokenResponse(t *testing.T) {
|
||||
verifier := &IDTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
MaxAgeIAT: 2 * time.Minute,
|
||||
ClientID: tu.ValidClientID,
|
||||
Offset: time.Second,
|
||||
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
KeySet: tu.KeySet{},
|
||||
MaxAge: 2 * time.Minute,
|
||||
ACR: tu.ACRVerify,
|
||||
Nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
oauth2Only bool
|
||||
tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims])
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "succes, oauth2 only",
|
||||
oauth2Only: true,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: token,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "id_token missing error",
|
||||
oauth2Only: false,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: token,
|
||||
}
|
||||
},
|
||||
wantErr: ErrMissingIDToken,
|
||||
},
|
||||
{
|
||||
name: "verify tokens error",
|
||||
oauth2Only: false,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
token = token.WithExtra(map[string]any{
|
||||
"id_token": "foobar",
|
||||
})
|
||||
return token, nil
|
||||
},
|
||||
wantErr: oidc.ErrParse,
|
||||
},
|
||||
{
|
||||
name: "success, with id_token",
|
||||
oauth2Only: false,
|
||||
tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
|
||||
accesToken, _ := tu.ValidAccessToken()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: accesToken,
|
||||
}
|
||||
idToken, claims := tu.ValidIDToken()
|
||||
token = token.WithExtra(map[string]any{
|
||||
"id_token": idToken,
|
||||
})
|
||||
return token, &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: token,
|
||||
IDTokenClaims: claims,
|
||||
IDToken: idToken,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := &relyingParty{
|
||||
oauth2Only: tt.oauth2Only,
|
||||
idTokenVerifier: verifier,
|
||||
}
|
||||
token, want := tt.tokens()
|
||||
got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`
|
||||
|
|
45
pkg/client/rp/userinfo_example_test.go
Normal file
45
pkg/client/rp/userinfo_example_test.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package rp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
Subject string `json:"sub,omitempty"`
|
||||
oidc.UserInfoProfile
|
||||
oidc.UserInfoEmail
|
||||
oidc.UserInfoPhone
|
||||
Address *oidc.UserInfoAddress `json:"address,omitempty"`
|
||||
|
||||
// Foo and Bar are custom claims
|
||||
Foo string `json:"foo,omitempty"`
|
||||
Bar struct {
|
||||
Val1 string `json:"val_1,omitempty"`
|
||||
Val2 string `json:"val_2,omitempty"`
|
||||
} `json:"bar,omitempty"`
|
||||
|
||||
// Claims are all the combined claims, including custom.
|
||||
Claims map[string]any `json:"-,omitempty"`
|
||||
}
|
||||
|
||||
func (u *UserInfo) GetSubject() string {
|
||||
return u.Subject
|
||||
}
|
||||
|
||||
func ExampleUserinfo_custom() {
|
||||
rpo, err := rp.NewRelyingPartyOIDC(context.TODO(), "http://localhost:8080", "clientid", "clientsecret", "http://example.com/redirect", []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
info, err := rp.Userinfo[*UserInfo](context.TODO(), "accesstokenstring", "Bearer", "userid", rpo)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(info)
|
||||
}
|
|
@ -4,24 +4,14 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type IDTokenVerifier interface {
|
||||
oidc.Verifier
|
||||
ClientID() string
|
||||
SupportedSignAlgs() []string
|
||||
KeySet() oidc.KeySet
|
||||
Nonce(context.Context) string
|
||||
ACR() oidc.ACRVerifier
|
||||
MaxAge() time.Duration
|
||||
}
|
||||
|
||||
// VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
|
||||
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
claims, err = VerifyIDToken[C](ctx, idToken, v)
|
||||
|
@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str
|
|||
|
||||
// VerifyIDToken validates the id token according to
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
|
||||
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
decrypted, err := oidc.DecryptToken(token)
|
||||
|
@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
|
|||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
|
||||
if err = oidc.CheckAudience(claims, v.ClientID); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
|
||||
if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
|
@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
|
|||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
type IDTokenVerifier oidc.Verifier
|
||||
|
||||
// VerifyAccessToken validates the access token according to
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||
|
@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl
|
|||
return nil
|
||||
}
|
||||
|
||||
// NewIDTokenVerifier returns an implementation of `IDTokenVerifier`
|
||||
// for `VerifyTokens` and `VerifyIDToken`
|
||||
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier {
|
||||
v := &idTokenVerifier{
|
||||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
keySet: keySet,
|
||||
offset: time.Second,
|
||||
nonce: func(_ context.Context) string {
|
||||
// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
|
||||
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
|
||||
v := &IDTokenVerifier{
|
||||
Issuer: issuer,
|
||||
ClientID: clientID,
|
||||
KeySet: keySet,
|
||||
Offset: time.Second,
|
||||
Nonce: func(_ context.Context) string {
|
||||
return ""
|
||||
},
|
||||
}
|
||||
|
@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
|
|||
}
|
||||
|
||||
// VerifierOption is the type for providing dynamic options to the IDTokenVerifier
|
||||
type VerifierOption func(*idTokenVerifier)
|
||||
type VerifierOption func(*IDTokenVerifier)
|
||||
|
||||
// WithIssuedAtOffset mitigates the risk of iat to be in the future
|
||||
// because of clock skews with the ability to add an offset to the current time
|
||||
func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.offset = offset
|
||||
func WithIssuedAtOffset(offset time.Duration) VerifierOption {
|
||||
return func(v *IDTokenVerifier) {
|
||||
v.Offset = offset
|
||||
}
|
||||
}
|
||||
|
||||
// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
||||
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.maxAgeIAT = maxAge
|
||||
func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption {
|
||||
return func(v *IDTokenVerifier) {
|
||||
v.MaxAgeIAT = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
// WithNonce sets the function to check the nonce
|
||||
func WithNonce(nonce func(context.Context) string) VerifierOption {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.nonce = nonce
|
||||
return func(v *IDTokenVerifier) {
|
||||
v.Nonce = nonce
|
||||
}
|
||||
}
|
||||
|
||||
// WithACRVerifier sets the verifier for the acr claim
|
||||
func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.acr = verifier
|
||||
return func(v *IDTokenVerifier) {
|
||||
v.ACR = verifier
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
|
||||
func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.maxAge = maxAge
|
||||
return func(v *IDTokenVerifier) {
|
||||
v.MaxAge = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
|
||||
func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
|
||||
return func(v *idTokenVerifier) {
|
||||
v.supportedSignAlgs = algs
|
||||
return func(v *IDTokenVerifier) {
|
||||
v.SupportedSignAlgs = algs
|
||||
}
|
||||
}
|
||||
|
||||
type idTokenVerifier struct {
|
||||
issuer string
|
||||
maxAgeIAT time.Duration
|
||||
offset time.Duration
|
||||
clientID string
|
||||
supportedSignAlgs []string
|
||||
keySet oidc.KeySet
|
||||
acr oidc.ACRVerifier
|
||||
maxAge time.Duration
|
||||
nonce func(ctx context.Context) string
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) Issuer() string {
|
||||
return i.issuer
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) MaxAgeIAT() time.Duration {
|
||||
return i.maxAgeIAT
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) Offset() time.Duration {
|
||||
return i.offset
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) ClientID() string {
|
||||
return i.clientID
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) SupportedSignAlgs() []string {
|
||||
return i.supportedSignAlgs
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) KeySet() oidc.KeySet {
|
||||
return i.keySet
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) Nonce(ctx context.Context) string {
|
||||
return i.nonce(ctx)
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) ACR() oidc.ACRVerifier {
|
||||
return i.acr
|
||||
}
|
||||
|
||||
func (i *idTokenVerifier) MaxAge() time.Duration {
|
||||
return i.maxAge
|
||||
}
|
||||
|
|
|
@ -5,24 +5,24 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestVerifyTokens(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
clientID: tu.ValidClientID,
|
||||
verifier := &IDTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
MaxAgeIAT: 2 * time.Minute,
|
||||
Offset: time.Second,
|
||||
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
KeySet: tu.KeySet{},
|
||||
MaxAge: 2 * time.Minute,
|
||||
ACR: tu.ACRVerify,
|
||||
Nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
ClientID: tu.ValidClientID,
|
||||
}
|
||||
accessToken, _ := tu.ValidAccessToken()
|
||||
atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
|
||||
|
@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestVerifyIDToken(t *testing.T) {
|
||||
verifier := &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
maxAgeIAT: 2 * time.Minute,
|
||||
offset: time.Second,
|
||||
supportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
keySet: tu.KeySet{},
|
||||
maxAge: 2 * time.Minute,
|
||||
acr: tu.ACRVerify,
|
||||
nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
verifier := &IDTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
MaxAgeIAT: 2 * time.Minute,
|
||||
Offset: time.Second,
|
||||
SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
|
||||
KeySet: tu.KeySet{},
|
||||
MaxAge: 2 * time.Minute,
|
||||
ACR: tu.ACRVerify,
|
||||
Nonce: func(context.Context) string { return tu.ValidNonce },
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -231,7 +231,7 @@ func TestVerifyIDToken(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, want := tt.tokenClaims()
|
||||
verifier.clientID = tt.clientID
|
||||
verifier.ClientID = tt.clientID
|
||||
got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
|
@ -312,7 +312,7 @@ func TestNewIDTokenVerifier(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want IDTokenVerifier
|
||||
want *IDTokenVerifier
|
||||
}{
|
||||
{
|
||||
name: "nil nonce", // otherwise assert.Equal will fail on the function
|
||||
|
@ -329,16 +329,16 @@ func TestNewIDTokenVerifier(t *testing.T) {
|
|||
WithSupportedSigningAlgorithms("ABC", "DEF"),
|
||||
},
|
||||
},
|
||||
want: &idTokenVerifier{
|
||||
issuer: tu.ValidIssuer,
|
||||
offset: time.Minute,
|
||||
maxAgeIAT: time.Hour,
|
||||
clientID: tu.ValidClientID,
|
||||
keySet: tu.KeySet{},
|
||||
nonce: nil,
|
||||
acr: nil,
|
||||
maxAge: 2 * time.Hour,
|
||||
supportedSignAlgs: []string{"ABC", "DEF"},
|
||||
want: &IDTokenVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
Offset: time.Minute,
|
||||
MaxAgeIAT: time.Hour,
|
||||
ClientID: tu.ValidClientID,
|
||||
KeySet: tu.KeySet{},
|
||||
Nonce: nil,
|
||||
ACR: nil,
|
||||
MaxAge: 2 * time.Hour,
|
||||
SupportedSignAlgs: []string{"ABC", "DEF"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -4,9 +4,9 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
tu "github.com/zitadel/oidc/v2/internal/testutil"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// MyCustomClaims extends the TokenClaims base,
|
||||
|
|
52
pkg/client/rs/introspect_example_test.go
Normal file
52
pkg/client/rs/introspect_example_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package rs_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rs"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type IntrospectionResponse struct {
|
||||
Active bool `json:"active"`
|
||||
Scope oidc.SpaceDelimitedArray `json:"scope,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
Expiration oidc.Time `json:"exp,omitempty"`
|
||||
IssuedAt oidc.Time `json:"iat,omitempty"`
|
||||
NotBefore oidc.Time `json:"nbf,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audience oidc.Audience `json:"aud,omitempty"`
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
JWTID string `json:"jti,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
oidc.UserInfoProfile
|
||||
oidc.UserInfoEmail
|
||||
oidc.UserInfoPhone
|
||||
Address *oidc.UserInfoAddress `json:"address,omitempty"`
|
||||
|
||||
// Foo and Bar are custom claims
|
||||
Foo string `json:"foo,omitempty"`
|
||||
Bar struct {
|
||||
Val1 string `json:"val_1,omitempty"`
|
||||
Val2 string `json:"val_2,omitempty"`
|
||||
} `json:"bar,omitempty"`
|
||||
|
||||
// Claims are all the combined claims, including custom.
|
||||
Claims map[string]any `json:"-,omitempty"`
|
||||
}
|
||||
|
||||
func ExampleIntrospect_custom() {
|
||||
rss, err := rs.NewResourceServerClientCredentials(context.TODO(), "http://localhost:8080", "clientid", "clientsecret")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
resp, err := rs.Introspect[*IntrospectionResponse](context.TODO(), rss, "accesstokenstring")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(resp)
|
||||
}
|
|
@ -6,9 +6,9 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type ResourceServer interface {
|
||||
|
@ -42,14 +42,14 @@ func (r *resourceServer) AuthFn() (any, error) {
|
|||
return r.authFn()
|
||||
}
|
||||
|
||||
func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
|
||||
func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
|
||||
authorizer := func() (any, error) {
|
||||
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
||||
}
|
||||
return newResourceServer(issuer, authorizer, option...)
|
||||
return newResourceServer(ctx, issuer, authorizer, option...)
|
||||
}
|
||||
|
||||
func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
|
||||
func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
|
||||
signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -61,10 +61,10 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt
|
|||
}
|
||||
return client.ClientAssertionFormAuthorization(assertion), nil
|
||||
}
|
||||
return newResourceServer(issuer, authorizer, options...)
|
||||
return newResourceServer(ctx, issuer, authorizer, options...)
|
||||
}
|
||||
|
||||
func newResourceServer(issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) {
|
||||
func newResourceServer(ctx context.Context, issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) {
|
||||
rs := &resourceServer{
|
||||
issuer: issuer,
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
|
@ -73,7 +73,7 @@ func newResourceServer(issuer string, authorizer func() (any, error), options ..
|
|||
optFunc(rs)
|
||||
}
|
||||
if rs.introspectURL == "" || rs.tokenURL == "" {
|
||||
config, err := client.Discover(rs.issuer, rs.httpClient)
|
||||
config, err := client.Discover(ctx, rs.issuer, rs.httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -91,12 +91,12 @@ func newResourceServer(issuer string, authorizer func() (any, error), options ..
|
|||
return rs, nil
|
||||
}
|
||||
|
||||
func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) {
|
||||
func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) {
|
||||
c, err := client.ConfigFromKeyFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
|
||||
return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
|
||||
}
|
||||
|
||||
type Option func(*resourceServer)
|
||||
|
@ -116,21 +116,27 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) {
|
||||
// Introspect calls the [RFC7662] Token Introspection
|
||||
// endpoint and returns the response in an instance of type R.
|
||||
// [*oidc.IntrospectionResponse] can be used as a good example, or use a custom type if type-safe
|
||||
// access to custom claims is needed.
|
||||
//
|
||||
// [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662
|
||||
func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) {
|
||||
if rp.IntrospectionURL() == "" {
|
||||
return nil, errors.New("resource server: introspection URL is empty")
|
||||
return resp, errors.New("resource server: introspection URL is empty")
|
||||
}
|
||||
authFn, err := rp.AuthFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return resp, err
|
||||
}
|
||||
req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
|
||||
req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return resp, err
|
||||
}
|
||||
resp := new(oidc.IntrospectionResponse)
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
|
||||
return nil, err
|
||||
|
||||
if err := httphelper.HttpRequest(rp.HttpClient(), req, &resp); err != nil {
|
||||
return resp, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestNewResourceServer(t *testing.T) {
|
||||
|
@ -164,7 +165,7 @@ func TestNewResourceServer(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newResourceServer(tt.args.issuer, tt.args.authorizer, tt.args.options...)
|
||||
got, err := newResourceServer(context.Background(), tt.args.issuer, tt.args.authorizer, tt.args.options...)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
|
@ -187,6 +188,7 @@ func TestIntrospect(t *testing.T) {
|
|||
token string
|
||||
}
|
||||
rp, err := newResourceServer(
|
||||
context.Background(),
|
||||
"https://accounts.spotify.com",
|
||||
nil,
|
||||
)
|
||||
|
@ -208,7 +210,7 @@ func TestIntrospect(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Introspect(tt.args.ctx, tt.args.rp, tt.args.token)
|
||||
_, err := Introspect[*oidc.IntrospectionResponse](tt.args.ctx, tt.args.rp, tt.args.token)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
package tokenexchange
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type TokenExchanger interface {
|
||||
|
@ -21,18 +22,18 @@ type OAuthTokenExchange struct {
|
|||
authFn func() (any, error)
|
||||
}
|
||||
|
||||
func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
return newOAuthTokenExchange(issuer, nil, options...)
|
||||
func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
return newOAuthTokenExchange(ctx, issuer, nil, options...)
|
||||
}
|
||||
|
||||
func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
|
||||
authorizer := func() (any, error) {
|
||||
return httphelper.AuthorizeBasic(clientID, clientSecret), nil
|
||||
}
|
||||
return newOAuthTokenExchange(issuer, authorizer, options...)
|
||||
return newOAuthTokenExchange(ctx, issuer, authorizer, options...)
|
||||
}
|
||||
|
||||
func newOAuthTokenExchange(issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
|
||||
func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
|
||||
te := &OAuthTokenExchange{
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
}
|
||||
|
@ -41,7 +42,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (any, error), option
|
|||
}
|
||||
|
||||
if te.tokenEndpoint == "" {
|
||||
config, err := client.Discover(issuer, te.httpClient)
|
||||
config, err := client.Discover(ctx, issuer, te.httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -89,6 +90,7 @@ func (te *OAuthTokenExchange) AuthFn() (any, error) {
|
|||
// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
|
||||
// SubjectToken and SubjectTokenType are required parameters.
|
||||
func ExchangeToken(
|
||||
ctx context.Context,
|
||||
te TokenExchanger,
|
||||
SubjectToken string,
|
||||
SubjectTokenType oidc.TokenType,
|
||||
|
@ -123,5 +125,5 @@ func ExchangeToken(
|
|||
RequestedTokenType: RequestedTokenType,
|
||||
}
|
||||
|
||||
return client.CallTokenExchangeEndpoint(request, authFn, te)
|
||||
return client.CallTokenExchangeEndpoint(ctx, request, authFn, te)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"fmt"
|
||||
"hash"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm")
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
func Sign(object any, signer jose.Signer) (string, error) {
|
||||
|
|
|
@ -33,7 +33,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization {
|
|||
}
|
||||
}
|
||||
|
||||
func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) {
|
||||
func FormRequest(ctx context.Context, endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) {
|
||||
form := url.Values{}
|
||||
if err := encoder.Encode(request, form); err != nil {
|
||||
return nil, err
|
||||
|
@ -42,7 +42,7 @@ func FormRequest(endpoint string, request any, encoder Encoder, authFn any) (*ht
|
|||
fn(form)
|
||||
}
|
||||
body := strings.NewReader(form.Encode())
|
||||
req, err := http.NewRequest("POST", endpoint, body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
const (
|
||||
// ScopeOpenID defines the scope `openid`
|
||||
// OpenID Connect requests MUST contain the `openid` scope value
|
||||
|
@ -77,7 +81,7 @@ type AuthRequest struct {
|
|||
UILocales Locales `json:"ui_locales" schema:"ui_locales"`
|
||||
IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"`
|
||||
LoginHint string `json:"login_hint" schema:"login_hint"`
|
||||
ACRValues []string `json:"acr_values" schema:"acr_values"`
|
||||
ACRValues SpaceDelimitedArray `json:"acr_values" schema:"acr_values"`
|
||||
|
||||
CodeChallenge string `json:"code_challenge" schema:"code_challenge"`
|
||||
CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"`
|
||||
|
@ -86,6 +90,15 @@ type AuthRequest struct {
|
|||
RequestParam string `schema:"request"`
|
||||
}
|
||||
|
||||
func (a *AuthRequest) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Any("scopes", a.Scopes),
|
||||
slog.String("response_type", string(a.ResponseType)),
|
||||
slog.String("client_id", a.ClientID),
|
||||
slog.String("redirect_uri", a.RedirectURI),
|
||||
)
|
||||
}
|
||||
|
||||
// GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface
|
||||
func (a *AuthRequest) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
|
|
27
pkg/oidc/authorization_test.go
Normal file
27
pkg/oidc/authorization_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
//go:build go1.20
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestAuthRequest_LogValue(t *testing.T) {
|
||||
a := &AuthRequest{
|
||||
Scopes: SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "respType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
}
|
||||
want := slog.GroupValue(
|
||||
slog.Any("scopes", SpaceDelimitedArray{"a", "b"}),
|
||||
slog.String("response_type", "respType"),
|
||||
slog.String("client_id", "123"),
|
||||
slog.String("redirect_uri", "http://example.com/callback"),
|
||||
)
|
||||
got := a.LogValue()
|
||||
assert.Equal(t, want, got)
|
||||
}
|
|
@ -3,7 +3,7 @@ package oidc
|
|||
import (
|
||||
"crypto/sha256"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -3,6 +3,8 @@ package oidc
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type errorType string
|
||||
|
@ -171,3 +173,34 @@ func DefaultToServerError(err error, description string) *Error {
|
|||
}
|
||||
return oauth
|
||||
}
|
||||
|
||||
func (e *Error) LogLevel() slog.Level {
|
||||
level := slog.LevelWarn
|
||||
if e.ErrorType == ServerError {
|
||||
level = slog.LevelError
|
||||
}
|
||||
if e.ErrorType == AuthorizationPending {
|
||||
level = slog.LevelInfo
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
func (e *Error) LogValue() slog.Value {
|
||||
attrs := make([]slog.Attr, 0, 5)
|
||||
if e.Parent != nil {
|
||||
attrs = append(attrs, slog.Any("parent", e.Parent))
|
||||
}
|
||||
if e.Description != "" {
|
||||
attrs = append(attrs, slog.String("description", e.Description))
|
||||
}
|
||||
if e.ErrorType != "" {
|
||||
attrs = append(attrs, slog.String("type", string(e.ErrorType)))
|
||||
}
|
||||
if e.State != "" {
|
||||
attrs = append(attrs, slog.String("state", e.State))
|
||||
}
|
||||
if e.redirectDisabled {
|
||||
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
|
||||
}
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
|
83
pkg/oidc/error_go120_test.go
Normal file
83
pkg/oidc/error_go120_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
//go:build go1.20
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestError_LogValue(t *testing.T) {
|
||||
type fields struct {
|
||||
Parent error
|
||||
ErrorType errorType
|
||||
Description string
|
||||
State string
|
||||
redirectDisabled bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want slog.Value
|
||||
}{
|
||||
{
|
||||
name: "parent",
|
||||
fields: fields{
|
||||
Parent: io.EOF,
|
||||
},
|
||||
want: slog.GroupValue(slog.Any("parent", io.EOF)),
|
||||
},
|
||||
{
|
||||
name: "description",
|
||||
fields: fields{
|
||||
Description: "oops",
|
||||
},
|
||||
want: slog.GroupValue(slog.String("description", "oops")),
|
||||
},
|
||||
{
|
||||
name: "errorType",
|
||||
fields: fields{
|
||||
ErrorType: ExpiredToken,
|
||||
},
|
||||
want: slog.GroupValue(slog.String("type", string(ExpiredToken))),
|
||||
},
|
||||
{
|
||||
name: "state",
|
||||
fields: fields{
|
||||
State: "123",
|
||||
},
|
||||
want: slog.GroupValue(slog.String("state", "123")),
|
||||
},
|
||||
{
|
||||
name: "all fields",
|
||||
fields: fields{
|
||||
Parent: io.EOF,
|
||||
Description: "oops",
|
||||
ErrorType: ExpiredToken,
|
||||
State: "123",
|
||||
},
|
||||
want: slog.GroupValue(
|
||||
slog.Any("parent", io.EOF),
|
||||
slog.String("description", "oops"),
|
||||
slog.String("type", string(ExpiredToken)),
|
||||
slog.String("state", "123"),
|
||||
),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &Error{
|
||||
Parent: tt.fields.Parent,
|
||||
ErrorType: tt.fields.ErrorType,
|
||||
Description: tt.fields.Description,
|
||||
State: tt.fields.State,
|
||||
redirectDisabled: tt.fields.redirectDisabled,
|
||||
}
|
||||
got := e.LogValue()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
81
pkg/oidc/error_test.go
Normal file
81
pkg/oidc/error_test.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestDefaultToServerError(t *testing.T) {
|
||||
type args struct {
|
||||
err error
|
||||
description string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Error
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
args: args{
|
||||
err: io.ErrClosedPipe,
|
||||
description: "oops",
|
||||
},
|
||||
want: &Error{
|
||||
ErrorType: ServerError,
|
||||
Description: "oops",
|
||||
Parent: io.ErrClosedPipe,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "our Error",
|
||||
args: args{
|
||||
err: ErrAccessDenied(),
|
||||
description: "oops",
|
||||
},
|
||||
want: &Error{
|
||||
ErrorType: AccessDenied,
|
||||
Description: "The authorization request was denied.",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DefaultToServerError(tt.args.err, tt.args.description)
|
||||
assert.ErrorIs(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError_LogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *Error
|
||||
want slog.Level
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
err: ErrServerError(),
|
||||
want: slog.LevelError,
|
||||
},
|
||||
{
|
||||
name: "authorization pending",
|
||||
err: ErrAuthorizationPending(),
|
||||
want: slog.LevelInfo,
|
||||
},
|
||||
{
|
||||
name: "some other error",
|
||||
err: ErrAccessDenied(),
|
||||
want: slog.LevelWarn,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.err.LogLevel()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
"crypto/rsa"
|
||||
"errors"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
func TestFindKey(t *testing.T) {
|
||||
|
|
|
@ -5,11 +5,11 @@ import (
|
|||
"os"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -4,9 +4,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -8,10 +8,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
type Audience []string
|
||||
|
@ -151,7 +151,7 @@ type ResponseType string
|
|||
|
||||
type ResponseMode string
|
||||
|
||||
func (s SpaceDelimitedArray) Encode() string {
|
||||
func (s SpaceDelimitedArray) String() string {
|
||||
return strings.Join(s, " ")
|
||||
}
|
||||
|
||||
|
@ -161,11 +161,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error {
|
|||
}
|
||||
|
||||
func (s SpaceDelimitedArray) MarshalText() ([]byte, error) {
|
||||
return []byte(s.Encode()), nil
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal((s).Encode())
|
||||
return json.Marshal((s).String())
|
||||
}
|
||||
|
||||
func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error {
|
||||
|
@ -210,7 +210,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
|
|||
func NewEncoder() *schema.Encoder {
|
||||
e := schema.NewEncoder()
|
||||
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
|
||||
return value.Interface().(SpaceDelimitedArray).Encode()
|
||||
return value.Interface().(SpaceDelimitedArray).String()
|
||||
})
|
||||
return e
|
||||
}
|
||||
|
|
|
@ -9,9 +9,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
|
|
|
@ -29,6 +29,11 @@ func (u *UserInfo) GetAddress() *UserInfoAddress {
|
|||
return u.Address
|
||||
}
|
||||
|
||||
// GetSubject implements [rp.SubjectGetter]
|
||||
func (u *UserInfo) GetSubject() string {
|
||||
return u.Subject
|
||||
}
|
||||
|
||||
type uiAlias UserInfo
|
||||
|
||||
func (u *UserInfo) MarshalJSON() ([]byte, error) {
|
||||
|
|
|
@ -10,9 +10,9 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
|
||||
str "github.com/zitadel/oidc/v2/pkg/strings"
|
||||
str "github.com/zitadel/oidc/v3/pkg/strings"
|
||||
)
|
||||
|
||||
type Claims interface {
|
||||
|
@ -61,10 +61,19 @@ var (
|
|||
ErrAtHash = errors.New("at_hash does not correspond to access token")
|
||||
)
|
||||
|
||||
type Verifier interface {
|
||||
Issuer() string
|
||||
MaxAgeIAT() time.Duration
|
||||
Offset() time.Duration
|
||||
// Verifier caries configuration for the various token verification
|
||||
// functions. Use package specific constructor functions to know
|
||||
// which values need to be set.
|
||||
type Verifier struct {
|
||||
Issuer string
|
||||
MaxAgeIAT time.Duration
|
||||
Offset time.Duration
|
||||
ClientID string
|
||||
SupportedSignAlgs []string
|
||||
MaxAge time.Duration
|
||||
ACR ACRVerifier
|
||||
KeySet KeySet
|
||||
Nonce func(ctx context.Context) string
|
||||
}
|
||||
|
||||
// ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
|
||||
|
@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// CheckAuthorizedParty checks azp (authorized party) claim requirements.
|
||||
//
|
||||
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
|
||||
// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func CheckAuthorizedParty(claims Claims, clientID string) error {
|
||||
if len(claims.GetAudience()) > 1 {
|
||||
if claims.GetAuthorizedParty() == "" {
|
||||
|
@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
|
|||
}
|
||||
|
||||
func CheckExpiration(claims Claims, offset time.Duration) error {
|
||||
expiration := claims.GetExpiration().Round(time.Second)
|
||||
if !time.Now().UTC().Add(offset).Before(expiration) {
|
||||
expiration := claims.GetExpiration()
|
||||
if !time.Now().Add(offset).Before(expiration) {
|
||||
return ErrExpired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
|
||||
issuedAt := claims.GetIssuedAt().Round(time.Second)
|
||||
issuedAt := claims.GetIssuedAt()
|
||||
if issuedAt.IsZero() {
|
||||
return ErrIatMissing
|
||||
}
|
||||
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second)
|
||||
nowWithOffset := time.Now().Add(offset).Round(time.Second)
|
||||
if issuedAt.After(nowWithOffset) {
|
||||
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
|
||||
}
|
||||
if maxAgeIAT == 0 {
|
||||
return nil
|
||||
}
|
||||
maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second)
|
||||
maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second)
|
||||
if issuedAt.Before(maxAge) {
|
||||
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
|
||||
}
|
||||
|
@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error {
|
|||
if claims.GetAuthTime().IsZero() {
|
||||
return ErrAuthTimeNotPresent
|
||||
}
|
||||
authTime := claims.GetAuthTime().Round(time.Second)
|
||||
maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second)
|
||||
authTime := claims.GetAuthTime()
|
||||
maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
|
||||
if authTime.Before(maxAuthTime) {
|
||||
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
|
||||
}
|
||||
|
|
128
pkg/oidc/verifier_parse_test.go
Normal file
128
pkg/oidc/verifier_parse_test.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestParseToken(t *testing.T) {
|
||||
token, wantClaims := tu.ValidIDToken()
|
||||
wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload
|
||||
|
||||
wantPayload, err := json.Marshal(wantClaims)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenString string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "split error",
|
||||
tokenString: "nope",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "base64 error",
|
||||
tokenString: "foo.~.bar",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
tokenString: token,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotClaims := new(oidc.IDTokenClaims)
|
||||
gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, wantClaims, gotClaims)
|
||||
assert.JSONEq(t, string(wantPayload), string(gotPayload))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSignature(t *testing.T) {
|
||||
errCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
token, _ := tu.ValidIDToken()
|
||||
payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{})
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
token string
|
||||
payload []byte
|
||||
supportedSigAlgs []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "parse error",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
token: "~",
|
||||
payload: payload,
|
||||
},
|
||||
wantErr: oidc.ErrParse,
|
||||
},
|
||||
{
|
||||
name: "default sigAlg",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
token: token,
|
||||
payload: payload,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unsupported sigAlg",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
token: token,
|
||||
payload: payload,
|
||||
supportedSigAlgs: []string{"foo", "bar"},
|
||||
},
|
||||
wantErr: oidc.ErrSignatureUnsupportedAlg,
|
||||
},
|
||||
{
|
||||
name: "verify error",
|
||||
args: args{
|
||||
ctx: errCtx,
|
||||
token: token,
|
||||
payload: payload,
|
||||
},
|
||||
wantErr: oidc.ErrSignatureInvalid,
|
||||
},
|
||||
{
|
||||
name: "inequal payloads",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
token: token,
|
||||
payload: []byte{0, 1, 2},
|
||||
},
|
||||
wantErr: oidc.ErrSignatureInvalidPayload,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims := new(oidc.TokenClaims)
|
||||
err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{})
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
374
pkg/oidc/verifier_test.go
Normal file
374
pkg/oidc/verifier_test.go
Normal file
|
@ -0,0 +1,374 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDecryptToken(t *testing.T) {
|
||||
const tokenString = "ABC"
|
||||
got, err := DecryptToken(tokenString)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tokenString, got)
|
||||
}
|
||||
|
||||
func TestDefaultACRVerifier(t *testing.T) {
|
||||
acrVerfier := DefaultACRVerifier([]string{"foo", "bar"})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
acr string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
acr: "bar",
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
acr: "hello",
|
||||
wantErr: "expected one of: [foo bar], got: \"hello\"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := acrVerfier(tt.acr)
|
||||
if tt.wantErr != "" {
|
||||
assert.EqualError(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSubject(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "missing",
|
||||
claims: &TokenClaims{},
|
||||
wantErr: ErrSubjectMissing,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
claims: &TokenClaims{
|
||||
Subject: "foo",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckSubject(tt.claims)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIssuer(t *testing.T) {
|
||||
const issuer = "foo.bar"
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "missing",
|
||||
claims: &TokenClaims{},
|
||||
wantErr: ErrIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "wrong",
|
||||
claims: &TokenClaims{
|
||||
Issuer: "wrong",
|
||||
},
|
||||
wantErr: ErrIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
claims: &TokenClaims{
|
||||
Issuer: issuer,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckIssuer(tt.claims, issuer)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAudience(t *testing.T) {
|
||||
const clientID = "foo.bar"
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "missing",
|
||||
claims: &TokenClaims{},
|
||||
wantErr: ErrAudience,
|
||||
},
|
||||
{
|
||||
name: "wrong",
|
||||
claims: &TokenClaims{
|
||||
Audience: []string{"wrong"},
|
||||
},
|
||||
wantErr: ErrAudience,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
claims: &TokenClaims{
|
||||
Audience: []string{clientID},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckAudience(tt.claims, clientID)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthorizedParty(t *testing.T) {
|
||||
const clientID = "foo.bar"
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "single audience, no azp",
|
||||
claims: &TokenClaims{
|
||||
Audience: []string{clientID},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple audience, no azp",
|
||||
claims: &TokenClaims{
|
||||
Audience: []string{clientID, "other"},
|
||||
},
|
||||
wantErr: ErrAzpMissing,
|
||||
},
|
||||
{
|
||||
name: "single audience, with azp",
|
||||
claims: &TokenClaims{
|
||||
Audience: []string{clientID},
|
||||
AuthorizedParty: clientID,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple audience, with azp",
|
||||
claims: &TokenClaims{
|
||||
Audience: []string{clientID, "other"},
|
||||
AuthorizedParty: clientID,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "wrong azp",
|
||||
claims: &TokenClaims{
|
||||
AuthorizedParty: "wrong",
|
||||
},
|
||||
wantErr: ErrAzpInvalid,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckAuthorizedParty(tt.claims, clientID)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpiration(t *testing.T) {
|
||||
const offset = time.Minute
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "missing",
|
||||
claims: &TokenClaims{},
|
||||
wantErr: ErrExpired,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
claims: &TokenClaims{
|
||||
Expiration: FromTime(time.Now().Add(-2 * offset)),
|
||||
},
|
||||
wantErr: ErrExpired,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
claims: &TokenClaims{
|
||||
Expiration: FromTime(time.Now().Add(2 * offset)),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckExpiration(tt.claims, offset)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIssuedAt(t *testing.T) {
|
||||
const offset = time.Minute
|
||||
tests := []struct {
|
||||
name string
|
||||
maxAgeIAT time.Duration
|
||||
claims Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "missing",
|
||||
claims: &TokenClaims{},
|
||||
wantErr: ErrIatMissing,
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
claims: &TokenClaims{
|
||||
IssuedAt: FromTime(time.Now().Add(time.Hour)),
|
||||
},
|
||||
wantErr: ErrIatInFuture,
|
||||
},
|
||||
{
|
||||
name: "no max",
|
||||
claims: &TokenClaims{
|
||||
IssuedAt: FromTime(time.Now()),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "past max",
|
||||
maxAgeIAT: time.Minute,
|
||||
claims: &TokenClaims{
|
||||
IssuedAt: FromTime(time.Now().Add(-time.Hour)),
|
||||
},
|
||||
wantErr: ErrIatToOld,
|
||||
},
|
||||
{
|
||||
name: "within max",
|
||||
maxAgeIAT: time.Hour,
|
||||
claims: &TokenClaims{
|
||||
IssuedAt: FromTime(time.Now()),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckNonce(t *testing.T) {
|
||||
const nonce = "123"
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "missing",
|
||||
claims: &TokenClaims{},
|
||||
wantErr: ErrNonceInvalid,
|
||||
},
|
||||
{
|
||||
name: "wrong",
|
||||
claims: &TokenClaims{
|
||||
Nonce: "wrong",
|
||||
},
|
||||
wantErr: ErrNonceInvalid,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
claims: &TokenClaims{
|
||||
Nonce: nonce,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckNonce(tt.claims, nonce)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthorizationContextClassReference(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
acr ACRVerifier
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
acr: func(s string) error { return errors.New("oops") },
|
||||
wantErr: ErrAcrInvalid,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
acr: func(s string) error { return nil },
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthTime(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
maxAge time.Duration
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "no max age",
|
||||
claims: &TokenClaims{},
|
||||
},
|
||||
{
|
||||
name: "missing",
|
||||
claims: &TokenClaims{},
|
||||
maxAge: time.Minute,
|
||||
wantErr: ErrAuthTimeNotPresent,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
maxAge: time.Minute,
|
||||
claims: &TokenClaims{
|
||||
AuthTime: FromTime(time.Now().Add(-time.Hour)),
|
||||
},
|
||||
wantErr: ErrAuthTimeToOld,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
maxAge: time.Minute,
|
||||
claims: &TokenClaims{
|
||||
AuthTime: NowTime(),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckAuthTime(tt.claims, tt.maxAge)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -10,11 +11,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
str "github.com/zitadel/oidc/v2/pkg/strings"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
str "github.com/zitadel/oidc/v3/pkg/strings"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type AuthRequest interface {
|
||||
|
@ -39,16 +39,17 @@ type Authorizer interface {
|
|||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
Encoder() httphelper.Encoder
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
Crypto() Crypto
|
||||
RequestObjectSupported() bool
|
||||
Logger() *slog.Logger
|
||||
}
|
||||
|
||||
// AuthorizeValidator is an extension of Authorizer interface
|
||||
// implementing its own validation mechanism for the auth request
|
||||
type AuthorizeValidator interface {
|
||||
Authorizer
|
||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
|
||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
|
||||
}
|
||||
|
||||
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -68,23 +69,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
|
|||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, nil, err, authorizer)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
|
||||
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
|
||||
err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
}
|
||||
if authReq.ClientID == "" {
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer)
|
||||
return
|
||||
}
|
||||
if authReq.RedirectURI == "" {
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
|
||||
return
|
||||
}
|
||||
validation := ValidateAuthRequest
|
||||
|
@ -93,21 +94,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
}
|
||||
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
if authReq.RequestParam != "" {
|
||||
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer)
|
||||
return
|
||||
}
|
||||
req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer)
|
||||
return
|
||||
}
|
||||
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
|
||||
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer)
|
||||
return
|
||||
}
|
||||
RedirectToLogin(req.GetID(), client, w, r)
|
||||
|
@ -129,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A
|
|||
|
||||
// ParseRequestObject parse the `request` parameter, validates the token including the signature
|
||||
// and copies the token claims into the auth request
|
||||
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) {
|
||||
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error {
|
||||
requestObject := new(oidc.RequestObject)
|
||||
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
if requestObject.Issuer != requestObject.ClientID {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
if !str.Contains(requestObject.Audience, issuer) {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
|
||||
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
|
||||
return authReq, err
|
||||
return err
|
||||
}
|
||||
CopyRequestObjectToAuthRequest(authReq, requestObject)
|
||||
return authReq, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
|
||||
|
@ -205,7 +206,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
|
|||
}
|
||||
|
||||
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
|
||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
|
||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
|
||||
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -385,7 +386,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType)
|
|||
|
||||
// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
|
||||
// and returns the `sub` claim
|
||||
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) {
|
||||
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) {
|
||||
if idTokenHint == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
@ -405,32 +406,41 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *
|
|||
|
||||
// AuthorizeCallback handles the callback after authentication in the Login UI
|
||||
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
params := mux.Vars(r)
|
||||
id := params["id"]
|
||||
if id == "" {
|
||||
AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder())
|
||||
id, err := ParseAuthorizeCallbackRequest(r)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer)
|
||||
return
|
||||
}
|
||||
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, nil, err, authorizer)
|
||||
return
|
||||
}
|
||||
if !authReq.Done() {
|
||||
AuthRequestError(w, r, authReq,
|
||||
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
|
||||
authorizer.Encoder())
|
||||
authorizer)
|
||||
return
|
||||
}
|
||||
AuthResponse(authReq, authorizer, w, r)
|
||||
}
|
||||
|
||||
func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
|
||||
if err = r.ParseForm(); err != nil {
|
||||
return "", fmt.Errorf("cannot parse form: %w", err)
|
||||
}
|
||||
id = r.Form.Get("id")
|
||||
if id == "" {
|
||||
return "", errors.New("auth request callback is missing id")
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// AuthResponse creates the successful authentication response (either code or tokens)
|
||||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
|
@ -444,7 +454,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
|
|||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
codeResponse := struct {
|
||||
|
@ -456,7 +466,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
|
|||
}
|
||||
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
|
@ -471,12 +481,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
|
|||
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "")
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
|
|
|
@ -11,14 +11,16 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/op/mock"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
|
@ -39,7 +41,7 @@ func TestAuthorize(t *testing.T) {
|
|||
|
||||
expect := authorizer.EXPECT()
|
||||
expect.Decoder().Return(schema.NewDecoder())
|
||||
expect.Encoder().Return(schema.NewEncoder())
|
||||
expect.Logger().Return(slog.Default())
|
||||
|
||||
if tt.expect != nil {
|
||||
tt.expect(expect)
|
||||
|
@ -123,7 +125,7 @@ func TestValidateAuthRequest(t *testing.T) {
|
|||
type args struct {
|
||||
authRequest *oidc.AuthRequest
|
||||
storage op.Storage
|
||||
verifier op.IDTokenHintVerifier
|
||||
verifier *op.IDTokenHintVerifier
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -996,7 +998,7 @@ func TestAuthResponseCode(t *testing.T) {
|
|||
authorizer.EXPECT().Crypto().Return(&mockCrypto{
|
||||
returnErr: io.ErrClosedPipe,
|
||||
})
|
||||
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||
authorizer.EXPECT().Logger().Return(slog.Default())
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
|
@ -1071,3 +1073,71 @@ func TestAuthResponseCode(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseAuthorizeCallbackRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantId string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "parse error",
|
||||
url: "/?id;=99",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing id",
|
||||
url: "/",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
url: "/?id=99",
|
||||
wantId: "99",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
|
||||
gotId, err := op.ParseAuthorizeCallbackRequest(r)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.wantId, gotId)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
||||
token, _ := tu.ValidIDToken()
|
||||
tests := []struct {
|
||||
name string
|
||||
idTokenHint string
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "verify err",
|
||||
idTokenHint: "foo",
|
||||
wantErr: oidc.ErrLoginRequired(),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
idTokenHint: token,
|
||||
want: tu.ValidSubject,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
//go:generate go get github.com/dmarkham/enumer
|
||||
|
@ -87,7 +87,7 @@ var (
|
|||
)
|
||||
|
||||
type ClientJWTProfile interface {
|
||||
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||
JWTProfileVerifier(context.Context) *JWTProfileVerifier
|
||||
}
|
||||
|
||||
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
|
||||
|
@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
|
|||
}
|
||||
return data.ClientID, false, nil
|
||||
}
|
||||
|
||||
type ClientCredentials struct {
|
||||
ClientID string `schema:"client_id"`
|
||||
ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body
|
||||
ClientAssertion string `schema:"client_assertion"` // JWT
|
||||
ClientAssertionType string `schema:"client_assertion_type"`
|
||||
}
|
||||
|
|
|
@ -11,18 +11,18 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/op/mock"
|
||||
"github.com/zitadel/schema"
|
||||
)
|
||||
|
||||
type testClientJWTProfile struct{}
|
||||
|
||||
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
|
||||
func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil }
|
||||
|
||||
func TestClientJWTAuth(t *testing.T) {
|
||||
type args struct {
|
||||
|
|
|
@ -22,14 +22,14 @@ var (
|
|||
type Configuration interface {
|
||||
IssuerFromRequest(r *http.Request) string
|
||||
Insecure() bool
|
||||
AuthorizationEndpoint() Endpoint
|
||||
TokenEndpoint() Endpoint
|
||||
IntrospectionEndpoint() Endpoint
|
||||
UserinfoEndpoint() Endpoint
|
||||
RevocationEndpoint() Endpoint
|
||||
EndSessionEndpoint() Endpoint
|
||||
KeysEndpoint() Endpoint
|
||||
DeviceAuthorizationEndpoint() Endpoint
|
||||
AuthorizationEndpoint() *Endpoint
|
||||
TokenEndpoint() *Endpoint
|
||||
IntrospectionEndpoint() *Endpoint
|
||||
UserinfoEndpoint() *Endpoint
|
||||
RevocationEndpoint() *Endpoint
|
||||
EndSessionEndpoint() *Endpoint
|
||||
KeysEndpoint() *Endpoint
|
||||
DeviceAuthorizationEndpoint() *Endpoint
|
||||
|
||||
AuthMethodPostSupported() bool
|
||||
CodeMethodS256Supported() bool
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/v2/pkg/crypto"
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
)
|
||||
|
||||
type Crypto interface {
|
||||
|
|
|
@ -12,8 +12,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type DeviceAuthorizationConfig struct {
|
||||
|
@ -57,47 +57,57 @@ var (
|
|||
func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := DeviceAuthorization(w, r, o); err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, o.Logger())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
|
||||
storage, err := assertDeviceStorage(o.Storage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := ParseDeviceCodeRequest(r, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httphelper.MarshalJSON(w, response)
|
||||
return nil
|
||||
}
|
||||
|
||||
func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
storage, err := assertDeviceStorage(o.Storage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config := o.DeviceAuthorization()
|
||||
|
||||
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
expires := time.Now().Add(config.Lifetime)
|
||||
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
||||
err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var verification *url.URL
|
||||
if config.UserFormURL != "" {
|
||||
if verification, err = url.Parse(config.UserFormURL); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
|
||||
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
|
||||
if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil {
|
||||
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
verification.Path = config.UserFormPath
|
||||
}
|
||||
|
@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
|
|||
|
||||
verification.RawQuery = "user_code=" + userCode
|
||||
response.VerificationURIComplete = verification.String()
|
||||
|
||||
httphelper.MarshalJSON(w, response)
|
||||
return nil
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
|
||||
|
@ -201,7 +209,7 @@ func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
r = r.WithContext(ctx)
|
||||
|
||||
if err := deviceAccessToken(w, r, exchanger); err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, exchanger.Logger())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,8 +16,9 @@ import (
|
|||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func Test_deviceAuthorizationHandler(t *testing.T) {
|
||||
|
@ -319,7 +320,7 @@ func BenchmarkNewUserCode(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestDeviceAccessToken(t *testing.T) {
|
||||
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
|
||||
storage := testProvider.Storage().(*storage.Storage)
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
|
||||
storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
|
||||
|
||||
|
@ -344,7 +345,7 @@ func TestDeviceAccessToken(t *testing.T) {
|
|||
func TestCheckDeviceAuthorizationState(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
storage := testProvider.Storage().(op.DeviceAuthorizationStorage)
|
||||
storage := testProvider.Storage().(*storage.Storage)
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
|
||||
storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})
|
||||
|
|
|
@ -4,10 +4,10 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
type DiscoverStorage interface {
|
||||
|
@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{
|
|||
|
||||
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
Discover(w, CreateDiscoveryConfig(r, c, s))
|
||||
Discover(w, CreateDiscoveryConfig(r.Context(), c, s))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
|
|||
httphelper.MarshalJSON(w, config)
|
||||
}
|
||||
|
||||
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
|
||||
issuer := config.IssuerFromRequest(r)
|
||||
func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
|
||||
issuer := IssuerFromContext(ctx)
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
|
||||
|
@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
|
|||
ResponseTypesSupported: ResponseTypes(config),
|
||||
GrantTypesSupported: GrantTypes(config),
|
||||
SubjectTypesSupported: SubjectTypes(config),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
|
||||
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
|
||||
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
|
||||
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
|
||||
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
|
||||
ClaimsSupported: SupportedClaims(config),
|
||||
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
|
||||
UILocalesSupported: config.SupportedUILocales(),
|
||||
RequestParameterSupported: config.RequestObjectSupported(),
|
||||
}
|
||||
}
|
||||
|
||||
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
|
||||
issuer := IssuerFromContext(ctx)
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
|
||||
TokenEndpoint: endpoints.Token.Absolute(issuer),
|
||||
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
|
||||
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
|
||||
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
|
||||
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
|
||||
JwksURI: endpoints.JwksURI.Absolute(issuer),
|
||||
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
|
||||
ScopesSupported: Scopes(config),
|
||||
ResponseTypesSupported: ResponseTypes(config),
|
||||
GrantTypesSupported: GrantTypes(config),
|
||||
SubjectTypesSupported: SubjectTypes(config),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
|
||||
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
|
||||
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
|
||||
|
|
|
@ -6,14 +6,14 @@ import (
|
|||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/op/mock"
|
||||
)
|
||||
|
||||
func TestDiscover(t *testing.T) {
|
||||
|
@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) {
|
|||
|
||||
func TestCreateDiscoveryConfig(t *testing.T) {
|
||||
type args struct {
|
||||
request *http.Request
|
||||
c op.Configuration
|
||||
s op.DiscoverStorage
|
||||
ctx context.Context
|
||||
c op.Configuration
|
||||
s op.DiscoverStorage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s)
|
||||
got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,32 +1,46 @@
|
|||
package op
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Endpoint struct {
|
||||
path string
|
||||
url string
|
||||
}
|
||||
|
||||
func NewEndpoint(path string) Endpoint {
|
||||
return Endpoint{path: path}
|
||||
func NewEndpoint(path string) *Endpoint {
|
||||
return &Endpoint{path: path}
|
||||
}
|
||||
|
||||
func NewEndpointWithURL(path, url string) Endpoint {
|
||||
return Endpoint{path: path, url: url}
|
||||
func NewEndpointWithURL(path, url string) *Endpoint {
|
||||
return &Endpoint{path: path, url: url}
|
||||
}
|
||||
|
||||
func (e Endpoint) Relative() string {
|
||||
func (e *Endpoint) Relative() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return relativeEndpoint(e.path)
|
||||
}
|
||||
|
||||
func (e Endpoint) Absolute(host string) string {
|
||||
func (e *Endpoint) Absolute(host string) string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.url != "" {
|
||||
return e.url
|
||||
}
|
||||
return absoluteEndpoint(host, e.path)
|
||||
}
|
||||
|
||||
func (e Endpoint) Validate() error {
|
||||
var ErrNilEndpoint = errors.New("nil endpoint")
|
||||
|
||||
func (e *Endpoint) Validate() error {
|
||||
if e == nil {
|
||||
return ErrNilEndpoint
|
||||
}
|
||||
return nil // TODO:
|
||||
}
|
||||
|
||||
|
|
|
@ -3,13 +3,14 @@ package op_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func TestEndpoint_Path(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
e *op.Endpoint
|
||||
want string
|
||||
}{
|
||||
{
|
||||
|
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
|
|||
op.NewEndpointWithURL("/test", "http://test.com/test"),
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
e *op.Endpoint
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
|
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
args{"https://host"},
|
||||
"https://test.com/test",
|
||||
},
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
args{"https://host"},
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
func TestEndpoint_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
wantErr bool
|
||||
e *op.Endpoint
|
||||
wantErr error
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
op.ErrNilEndpoint,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
err := tt.e.Validate()
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
137
pkg/op/error.go
137
pkg/op/error.go
|
@ -1,10 +1,14 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
|
@ -13,13 +17,31 @@ type ErrAuthRequest interface {
|
|||
GetState() string
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) {
|
||||
// LogAuthRequest is an optional interface,
|
||||
// that allows logging AuthRequest fields.
|
||||
// If the AuthRequest does not implement this interface,
|
||||
// no details shall be printed to the logs.
|
||||
type LogAuthRequest interface {
|
||||
ErrAuthRequest
|
||||
slog.LogValuer
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) {
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
logger := authorizer.Logger().With("oidc_error", e)
|
||||
|
||||
if authReq == nil {
|
||||
logger.Log(r.Context(), e.LogLevel(), "auth request")
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
|
||||
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
|
||||
logger = logger.With("auth_request", logAuthReq)
|
||||
}
|
||||
|
||||
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
|
||||
logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting")
|
||||
http.Error(w, e.Description, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
@ -28,19 +50,120 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
|
|||
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
||||
responseMode = rm.GetResponseMode()
|
||||
}
|
||||
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
|
||||
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder())
|
||||
if err != nil {
|
||||
logger.ErrorContext(r.Context(), "auth response URL", "error", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
logger.Log(r.Context(), e.LogLevel(), "auth request")
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func RequestError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
status := http.StatusBadRequest
|
||||
if e.ErrorType == oidc.InvalidClient {
|
||||
status = 401
|
||||
status = http.StatusUnauthorized
|
||||
}
|
||||
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
||||
httphelper.MarshalJSONWithStatus(w, e, status)
|
||||
}
|
||||
|
||||
// TryErrorRedirect tries to handle an error by redirecting a client.
|
||||
// If this attempt fails, an error is returned that must be returned
|
||||
// to the client instead.
|
||||
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
|
||||
e := oidc.DefaultToServerError(parent, parent.Error())
|
||||
logger = logger.With("oidc_error", e)
|
||||
|
||||
if authReq == nil {
|
||||
logger.Log(ctx, e.LogLevel(), "auth request")
|
||||
return nil, AsStatusError(e, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
|
||||
logger = logger.With("auth_request", logAuthReq)
|
||||
}
|
||||
|
||||
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
|
||||
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
|
||||
return nil, AsStatusError(e, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
e.State = authReq.GetState()
|
||||
var responseMode oidc.ResponseMode
|
||||
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
||||
responseMode = rm.GetResponseMode()
|
||||
}
|
||||
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, "auth response URL", "error", err)
|
||||
return nil, AsStatusError(err, http.StatusBadRequest)
|
||||
}
|
||||
logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
|
||||
return NewRedirect(url), nil
|
||||
}
|
||||
|
||||
// StatusError wraps an error with a HTTP status code.
|
||||
// The status code is passed to the handler's writer.
|
||||
type StatusError struct {
|
||||
parent error
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// NewStatusError sets the parent and statusCode to a new StatusError.
|
||||
// It is recommended for parent to be an [oidc.Error].
|
||||
//
|
||||
// Typically implementations should only use this to signal something
|
||||
// very specific, like an internal server error.
|
||||
// If a returned error is not a StatusError, the framework
|
||||
// will set a statusCode based on what the standard specifies,
|
||||
// which is [http.StatusBadRequest] for most of the time.
|
||||
// If the error encountered can described clearly with a [oidc.Error],
|
||||
// do not use this function, as it might break standard rules!
|
||||
func NewStatusError(parent error, statusCode int) StatusError {
|
||||
return StatusError{
|
||||
parent: parent,
|
||||
statusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// AsStatusError unwraps a StatusError from err
|
||||
// and returns it unmodified if found.
|
||||
// If no StatuError was found, a new one is returned
|
||||
// with statusCode set to it as a default.
|
||||
func AsStatusError(err error, statusCode int) (target StatusError) {
|
||||
if errors.As(err, &target) {
|
||||
return target
|
||||
}
|
||||
return NewStatusError(err, statusCode)
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
|
||||
}
|
||||
|
||||
func (e StatusError) Unwrap() error {
|
||||
return e.parent
|
||||
}
|
||||
|
||||
func (e StatusError) Is(err error) bool {
|
||||
var target StatusError
|
||||
if !errors.As(err, &target) {
|
||||
return false
|
||||
}
|
||||
return errors.Is(e.parent, target.parent) &&
|
||||
e.statusCode == target.statusCode
|
||||
}
|
||||
|
||||
// WriteError asserts for a StatusError containing an [oidc.Error].
|
||||
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
|
||||
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
|
||||
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
||||
statusError := AsStatusError(err, http.StatusBadRequest)
|
||||
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
|
||||
|
||||
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
||||
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
|
||||
}
|
||||
|
|
677
pkg/op/error_test.go
Normal file
677
pkg/op/error_test.go
Normal file
|
@ -0,0 +1,677 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func TestAuthRequestError(t *testing.T) {
|
||||
type args struct {
|
||||
authReq ErrAuthRequest
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantCode int
|
||||
wantHeaders map[string]string
|
||||
wantBody string
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "nil auth request",
|
||||
args: args{
|
||||
authReq: nil,
|
||||
err: io.ErrClosedPipe,
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "io: read/write on closed pipe\n",
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth request",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, no redirect URI",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "sign in\n",
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, redirect disabled",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "oops\n",
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"invalid_request",
|
||||
"redirect_disabled":true
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, url parse error",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "can't parse this!\n",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n",
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth response URL",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"can't parse this!\n",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"error":{
|
||||
"type":"server_error",
|
||||
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request redirect",
|
||||
args: args{
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
err: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"},
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
authorizer := &Provider{
|
||||
encoder: schema.NewEncoder(),
|
||||
logger: slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/path", nil)
|
||||
AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tt.wantCode, res.StatusCode)
|
||||
for key, wantHeader := range tt.wantHeaders {
|
||||
gotHeader := res.Header.Get(key)
|
||||
assert.Equalf(t, wantHeader, gotHeader, "header %q", key)
|
||||
}
|
||||
gotBody, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err, "read result body")
|
||||
assert.Equal(t, tt.wantBody, string(gotBody), "result body")
|
||||
|
||||
gotLog := logOut.String()
|
||||
t.Log(gotLog)
|
||||
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int
|
||||
wantBody string
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
err: io.ErrClosedPipe,
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"request error",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"description":"io: read/write on closed pipe",
|
||||
"type":"server_error"}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "invalid client",
|
||||
err: oidc.ErrInvalidClient().WithDescription("not good"),
|
||||
wantCode: http.StatusUnauthorized,
|
||||
wantBody: `{"error":"invalid_client", "error_description":"not good"}`,
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"request error",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"description":"not good",
|
||||
"type":"invalid_client"}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
logger := slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/path", nil)
|
||||
RequestError(w, r, tt.err, logger)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tt.wantCode, res.StatusCode, "status code")
|
||||
|
||||
gotBody, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err, "read result body")
|
||||
assert.JSONEq(t, tt.wantBody, string(gotBody), "result body")
|
||||
|
||||
gotLog := logOut.String()
|
||||
t.Log(gotLog)
|
||||
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryErrorRedirect(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
authReq ErrAuthRequest
|
||||
parent error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Redirect
|
||||
wantErr error
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "nil auth request",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: nil,
|
||||
parent: io.ErrClosedPipe,
|
||||
},
|
||||
wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest),
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth request",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, no redirect URI",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest),
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, redirect disabled",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
|
||||
},
|
||||
wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest),
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"invalid_request",
|
||||
"redirect_disabled":true
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, url parse error",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "can't parse this!\n",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantErr: func() error {
|
||||
//lint:ignore SA1007 just recreating the error for testing
|
||||
_, err := url.Parse("can't parse this!\n")
|
||||
err = oidc.ErrServerError().WithParent(err)
|
||||
return NewStatusError(err, http.StatusBadRequest)
|
||||
}(),
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth response URL",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"can't parse this!\n",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"error":{
|
||||
"type":"server_error",
|
||||
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request redirect",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
want: &Redirect{
|
||||
URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
|
||||
},
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request redirect",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
},
|
||||
"url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
logger := slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
)
|
||||
encoder := schema.NewEncoder()
|
||||
|
||||
got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
|
||||
gotLog := logOut.String()
|
||||
t.Log(gotLog)
|
||||
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStatusError(t *testing.T) {
|
||||
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
|
||||
|
||||
want := "Internal Server Error: io: read/write on closed pipe"
|
||||
got := fmt.Sprint(err)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestAsStatusError(t *testing.T) {
|
||||
type args struct {
|
||||
err error
|
||||
statusCode int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "already status error",
|
||||
args: args{
|
||||
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
want: "Internal Server Error: io: read/write on closed pipe",
|
||||
},
|
||||
{
|
||||
name: "oidc error",
|
||||
args: args{
|
||||
err: oidc.ErrAcrInvalid,
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
want: "Bad Request: acr is invalid",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := AsStatusError(tt.args.err, tt.args.statusCode)
|
||||
got := fmt.Sprint(err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusError_Unwrap(t *testing.T) {
|
||||
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
}
|
||||
|
||||
func TestStatusError_Is(t *testing.T) {
|
||||
type args struct {
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
args: args{err: nil},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
args: args{err: io.EOF},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "other parent",
|
||||
args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "other status",
|
||||
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "same",
|
||||
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped",
|
||||
args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
|
||||
if got := e.Is(tt.args.err); got != tt.want {
|
||||
t.Errorf("StatusError.Is() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "not a status or oidc error",
|
||||
err: io.ErrClosedPipe,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{
|
||||
"error":"server_error",
|
||||
"error_description":"io: read/write on closed pipe"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "status error w/o oidc",
|
||||
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantBody: `{
|
||||
"error":"server_error",
|
||||
"error_description":"io: read/write on closed pipe"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "oidc error w/o status",
|
||||
err: oidc.ErrInvalidRequest().WithDescription("oops"),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{
|
||||
"error":"invalid_request",
|
||||
"error_description":"oops"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"invalid_request"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "status with oidc error",
|
||||
err: NewStatusError(
|
||||
oidc.ErrUnauthorizedClient().WithDescription("oops"),
|
||||
http.StatusUnauthorized,
|
||||
),
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantBody: `{
|
||||
"error":"unauthorized_client",
|
||||
"error_description":"oops"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"unauthorized_client"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
logger := slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
)
|
||||
r := httptest.NewRequest("GET", "/target", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteError(w, r, tt.err, logger)
|
||||
res := w.Result()
|
||||
assert.Equal(t, tt.wantStatus, res.StatusCode, "status code")
|
||||
gotBody, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, tt.wantBody, string(gotBody), "body")
|
||||
assert.JSONEq(t, tt.wantLog, logOut.String())
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,9 +4,9 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
)
|
||||
|
||||
type KeyProvider interface {
|
||||
|
|
|
@ -7,13 +7,13 @@ import (
|
|||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/op/mock"
|
||||
)
|
||||
|
||||
func TestKeys(t *testing.T) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer)
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Authorizer)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -9,8 +9,9 @@ import (
|
|||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
http "github.com/zitadel/oidc/v2/pkg/http"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
http "github.com/zitadel/oidc/v3/pkg/http"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
slog "golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// MockAuthorizer is a mock of Authorizer interface.
|
||||
|
@ -79,10 +80,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
|||
}
|
||||
|
||||
// IDTokenHintVerifier mocks base method.
|
||||
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier {
|
||||
func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
|
||||
ret0, _ := ret[0].(op.IDTokenHintVerifier)
|
||||
ret0, _ := ret[0].(*op.IDTokenHintVerifier)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
|
||||
}
|
||||
|
||||
// Logger mocks base method.
|
||||
func (m *MockAuthorizer) Logger() *slog.Logger {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Logger")
|
||||
ret0, _ := ret[0].(*slog.Logger)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Logger indicates an expected call of Logger.
|
||||
func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger))
|
||||
}
|
||||
|
||||
// RequestObjectSupported mocks base method.
|
||||
func (m *MockAuthorizer) RequestObjectSupported() bool {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -4,12 +4,12 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/schema"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"github.com/zitadel/schema"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func NewAuthorizer(t *testing.T) op.Authorizer {
|
||||
|
@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) {
|
|||
func ExpectVerifier(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
|
||||
func() op.IDTokenHintVerifier {
|
||||
func() *op.IDTokenHintVerifier {
|
||||
return op.NewIDTokenHintVerifier("", nil)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func NewClient(t *testing.T) op.Client {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client)
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Client)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -9,8 +9,8 @@ import (
|
|||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
oidc "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration)
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Configuration)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
language "golang.org/x/text/language"
|
||||
)
|
||||
|
||||
|
@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
|
|||
}
|
||||
|
||||
// AuthorizationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
|
|||
}
|
||||
|
||||
// DeviceAuthorizationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
|
|||
}
|
||||
|
||||
// EndSessionEndpoint mocks base method.
|
||||
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "EndSessionEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
|
|||
}
|
||||
|
||||
// IntrospectionEndpoint mocks base method.
|
||||
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
|
|||
}
|
||||
|
||||
// KeysEndpoint mocks base method.
|
||||
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeysEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
|
|||
}
|
||||
|
||||
// RevocationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RevocationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
|
|||
}
|
||||
|
||||
// TokenEndpoint mocks base method.
|
||||
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TokenEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
|
|||
}
|
||||
|
||||
// UserinfoEndpoint mocks base method.
|
||||
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserinfoEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage)
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: DiscoverStorage)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -8,8 +8,8 @@ import (
|
|||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// MockDiscoverStorage is a mock of DiscoverStorage interface.
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
package mock
|
||||
|
||||
//go:generate go install github.com/golang/mock/mockgen@v1.6.0
|
||||
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage
|
||||
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer
|
||||
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client
|
||||
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration
|
||||
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage
|
||||
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key
|
||||
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider
|
||||
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v3/pkg/op Storage
|
||||
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v3/pkg/op Authorizer
|
||||
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v3/pkg/op Client
|
||||
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v3/pkg/op Configuration
|
||||
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v3/pkg/op DiscoverStorage
|
||||
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v3/pkg/op SigningKey,Key
|
||||
//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v3/pkg/op KeyProvider
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider)
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: KeyProvider)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
// MockKeyProvider is a mock of KeyProvider interface.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key)
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: SigningKey,Key)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -7,8 +7,8 @@ package mock
|
|||
import (
|
||||
reflect "reflect"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// MockSigningKey is a mock of SigningKey interface.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage)
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: Storage)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
@ -9,10 +9,10 @@ import (
|
|||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
oidc "github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v2/pkg/op"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
oidc "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
// MockStorage is a mock of Storage interface.
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func NewStorage(t *testing.T) op.Storage {
|
||||
|
|
123
pkg/op/op.go
123
pkg/op/op.go
|
@ -6,16 +6,17 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/go-chi/chi"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/rs/cors"
|
||||
"github.com/zitadel/schema"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/text/language"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -33,7 +34,7 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
DefaultEndpoints = &endpoints{
|
||||
DefaultEndpoints = &Endpoints{
|
||||
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
|
||||
Token: NewEndpoint(defaultTokenEndpoint),
|
||||
Introspection: NewEndpoint(defaultIntrospectEndpoint),
|
||||
|
@ -76,29 +77,35 @@ func init() {
|
|||
}
|
||||
|
||||
type OpenIDProvider interface {
|
||||
http.Handler
|
||||
Configuration
|
||||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
Encoder() httphelper.Encoder
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
AccessTokenVerifier(context.Context) AccessTokenVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
AccessTokenVerifier(context.Context) *AccessTokenVerifier
|
||||
Crypto() Crypto
|
||||
DefaultLogoutRedirectURI() string
|
||||
Probes() []ProbesFn
|
||||
|
||||
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
|
||||
Logger() *slog.Logger
|
||||
|
||||
// Deprecated: Provider now implements http.Handler directly.
|
||||
HttpHandler() http.Handler
|
||||
}
|
||||
|
||||
type HttpInterceptor func(http.Handler) http.Handler
|
||||
|
||||
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
router.Use(intercept(o.IssuerFromRequest, interceptors...))
|
||||
router.HandleFunc(healthEndpoint, healthHandler)
|
||||
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
|
||||
router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
|
||||
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o))
|
||||
router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o))
|
||||
router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
|
||||
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
|
||||
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
|
||||
|
@ -132,16 +139,17 @@ type Config struct {
|
|||
DeviceAuthorization DeviceAuthorizationConfig
|
||||
}
|
||||
|
||||
type endpoints struct {
|
||||
Authorization Endpoint
|
||||
Token Endpoint
|
||||
Introspection Endpoint
|
||||
Userinfo Endpoint
|
||||
Revocation Endpoint
|
||||
EndSession Endpoint
|
||||
CheckSessionIframe Endpoint
|
||||
JwksURI Endpoint
|
||||
DeviceAuthorization Endpoint
|
||||
// Endpoints defines endpoint routes.
|
||||
type Endpoints struct {
|
||||
Authorization *Endpoint
|
||||
Token *Endpoint
|
||||
Introspection *Endpoint
|
||||
Userinfo *Endpoint
|
||||
Revocation *Endpoint
|
||||
EndSession *Endpoint
|
||||
CheckSessionIframe *Endpoint
|
||||
JwksURI *Endpoint
|
||||
DeviceAuthorization *Endpoint
|
||||
}
|
||||
|
||||
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
||||
|
@ -186,6 +194,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
|
|||
storage: storage,
|
||||
endpoints: DefaultEndpoints,
|
||||
timer: make(<-chan time.Time),
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
for _, optFunc := range opOpts {
|
||||
|
@ -199,7 +208,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
|
|||
return nil, err
|
||||
}
|
||||
|
||||
o.httpHandler = CreateRouter(o, o.interceptors...)
|
||||
o.Handler = CreateRouter(o, o.interceptors...)
|
||||
|
||||
o.decoder = schema.NewDecoder()
|
||||
o.decoder.IgnoreUnknownKeys(true)
|
||||
|
@ -215,20 +224,21 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
|
|||
}
|
||||
|
||||
type Provider struct {
|
||||
http.Handler
|
||||
config *Config
|
||||
issuer IssuerFromRequest
|
||||
insecure bool
|
||||
endpoints *endpoints
|
||||
endpoints *Endpoints
|
||||
storage Storage
|
||||
keySet *openIDKeySet
|
||||
crypto Crypto
|
||||
httpHandler http.Handler
|
||||
decoder *schema.Decoder
|
||||
encoder *schema.Encoder
|
||||
interceptors []HttpInterceptor
|
||||
timer <-chan time.Time
|
||||
accessTokenVerifierOpts []AccessTokenVerifierOpt
|
||||
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (o *Provider) IssuerFromRequest(r *http.Request) string {
|
||||
|
@ -239,35 +249,35 @@ func (o *Provider) Insecure() bool {
|
|||
return o.insecure
|
||||
}
|
||||
|
||||
func (o *Provider) AuthorizationEndpoint() Endpoint {
|
||||
func (o *Provider) AuthorizationEndpoint() *Endpoint {
|
||||
return o.endpoints.Authorization
|
||||
}
|
||||
|
||||
func (o *Provider) TokenEndpoint() Endpoint {
|
||||
func (o *Provider) TokenEndpoint() *Endpoint {
|
||||
return o.endpoints.Token
|
||||
}
|
||||
|
||||
func (o *Provider) IntrospectionEndpoint() Endpoint {
|
||||
func (o *Provider) IntrospectionEndpoint() *Endpoint {
|
||||
return o.endpoints.Introspection
|
||||
}
|
||||
|
||||
func (o *Provider) UserinfoEndpoint() Endpoint {
|
||||
func (o *Provider) UserinfoEndpoint() *Endpoint {
|
||||
return o.endpoints.Userinfo
|
||||
}
|
||||
|
||||
func (o *Provider) RevocationEndpoint() Endpoint {
|
||||
func (o *Provider) RevocationEndpoint() *Endpoint {
|
||||
return o.endpoints.Revocation
|
||||
}
|
||||
|
||||
func (o *Provider) EndSessionEndpoint() Endpoint {
|
||||
func (o *Provider) EndSessionEndpoint() *Endpoint {
|
||||
return o.endpoints.EndSession
|
||||
}
|
||||
|
||||
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
|
||||
func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
|
||||
return o.endpoints.DeviceAuthorization
|
||||
}
|
||||
|
||||
func (o *Provider) KeysEndpoint() Endpoint {
|
||||
func (o *Provider) KeysEndpoint() *Endpoint {
|
||||
return o.endpoints.JwksURI
|
||||
}
|
||||
|
||||
|
@ -354,15 +364,15 @@ func (o *Provider) Encoder() httphelper.Encoder {
|
|||
return o.encoder
|
||||
}
|
||||
|
||||
func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier {
|
||||
func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
|
||||
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
|
||||
}
|
||||
|
||||
func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier {
|
||||
func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
|
||||
return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
|
||||
}
|
||||
|
||||
func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier {
|
||||
func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
|
||||
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
|
||||
}
|
||||
|
||||
|
@ -387,8 +397,13 @@ func (o *Provider) Probes() []ProbesFn {
|
|||
}
|
||||
}
|
||||
|
||||
func (o *Provider) Logger() *slog.Logger {
|
||||
return o.logger
|
||||
}
|
||||
|
||||
// Deprecated: Provider now implements http.Handler directly.
|
||||
func (o *Provider) HttpHandler() http.Handler {
|
||||
return o.httpHandler
|
||||
return o
|
||||
}
|
||||
|
||||
type openIDKeySet struct {
|
||||
|
@ -421,7 +436,7 @@ func WithAllowInsecure() Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -431,7 +446,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -441,7 +456,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -451,7 +466,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -461,7 +476,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -471,7 +486,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -481,7 +496,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -491,7 +506,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -501,8 +516,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
|
||||
// WithCustomEndpoints sets multiple endpoints at once.
|
||||
// Non of the endpoints may be nil, or an error will
|
||||
// be returned when the Option used by the Provider.
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
|
||||
if err := e.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
o.endpoints.Authorization = auth
|
||||
o.endpoints.Token = token
|
||||
o.endpoints.Userinfo = userInfo
|
||||
|
@ -534,6 +557,16 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithLogger lets a logger other than slog.Default().
|
||||
//
|
||||
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
|
||||
func WithLogger(logger *slog.Logger) Option {
|
||||
return func(o *Provider) error {
|
||||
o.logger = logger
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
|
||||
issuerInterceptor := NewIssuerInterceptor(i)
|
||||
return func(handler http.Handler) http.Handler {
|
||||
|
|
|
@ -14,9 +14,9 @@ import (
|
|||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/example/server/storage"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
|
@ -157,7 +157,7 @@ func TestRoutes(t *testing.T) {
|
|||
values: map[string]string{
|
||||
"client_id": client.GetID(),
|
||||
"redirect_uri": "https://example.com",
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"response_type": string(oidc.ResponseTypeCode),
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
|
@ -194,7 +194,7 @@ func TestRoutes(t *testing.T) {
|
|||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeBearer),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"assertion": jwtToken,
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
|
@ -207,7 +207,7 @@ func TestRoutes(t *testing.T) {
|
|||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeTokenExchange),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"subject_token": jwtToken,
|
||||
"subject_token_type": string(oidc.AccessTokenType),
|
||||
},
|
||||
|
@ -224,7 +224,7 @@ func TestRoutes(t *testing.T) {
|
|||
basicAuth: &basicAuth{"sid1", "verysecret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeClientCredentials),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
|
||||
|
@ -339,7 +339,7 @@ func TestRoutes(t *testing.T) {
|
|||
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"device", "secret"},
|
||||
values: map[string]string{
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
|
@ -371,7 +371,7 @@ func TestRoutes(t *testing.T) {
|
|||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
testProvider.HttpHandler().ServeHTTP(rec, req)
|
||||
testProvider.ServeHTTP(rec, req)
|
||||
|
||||
resp := rec.Result()
|
||||
require.NoError(t, err)
|
||||
|
@ -396,3 +396,54 @@ func TestRoutes(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCustomEndpoints(t *testing.T) {
|
||||
type args struct {
|
||||
auth *op.Endpoint
|
||||
token *op.Endpoint
|
||||
userInfo *op.Endpoint
|
||||
revocation *op.Endpoint
|
||||
endSession *op.Endpoint
|
||||
keys *op.Endpoint
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "all nil",
|
||||
args: args{},
|
||||
wantErr: op.ErrNilEndpoint,
|
||||
},
|
||||
{
|
||||
name: "all set",
|
||||
args: args{
|
||||
auth: op.NewEndpoint("/authorize"),
|
||||
token: op.NewEndpoint("/oauth/token"),
|
||||
userInfo: op.NewEndpoint("/userinfo"),
|
||||
revocation: op.NewEndpoint("/revoke"),
|
||||
endSession: op.NewEndpoint("/end_session"),
|
||||
keys: op.NewEndpoint("/keys"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
|
||||
storage.NewStorage(storage.NewUserStore(testIssuer)),
|
||||
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
|
||||
)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
if tt.wantErr != nil {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
|
||||
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
|
||||
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
|
||||
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
|
||||
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
|
||||
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"errors"
|
||||
"net/http"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
)
|
||||
|
||||
type ProbesFn func(context.Context) error
|
||||
|
@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn {
|
|||
}
|
||||
|
||||
func ok(w http.ResponseWriter) {
|
||||
httphelper.MarshalJSON(w, status{"ok"})
|
||||
httphelper.MarshalJSON(w, Status{"ok"})
|
||||
}
|
||||
|
||||
type status struct {
|
||||
type Status struct {
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
|
346
pkg/op/server.go
Normal file
346
pkg/op/server.go
Normal file
|
@ -0,0 +1,346 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// Server describes the interface that needs to be implemented to serve
|
||||
// OpenID Connect and Oauth2 standard requests.
|
||||
//
|
||||
// Methods are called after the HTTP route is resolved and
|
||||
// the request body is parsed into the Request's Data field.
|
||||
// When a method is called, it can be assumed that required fields,
|
||||
// as described in their relevant standard, are validated already.
|
||||
// The Response Data field may be of any type to allow flexibility
|
||||
// to extend responses with custom fields. There are however requirements
|
||||
// in the standards regarding the response models. Where applicable
|
||||
// the method documentation gives a recommended type which can be used
|
||||
// directly or extended upon.
|
||||
//
|
||||
// The addition of new methods is not considered a breaking change
|
||||
// as defined by semver rules.
|
||||
// Implementations MUST embed [UnimplementedServer] to maintain
|
||||
// forward compatibility.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Server interface {
|
||||
// Health returns a status of "ok" once the Server is listening.
|
||||
// The recommended Response Data type is [Status].
|
||||
Health(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// Ready returns a status of "ok" once all dependencies,
|
||||
// such as database storage, are ready.
|
||||
// An error can be returned to explain what is not ready.
|
||||
// The recommended Response Data type is [Status].
|
||||
Ready(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// Discovery returns the OpenID Provider Configuration Information for this server.
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
|
||||
// The recommended Response Data type is [oidc.DiscoveryConfiguration].
|
||||
Discovery(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// Keys serves the JWK set which the client can use verify signatures from the op.
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key.
|
||||
// The recommended Response Data type is [jose.JSONWebKeySet].
|
||||
Keys(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// VerifyAuthRequest verifies the Auth Request and
|
||||
// adds the Client to the request.
|
||||
//
|
||||
// When the `request` field is populated with a
|
||||
// "Request Object" JWT, it needs to be Validated
|
||||
// and its claims overwrite any fields in the AuthRequest.
|
||||
// If the implementation does not support "Request Object",
|
||||
// it MUST return an [oidc.ErrRequestNotSupported].
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RequestObject
|
||||
VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error)
|
||||
|
||||
// Authorize initiates the authorization flow and redirects to a login page.
|
||||
// See the various https://openid.net/specs/openid-connect-core-1_0.html
|
||||
// authorize endpoint sections (one for each type of flow).
|
||||
Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error)
|
||||
|
||||
// DeviceAuthorization initiates the device authorization flow.
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
|
||||
// The recommended Response Data type is [oidc.DeviceAuthorizationResponse].
|
||||
DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error)
|
||||
|
||||
// VerifyClient is called on most oauth/token handlers to authenticate,
|
||||
// using either a secret (POST, Basic) or assertion (JWT).
|
||||
// If no secrets are provided, the client must be public.
|
||||
// This method is called before each method that takes a
|
||||
// [ClientRequest] argument.
|
||||
VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error)
|
||||
|
||||
// CodeExchange returns Tokens after an authorization code
|
||||
// is obtained in a successful Authorize flow.
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value authorization_code
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error)
|
||||
|
||||
// RefreshToken returns new Tokens after verifying a Refresh token.
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value refresh_token
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error)
|
||||
|
||||
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer
|
||||
// https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error)
|
||||
|
||||
// TokenExchange handles the OAuth 2.0 token exchange grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange
|
||||
// https://datatracker.ietf.org/doc/html/rfc8693
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error)
|
||||
|
||||
// ClientCredentialsExchange handles the OAuth 2.0 client credentials grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value client_credentials
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error)
|
||||
|
||||
// DeviceToken handles the OAuth 2.0 Device Authorization Grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value urn:ietf:params:oauth:grant-type:device_code.
|
||||
// It is typically called in a polling fashion and appropriate errors
|
||||
// should be returned to signal authorization_pending or access_denied etc.
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4,
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5.
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error)
|
||||
|
||||
// Introspect handles the OAuth 2.0 Token Introspection endpoint.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7662
|
||||
// The recommended Response Data type is [oidc.IntrospectionResponse].
|
||||
Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error)
|
||||
|
||||
// UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
// The recommended Response Data type is [oidc.UserInfo].
|
||||
UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error)
|
||||
|
||||
// Revocation handles token revocation using an access or refresh token.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7009
|
||||
// There are no response requirements. Data may remain empty.
|
||||
Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error)
|
||||
|
||||
// EndSession handles the OpenID Connect RP-Initiated Logout.
|
||||
// https://openid.net/specs/openid-connect-rpinitiated-1_0.html
|
||||
// There are no response requirements. Data may remain empty.
|
||||
EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error)
|
||||
|
||||
// mustImpl forces implementations to embed the UnimplementedServer for forward
|
||||
// compatibility with the interface.
|
||||
mustImpl()
|
||||
}
|
||||
|
||||
// Request contains the [http.Request] informational fields
|
||||
// and parsed Data from the request body (POST) or URL parameters (GET).
|
||||
// Data can be assumed to be validated according to the applicable
|
||||
// standard for the specific endpoints.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Request[T any] struct {
|
||||
Method string
|
||||
URL *url.URL
|
||||
Header http.Header
|
||||
Form url.Values
|
||||
PostForm url.Values
|
||||
Data *T
|
||||
}
|
||||
|
||||
func (r *Request[_]) path() string {
|
||||
return r.URL.Path
|
||||
}
|
||||
|
||||
func newRequest[T any](r *http.Request, data *T) *Request[T] {
|
||||
return &Request[T]{
|
||||
Method: r.Method,
|
||||
URL: r.URL,
|
||||
Header: r.Header,
|
||||
Form: r.Form,
|
||||
PostForm: r.PostForm,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ClientRequest is a Request with a verified client attached to it.
|
||||
// Methods that receive this argument may assume the client was authenticated,
|
||||
// or verified to be a public client.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type ClientRequest[T any] struct {
|
||||
*Request[T]
|
||||
Client Client
|
||||
}
|
||||
|
||||
func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] {
|
||||
return &ClientRequest[T]{
|
||||
Request: newRequest[T](r, data),
|
||||
Client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Response object for most [Server] methods.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Response struct {
|
||||
// Header map will be merged with the
|
||||
// header on the [http.ResponseWriter].
|
||||
Header http.Header
|
||||
|
||||
// Data will be JSON marshaled to
|
||||
// the response body.
|
||||
// We allow any type, so that implementations
|
||||
// can extend the standard types as they wish.
|
||||
// However, each method will recommend which
|
||||
// (base) type to use as model, in order to
|
||||
// be compliant with the standards.
|
||||
Data any
|
||||
}
|
||||
|
||||
// NewResponse creates a new response for data,
|
||||
// without custom headers.
|
||||
func NewResponse(data any) *Response {
|
||||
return &Response{
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (resp *Response) writeOut(w http.ResponseWriter) {
|
||||
gu.MapMerge(resp.Header, w.Header())
|
||||
httphelper.MarshalJSON(w, resp.Data)
|
||||
}
|
||||
|
||||
// Redirect is a special response type which will
|
||||
// initiate a [http.StatusFound] redirect.
|
||||
// The Params field will be encoded and set to the
|
||||
// URL's RawQuery field before building the URL.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Redirect struct {
|
||||
// Header map will be merged with the
|
||||
// header on the [http.ResponseWriter].
|
||||
Header http.Header
|
||||
|
||||
URL string
|
||||
}
|
||||
|
||||
func NewRedirect(url string) *Redirect {
|
||||
return &Redirect{URL: url}
|
||||
}
|
||||
|
||||
func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {
|
||||
gu.MapMerge(r.Header, w.Header())
|
||||
http.Redirect(w, r, red.URL, http.StatusFound)
|
||||
}
|
||||
|
||||
type UnimplementedServer struct{}
|
||||
|
||||
// UnimplementedStatusCode is the status code returned for methods
|
||||
// that are not yet implemented.
|
||||
// Note that this means methods in the sense of the Go interface,
|
||||
// and not http methods covered by "501 Not Implemented".
|
||||
var UnimplementedStatusCode = http.StatusNotFound
|
||||
|
||||
func unimplementedError(r interface{ path() string }) StatusError {
|
||||
err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path())
|
||||
return NewStatusError(err, UnimplementedStatusCode)
|
||||
}
|
||||
|
||||
func unimplementedGrantError(gt oidc.GrantType) StatusError {
|
||||
err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt)
|
||||
return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
}
|
||||
|
||||
func (UnimplementedServer) mustImpl() {}
|
||||
|
||||
func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
||||
if r.Data.RequestParam != "" {
|
||||
return nil, oidc.ErrRequestNotSupported()
|
||||
}
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeCode)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
480
pkg/op/server_http.go
Normal file
480
pkg/op/server_http.go
Normal file
|
@ -0,0 +1,480 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/rs/cors"
|
||||
"github.com/zitadel/logging"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// RegisterServer registers an implementation of Server.
|
||||
// The resulting handler takes care of routing and request parsing,
|
||||
// with some basic validation of required fields.
|
||||
// The routes can be customized with [WithEndpoints].
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
|
||||
decoder := schema.NewDecoder()
|
||||
decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
ws := &webServer{
|
||||
server: server,
|
||||
endpoints: endpoints,
|
||||
decoder: decoder,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(ws)
|
||||
}
|
||||
|
||||
ws.createRouter()
|
||||
return ws
|
||||
}
|
||||
|
||||
type ServerOption func(s *webServer)
|
||||
|
||||
// WithHTTPMiddleware sets the passed middleware chain to the root of
|
||||
// the Server's router.
|
||||
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.middleware = m
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecoder overrides the default decoder,
|
||||
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
||||
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.decoder = decoder
|
||||
}
|
||||
}
|
||||
|
||||
// WithFallbackLogger overrides the fallback logger, which
|
||||
// is used when no logger was found in the context.
|
||||
// Defaults to [slog.Default].
|
||||
func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
type webServer struct {
|
||||
http.Handler
|
||||
server Server
|
||||
middleware []func(http.Handler) http.Handler
|
||||
endpoints Endpoints
|
||||
decoder httphelper.Decoder
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||
if logger, ok := logging.FromContext(ctx); ok {
|
||||
return logger
|
||||
}
|
||||
return s.logger
|
||||
}
|
||||
|
||||
func (s *webServer) createRouter() {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
router.Use(s.middleware...)
|
||||
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||
|
||||
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
|
||||
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
||||
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
|
||||
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
|
||||
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
|
||||
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
||||
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
|
||||
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
||||
s.Handler = router
|
||||
}
|
||||
|
||||
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
|
||||
if e != nil {
|
||||
router.HandleFunc(e.Relative(), hf)
|
||||
s.logger.Info("registered route", "endpoint", e.Relative())
|
||||
}
|
||||
}
|
||||
|
||||
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
|
||||
|
||||
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
client, err := s.verifyRequestClient(r)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
|
||||
if !ValidateGrantType(client, grantType) {
|
||||
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
}
|
||||
handler(w, r, client)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
||||
if err = r.ParseForm(); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||
}
|
||||
cc := new(ClientCredentials)
|
||||
if err = s.decoder.Decode(cc, r.Form); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||
}
|
||||
// Basic auth takes precedence, so if set it overwrites the form data.
|
||||
if clientID, clientSecret, ok := r.BasicAuth(); ok {
|
||||
cc.ClientID, err = url.QueryUnescape(clientID)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||
}
|
||||
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||
}
|
||||
}
|
||||
if cc.ClientID == "" && cc.ClientAssertion == "" {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
|
||||
}
|
||||
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
|
||||
}
|
||||
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
|
||||
Method: r.Method,
|
||||
URL: r.URL,
|
||||
Header: r.Header,
|
||||
Form: r.Form,
|
||||
Data: cc,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
redirect, err := s.authorize(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
redirect.writeOut(w, r)
|
||||
}
|
||||
|
||||
func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
|
||||
cr, err := s.server.VerifyAuthRequest(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := cr.Data
|
||||
if authReq.RedirectURI == "" {
|
||||
return nil, ErrAuthReqMissingRedirectURI
|
||||
}
|
||||
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.server.Authorize(ctx, cr)
|
||||
}
|
||||
|
||||
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
|
||||
switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
|
||||
case oidc.GrantTypeCode:
|
||||
s.withClient(s.codeExchangeHandler)(w, r)
|
||||
case oidc.GrantTypeRefreshToken:
|
||||
s.withClient(s.refreshTokenHandler)(w, r)
|
||||
case oidc.GrantTypeClientCredentials:
|
||||
s.withClient(s.clientCredentialsHandler)(w, r)
|
||||
case oidc.GrantTypeBearer:
|
||||
s.jwtProfileHandler(w, r)
|
||||
case oidc.GrantTypeTokenExchange:
|
||||
s.withClient(s.tokenExchangeHandler)(w, r)
|
||||
case oidc.GrantTypeDeviceCode:
|
||||
s.withClient(s.deviceTokenHandler)(w, r)
|
||||
case "":
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
|
||||
default:
|
||||
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Assertion == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Code == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RedirectURI == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RefreshToken == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.SubjectToken == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.SubjectTokenType == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if !request.SubjectTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
|
||||
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.DeviceCode == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Token == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if token, err := getAccessToken(r); err == nil {
|
||||
request.AccessToken = token
|
||||
}
|
||||
if request.AccessToken == "" {
|
||||
err = NewStatusError(
|
||||
oidc.ErrInvalidRequest().WithDescription("access token missing"),
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Token == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w, r)
|
||||
}
|
||||
|
||||
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
|
||||
dst := new(R)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||
}
|
||||
form := r.Form
|
||||
if postOnly {
|
||||
form = r.PostForm
|
||||
}
|
||||
if err := decoder.Decode(dst, form); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||
}
|
||||
return dst, nil
|
||||
}
|
345
pkg/op/server_http_routes_test.go
Normal file
345
pkg/op/server_http_routes_test.go
Normal file
|
@ -0,0 +1,345 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func jwtProfile() (string, error) {
|
||||
keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer)
|
||||
}
|
||||
|
||||
func TestServerRoutes(t *testing.T) {
|
||||
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
|
||||
|
||||
storage := testProvider.Storage().(routesTestStorage)
|
||||
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
||||
|
||||
client, err := storage.GetClientByClientID(ctx, "web")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcAuthReq := &oidc.AuthRequest{
|
||||
ClientID: client.GetID(),
|
||||
RedirectURI: "https://example.com",
|
||||
MaxAge: gu.Ptr[uint](300),
|
||||
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
|
||||
ResponseType: oidc.ResponseTypeCode,
|
||||
}
|
||||
|
||||
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
|
||||
require.NoError(t, err)
|
||||
storage.AuthRequestDone(authReq.GetID())
|
||||
|
||||
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
|
||||
require.NoError(t, err)
|
||||
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
jwtProfileToken, err := jwtProfile()
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcAuthReq.IDTokenHint = idToken
|
||||
|
||||
serverURL, err := url.Parse(testIssuer)
|
||||
require.NoError(t, err)
|
||||
|
||||
type basicAuth struct {
|
||||
username, password string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
basicAuth *basicAuth
|
||||
header map[string]string
|
||||
values map[string]string
|
||||
body map[string]string
|
||||
wantCode int
|
||||
headerContains map[string]string
|
||||
json string // test for exact json output
|
||||
contains []string // when the body output is not constant, we just check for snippets to be present in the response
|
||||
}{
|
||||
{
|
||||
name: "health",
|
||||
method: http.MethodGet,
|
||||
path: "/healthz",
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"status":"ok"}`,
|
||||
},
|
||||
{
|
||||
name: "ready",
|
||||
method: http.MethodGet,
|
||||
path: "/ready",
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"status":"ok"}`,
|
||||
},
|
||||
{
|
||||
name: "discovery",
|
||||
method: http.MethodGet,
|
||||
path: oidc.DiscoveryEndpoint,
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
|
||||
},
|
||||
{
|
||||
name: "authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.AuthorizationEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"client_id": client.GetID(),
|
||||
"redirect_uri": "https://example.com",
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"response_type": string(oidc.ResponseTypeCode),
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
|
||||
},
|
||||
{
|
||||
// This call will fail. A successfull test is already
|
||||
// part of client/integration_test.go
|
||||
name: "code exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeCode),
|
||||
"client_id": client.GetID(),
|
||||
"client_secret": "secret",
|
||||
"redirect_uri": "https://example.com",
|
||||
"code": "123",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"invalid_grant", "error_description":"invalid code"}`,
|
||||
},
|
||||
{
|
||||
name: "JWT authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeBearer),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"assertion": jwtProfileToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`},
|
||||
},
|
||||
{
|
||||
name: "Token exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeTokenExchange),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"subject_token": jwtToken,
|
||||
"subject_token_type": string(oidc.AccessTokenType),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"access_token":"`,
|
||||
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Client credentials exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"sid1", "verysecret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeClientCredentials),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
|
||||
},
|
||||
{
|
||||
// This call will fail. A successfull test is already
|
||||
// part of device_test.go
|
||||
name: "device token",
|
||||
method: http.MethodPost,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"device", "secret"},
|
||||
header: map[string]string{
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
body: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeDeviceCode),
|
||||
"device_code": "123",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
|
||||
},
|
||||
{
|
||||
name: "missing grant type",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
|
||||
},
|
||||
{
|
||||
name: "unsupported grant type",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": "foo",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
|
||||
},
|
||||
{
|
||||
name: "introspection",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.IntrospectionEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"token": accessToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
|
||||
},
|
||||
{
|
||||
name: "user info",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.UserinfoEndpoint().Relative(),
|
||||
header: map[string]string{
|
||||
"authorization": "Bearer " + accessToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
|
||||
},
|
||||
{
|
||||
name: "refresh token",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeRefreshToken),
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": client.GetID(),
|
||||
"client_secret": "secret",
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"access_token":"`,
|
||||
`","token_type":"Bearer","refresh_token":"`,
|
||||
`","expires_in":299,"id_token":"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "revoke",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.RevocationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"token": accessTokenRevoke,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "end session",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.EndSessionEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"id_token_hint": idToken,
|
||||
"client_id": "web",
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
headerContains: map[string]string{"Location": "/logged-out"},
|
||||
contains: []string{`<a href="/logged-out">Found</a>.`},
|
||||
},
|
||||
{
|
||||
name: "keys",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.KeysEndpoint().Relative(),
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
|
||||
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "device authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"device", "secret"},
|
||||
values: map[string]string{
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"device_code":"`, `","user_code":"`,
|
||||
`","verification_uri":"https://localhost:9998/device"`,
|
||||
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
|
||||
`","expires_in":300,"interval":5}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u := gu.PtrCopy(serverURL)
|
||||
u.Path = tt.path
|
||||
if tt.values != nil {
|
||||
u.RawQuery = mapAsValues(tt.values)
|
||||
}
|
||||
var body io.Reader
|
||||
if tt.body != nil {
|
||||
body = strings.NewReader(mapAsValues(tt.body))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(tt.method, u.String(), body)
|
||||
for k, v := range tt.header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
if tt.basicAuth != nil {
|
||||
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
server.ServeHTTP(rec, req)
|
||||
|
||||
resp := rec.Result()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCode, resp.StatusCode)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBodyString := string(respBody)
|
||||
t.Log(respBodyString)
|
||||
t.Log(resp.Header)
|
||||
|
||||
if tt.json != "" {
|
||||
assert.JSONEq(t, tt.json, respBodyString)
|
||||
}
|
||||
for _, c := range tt.contains {
|
||||
assert.Contains(t, respBodyString, c)
|
||||
}
|
||||
for k, v := range tt.headerContains {
|
||||
assert.Contains(t, resp.Header.Get(k), v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1333
pkg/op/server_http_test.go
Normal file
1333
pkg/op/server_http_test.go
Normal file
File diff suppressed because it is too large
Load diff
344
pkg/op/server_legacy.go
Normal file
344
pkg/op/server_legacy.go
Normal file
|
@ -0,0 +1,344 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// LegacyServer is an implementation of [Server[] that
|
||||
// simply wraps a [OpenIDProvider].
|
||||
// It can be used to transition from the former Provider/Storage
|
||||
// interfaces to the new Server interface.
|
||||
type LegacyServer struct {
|
||||
UnimplementedServer
|
||||
provider OpenIDProvider
|
||||
endpoints Endpoints
|
||||
}
|
||||
|
||||
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
|
||||
// the Server's router.
|
||||
//
|
||||
// Only non-nil endpoints will be registered on the router.
|
||||
// Nil endpoints are disabled.
|
||||
//
|
||||
// The passed endpoints is also set to the provider,
|
||||
// to be consistent with the discovery config.
|
||||
// Any `With*Endpoint()` option used on the provider is
|
||||
// therefore ineffective.
|
||||
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
|
||||
server := RegisterServer(&LegacyServer{
|
||||
provider: provider,
|
||||
endpoints: endpoints,
|
||||
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", server)
|
||||
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return NewResponse(Status{Status: "ok"}), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
for _, probe := range s.provider.Probes() {
|
||||
// shouldn't we run probes in Go routines?
|
||||
if err := probe(ctx); err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
return NewResponse(Status{Status: "ok"}), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return NewResponse(
|
||||
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
|
||||
), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
keys, err := s.provider.Storage().KeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
return NewResponse(jsonWebKeySet(keys)), nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrAuthReqMissingClientID = errors.New("auth request is missing client_id")
|
||||
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
|
||||
)
|
||||
|
||||
func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
||||
if r.Data.RequestParam != "" {
|
||||
if !s.provider.RequestObjectSupported() {
|
||||
return nil, oidc.ErrRequestNotSupported()
|
||||
}
|
||||
err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if r.Data.ClientID == "" {
|
||||
return nil, ErrAuthReqMissingClientID
|
||||
}
|
||||
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
||||
if err != nil {
|
||||
return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
|
||||
}
|
||||
|
||||
return &ClientRequest[oidc.AuthRequest]{
|
||||
Request: r,
|
||||
Client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
|
||||
userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
|
||||
if err != nil {
|
||||
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
|
||||
}
|
||||
return NewRedirect(r.Client.LoginURL(req.GetID())), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
||||
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
||||
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
|
||||
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
|
||||
if !ok {
|
||||
return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
|
||||
}
|
||||
return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret)
|
||||
}
|
||||
|
||||
if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
|
||||
jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
|
||||
if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
|
||||
}
|
||||
return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger)
|
||||
}
|
||||
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithParent(err)
|
||||
}
|
||||
|
||||
switch client.AuthMethod() {
|
||||
case oidc.AuthMethodNone:
|
||||
return client, nil
|
||||
case oidc.AuthMethodPrivateKeyJWT:
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
|
||||
case oidc.AuthMethodPost:
|
||||
if !s.provider.AuthMethodPostSupported() {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
|
||||
}
|
||||
}
|
||||
|
||||
err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
|
||||
authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.Client.AuthMethod() == oidc.AuthMethodNone {
|
||||
if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
|
||||
if !s.provider.GrantTypeRefreshTokenSupported() {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
|
||||
}
|
||||
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.Client.GetID() != request.GetClientID() {
|
||||
return nil, oidc.ErrInvalidGrant()
|
||||
}
|
||||
if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
|
||||
exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
|
||||
if !ok {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
|
||||
}
|
||||
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
|
||||
if !s.provider.GrantTypeTokenExchangeSupported() {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
|
||||
}
|
||||
tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
|
||||
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
|
||||
if !ok {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
|
||||
}
|
||||
tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
|
||||
if !s.provider.GrantTypeClientCredentialsSupported() {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
|
||||
}
|
||||
// use a limited context timeout shorter as the default
|
||||
// poll interval of 5 seconds.
|
||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
|
||||
state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
subject: state.Subject,
|
||||
audience: []string{r.Client.GetID()},
|
||||
scopes: state.Scopes,
|
||||
}
|
||||
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
||||
response := new(oidc.IntrospectionResponse)
|
||||
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
|
||||
if !ok {
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID())
|
||||
if err != nil {
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
response.Active = true
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
|
||||
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken)
|
||||
if !ok {
|
||||
return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized)
|
||||
}
|
||||
info := new(oidc.UserInfo)
|
||||
err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin"))
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusForbidden)
|
||||
}
|
||||
return NewResponse(info), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
|
||||
var subject string
|
||||
doDecrypt := true
|
||||
if r.Data.TokenTypeHint != "access_token" {
|
||||
userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token)
|
||||
if err != nil {
|
||||
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
|
||||
if !errors.Is(err, ErrInvalidRefreshToken) {
|
||||
return nil, RevocationError(oidc.ErrServerError().WithParent(err))
|
||||
}
|
||||
} else {
|
||||
r.Data.Token = tokenID
|
||||
subject = userID
|
||||
doDecrypt = false
|
||||
}
|
||||
}
|
||||
if doDecrypt {
|
||||
tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token)
|
||||
if ok {
|
||||
r.Data.Token = tokenID
|
||||
subject = userID
|
||||
}
|
||||
}
|
||||
if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil {
|
||||
return nil, RevocationError(err)
|
||||
}
|
||||
return NewResponse(nil), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
|
||||
session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRedirect(session.RedirectURI), nil
|
||||
}
|
5
pkg/op/server_test.go
Normal file
5
pkg/op/server_test.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package op
|
||||
|
||||
// implementation check
|
||||
var _ Server = &UnimplementedServer{}
|
||||
var _ Server = &LegacyServer{}
|
|
@ -6,15 +6,17 @@ import (
|
|||
"net/url"
|
||||
"path"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
type SessionEnder interface {
|
||||
Decoder() httphelper.Decoder
|
||||
Storage() Storage
|
||||
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
|
||||
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
|
||||
DefaultLogoutRedirectURI() string
|
||||
Logger() *slog.Logger
|
||||
}
|
||||
|
||||
func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
|
||||
|
@ -31,7 +33,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
|
|||
}
|
||||
session, err := ValidateEndSessionRequest(r.Context(), req, ender)
|
||||
if err != nil {
|
||||
RequestError(w, r, err)
|
||||
RequestError(w, r, err, ender.Logger())
|
||||
return
|
||||
}
|
||||
redirect := session.RedirectURI
|
||||
|
@ -41,7 +43,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
|
|||
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
|
||||
}
|
||||
if err != nil {
|
||||
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
|
||||
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
|
|
|
@ -3,7 +3,7 @@ package op
|
|||
import (
|
||||
"errors"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
var ErrSignerCreationFailed = errors.New("signer creation failed")
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue