make endpoints pointers to enable/disable them

This commit is contained in:
Tim Möhlmann 2023-09-27 18:09:00 +03:00
parent f6cb47fbbb
commit af22c1a4d8
12 changed files with 229 additions and 93 deletions

View file

@ -79,7 +79,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
handler := http.Handler(provider) handler := http.Handler(provider)
if wrapServer { if wrapServer {
handler = op.NewLegacyServer(provider) handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
} }
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)

View file

@ -20,14 +20,14 @@ var (
type Configuration interface { type Configuration interface {
IssuerFromRequest(r *http.Request) string IssuerFromRequest(r *http.Request) string
Insecure() bool Insecure() bool
AuthorizationEndpoint() Endpoint AuthorizationEndpoint() *Endpoint
TokenEndpoint() Endpoint TokenEndpoint() *Endpoint
IntrospectionEndpoint() Endpoint IntrospectionEndpoint() *Endpoint
UserinfoEndpoint() Endpoint UserinfoEndpoint() *Endpoint
RevocationEndpoint() Endpoint RevocationEndpoint() *Endpoint
EndSessionEndpoint() Endpoint EndSessionEndpoint() *Endpoint
KeysEndpoint() Endpoint KeysEndpoint() *Endpoint
DeviceAuthorizationEndpoint() Endpoint DeviceAuthorizationEndpoint() *Endpoint
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
CodeMethodS256Supported() bool CodeMethodS256Supported() bool

View file

@ -64,6 +64,37 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di
} }
} }
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
TokenEndpoint: endpoints.Token.Absolute(issuer),
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
JwksURI: endpoints.JwksURI.Absolute(issuer),
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
ClaimsSupported: SupportedClaims(config),
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
UILocalesSupported: config.SupportedUILocales(),
RequestParameterSupported: config.RequestObjectSupported(),
}
}
func Scopes(c Configuration) []string { func Scopes(c Configuration) []string {
return DefaultSupportedScopes // TODO: config return DefaultSupportedScopes // TODO: config
} }

View file

@ -1,32 +1,46 @@
package op package op
import "strings" import (
"errors"
"strings"
)
type Endpoint struct { type Endpoint struct {
path string path string
url string url string
} }
func NewEndpoint(path string) Endpoint { func NewEndpoint(path string) *Endpoint {
return Endpoint{path: path} return &Endpoint{path: path}
} }
func NewEndpointWithURL(path, url string) Endpoint { func NewEndpointWithURL(path, url string) *Endpoint {
return Endpoint{path: path, url: url} return &Endpoint{path: path, url: url}
} }
func (e Endpoint) Relative() string { func (e *Endpoint) Relative() string {
if e == nil {
return ""
}
return relativeEndpoint(e.path) return relativeEndpoint(e.path)
} }
func (e Endpoint) Absolute(host string) string { func (e *Endpoint) Absolute(host string) string {
if e == nil {
return ""
}
if e.url != "" { if e.url != "" {
return e.url return e.url
} }
return absoluteEndpoint(host, e.path) return absoluteEndpoint(host, e.path)
} }
func (e Endpoint) Validate() error { var ErrNilEndpoint = errors.New("nil endpoint")
func (e *Endpoint) Validate() error {
if e == nil {
return ErrNilEndpoint
}
return nil // TODO: return nil // TODO:
} }

View file

@ -3,13 +3,14 @@ package op_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
) )
func TestEndpoint_Path(t *testing.T) { func TestEndpoint_Path(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
e op.Endpoint e *op.Endpoint
want string want string
}{ }{
{ {
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
op.NewEndpointWithURL("/test", "http://test.com/test"), op.NewEndpointWithURL("/test", "http://test.com/test"),
"/test", "/test",
}, },
{
"nil",
nil,
"",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
} }
tests := []struct { tests := []struct {
name string name string
e op.Endpoint e *op.Endpoint
args args args args
want string want string
}{ }{
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
args{"https://host"}, args{"https://host"},
"https://test.com/test", "https://test.com/test",
}, },
{
"nil",
nil,
args{"https://host"},
"",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
func TestEndpoint_Validate(t *testing.T) { func TestEndpoint_Validate(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
e op.Endpoint e *op.Endpoint
wantErr bool wantErr error
}{ }{
// TODO: Add test cases. {
"nil",
nil,
op.ErrNilEndpoint,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Validate(); (err != nil) != tt.wantErr { err := tt.e.Validate()
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
}
}) })
} }
} }

View file

