This commit is contained in:
Livio Amstutz 2019-11-21 14:38:23 +01:00
parent 8b0f4438fb
commit 720fe28f70
4 changed files with 244 additions and 192 deletions

View file

@ -16,7 +16,7 @@ func main() {
Port: "9998", Port: "9998",
} }
storage := &mock.Storage{} storage := &mock.Storage{}
handler, err := server.NewDefaultHandler(config, storage) handler, err := server.NewDefaultOP(config, storage, server.WithCustomTokenEndpoint("test"))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View file

@ -1,47 +1,47 @@
package server package server
import ( // import (
"net/http" // "net/http"
"net/http/httptest" // "net/http/httptest"
"testing" // "testing"
"github.com/stretchr/testify/require" // "github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc" // "github.com/caos/oidc/pkg/oidc"
) // )
func TestDefaultHandler_HandleDiscovery(t *testing.T) { // func TestDefaultHandler_HandleDiscovery(t *testing.T) {
type fields struct { // type fields struct {
config *Config // config *Config
discoveryConfig *oidc.DiscoveryConfiguration // discoveryConfig *oidc.DiscoveryConfiguration
storage Storage // storage Storage
http *http.Server // http *http.Server
} // }
type args struct { // type args struct {
w http.ResponseWriter // w http.ResponseWriter
r *http.Request // r *http.Request
} // }
tests := []struct { // tests := []struct {
name string // name string
fields fields // fields fields
args args // args args
want string // want string
wantCode int // wantCode int
}{ // }{
{"OK", fields{config: nil, discoveryConfig: &oidc.DiscoveryConfiguration{Issuer: "test"}}, args{httptest.NewRecorder(), nil}, `{"issuer":"test"}`, 200}, // {"OK", fields{config: nil, discoveryConfig: &oidc.DiscoveryConfiguration{Issuer: "test"}}, args{httptest.NewRecorder(), nil}, `{"issuer":"test"}`, 200},
} // }
for _, tt := range tests { // for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { // t.Run(tt.name, func(t *testing.T) {
h := &DefaultHandler{ // h := &DefaultHandler{
config: tt.fields.config, // config: tt.fields.config,
discoveryConfig: tt.fields.discoveryConfig, // discoveryConfig: tt.fields.discoveryConfig,
storage: tt.fields.storage, // storage: tt.fields.storage,
http: tt.fields.http, // http: tt.fields.http,
} // }
h.HandleDiscovery(tt.args.w, tt.args.r) // h.HandleDiscovery(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder) // rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, tt.want, rec.Body.String()) // require.Equal(t, tt.want, rec.Body.String())
require.Equal(t, tt.wantCode, rec.Code) // require.Equal(t, tt.wantCode, rec.Code)
}) // })
} // }
} // }

View file

@ -11,8 +11,9 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
) )
type DefaultHandler struct { type DefaultOP struct {
config *Config config *Config
endpoints endpoints
discoveryConfig *oidc.DiscoveryConfiguration discoveryConfig *oidc.DiscoveryConfiguration
storage Storage storage Storage
http *http.Server http *http.Server
@ -20,13 +21,6 @@ type DefaultHandler struct {
type Config struct { type Config struct {
Issuer string Issuer string
AuthorizationEndpoint Endpoint
TokenEndpoint Endpoint
IntrospectionEndpoint Endpoint
UserinfoEndpoint Endpoint
EndSessionEndpoint Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
// ScopesSupported: oidc.SupportedScopes, // ScopesSupported: oidc.SupportedScopes,
// ResponseTypesSupported: responseTypes, // ResponseTypesSupported: responseTypes,
// GrantTypesSupported: oidc.SupportedGrantTypes, // GrantTypesSupported: oidc.SupportedGrantTypes,
@ -37,32 +31,202 @@ type Config struct {
Port string Port string
} }
const ( type endpoints struct {
defaultAuthorizationEndpoint = "authorize" Authorization Endpoint
defaulTokenEndpoint = "token" Token Endpoint
defaultIntrospectEndpoint = "introspect" IntrospectionEndpoint Endpoint
defaultUserinfoEndpoint = "me" Userinfo Endpoint
) EndSessionEndpoint Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
}
func (c *Config) DefaultAndValidate() error { type DefaultOPOpts func(o *DefaultOP) error
if err := ValidateIssuer(c.Issuer); err != nil {
func WithCustomAuthEndpoint(endpoint Endpoint) DefaultOPOpts {
return func(o *DefaultOP) error {
if err := endpoint.Validate(); err != nil {
return err return err
} }
if c.AuthorizationEndpoint == "" { o.endpoints.Authorization = endpoint
c.AuthorizationEndpoint = defaultAuthorizationEndpoint
}
if c.TokenEndpoint == "" {
c.TokenEndpoint = defaulTokenEndpoint
}
if c.IntrospectionEndpoint == "" {
c.IntrospectionEndpoint = defaultIntrospectEndpoint
}
if c.UserinfoEndpoint == "" {
c.UserinfoEndpoint = defaultUserinfoEndpoint
}
return nil return nil
}
} }
func WithCustomTokenEndpoint(endpoint Endpoint) DefaultOPOpts {
return func(o *DefaultOP) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.Token = endpoint
return nil
}
}
func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
return func(o *DefaultOP) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.Userinfo = endpoint
return nil
}
}
const (
defaultAuthorizationEndpoint = "authorize"
defaulTokenEndpoint = "oauth/token"
defaultIntrospectEndpoint = "introspect"
defaultUserinfoEndpoint = "userinfo"
)
func CreateDiscoveryConfig(c Configuration) *oidc.DiscoveryConfiguration {
return &oidc.DiscoveryConfiguration{
Issuer: c.Issuer(),
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
// IntrospectionEndpoint: c.absoluteEndpoint(c.IntrospectionEndpoint),
// UserinfoEndpoint: c.absoluteEndpoint(c.UserinfoEndpoint),
// EndSessionEndpoint: c.absoluteEndpoint(c.EndSessionEndpoint),
// CheckSessionIframe: c.absoluteEndpoint(c.CheckSessionIframe),
// JwksURI: c.absoluteEndpoint(c.JwksURI),
// ScopesSupported: oidc.SupportedScopes,
// ResponseTypesSupported: responseTypes,
// GrantTypesSupported: oidc.SupportedGrantTypes,
// ClaimsSupported: oidc.SupportedClaims,
// IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
// SubjectTypesSupported: []string{"public"},
// TokenEndpointAuthMethodsSupported:
}
}
var DefaultEndpoints = endpoints{
Authorization: defaultAuthorizationEndpoint,
Token: defaulTokenEndpoint,
IntrospectionEndpoint: defaultIntrospectEndpoint,
Userinfo: defaultUserinfoEndpoint,
}
func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
if err := ValidateIssuer(config.Issuer); err != nil {
return nil, err
}
p := &DefaultOP{
config: config,
storage: storage,
endpoints: DefaultEndpoints,
}
for _, optFunc := range opOpts {
if err := optFunc(p); err != nil {
return nil, err
}
}
p.discoveryConfig = CreateDiscoveryConfig(p)
router := CreateRouter(p)
p.http = &http.Server{
Addr: ":" + config.Port,
Handler: router,
}
return p, nil
}
func (p *DefaultOP) Issuer() string {
return p.config.Issuer
}
type Endpoint string
func (e Endpoint) Relative() string {
return relativeEndpoint(string(e))
}
func (e Endpoint) Absolute(host string) string {
return absoluteEndpoint(host, string(e))
}
func (e Endpoint) Validate() error {
return nil //TODO:
}
func (p *DefaultOP) AuthorizationEndpoint() Endpoint {
return p.endpoints.Authorization
}
func (p *DefaultOP) TokenEndpoint() Endpoint {
return Endpoint(p.endpoints.Token)
}
func (p *DefaultOP) UserinfoEndpoint() Endpoint {
return Endpoint(p.endpoints.Userinfo)
}
func (p *DefaultOP) Port() string {
return p.config.Port
}
func (p *DefaultOP) HttpHandler() *http.Server {
return p.http
}
func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
utils.MarshalJSON(w, p.discoveryConfig)
}
func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
authRequest, err := ParseAuthRequest(w, r)
if err != nil {
//TODO: return err
}
err = ValidateAuthRequest(authRequest)
if err != nil {
//TODO: return err
}
if NeedsExistingSession(authRequest) {
// session, err := p.storage.CheckSession(authRequest)
// if err != nil {
// //TODO: return err
// }
}
err = p.storage.CreateAuthRequest(authRequest)
if err != nil {
//TODO: return err
}
//TODO: redirect?
}
func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) {
}
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
}
// func (c *Config) DefaultAndValidate() error {
// if err := ValidateIssuer(c.Issuer); err != nil {
// return err
// }
// if c.AuthorizationEndpoint == "" {
// c.AuthorizationEndpoint = defaultAuthorizationEndpoint
// }
// if c.TokenEndpoint == "" {
// c.TokenEndpoint = defaulTokenEndpoint
// }
// if c.IntrospectionEndpoint == "" {
// c.IntrospectionEndpoint = defaultIntrospectEndpoint
// }
// if c.UserinfoEndpoint == "" {
// c.UserinfoEndpoint = defaultUserinfoEndpoint
// }
// return nil
// }
func ValidateIssuer(issuer string) error { func ValidateIssuer(issuer string) error {
if issuer == "" { if issuer == "" {
return errors.New("missing issuer") return errors.New("missing issuer")
@ -85,27 +249,6 @@ func ValidateIssuer(issuer string) error {
return nil return nil
} }
func OIDC(c Configuration) *oidc.DiscoveryConfiguration {
return &oidc.DiscoveryConfiguration{
Issuer: c.Issuer(),
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
// TokenEndpoint: c.absoluteEndpoint(c.TokenEndpoint),
// IntrospectionEndpoint: c.absoluteEndpoint(c.IntrospectionEndpoint),
// UserinfoEndpoint: c.absoluteEndpoint(c.UserinfoEndpoint),
// EndSessionEndpoint: c.absoluteEndpoint(c.EndSessionEndpoint),
// CheckSessionIframe: c.absoluteEndpoint(c.CheckSessionIframe),
// JwksURI: c.absoluteEndpoint(c.JwksURI),
// ScopesSupported: oidc.SupportedScopes,
// ResponseTypesSupported: responseTypes,
// GrantTypesSupported: oidc.SupportedGrantTypes,
// ClaimsSupported: oidc.SupportedClaims,
// IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
// SubjectTypesSupported: []string{"public"},
// TokenEndpointAuthMethodsSupported:
}
}
func (c *Config) absoluteEndpoint(endpoint string) string { func (c *Config) absoluteEndpoint(endpoint string) string {
return strings.TrimSuffix(c.Issuer, "/") + relativeEndpoint(endpoint) return strings.TrimSuffix(c.Issuer, "/") + relativeEndpoint(endpoint)
} }
@ -117,94 +260,3 @@ func absoluteEndpoint(host, endpoint string) string {
func relativeEndpoint(endpoint string) string { func relativeEndpoint(endpoint string) string {
return "/" + strings.TrimPrefix(endpoint, "/") return "/" + strings.TrimPrefix(endpoint, "/")
} }
func NewDefaultHandler(config *Config, storage Storage) (Handler, error) {
err := config.DefaultAndValidate()
if err != nil {
return nil, err
}
h := &DefaultHandler{
config: config,
storage: storage,
}
h.discoveryConfig = OIDC(h)
router := CreateRouter(h)
h.http = &http.Server{
Addr: ":" + config.Port,
Handler: router,
}
return h, nil
}
func (h *DefaultHandler) Issuer() string {
return h.config.Issuer
}
type Endpoint string
func (e Endpoint) Relative() string {
return relativeEndpoint(string(e))
}
func (e Endpoint) Absolute(host string) string {
return absoluteEndpoint(host, string(e))
}
func (e Endpoint) Validate() error {
return nil //TODO:
}
func (h *DefaultHandler) AuthorizationEndpoint() Endpoint {
return Endpoint(h.config.AuthorizationEndpoint)
}
func (h *DefaultHandler) TokenEndpoint() Endpoint {
return Endpoint(h.config.TokenEndpoint)
}
func (h *DefaultHandler) UserinfoEndpoint() Endpoint {
return Endpoint(h.config.UserinfoEndpoint)
}
func (h *DefaultHandler) Port() string {
return h.config.Port
}
func (h *DefaultHandler) HttpHandler() *http.Server {
return h.http
}
func (h *DefaultHandler) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
utils.MarshalJSON(w, h.discoveryConfig)
}
func (h *DefaultHandler) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
authRequest, err := ParseAuthRequest(w, r)
if err != nil {
//TODO: return err
}
err = ValidateAuthRequest(authRequest)
if err != nil {
//TODO: return err
}
if NeedsExistingSession(authRequest) {
// session, err := h.storage.CheckSession(authRequest)
// if err != nil {
// //TODO: return err
// }
}
err = h.storage.CreateAuthRequest(authRequest)
if err != nil {
//TODO: return err
}
//TODO: redirect?
}
func (h *DefaultHandler) HandleExchange(w http.ResponseWriter, r *http.Request) {
}
func (h *DefaultHandler) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
}

View file

@ -10,7 +10,7 @@ import (
"github.com/caos/utils/logging" "github.com/caos/utils/logging"
) )
type Handler interface { type OpenIDProvider interface {
Configuration Configuration
// Storage() Storage // Storage() Storage
HandleDiscovery(w http.ResponseWriter, r *http.Request) HandleDiscovery(w http.ResponseWriter, r *http.Request)
@ -20,25 +20,25 @@ type Handler interface {
HttpHandler() *http.Server HttpHandler() *http.Server
} }
func CreateRouter(h Handler) *mux.Router { func CreateRouter(o OpenIDProvider) *mux.Router {
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc(oidc.DiscoveryEndpoint, h.HandleDiscovery) router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery)
router.HandleFunc(h.AuthorizationEndpoint().Relative(), h.HandleAuthorize) router.HandleFunc(o.AuthorizationEndpoint().Relative(), o.HandleAuthorize)
router.HandleFunc(h.TokenEndpoint().Relative(), h.HandleExchange) router.HandleFunc(o.TokenEndpoint().Relative(), o.HandleExchange)
router.HandleFunc(h.UserinfoEndpoint().Relative(), h.HandleUserinfo) router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
return router return router
} }
func Start(ctx context.Context, h Handler) { func Start(ctx context.Context, o OpenIDProvider) {
go func() { go func() {
<-ctx.Done() <-ctx.Done()
err := h.HttpHandler().Shutdown(ctx) err := o.HttpHandler().Shutdown(ctx)
logging.Log("SERVE-REqwpM").OnError(err).Error("graceful shutdown of oidc server failed") logging.Log("SERVE-REqwpM").OnError(err).Error("graceful shutdown of oidc server failed")
}() }()
go func() { go func() {
err := h.HttpHandler().ListenAndServe() err := o.HttpHandler().ListenAndServe()
logging.Log("SERVE-4YNIwG").OnError(err).Panic("oidc server serve failed") logging.Log("SERVE-4YNIwG").OnError(err).Panic("oidc server serve failed")
}() }()
logging.LogWithFields("SERVE-koAFMs", "port", h.Port()).Info("oidc server is listening") logging.LogWithFields("SERVE-koAFMs", "port", o.Port()).Info("oidc server is listening")
} }