diff --git a/pkg/rp/default_rp.go b/pkg/rp/default_rp.go index 4087a97..7c956bb 100644 --- a/pkg/rp/default_rp.go +++ b/pkg/rp/default_rp.go @@ -52,7 +52,7 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc } if p.verifier == nil { - p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, nil) //utils.NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint + p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint } return p, nil diff --git a/pkg/rp/jwks.go b/pkg/rp/jwks.go new file mode 100644 index 0000000..3851374 --- /dev/null +++ b/pkg/rp/jwks.go @@ -0,0 +1,178 @@ +package rp + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/caos/oidc/pkg/utils" + + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" +) + +func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet { + return &remoteKeySet{httpClient: client, jwksURL: jwksURL} +} + +type remoteKeySet struct { + jwksURL string + httpClient *http.Client + // now func() time.Time + + // guard all other fields + mu sync.Mutex + + // inflight suppresses parallel execution of updateKeys and allows + // multiple goroutines to wait for its result. + inflight *inflight + + // A set of cached keys and their expiry. + cachedKeys []jose.JSONWebKey + expiry time.Time +} + +// inflight is used to wait on some in-flight request from multiple goroutines. +type inflight struct { + doneCh chan struct{} + + keys []jose.JSONWebKey + err error +} + +func newInflight() *inflight { + return &inflight{doneCh: make(chan struct{})} +} + +// wait returns a channel that multiple goroutines can receive on. Once it returns +// a value, the inflight request is done and result() can be inspected. +func (i *inflight) wait() <-chan struct{} { + return i.doneCh +} + +// done can only be called by a single goroutine. It records the result of the +// inflight request and signals other goroutines that the result is safe to +// inspect. +func (i *inflight) done(keys []jose.JSONWebKey, err error) { + i.keys = keys + i.err = err + close(i.doneCh) +} + +// result cannot be called until the wait() channel has returned a value. +func (i *inflight) result() ([]jose.JSONWebKey, error) { + return i.keys, i.err +} + +func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { + // We don't support JWTs signed with multiple signatures. + keyID := "" + for _, sig := range jws.Signatures { + keyID = sig.Header.KeyID + break + } + + keys, _ := r.keysFromCache() + payload, err, ok := checkKey(keyID, keys, jws) + if ok { + return payload, err + } + + keys, err = r.keysFromRemote(ctx) + if err != nil { + return nil, fmt.Errorf("fetching keys %v", err) + } + + payload, err, ok = checkKey(keyID, keys, jws) + if !ok { + return nil, errors.New("invalid kid") + } + return payload, err +} + +func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) { + r.mu.Lock() + defer r.mu.Unlock() + return r.cachedKeys, r.expiry +} + +// keysFromRemote syncs the key set from the remote set, records the values in the +// cache, and returns the key set. +func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) { + // Need to lock to inspect the inflight request field. + r.mu.Lock() + // If there's not a current inflight request, create one. + if r.inflight == nil { + r.inflight = newInflight() + + // This goroutine has exclusive ownership over the current inflight + // request. It releases the resource by nil'ing the inflight field + // once the goroutine is done. + go r.updateKeys(ctx) + } + inflight := r.inflight + r.mu.Unlock() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-inflight.wait(): + return inflight.result() + } +} + +func (r *remoteKeySet) updateKeys(ctx context.Context) { + // Sync keys and finish inflight when that's done. + keys, expiry, err := r.fetchRemoteKeys(ctx) + + r.inflight.done(keys, err) + + // Lock to update the keys and indicate that there is no longer an + // inflight request. + r.mu.Lock() + defer r.mu.Unlock() + + if err == nil { + r.cachedKeys = keys + r.expiry = expiry + } + + // Free inflight so a different request can run. + r.inflight = nil +} + +func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, time.Time, error) { + req, err := http.NewRequest("GET", r.jwksURL, nil) + if err != nil { + return nil, time.Time{}, fmt.Errorf("oidc: can't create request: %v", err) + } + + keySet := new(jose.JSONWebKeySet) + if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil { + return nil, time.Time{}, fmt.Errorf("oidc: failed to get keys: %v", err) + } + + // If the server doesn't provide cache control headers, assume the + // keys expire immediately. + // expiry := r.now() + + // _, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{}) + // if err == nil && e.After(expiry) { + // expiry = e + // } + return keySet.Keys, time.Now(), nil +} + +func checkKey(keyID string, keys []jose.JSONWebKey, jws *jose.JSONWebSignature) ([]byte, error, bool) { + for _, key := range keys { + if keyID == "" || key.KeyID == keyID { + payload, err := jws.Verify(&key) + return payload, err, true + } + } + return nil, nil, false +}