@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
} }
// AuthorizationEndpoint mocks base method. // AuthorizationEndpoint mocks base method.
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizationEndpoint") ret := m.ctrl.Call(m, "AuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
} }
// DeviceAuthorizationEndpoint mocks base method. // DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
} }
// EndSessionEndpoint mocks base method. // EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EndSessionEndpoint") ret := m.ctrl.Call(m, "EndSessionEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
} }
// IntrospectionEndpoint mocks base method. // IntrospectionEndpoint mocks base method.
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionEndpoint") ret := m.ctrl.Call(m, "IntrospectionEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
} }
// KeysEndpoint mocks base method. // KeysEndpoint mocks base method.
func (m *MockConfiguration) KeysEndpoint() op.Endpoint { func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeysEndpoint") ret := m.ctrl.Call(m, "KeysEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
} }
// RevocationEndpoint mocks base method. // RevocationEndpoint mocks base method.
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpoint") ret := m.ctrl.Call(m, "RevocationEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
} }
// TokenEndpoint mocks base method. // TokenEndpoint mocks base method.
func (m *MockConfiguration) TokenEndpoint() op.Endpoint { func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpoint") ret := m.ctrl.Call(m, "TokenEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
} }
// UserinfoEndpoint mocks base method. // UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserinfoEndpoint") ret := m.ctrl.Call(m, "UserinfoEndpoint")
ret0, _ := ret[0].(op.Endpoint) ret0, _ := ret[0].(*op.Endpoint)
return ret0 return ret0
} }

View file

@ -131,16 +131,17 @@ type Config struct {
DeviceAuthorization DeviceAuthorizationConfig DeviceAuthorization DeviceAuthorizationConfig
} }
// Endpoints defines endpoint routes.
type Endpoints struct { type Endpoints struct {
Authorization Endpoint Authorization *Endpoint
Token Endpoint Token *Endpoint
Introspection Endpoint Introspection *Endpoint
Userinfo Endpoint Userinfo *Endpoint
Revocation Endpoint Revocation *Endpoint
EndSession Endpoint EndSession *Endpoint
CheckSessionIframe Endpoint CheckSessionIframe *Endpoint
JwksURI Endpoint JwksURI *Endpoint
DeviceAuthorization Endpoint DeviceAuthorization *Endpoint
} }
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool {
return o.insecure return o.insecure
} }
func (o *Provider) AuthorizationEndpoint() Endpoint { func (o *Provider) AuthorizationEndpoint() *Endpoint {
return o.endpoints.Authorization return o.endpoints.Authorization
} }
func (o *Provider) TokenEndpoint() Endpoint { func (o *Provider) TokenEndpoint() *Endpoint {
return o.endpoints.Token return o.endpoints.Token
} }
func (o *Provider) IntrospectionEndpoint() Endpoint { func (o *Provider) IntrospectionEndpoint() *Endpoint {
return o.endpoints.Introspection return o.endpoints.Introspection
} }
func (o *Provider) UserinfoEndpoint() Endpoint { func (o *Provider) UserinfoEndpoint() *Endpoint {
return o.endpoints.Userinfo return o.endpoints.Userinfo
} }
func (o *Provider) RevocationEndpoint() Endpoint { func (o *Provider) RevocationEndpoint() *Endpoint {
return o.endpoints.Revocation return o.endpoints.Revocation
} }
func (o *Provider) EndSessionEndpoint() Endpoint { func (o *Provider) EndSessionEndpoint() *Endpoint {
return o.endpoints.EndSession return o.endpoints.EndSession
} }
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
return o.endpoints.DeviceAuthorization return o.endpoints.DeviceAuthorization
} }
func (o *Provider) KeysEndpoint() Endpoint { func (o *Provider) KeysEndpoint() *Endpoint {
return o.endpoints.JwksURI return o.endpoints.JwksURI
} }
@ -420,7 +421,7 @@ func WithAllowInsecure() Option {
} }
} }
func WithCustomAuthEndpoint(endpoint Endpoint) Option { func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomTokenEndpoint(endpoint Endpoint) Option { func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomRevocationEndpoint(endpoint Endpoint) Option { func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomKeysEndpoint(endpoint Endpoint) Option { func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option { func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { // WithCustomEndpoints sets multiple endpoints at once.
// Non of the endpoints may be nil, or an error will
// be returned when the Option used by the Provider.
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
return func(o *Provider) error { return func(o *Provider) error {
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
if err := e.Validate(); err != nil {
return err
}
}
o.endpoints.Authorization = auth o.endpoints.Authorization = auth
o.endpoints.Token = token o.endpoints.Token = token
o.endpoints.Userinfo = userInfo o.endpoints.Userinfo = userInfo

View file

@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) {
}) })
} }
} }
func TestWithCustomEndpoints(t *testing.T) {
type args struct {
auth *op.Endpoint
token *op.Endpoint
userInfo *op.Endpoint
revocation *op.Endpoint
endSession *op.Endpoint
keys *op.Endpoint
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "all nil",
args: args{},
wantErr: op.ErrNilEndpoint,
},
{
name: "all set",
args: args{
auth: op.NewEndpoint("/authorize"),
token: op.NewEndpoint("/oauth/token"),
userInfo: op.NewEndpoint("/userinfo"),
revocation: op.NewEndpoint("/revoke"),
endSession: op.NewEndpoint("/end_session"),
keys: op.NewEndpoint("/keys"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
storage.NewStorage(storage.NewUserStore(testIssuer)),
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
)
require.ErrorIs(t, err, tt.wantErr)
if tt.wantErr != nil {
return
}
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
})
}
}

