diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 9a5aecf..7465ae0 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -140,7 +140,7 @@ func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequ } return a, nil } -func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey, _ chan<- error, _ <-chan bool) { +func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey, _ chan<- error, _ <-chan time.Time) { keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key} } func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) { diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index bcc9c31..63d9cd3 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -45,6 +45,7 @@ type DefaultOP struct { encoder *schema.Encoder interceptor HttpInterceptor retry func(int) (bool, int) + timer <-chan time.Time } type Config struct { @@ -123,6 +124,13 @@ func WithRetry(max int, sleep time.Duration) DefaultOPOpts { } } +func WithTimer(timer <-chan time.Time) DefaultOPOpts { + return func(o *DefaultOP) error { + o.timer = timer + return nil + } +} + func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { err := ValidateIssuer(config.Issuer) if err != nil { @@ -133,18 +141,19 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts . config: config, storage: storage, endpoints: DefaultEndpoints, + timer: make(<-chan time.Time), } - keyCh := make(chan jose.SigningKey) - p.signer = NewDefaultSigner(ctx, storage, keyCh) - go p.ensureKey(ctx, storage, keyCh) - for _, optFunc := range opOpts { if err := optFunc(p); err != nil { return nil, err } } + keyCh := make(chan jose.SigningKey) + p.signer = NewDefaultSigner(ctx, storage, keyCh) + go p.ensureKey(ctx, storage, keyCh, p.timer) + router := CreateRouter(p, p.interceptor) p.http = &http.Server{ Addr: ":" + config.Port, @@ -252,12 +261,11 @@ func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) { Userinfo(w, r, p) } -func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- jose.SigningKey) { +func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- jose.SigningKey, timer <-chan time.Time) { count := 0 - explicit := make(chan bool) + timer = time.After(0) errCh := make(chan error) - go storage.GetSigningKey(ctx, keyCh, errCh, explicit) - explicit <- true + go storage.GetSigningKey(ctx, keyCh, errCh, timer) for { select { case <-ctx.Done(): @@ -275,7 +283,7 @@ func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- } ok, count = p.retry(count) if ok { - explicit <- true + timer = time.After(0) continue } logging.Log("OP-n6ynVE").WithError(err).Panic("error in key signer") diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 181ce3f..11a8ab5 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -11,6 +11,7 @@ import ( gomock "github.com/golang/mock/gomock" go_jose_v2 "gopkg.in/square/go-jose.v2" reflect "reflect" + time "time" ) // MockStorage is a mock of Storage interface @@ -125,7 +126,7 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { } // GetSigningKey mocks base method -func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan bool) { +func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- go_jose_v2.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) { m.ctrl.T.Helper() m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3) } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 803aa58..fd9a582 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -14,7 +14,7 @@ type AuthStorage interface { AuthRequestByID(context.Context, string) (AuthRequest, error) DeleteAuthRequest(context.Context, string) error - GetSigningKey(context.Context, chan<- jose.SigningKey, chan<- error, <-chan bool) + GetSigningKey(context.Context, chan<- jose.SigningKey, chan<- error, <-chan time.Time) GetKeySet(context.Context) (*jose.JSONWebKeySet, error) SaveNewKeyPair(context.Context) error }