make endpoints pointers to enable/disable them

This commit is contained in:
Tim Möhlmann 2023-09-27 18:09:00 +03:00
parent f6cb47fbbb
commit af22c1a4d8
12 changed files with 229 additions and 93 deletions

View file

@ -79,7 +79,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
handler := http.Handler(provider)
if wrapServer {
handler = op.NewLegacyServer(provider)
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
}
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)

View file

@ -20,14 +20,14 @@ var (
type Configuration interface {
IssuerFromRequest(r *http.Request) string
Insecure() bool
AuthorizationEndpoint() Endpoint
TokenEndpoint() Endpoint
IntrospectionEndpoint() Endpoint
UserinfoEndpoint() Endpoint
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint
AuthorizationEndpoint() *Endpoint
TokenEndpoint() *Endpoint
IntrospectionEndpoint() *Endpoint
UserinfoEndpoint() *Endpoint
RevocationEndpoint() *Endpoint
EndSessionEndpoint() *Endpoint
KeysEndpoint() *Endpoint
DeviceAuthorizationEndpoint() *Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool

View file

@ -64,6 +64,37 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di
}
}
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
TokenEndpoint: endpoints.Token.Absolute(issuer),
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
JwksURI: endpoints.JwksURI.Absolute(issuer),
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
ClaimsSupported: SupportedClaims(config),
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
UILocalesSupported: config.SupportedUILocales(),
RequestParameterSupported: config.RequestObjectSupported(),
}
}
func Scopes(c Configuration) []string {
return DefaultSupportedScopes // TODO: config
}

View file

@ -1,32 +1,46 @@
package op
import "strings"
import (
"errors"
"strings"
)
type Endpoint struct {
path string
url string
}
func NewEndpoint(path string) Endpoint {
return Endpoint{path: path}
func NewEndpoint(path string) *Endpoint {
return &Endpoint{path: path}
}
func NewEndpointWithURL(path, url string) Endpoint {
return Endpoint{path: path, url: url}
func NewEndpointWithURL(path, url string) *Endpoint {
return &Endpoint{path: path, url: url}
}
func (e Endpoint) Relative() string {
func (e *Endpoint) Relative() string {
if e == nil {
return ""
}
return relativeEndpoint(e.path)
}
func (e Endpoint) Absolute(host string) string {
func (e *Endpoint) Absolute(host string) string {
if e == nil {
return ""
}
if e.url != "" {
return e.url
}
return absoluteEndpoint(host, e.path)
}
func (e Endpoint) Validate() error {
var ErrNilEndpoint = errors.New("nil endpoint")
func (e *Endpoint) Validate() error {
if e == nil {
return ErrNilEndpoint
}
return nil // TODO:
}

View file

@ -3,13 +3,14 @@ package op_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/op"
)
func TestEndpoint_Path(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
e *op.Endpoint
want string
}{
{
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
op.NewEndpointWithURL("/test", "http://test.com/test"),
"/test",
},
{
"nil",
nil,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
}
tests := []struct {
name string
e op.Endpoint
e *op.Endpoint
args args
want string
}{
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
args{"https://host"},
"https://test.com/test",
},
{
"nil",
nil,
args{"https://host"},
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
func TestEndpoint_Validate(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
wantErr bool
e *op.Endpoint
wantErr error
}{
// TODO: Add test cases.
{
"nil",
nil,
op.ErrNilEndpoint,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
err := tt.e.Validate()
require.ErrorIs(t, err, tt.wantErr)
})
}
}

View file

@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
}
// AuthorizationEndpoint mocks base method.
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
}
// DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
}
// EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EndSessionEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
}
// IntrospectionEndpoint mocks base method.
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
}
// KeysEndpoint mocks base method.
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeysEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
}
// RevocationEndpoint mocks base method.
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
}
// TokenEndpoint mocks base method.
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
}
// UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserinfoEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}

View file