View file

@ -20,13 +20,13 @@ import (
// The routes can be customized with [WithEndpoints]. // The routes can be customized with [WithEndpoints].
// //
// EXPERIMENTAL: may change until v4 // EXPERIMENTAL: may change until v4
func RegisterServer(server Server, options ...ServerOption) http.Handler { func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true) decoder.IgnoreUnknownKeys(true)
ws := &webServer{ ws := &webServer{
server: server, server: server,
endpoints: *DefaultEndpoints, endpoints: endpoints,
decoder: decoder, decoder: decoder,
logger: slog.Default(), logger: slog.Default(),
} }
@ -49,13 +49,6 @@ func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
} }
} }
// WithEndpoints overrides the [DefaultEndpoints]
func WithEndpoints(endpoints Endpoints) ServerOption {
return func(s *webServer) {
s.endpoints = endpoints
}
}
// WithDecoder overrides the default decoder, // WithDecoder overrides the default decoder,
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true. // which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
func WithDecoder(decoder httphelper.Decoder) ServerOption { func WithDecoder(decoder httphelper.Decoder) ServerOption {
@ -96,17 +89,25 @@ func (s *webServer) createRouter() {
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler)
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.withClient(s.deviceAuthorizationHandler)) s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler) s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
router.HandleFunc(s.endpoints.Introspection.Relative(), s.withClient(s.introspectionHandler)) s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler) s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
router.HandleFunc(s.endpoints.Revocation.Relative(), s.withClient(s.revocationHandler)) s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler) s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys)) s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
s.Handler = router s.Handler = router
} }
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
if e != nil {
router.HandleFunc(e.Relative(), hf)
s.logger.Info("registered route", "endpoint", e.Relative())
}
}
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {

View file

@ -18,7 +18,7 @@ import (
) )
func TestServerRoutes(t *testing.T) { func TestServerRoutes(t *testing.T) {
server := op.NewLegacyServer(testProvider) server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
storage := testProvider.Storage().(routesTestStorage) storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer) ctx := op.ContextWithIssuer(context.Background(), testIssuer)

View file

@ -25,15 +25,14 @@ import (
func TestRegisterServer(t *testing.T) { func TestRegisterServer(t *testing.T) {
server := UnimplementedServer{} server := UnimplementedServer{}
endpoints := Endpoints{ endpoints := Endpoints{
Authorization: Endpoint{ Authorization: &Endpoint{
path: "/auth", path: "/auth",
}, },
} }
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
h := RegisterServer(server, h := RegisterServer(server, endpoints,
WithEndpoints(endpoints),
WithDecoder(decoder), WithDecoder(decoder),
WithFallbackLogger(logger), WithFallbackLogger(logger),
) )

View file

@ -10,15 +10,31 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
) )
// LegacyServer is an implementation of [Server[] that
// simply wraps a [OpenIDProvider].
// It can be used to transition from the former Provider/Storage
// interfaces to the new Server interface.
type LegacyServer struct { type LegacyServer struct {
UnimplementedServer UnimplementedServer
provider OpenIDProvider provider OpenIDProvider
endpoints Endpoints
} }
func NewLegacyServer(provider OpenIDProvider) http.Handler { // NewLegacyServer wraps provider in a `Server` and returns a handler which is
// the Server's router.
//
// Only non-nil endpoints will be registered on the router.
// Nil endpoints are disabled.
//
// The passed endpoints is also set to the provider,
// to be consistent with the discovery config.
// Any `With*Endpoint()` option used on the provider is
// therefore ineffective.
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
server := RegisterServer(&LegacyServer{ server := RegisterServer(&LegacyServer{
provider: provider, provider: provider,
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) endpoints: endpoints,
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter() router := chi.NewRouter()
router.Mount("/", server) router.Mount("/", server)
@ -43,7 +59,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse( return NewResponse(
CreateDiscoveryConfig(ctx, s.provider, s.provider.Storage()), createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
), nil ), nil
} }