make endpoints pointers to enable/disable them
This commit is contained in:
parent
f6cb47fbbb
commit
af22c1a4d8
12 changed files with 229 additions and 93 deletions
|
@ -79,7 +79,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
|
||||||
|
|
||||||
handler := http.Handler(provider)
|
handler := http.Handler(provider)
|
||||||
if wrapServer {
|
if wrapServer {
|
||||||
handler = op.NewLegacyServer(provider)
|
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
|
||||||
}
|
}
|
||||||
|
|
||||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||||
|
|
|
@ -20,14 +20,14 @@ var (
|
||||||
type Configuration interface {
|
type Configuration interface {
|
||||||
IssuerFromRequest(r *http.Request) string
|
IssuerFromRequest(r *http.Request) string
|
||||||
Insecure() bool
|
Insecure() bool
|
||||||
AuthorizationEndpoint() Endpoint
|
AuthorizationEndpoint() *Endpoint
|
||||||
TokenEndpoint() Endpoint
|
TokenEndpoint() *Endpoint
|
||||||
IntrospectionEndpoint() Endpoint
|
IntrospectionEndpoint() *Endpoint
|
||||||
UserinfoEndpoint() Endpoint
|
UserinfoEndpoint() *Endpoint
|
||||||
RevocationEndpoint() Endpoint
|
RevocationEndpoint() *Endpoint
|
||||||
EndSessionEndpoint() Endpoint
|
EndSessionEndpoint() *Endpoint
|
||||||
KeysEndpoint() Endpoint
|
KeysEndpoint() *Endpoint
|
||||||
DeviceAuthorizationEndpoint() Endpoint
|
DeviceAuthorizationEndpoint() *Endpoint
|
||||||
|
|
||||||
AuthMethodPostSupported() bool
|
AuthMethodPostSupported() bool
|
||||||
CodeMethodS256Supported() bool
|
CodeMethodS256Supported() bool
|
||||||
|
|
|
@ -64,6 +64,37 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
|
||||||
|
issuer := IssuerFromContext(ctx)
|
||||||
|
return &oidc.DiscoveryConfiguration{
|
||||||
|
Issuer: issuer,
|
||||||
|
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
|
||||||
|
TokenEndpoint: endpoints.Token.Absolute(issuer),
|
||||||
|
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
|
||||||
|
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
|
||||||
|
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
|
||||||
|
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
|
||||||
|
JwksURI: endpoints.JwksURI.Absolute(issuer),
|
||||||
|
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
|
||||||
|
ScopesSupported: Scopes(config),
|
||||||
|
ResponseTypesSupported: ResponseTypes(config),
|
||||||
|
GrantTypesSupported: GrantTypes(config),
|
||||||
|
SubjectTypesSupported: SubjectTypes(config),
|
||||||
|
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
|
||||||
|
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
|
||||||
|
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
|
||||||
|
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
|
||||||
|
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
|
||||||
|
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
|
||||||
|
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
|
||||||
|
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
|
||||||
|
ClaimsSupported: SupportedClaims(config),
|
||||||
|
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
|
||||||
|
UILocalesSupported: config.SupportedUILocales(),
|
||||||
|
RequestParameterSupported: config.RequestObjectSupported(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Scopes(c Configuration) []string {
|
func Scopes(c Configuration) []string {
|
||||||
return DefaultSupportedScopes // TODO: config
|
return DefaultSupportedScopes // TODO: config
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,32 +1,46 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type Endpoint struct {
|
type Endpoint struct {
|
||||||
path string
|
path string
|
||||||
url string
|
url string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEndpoint(path string) Endpoint {
|
func NewEndpoint(path string) *Endpoint {
|
||||||
return Endpoint{path: path}
|
return &Endpoint{path: path}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEndpointWithURL(path, url string) Endpoint {
|
func NewEndpointWithURL(path, url string) *Endpoint {
|
||||||
return Endpoint{path: path, url: url}
|
return &Endpoint{path: path, url: url}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) Relative() string {
|
func (e *Endpoint) Relative() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return relativeEndpoint(e.path)
|
return relativeEndpoint(e.path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) Absolute(host string) string {
|
func (e *Endpoint) Absolute(host string) string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
if e.url != "" {
|
if e.url != "" {
|
||||||
return e.url
|
return e.url
|
||||||
}
|
}
|
||||||
return absoluteEndpoint(host, e.path)
|
return absoluteEndpoint(host, e.path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Endpoint) Validate() error {
|
var ErrNilEndpoint = errors.New("nil endpoint")
|
||||||
|
|
||||||
|
func (e *Endpoint) Validate() error {
|
||||||
|
if e == nil {
|
||||||
|
return ErrNilEndpoint
|
||||||
|
}
|
||||||
return nil // TODO:
|
return nil // TODO:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,14 @@ package op_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEndpoint_Path(t *testing.T) {
|
func TestEndpoint_Path(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
e op.Endpoint
|
e *op.Endpoint
|
||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
|
||||||
op.NewEndpointWithURL("/test", "http://test.com/test"),
|
op.NewEndpointWithURL("/test", "http://test.com/test"),
|
||||||
"/test",
|
"/test",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"nil",
|
||||||
|
nil,
|
||||||
|
"",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
e op.Endpoint
|
e *op.Endpoint
|
||||||
args args
|
args args
|
||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
|
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
|
||||||
args{"https://host"},
|
args{"https://host"},
|
||||||
"https://test.com/test",
|
"https://test.com/test",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"nil",
|
||||||
|
nil,
|
||||||
|
args{"https://host"},
|
||||||
|
"",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
|
||||||
func TestEndpoint_Validate(t *testing.T) {
|
func TestEndpoint_Validate(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
e op.Endpoint
|
e *op.Endpoint
|
||||||
wantErr bool
|
wantErr error
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
{
|
||||||
|
"nil",
|
||||||
|
nil,
|
||||||
|
op.ErrNilEndpoint,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
|
err := tt.e.Validate()
|
||||||
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizationEndpoint mocks base method.
|
// AuthorizationEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
|
func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
|
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceAuthorizationEndpoint mocks base method.
|
// DeviceAuthorizationEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
|
func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
|
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
|
||||||
}
|
}
|
||||||
|
|
||||||
// EndSessionEndpoint mocks base method.
|
// EndSessionEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "EndSessionEndpoint")
|
ret := m.ctrl.Call(m, "EndSessionEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
|
||||||
}
|
}
|
||||||
|
|
||||||
// IntrospectionEndpoint mocks base method.
|
// IntrospectionEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
|
func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
|
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeysEndpoint mocks base method.
|
// KeysEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
|
func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "KeysEndpoint")
|
ret := m.ctrl.Call(m, "KeysEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevocationEndpoint mocks base method.
|
// RevocationEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
|
func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RevocationEndpoint")
|
ret := m.ctrl.Call(m, "RevocationEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenEndpoint mocks base method.
|
// TokenEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
|
func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "TokenEndpoint")
|
ret := m.ctrl.Call(m, "TokenEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserinfoEndpoint mocks base method.
|
// UserinfoEndpoint mocks base method.
|
||||||
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "UserinfoEndpoint")
|
ret := m.ctrl.Call(m, "UserinfoEndpoint")
|
||||||
ret0, _ := ret[0].(op.Endpoint)
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
61
pkg/op/op.go
61
pkg/op/op.go
|
@ -131,16 +131,17 @@ type Config struct {
|
||||||
DeviceAuthorization DeviceAuthorizationConfig
|
DeviceAuthorization DeviceAuthorizationConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Endpoints defines endpoint routes.
|
||||||
type Endpoints struct {
|
type Endpoints struct {
|
||||||
Authorization Endpoint
|
Authorization *Endpoint
|
||||||
Token Endpoint
|
Token *Endpoint
|
||||||
Introspection Endpoint
|
Introspection *Endpoint
|
||||||
Userinfo Endpoint
|
Userinfo *Endpoint
|
||||||
Revocation Endpoint
|
Revocation *Endpoint
|
||||||
EndSession Endpoint
|
EndSession *Endpoint
|
||||||
CheckSessionIframe Endpoint
|
CheckSessionIframe *Endpoint
|
||||||
JwksURI Endpoint
|
JwksURI *Endpoint
|
||||||
DeviceAuthorization Endpoint
|
DeviceAuthorization *Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
||||||
|
@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool {
|
||||||
return o.insecure
|
return o.insecure
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) AuthorizationEndpoint() Endpoint {
|
func (o *Provider) AuthorizationEndpoint() *Endpoint {
|
||||||
return o.endpoints.Authorization
|
return o.endpoints.Authorization
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) TokenEndpoint() Endpoint {
|
func (o *Provider) TokenEndpoint() *Endpoint {
|
||||||
return o.endpoints.Token
|
return o.endpoints.Token
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) IntrospectionEndpoint() Endpoint {
|
func (o *Provider) IntrospectionEndpoint() *Endpoint {
|
||||||
return o.endpoints.Introspection
|
return o.endpoints.Introspection
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) UserinfoEndpoint() Endpoint {
|
func (o *Provider) UserinfoEndpoint() *Endpoint {
|
||||||
return o.endpoints.Userinfo
|
return o.endpoints.Userinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) RevocationEndpoint() Endpoint {
|
func (o *Provider) RevocationEndpoint() *Endpoint {
|
||||||
return o.endpoints.Revocation
|
return o.endpoints.Revocation
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) EndSessionEndpoint() Endpoint {
|
func (o *Provider) EndSessionEndpoint() *Endpoint {
|
||||||
return o.endpoints.EndSession
|
return o.endpoints.EndSession
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
|
func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
|
||||||
return o.endpoints.DeviceAuthorization
|
return o.endpoints.DeviceAuthorization
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Provider) KeysEndpoint() Endpoint {
|
func (o *Provider) KeysEndpoint() *Endpoint {
|
||||||
return o.endpoints.JwksURI
|
return o.endpoints.JwksURI
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -420,7 +421,7 @@ func WithAllowInsecure() Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
if err := endpoint.Validate(); err != nil {
|
if err := endpoint.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
|
// WithCustomEndpoints sets multiple endpoints at once.
|
||||||
|
// Non of the endpoints may be nil, or an error will
|
||||||
|
// be returned when the Option used by the Provider.
|
||||||
|
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
|
||||||
return func(o *Provider) error {
|
return func(o *Provider) error {
|
||||||
|
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
|
||||||
|
if err := e.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
o.endpoints.Authorization = auth
|
o.endpoints.Authorization = auth
|
||||||
o.endpoints.Token = token
|
o.endpoints.Token = token
|
||||||
o.endpoints.Userinfo = userInfo
|
o.endpoints.Userinfo = userInfo
|
||||||
|
|
|
@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithCustomEndpoints(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
auth *op.Endpoint
|
||||||
|
token *op.Endpoint
|
||||||
|
userInfo *op.Endpoint
|
||||||
|
revocation *op.Endpoint
|
||||||
|
endSession *op.Endpoint
|
||||||
|
keys *op.Endpoint
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all nil",
|
||||||
|
args: args{},
|
||||||
|
wantErr: op.ErrNilEndpoint,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all set",
|
||||||
|
args: args{
|
||||||
|
auth: op.NewEndpoint("/authorize"),
|
||||||
|
token: op.NewEndpoint("/oauth/token"),
|
||||||
|
userInfo: op.NewEndpoint("/userinfo"),
|
||||||
|
revocation: op.NewEndpoint("/revoke"),
|
||||||
|
endSession: op.NewEndpoint("/end_session"),
|
||||||
|
keys: op.NewEndpoint("/keys"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
|
||||||
|
storage.NewStorage(storage.NewUserStore(testIssuer)),
|
||||||
|
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
|
||||||
|
)
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
|
||||||
|
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
|
||||||
|
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
|
||||||
|
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
|
||||||
|
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
|
||||||
|
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -20,13 +20,13 @@ import (
|
||||||
// The routes can be customized with [WithEndpoints].
|
// The routes can be customized with [WithEndpoints].
|
||||||
//
|
//
|
||||||
// EXPERIMENTAL: may change until v4
|
// EXPERIMENTAL: may change until v4
|
||||||
func RegisterServer(server Server, options ...ServerOption) http.Handler {
|
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
|
||||||
decoder := schema.NewDecoder()
|
decoder := schema.NewDecoder()
|
||||||
decoder.IgnoreUnknownKeys(true)
|
decoder.IgnoreUnknownKeys(true)
|
||||||
|
|
||||||
ws := &webServer{
|
ws := &webServer{
|
||||||
server: server,
|
server: server,
|
||||||
endpoints: *DefaultEndpoints,
|
endpoints: endpoints,
|
||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
logger: slog.Default(),
|
logger: slog.Default(),
|
||||||
}
|
}
|
||||||
|
@ -49,13 +49,6 @@ func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithEndpoints overrides the [DefaultEndpoints]
|
|
||||||
func WithEndpoints(endpoints Endpoints) ServerOption {
|
|
||||||
return func(s *webServer) {
|
|
||||||
s.endpoints = endpoints
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithDecoder overrides the default decoder,
|
// WithDecoder overrides the default decoder,
|
||||||
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
||||||
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
||||||
|
@ -96,17 +89,25 @@ func (s *webServer) createRouter() {
|
||||||
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||||
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||||
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||||
router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler)
|
|
||||||
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.withClient(s.deviceAuthorizationHandler))
|
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
|
||||||
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler)
|
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
||||||
router.HandleFunc(s.endpoints.Introspection.Relative(), s.withClient(s.introspectionHandler))
|
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
|
||||||
router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler)
|
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
|
||||||
router.HandleFunc(s.endpoints.Revocation.Relative(), s.withClient(s.revocationHandler))
|
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
|
||||||
router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler)
|
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
||||||
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys))
|
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
|
||||||
|
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
||||||
s.Handler = router
|
s.Handler = router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
|
||||||
|
if e != nil {
|
||||||
|
router.HandleFunc(e.Relative(), hf)
|
||||||
|
s.logger.Info("registered route", "endpoint", e.Relative())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
|
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
|
||||||
|
|
||||||
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServerRoutes(t *testing.T) {
|
func TestServerRoutes(t *testing.T) {
|
||||||
server := op.NewLegacyServer(testProvider)
|
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
|
||||||
|
|
||||||
storage := testProvider.Storage().(routesTestStorage)
|
storage := testProvider.Storage().(routesTestStorage)
|
||||||
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
||||||
|
|
|
@ -25,15 +25,14 @@ import (
|
||||||
func TestRegisterServer(t *testing.T) {
|
func TestRegisterServer(t *testing.T) {
|
||||||
server := UnimplementedServer{}
|
server := UnimplementedServer{}
|
||||||
endpoints := Endpoints{
|
endpoints := Endpoints{
|
||||||
Authorization: Endpoint{
|
Authorization: &Endpoint{
|
||||||
path: "/auth",
|
path: "/auth",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
decoder := schema.NewDecoder()
|
decoder := schema.NewDecoder()
|
||||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
|
||||||
|
|
||||||
h := RegisterServer(server,
|
h := RegisterServer(server, endpoints,
|
||||||
WithEndpoints(endpoints),
|
|
||||||
WithDecoder(decoder),
|
WithDecoder(decoder),
|
||||||
WithFallbackLogger(logger),
|
WithFallbackLogger(logger),
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,15 +10,31 @@ import (
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LegacyServer is an implementation of [Server[] that
|
||||||
|
// simply wraps a [OpenIDProvider].
|
||||||
|
// It can be used to transition from the former Provider/Storage
|
||||||
|
// interfaces to the new Server interface.
|
||||||
type LegacyServer struct {
|
type LegacyServer struct {
|
||||||
UnimplementedServer
|
UnimplementedServer
|
||||||
provider OpenIDProvider
|
provider OpenIDProvider
|
||||||
|
endpoints Endpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLegacyServer(provider OpenIDProvider) http.Handler {
|
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
|
||||||
|
// the Server's router.
|
||||||
|
//
|
||||||
|
// Only non-nil endpoints will be registered on the router.
|
||||||
|
// Nil endpoints are disabled.
|
||||||
|
//
|
||||||
|
// The passed endpoints is also set to the provider,
|
||||||
|
// to be consistent with the discovery config.
|
||||||
|
// Any `With*Endpoint()` option used on the provider is
|
||||||
|
// therefore ineffective.
|
||||||
|
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
|
||||||
server := RegisterServer(&LegacyServer{
|
server := RegisterServer(&LegacyServer{
|
||||||
provider: provider,
|
provider: provider,
|
||||||
}, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
endpoints: endpoints,
|
||||||
|
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
||||||
|
|
||||||
router := chi.NewRouter()
|
router := chi.NewRouter()
|
||||||
router.Mount("/", server)
|
router.Mount("/", server)
|
||||||
|
@ -43,7 +59,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon
|
||||||
|
|
||||||
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||||
return NewResponse(
|
return NewResponse(
|
||||||
CreateDiscoveryConfig(ctx, s.provider, s.provider.Storage()),
|
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue