From 93709a18b68d97803999bbae1ad6ee4f46dfd14e Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 11 Feb 2020 17:17:09 +0100 Subject: [PATCH] add readiness and partial key rotation --- example/internal/mock/storage.go | 12 ++-- go.mod | 4 +- go.sum | 6 ++ pkg/op/default_op.go | 87 +++++++++++++++++++++++------ pkg/op/default_op_test.go | 49 ---------------- pkg/op/error.go | 2 +- pkg/op/mock/authorizer.mock.impl.go | 5 ++ pkg/op/mock/signer.mock.go | 15 +++++ pkg/op/mock/storage.mock.go | 42 ++++++++------ pkg/op/mock/storage.mock.impl.go | 18 +++++- pkg/op/op.go | 20 ++++--- pkg/op/probes.go | 51 +++++++++++++++++ pkg/op/signer.go | 64 +++++++++++---------- pkg/op/signer_test.go | 2 +- pkg/op/storage.go | 9 ++- 15 files changed, 254 insertions(+), 132 deletions(-) delete mode 100644 pkg/op/default_op_test.go create mode 100644 pkg/op/probes.go diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 37f47b1..9a5aecf 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -112,6 +112,10 @@ var ( t bool ) +func (s *AuthStorage) Health(ctx context.Context) error { + return nil +} + func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) { a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} if authReq.CodeChallenge != "" { @@ -136,14 +140,14 @@ func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequ } return a, nil } -func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) { - return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil +func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey, _ chan<- error, _ <-chan bool) { + keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key} } func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) { return s.key, nil } -func (s *AuthStorage) SaveKeyPair(ctx context.Context) (*jose.SigningKey, error) { - return s.GetSigningKey(ctx) +func (s *AuthStorage) SaveNewKeyPair(ctx context.Context) error { + return nil } func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) { pubkey := s.key.Public() diff --git a/go.mod b/go.mod index da7059a..70bbb7f 100644 --- a/go.mod +++ b/go.mod @@ -3,23 +3,21 @@ module github.com/caos/oidc go 1.13 require ( + github.com/caos/logging v0.0.0-20191210002624-b3260f690a6a github.com/golang/mock v1.3.1 github.com/golang/protobuf v1.3.2 // indirect github.com/google/uuid v1.1.1 github.com/gorilla/mux v1.7.3 github.com/gorilla/schema v1.1.0 github.com/gorilla/securecookie v1.1.1 - github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c - golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect golang.org/x/text v0.3.2 google.golang.org/appengine v1.6.5 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/square/go-jose.v2 v2.4.0 - gopkg.in/yaml.v2 v2.2.3 // indirect ) diff --git a/go.sum b/go.sum index 54a2ca8..d586cb1 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/caos/logging v0.0.0-20191210002624-b3260f690a6a h1:HOU/3xL/afsZ+2aCstfJlrzRkwYMTFR1TIEgps5ny8s= +github.com/caos/logging v0.0.0-20191210002624-b3260f690a6a/go.mod h1:9LKiDE2ChuGv6CHYif/kiugrfEXu9AwDiFWSreX7Wp0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -54,6 +56,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 h1:ZBzSG/7F4eNKz2L3GE9o300RX0Az1Bw5HF7PDraD+qU= golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191206220618-eeba5f6aabab h1:FvshnhkKW+LO3HWHodML8kuVX8rnJTxKm9dFPuI68UM= +golang.org/x/sys v0.0.0-20191206220618-eeba5f6aabab/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -72,3 +76,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3 h1:fvjTMHxHEw/mxHbtzPi3JCcKXQRAnQTBRo6YCJSVHKI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= +gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 7ce7925..07486c6 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -3,9 +3,12 @@ package op import ( "context" "net/http" + "time" "github.com/gorilla/schema" + "gopkg.in/square/go-jose.v2" + "github.com/caos/logging" "github.com/caos/oidc/pkg/oidc" ) @@ -32,16 +35,16 @@ var ( ) type DefaultOP struct { - config *Config - endpoints *endpoints - discoveryConfig *oidc.DiscoveryConfiguration - storage Storage - signer Signer - crypto Crypto - http *http.Server - decoder *schema.Decoder - encoder *schema.Encoder - interceptor HttpInterceptor + config *Config + endpoints *endpoints + storage Storage + signer Signer + crypto Crypto + http *http.Server + decoder *schema.Decoder + encoder *schema.Encoder + interceptor HttpInterceptor + retry func(int) (bool, int) } type Config struct { @@ -106,6 +109,20 @@ func WithHttpInterceptor(h HttpInterceptor) DefaultOPOpts { } } +func WithRetry(max int, sleep time.Duration) DefaultOPOpts { + return func(o *DefaultOP) error { + o.retry = func(count int) (bool, int) { + count++ + if count == max { + return false, count + } + time.Sleep(sleep) + return true, count + } + return nil + } +} + func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { err := ValidateIssuer(config.Issuer) if err != nil { @@ -118,10 +135,10 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts . endpoints: DefaultEndpoints, } - p.signer, err = NewDefaultSigner(ctx, storage) - if err != nil { - return nil, err - } + keyCh := make(chan jose.SigningKey) + // ctx, cancel := context.WithCancel(ctx) + p.signer = NewDefaultSigner(ctx, storage, keyCh) + go p.ensureKey(ctx, storage, keyCh) for _, optFunc := range opOpts { if err := optFunc(p); err != nil { @@ -129,8 +146,6 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts . } } - p.discoveryConfig = CreateDiscoveryConfig(p, p.signer) - router := CreateRouter(p, p.interceptor) p.http = &http.Server{ Addr: ":" + config.Port, @@ -179,7 +194,7 @@ func (p *DefaultOP) HttpHandler() *http.Server { } func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) { - Discover(w, p.discoveryConfig) + Discover(w, CreateDiscoveryConfig(p, p.Signer())) } func (p *DefaultOP) Decoder() *schema.Decoder { @@ -201,6 +216,13 @@ func (p *DefaultOP) Signer() Signer { func (p *DefaultOP) Crypto() Crypto { return p.crypto } +func (p *DefaultOP) HandleReady(w http.ResponseWriter, r *http.Request) { + probes := []ProbesFn{ + ReadySigner(p.Signer()), + ReadyStorage(p.Storage()), + } + Readiness(w, r, probes...) +} func (p *DefaultOP) HandleKeys(w http.ResponseWriter, r *http.Request) { Keys(w, r, p) @@ -230,3 +252,34 @@ func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) { func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) { Userinfo(w, r, p) } + +func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- jose.SigningKey) { + count := 0 + explicit := make(chan bool) + errCh := make(chan error) + go storage.GetSigningKey(ctx, keyCh, errCh, explicit) + explicit <- true + for { + select { + case <-ctx.Done(): + return + case err := <-errCh: + if err == nil { + continue + } + _, ok := err.(StorageNotFoundError) + if ok { + err := storage.SaveNewKeyPair(ctx) + if err == nil { + continue + } + } + ok, count = p.retry(count) + if ok { + explicit <- true + continue + } + logging.Log("OP-n6ynVE").WithError(err).Panic("error in key signer") + } + } +} diff --git a/pkg/op/default_op_test.go b/pkg/op/default_op_test.go deleted file mode 100644 index ed359a5..0000000 --- a/pkg/op/default_op_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package op - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/caos/oidc/pkg/oidc" -) - -func TestDefaultOP_HandleDiscovery(t *testing.T) { - type fields struct { - config *Config - endpoints *endpoints - discoveryConfig *oidc.DiscoveryConfiguration - storage Storage - http *http.Server - } - type args struct { - w http.ResponseWriter - r *http.Request - } - tests := []struct { - name string - fields fields - args args - want string - wantCode int - }{ - {"OK", fields{config: nil, endpoints: nil, discoveryConfig: &oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"}}, args{httptest.NewRecorder(), nil}, `{"issuer":"https://issuer.com"}`, 200}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &DefaultOP{ - config: tt.fields.config, - endpoints: tt.fields.endpoints, - discoveryConfig: tt.fields.discoveryConfig, - storage: tt.fields.storage, - http: tt.fields.http, - } - p.HandleDiscovery(tt.args.w, tt.args.r) - rec := tt.args.w.(*httptest.ResponseRecorder) - require.Equal(t, tt.want, rec.Body.String()) - require.Equal(t, tt.wantCode, rec.Code) - }) - } -} diff --git a/pkg/op/error.go b/pkg/op/error.go index 1e84c1a..c6e702e 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -89,7 +89,7 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) { type OAuthError struct { ErrorType errorType `json:"error" schema:"error"` - Description string `json:"description" schema:"description"` + Description string `json:"error_description" schema:"error_description"` state string `json:"state" schema:"state"` redirectDisabled bool } diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 0091877..93994ad 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -1,6 +1,7 @@ package mock import ( + "context" "testing" "github.com/golang/mock/gomock" @@ -67,6 +68,10 @@ func ExpectSigner(a op.Authorizer, t *testing.T) { type Sig struct{} +func (s *Sig) Health(ctx context.Context) error { + return nil +} + func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { return "", nil } diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index 5c7b669..c780752 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -5,6 +5,7 @@ package mock import ( + context "context" oidc "github.com/caos/oidc/pkg/oidc" gomock "github.com/golang/mock/gomock" go_jose_v2 "gopkg.in/square/go-jose.v2" @@ -34,6 +35,20 @@ func (m *MockSigner) EXPECT() *MockSignerMockRecorder { return m.recorder } +// Health mocks base method +func (m *MockSigner) Health(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Health", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Health indicates an expected call of Health +func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0) +} + // SignAccessToken mocks base method func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 3a36417..181ce3f 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -125,18 +125,15 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { } // GetSigningKey mocks base method -func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) { +func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan bool) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSigningKey", arg0) - ret0, _ := ret[0].(*go_jose_v2.SigningKey) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3) } // GetSigningKey indicates an expected call of GetSigningKey -func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0, arg1, arg2, arg3) } // GetUserinfoFromScopes mocks base method @@ -154,17 +151,30 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1) } -// SaveKeyPair mocks base method -func (m *MockStorage) SaveKeyPair(arg0 context.Context) (*go_jose_v2.SigningKey, error) { +// Health mocks base method +func (m *MockStorage) Health(arg0 context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveKeyPair", arg0) - ret0, _ := ret[0].(*go_jose_v2.SigningKey) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "Health", arg0) + ret0, _ := ret[0].(error) + return ret0 } -// SaveKeyPair indicates an expected call of SaveKeyPair -func (mr *MockStorageMockRecorder) SaveKeyPair(arg0 interface{}) *gomock.Call { +// Health indicates an expected call of Health +func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveKeyPair), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0) +} + +// SaveNewKeyPair mocks base method +func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveNewKeyPair", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveNewKeyPair indicates an expected call of SaveNewKeyPair +func (mr *MockStorageMockRecorder) SaveNewKeyPair(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNewKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveNewKeyPair), arg0) } diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 7cd62b9..e4328c7 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -86,17 +86,29 @@ func ExpectValidClientID(s op.Storage) { func ExpectSigningKeyError(s op.Storage) { mockS := s.(*MockStorage) - mockS.EXPECT().GetSigningKey(gomock.Any()).Return(nil, errors.New("error")) + mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, _ <-chan bool) { + errCh <- errors.New("error") + }, + ) } func ExpectSigningKeyInvalid(s op.Storage) { mockS := s.(*MockStorage) - mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{}, nil) + mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, _ <-chan bool) { + keyCh <- jose.SigningKey{} + }, + ) } func ExpectSigningKey(s op.Storage) { mockS := s.(*MockStorage) - mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, nil) + mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, _ <-chan bool) { + keyCh <- jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")} + }, + ) } type ConfClient struct { diff --git a/pkg/op/op.go b/pkg/op/op.go index 4d64e04..7bdd08e 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -10,8 +10,14 @@ import ( "github.com/caos/oidc/pkg/oidc" ) +const ( + healthzEndpoint = "/healthz" + readinessEndpoint = "/ready" +) + type OpenIDProvider interface { Configuration + HandleReady(w http.ResponseWriter, r *http.Request) HandleDiscovery(w http.ResponseWriter, r *http.Request) HandleAuthorize(w http.ResponseWriter, r *http.Request) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) @@ -23,19 +29,19 @@ type OpenIDProvider interface { type HttpInterceptor func(http.HandlerFunc) http.HandlerFunc -var ( - DefaultInterceptor = func(h http.HandlerFunc) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h(w, r) - }) - } -) +var DefaultInterceptor = func(h http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h(w, r) + }) +} func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router { if h == nil { h = DefaultInterceptor } router := mux.NewRouter() + router.HandleFunc(healthzEndpoint, Healthz) + router.HandleFunc(readinessEndpoint, o.HandleReady) router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) router.HandleFunc(o.AuthorizationEndpoint().Relative(), h(o.HandleAuthorize)) router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", h(o.HandleAuthorizeCallback)) diff --git a/pkg/op/probes.go b/pkg/op/probes.go new file mode 100644 index 0000000..50e8a0f --- /dev/null +++ b/pkg/op/probes.go @@ -0,0 +1,51 @@ +package op + +import ( + "context" + "errors" + "net/http" + + "github.com/caos/oidc/pkg/utils" +) + +type ProbesFn func(context.Context) error + +func Healthz(w http.ResponseWriter, r *http.Request) { + ok(w) +} + +func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) { + ctx := r.Context() + for _, probe := range probes { + if err := probe(ctx); err != nil { + http.Error(w, "not ready", http.StatusInternalServerError) + return + } + } + ok(w) +} + +func ReadySigner(s Signer) ProbesFn { + return func(ctx context.Context) error { + if s == nil { + return errors.New("no signer") + } + return s.Health(ctx) + } +} +func ReadyStorage(s Storage) ProbesFn { + return func(ctx context.Context) error { + if s == nil { + return errors.New("no storage") + } + return s.Health(ctx) + } +} + +func ok(w http.ResponseWriter) { + utils.MarshalJSON(w, status{"ok"}) +} + +type status struct { + Status string `json:"status,omitempty"` +} diff --git a/pkg/op/signer.go b/pkg/op/signer.go index 6235931..b4f770e 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -2,54 +2,60 @@ package op import ( "encoding/json" + "errors" "golang.org/x/net/context" "gopkg.in/square/go-jose.v2" + "github.com/caos/logging" "github.com/caos/oidc/pkg/oidc" ) type Signer interface { + Health(ctx context.Context) error SignIDToken(claims *oidc.IDTokenClaims) (string, error) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) SignatureAlgorithm() jose.SignatureAlgorithm } -type idTokenSigner struct { - signer jose.Signer - storage AuthStorage - algorithm jose.SignatureAlgorithm +type tokenSigner struct { + signer jose.Signer + storage AuthStorage + alg jose.SignatureAlgorithm } -func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) { - s := &idTokenSigner{ +func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { + s := &tokenSigner{ storage: storage, } - if err := s.initialize(ctx); err != nil { - return nil, err - } - return s, nil + + go s.refreshSigningKey(ctx, keyCh) + + return s } -func (s *idTokenSigner) initialize(ctx context.Context) error { - var key *jose.SigningKey - var err error - key, err = s.storage.GetSigningKey(ctx) - if err != nil { - key, err = s.storage.SaveKeyPair(ctx) - if err != nil { - return err - } +func (s *tokenSigner) Health(_ context.Context) error { + if s.signer == nil { + return errors.New("no signer") } - s.signer, err = jose.NewSigner(*key, &jose.SignerOptions{}) - if err != nil { - return err - } - s.algorithm = key.Algorithm return nil } -func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) { +func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) { + for { + select { + case <-ctx.Done(): + return + case key := <-keyCh: + s.alg = key.Algorithm + var err error + s.signer, err = jose.NewSigner(key, &jose.SignerOptions{}) + logging.Log("OP-pf32aw").OnError(err).Error("error creating signer") + } + } +} + +func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) { payload, err := json.Marshal(claims) if err != nil { return "", err @@ -57,7 +63,7 @@ func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) return s.Sign(payload) } -func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) { +func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) { payload, err := json.Marshal(claims) if err != nil { return "", err @@ -65,7 +71,7 @@ func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, return s.Sign(payload) } -func (s *idTokenSigner) Sign(payload []byte) (string, error) { +func (s *tokenSigner) Sign(payload []byte) (string, error) { result, err := s.signer.Sign(payload) if err != nil { return "", err @@ -73,6 +79,6 @@ func (s *idTokenSigner) Sign(payload []byte) (string, error) { return result.CompactSerialize() } -func (s *idTokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { - return s.algorithm +func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { + return s.alg } diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go index 21aab0d..75e184b 100644 --- a/pkg/op/signer_test.go +++ b/pkg/op/signer_test.go @@ -78,7 +78,7 @@ func Test_idTokenSigner_Sign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &idTokenSigner{ + s := &tokenSigner{ signer: tt.fields.signer, storage: tt.fields.storage, } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 83b9f3e..803aa58 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -14,9 +14,9 @@ type AuthStorage interface { AuthRequestByID(context.Context, string) (AuthRequest, error) DeleteAuthRequest(context.Context, string) error - GetSigningKey(context.Context) (*jose.SigningKey, error) + GetSigningKey(context.Context, chan<- jose.SigningKey, chan<- error, <-chan bool) GetKeySet(context.Context) (*jose.JSONWebKeySet, error) - SaveKeyPair(context.Context) (*jose.SigningKey, error) + SaveNewKeyPair(context.Context) error } type OPStorage interface { @@ -28,6 +28,11 @@ type OPStorage interface { type Storage interface { AuthStorage OPStorage + Health(context.Context) error +} + +type StorageNotFoundError interface { + IsNotFound() } type AuthRequest interface {