@ -131,16 +131,17 @@ type Config struct {
DeviceAuthorization DeviceAuthorizationConfig
}
// Endpoints defines endpoint routes.
type Endpoints struct {
Authorization Endpoint
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
DeviceAuthorization Endpoint
Authorization *Endpoint
Token *Endpoint
Introspection *Endpoint
Userinfo *Endpoint
Revocation *Endpoint
EndSession *Endpoint
CheckSessionIframe *Endpoint
JwksURI *Endpoint
DeviceAuthorization *Endpoint
}
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool {
return o.insecure
}
func (o *Provider) AuthorizationEndpoint() Endpoint {
func (o *Provider) AuthorizationEndpoint() *Endpoint {
return o.endpoints.Authorization
}
func (o *Provider) TokenEndpoint() Endpoint {
func (o *Provider) TokenEndpoint() *Endpoint {
return o.endpoints.Token
}
func (o *Provider) IntrospectionEndpoint() Endpoint {
func (o *Provider) IntrospectionEndpoint() *Endpoint {
return o.endpoints.Introspection
}
func (o *Provider) UserinfoEndpoint() Endpoint {
func (o *Provider) UserinfoEndpoint() *Endpoint {
return o.endpoints.Userinfo
}
func (o *Provider) RevocationEndpoint() Endpoint {
func (o *Provider) RevocationEndpoint() *Endpoint {
return o.endpoints.Revocation
}
func (o *Provider) EndSessionEndpoint() Endpoint {
func (o *Provider) EndSessionEndpoint() *Endpoint {
return o.endpoints.EndSession
}
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) KeysEndpoint() Endpoint {
func (o *Provider) KeysEndpoint() *Endpoint {
return o.endpoints.JwksURI
}
@ -420,7 +421,7 @@ func WithAllowInsecure() Option {
}
}
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
// WithCustomEndpoints sets multiple endpoints at once.
// Non of the endpoints may be nil, or an error will
// be returned when the Option used by the Provider.
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
return func(o *Provider) error {
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
if err := e.Validate(); err != nil {
return err
}
}
o.endpoints.Authorization = auth
o.endpoints.Token = token
o.endpoints.Userinfo = userInfo

View file

@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) {
})
}
}
func TestWithCustomEndpoints(t *testing.T) {
type args struct {
auth *op.Endpoint
token *op.Endpoint
userInfo *op.Endpoint
revocation *op.Endpoint
endSession *op.Endpoint
keys *op.Endpoint
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "all nil",
args: args{},
wantErr: op.ErrNilEndpoint,
},
{
name: "all set",
args: args{
auth: op.NewEndpoint("/authorize"),
token: op.NewEndpoint("/oauth/token"),
userInfo: op.NewEndpoint("/userinfo"),
revocation: op.NewEndpoint("/revoke"),
endSession: op.NewEndpoint("/end_session"),
keys: op.NewEndpoint("/keys"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
storage.NewStorage(storage.NewUserStore(testIssuer)),
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
)
require.ErrorIs(t, err, tt.wantErr)
if tt.wantErr != nil {
return
}
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
})
}
}

View file

@ -20,13 +20,13 @@ import (
// The routes can be customized with [WithEndpoints].
//
// EXPERIMENTAL: may change until v4
func RegisterServer(server Server, options ...ServerOption) http.Handler {
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
ws := &webServer{
server: server,
endpoints: *DefaultEndpoints,
endpoints: endpoints,
decoder: decoder,
logger: slog.Default(),
}
@ -49,13 +49,6 @@ func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
}
}
// WithEndpoints overrides the [DefaultEndpoints]
func WithEndpoints(endpoints Endpoints) ServerOption {
return func(s *webServer) {
s.endpoints = endpoints
}
}
// WithDecoder overrides the default decoder,
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
func WithDecoder(decoder httphelper.Decoder) ServerOption {
@ -96,17 +89,25 @@ func (s *webServer) createRouter() {
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler)
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.withClient(s.deviceAuthorizationHandler))
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler)
router.HandleFunc(s.endpoints.Introspection.Relative(), s.withClient(s.introspectionHandler))
router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler)
router.HandleFunc(s.endpoints.Revocation.Relative(), s.withClient(s.revocationHandler))
router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler)
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys))
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
s.Handler = router
}
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
if e != nil {
router.HandleFunc(e.Relative(), hf)
s.logger.Info("registered route", "endpoint", e.Relative())
}
}
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {

View file

@ -18,7 +18,7 @@ import (
)
func TestServerRoutes(t *testing.T) {
server := op.NewLegacyServer(testProvider)
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer)

View file

@ -25,15 +25,14 @@ import (
func TestRegisterServer(t *testing.T) {
server := UnimplementedServer{}
endpoints := Endpoints{
Authorization: Endpoint{
Authorization: &Endpoint{
path: "/auth",
},
}
decoder := schema.NewDecoder()
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
h := RegisterServer(server,
WithEndpoints(endpoints),
h := RegisterServer(server, endpoints,
WithDecoder(decoder),
WithFallbackLogger(logger),
)

View file

@ -10,15 +10,31 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// LegacyServer is an implementation of [Server[] that
// simply wraps a [OpenIDProvider].
// It can be used to transition from the former Provider/Storage
// interfaces to the new Server interface.
type LegacyServer struct {
UnimplementedServer
provider OpenIDProvider
provider OpenIDProvider
endpoints Endpoints
}
func NewLegacyServer(provider OpenIDProvider) http.Handler {
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
// the Server's router.
//
// Only non-nil endpoints will be registered on the router.
// Nil endpoints are disabled.
//
// The passed endpoints is also set to the provider,
// to be consistent with the discovery config.
// Any `With*Endpoint()` option used on the provider is
// therefore ineffective.
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
server := RegisterServer(&LegacyServer{
provider: provider,
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
provider: provider,
endpoints: endpoints,
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter()
router.Mount("/", server)
@ -43,7 +59,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse(
CreateDiscoveryConfig(ctx, s.provider, s.provider.Storage()),
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
), nil
}