zitadel-oidc/pkg/rp/jwks.go
2020-01-31 15:22:16 +01:00

166 lines
4 KiB
Go

package rp
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"github.com/caos/oidc/pkg/utils"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)
// NewRemoteKeySet returns an oidc.KeySet that verifies signatures with the
// keys published at jwksURL. The given client is used for every download of
// the key set.
func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
	keySet := new(remoteKeySet)
	keySet.jwksURL = jwksURL
	keySet.httpClient = client
	return keySet
}
// remoteKeySet verifies JWS signatures against a JSON Web Key Set that is
// downloaded from jwksURL and kept in memory between calls.
type remoteKeySet struct {
// jwksURL is the address the key set is fetched from.
jwksURL string
// httpClient is the client used to perform the fetch.
httpClient *http.Client
// mu guards the fields declared below it.
mu sync.Mutex
// inflight suppresses parallel execution of updateKeys and allows
// multiple goroutines to wait for its result. It is nil while no fetch is
// running.
inflight *inflight
// cachedKeys holds the keys returned by the most recent successful fetch.
// No expiry is tracked here; the cache is only replaced by a later fetch.
cachedKeys []jose.JSONWebKey
}
// inflight is used to wait on some in-flight request from multiple goroutines.
// It carries the outcome of that request once doneCh has been closed.
type inflight struct {
// doneCh is closed by done to signal that keys and err have been set.
doneCh chan struct{}
// keys is the key set produced by the request, as passed to done.
keys []jose.JSONWebKey
// err is the error produced by the request, as passed to done.
err error
}
// newInflight returns an inflight whose completion channel is still open,
// i.e. one that represents a request that has not finished yet.
func newInflight() *inflight {
	pending := new(inflight)
	pending.doneCh = make(chan struct{})
	return pending
}
// wait exposes the completion channel, which any number of goroutines may
// receive on. The channel is closed by done; once a receive on it succeeds,
// the inflight request is finished and result() can be inspected.
func (i *inflight) wait() <-chan struct{} {
	var finished <-chan struct{} = i.doneCh
	return finished
}
// done stores the outcome of the inflight request and then closes the
// completion channel so that waiting goroutines may read that outcome.
// It must be called by a single goroutine only: closing the channel a second
// time would panic.
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
	// The outcome has to be recorded before the channel is closed, because
	// waiters read it as soon as the close is observed.
	i.keys, i.err = keys, err
	close(i.doneCh)
}
// result hands back the keys and error recorded by done. It must not be
// called before a receive on the wait() channel has succeeded.
func (i *inflight) result() ([]jose.JSONWebKey, error) {
	keys, err := i.keys, i.err
	return keys, err
}
// VerifySignature checks the signature of jws against the key set and returns
// the verified payload.
//
// The key ID is taken from the first signature of jws. The cached keys are
// consulted first; only when checkKey reports that none of them is a
// candidate for that key ID is the key set fetched again and the lookup
// repeated. If the fetched set has no candidate either, an "invalid kid"
// error is returned.
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
	// We don't support JWTs signed with multiple signatures, so only the
	// first signature's key ID is looked at.
	var kid string
	if len(jws.Signatures) > 0 {
		kid = jws.Signatures[0].Header.KeyID
	}

	// Fast path: a cached key was selected, so its verdict is final.
	if payload, err, found := checkKey(kid, r.keysFromCache(), jws); found {
		return payload, err
	}

	// Slow path: the cache had no candidate, refresh it from the remote.
	remoteKeys, err := r.keysFromRemote(ctx)
	if err != nil {
		return nil, fmt.Errorf("fetching keys %v", err)
	}

	payload, err, found := checkKey(kid, remoteKeys, jws)
	if found {
		return payload, err
	}
	return nil, errors.New("invalid kid")
}
// keysFromCache returns the currently cached key set. The field is read while
// holding r.mu.
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
	r.mu.Lock()
	keys = r.cachedKeys
	r.mu.Unlock()
	return keys
}
// keysFromRemote returns the key set as fetched from the remote endpoint.
//
// At most one fetch runs at a time: the first caller starts updateKeys in a
// new goroutine, and every caller arriving while it is running waits for that
// same fetch. A caller stops waiting as soon as its own ctx is done, in which
// case ctx.Err() is returned.
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
	// The inflight field may only be inspected and replaced under the lock.
	r.mu.Lock()
	pending := r.inflight
	if pending == nil {
		// Nothing is running yet: register a new request before starting
		// the goroutine, which owns it exclusively and clears the field
		// again once it has finished.
		pending = newInflight()
		r.inflight = pending
		go r.updateKeys(ctx)
	}
	r.mu.Unlock()

	select {
	case <-pending.wait():
		return pending.result()
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// updateKeys fetches the remote key set, publishes the outcome to everyone
// waiting on the current inflight request and, on success, replaces the
// cached keys. It is meant to be run by the goroutine that keysFromRemote
// starts after registering r.inflight.
func (r *remoteKeySet) updateKeys(ctx context.Context) {
	fetched, fetchErr := r.fetchRemoteKeys(ctx)

	// Wake up the waiters first; they take the result from the inflight
	// value and do not depend on the cache.
	r.inflight.done(fetched, fetchErr)

	r.mu.Lock()
	defer r.mu.Unlock()

	// Clear the inflight slot so that a later call can start a new fetch.
	r.inflight = nil

	// A failed fetch leaves the previously cached keys in place.
	if fetchErr != nil {
		return
	}
	r.cachedKeys = fetched
}
// fetchRemoteKeys downloads the JSON Web Key Set from r.jwksURL using
// r.httpClient and returns the keys it contains.
//
// The request carries ctx, so that cancellation or a deadline on ctx applies
// to the download (presumably utils.HttpRequest sends req through the client
// unchanged; verify against that helper). Previously ctx was accepted but
// never attached to the request.
//
// An error is returned when the request cannot be built or when
// utils.HttpRequest fails; in both cases the returned key slice is nil.
func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
	req, err := http.NewRequest("GET", r.jwksURL, nil)
	if err != nil {
		return nil, fmt.Errorf("oidc: can't create request: %v", err)
	}
	// Request.WithContext panics on a nil context. This function runs in its
	// own goroutine, where a panic cannot be recovered by the caller, so a
	// nil ctx simply leaves the request unbound as before.
	if ctx != nil {
		req = req.WithContext(ctx)
	}

	keySet := new(jose.JSONWebKeySet)
	if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil {
		return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
	}
	return keySet.Keys, nil
}
// checkKey verifies jws with the keys that are candidates for keyID.
//
// A key is a candidate when keyID is empty (the signature names no key, so
// every key qualifies) or when its KeyID equals keyID. Every candidate is
// tried in order and the payload of the first successful verification is
// returned. Previously only the first candidate was tried, so a signature
// without a key ID could only ever be verified by the first key of the set.
//
// Return values:
//   - payload: the verified payload, or nil if no candidate verified jws.
//   - err: nil on success; otherwise the error of the first candidate that
//     failed, or nil when there was no candidate at all.
//   - ok: true if at least one candidate was found, false otherwise.
func checkKey(keyID string, keys []jose.JSONWebKey, jws *jose.JSONWebSignature) ([]byte, error, bool) {
	var (
		firstErr error
		matched  bool
	)
	// Index the slice instead of ranging by value so that no key is copied.
	for i := range keys {
		candidate := &keys[i]
		if keyID != "" && candidate.KeyID != keyID {
			continue
		}
		matched = true

		payload, err := jws.Verify(candidate)
		if err == nil {
			return payload, nil, true
		}
		// Keep the earliest failure: it is what was reported before other
		// candidates were tried as well.
		if firstErr == nil {
			firstErr = err
		}
	}
	return nil, firstErr, matched
}