* chore: move to new org * chore: change import * fix: update logging lib Co-authored-by: Fabienne <fabienne.gerschwiler@gmail.com> Co-authored-by: adlerhurst <silvan.reusser@gmail.com>
239 lines
6.6 KiB
Go
239 lines
6.6 KiB
Go
package rp
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"gopkg.in/square/go-jose.v2"
|
|
|
|
httphelper "github.com/zitadel/oidc/pkg/http"
|
|
"github.com/zitadel/oidc/pkg/oidc"
|
|
)
|
|
|
|
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
|
|
keyset := &remoteKeySet{httpClient: client, jwksURL: jwksURL}
|
|
for _, opt := range opts {
|
|
opt(keyset)
|
|
}
|
|
return keyset
|
|
}
|
|
|
|
//SkipRemoteCheck will suppress checking for new remote keys if signature validation fails with cached keys
|
|
//and no kid header is set in the JWT
|
|
//
|
|
//this might be handy to save some unnecessary round trips in cases where the JWT does not contain a kid header and
|
|
//there is only a single remote key
|
|
//please notice that remote keys will then only be fetched if cached keys are empty
|
|
func SkipRemoteCheck() func(set *remoteKeySet) {
|
|
return func(set *remoteKeySet) {
|
|
set.skipRemoteCheck = true
|
|
}
|
|
}
|
|
|
|
type remoteKeySet struct {
|
|
jwksURL string
|
|
httpClient *http.Client
|
|
defaultAlg string
|
|
skipRemoteCheck bool
|
|
|
|
// guard all other fields
|
|
mu sync.Mutex
|
|
|
|
// inflight suppresses parallel execution of updateKeys and allows
|
|
// multiple goroutines to wait for its result.
|
|
inflight *inflight
|
|
|
|
// A set of cached keys and their expiry.
|
|
cachedKeys []jose.JSONWebKey
|
|
}
|
|
|
|
// inflight is used to wait on some in-flight request from multiple goroutines.
//
// The owner records the outcome with done(), which also closes doneCh. Other
// goroutines block on wait() and then read the outcome via result().
type inflight struct {
	// doneCh is closed by done() once keys and err have been recorded.
	doneCh chan struct{}

	// keys and err hold the outcome of the request; they must only be read
	// after doneCh has been closed.
	keys []jose.JSONWebKey
	err  error
}
|
|
|
|
func newInflight() *inflight {
|
|
return &inflight{doneCh: make(chan struct{})}
|
|
}
|
|
|
|
// wait returns a channel that multiple goroutines can receive on. Once it returns
|
|
// a value, the inflight request is done and result() can be inspected.
|
|
func (i *inflight) wait() <-chan struct{} {
|
|
return i.doneCh
|
|
}
|
|
|
|
// done can only be called by a single goroutine. It records the result of the
|
|
// inflight request and signals other goroutines that the result is safe to
|
|
// inspect.
|
|
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
|
|
i.keys = keys
|
|
i.err = err
|
|
close(i.doneCh)
|
|
}
|
|
|
|
// result cannot be called until the wait() channel has returned a value.
|
|
func (i *inflight) result() ([]jose.JSONWebKey, error) {
|
|
return i.keys, i.err
|
|
}
|
|
|
|
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
|
keyID, alg := oidc.GetKeyIDAndAlg(jws)
|
|
if alg == "" {
|
|
alg = r.defaultAlg
|
|
}
|
|
payload, err := r.verifySignatureCached(jws, keyID, alg)
|
|
if payload != nil {
|
|
return payload, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return r.verifySignatureRemote(ctx, jws, keyID, alg)
|
|
}
|
|
|
|
//verifySignatureCached checks for a matching key in the cached key list
|
|
//
|
|
//if there is only one possible, it tries to verify the signature and will return the payload if successful
|
|
//
|
|
//it only returns an error if signature validation fails and keys exactMatch which is if either:
|
|
// - both kid are empty and skipRemoteCheck is set to true
|
|
// - or both (JWT and JWK) kid are equal
|
|
//
|
|
//otherwise it will return no error (so remote keys will be loaded)
|
|
func (r *remoteKeySet) verifySignatureCached(jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
|
|
keys := r.keysFromCache()
|
|
if len(keys) == 0 {
|
|
return nil, nil
|
|
}
|
|
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
|
if err != nil {
|
|
//no key / multiple found, try with remote keys
|
|
return nil, nil //nolint:nilerr
|
|
}
|
|
payload, err := jws.Verify(&key)
|
|
if payload != nil {
|
|
return payload, nil
|
|
}
|
|
if !r.exactMatch(key.KeyID, keyID) {
|
|
//no exact key match, try getting better match with remote keys
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("signature verification failed: %w", err)
|
|
}
|
|
|
|
func (r *remoteKeySet) exactMatch(jwkID, jwsID string) bool {
|
|
if jwkID == "" && jwsID == "" {
|
|
return r.skipRemoteCheck
|
|
}
|
|
return jwkID == jwsID
|
|
}
|
|
|
|
func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
|
|
keys, err := r.keysFromRemote(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err)
|
|
}
|
|
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to validate signature: %w", err)
|
|
}
|
|
payload, err := jws.Verify(&key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("signature verification failed: %w", err)
|
|
}
|
|
return payload, nil
|
|
}
|
|
|
|
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
return r.cachedKeys
|
|
}
|
|
|
|
// keysFromRemote syncs the key set from the remote set, records the values in the
|
|
// cache, and returns the key set.
|
|
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
|
|
// Need to lock to inspect the inflight request field.
|
|
r.mu.Lock()
|
|
// If there's not a current inflight request, create one.
|
|
if r.inflight == nil {
|
|
r.inflight = newInflight()
|
|
|
|
// This goroutine has exclusive ownership over the current inflight
|
|
// request. It releases the resource by nil'ing the inflight field
|
|
// once the goroutine is done.
|
|
go r.updateKeys(ctx)
|
|
}
|
|
inflight := r.inflight
|
|
r.mu.Unlock()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-inflight.wait():
|
|
return inflight.result()
|
|
}
|
|
}
|
|
|
|
func (r *remoteKeySet) updateKeys(ctx context.Context) {
|
|
// Sync keys and finish inflight when that's done.
|
|
keys, err := r.fetchRemoteKeys(ctx)
|
|
|
|
r.inflight.done(keys, err)
|
|
|
|
// Lock to update the keys and indicate that there is no longer an
|
|
// inflight request.
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
if err == nil {
|
|
r.cachedKeys = keys
|
|
}
|
|
|
|
// Free inflight so a different request can run.
|
|
r.inflight = nil
|
|
}
|
|
|
|
func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
|
|
req, err := http.NewRequest("GET", r.jwksURL, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("oidc: can't create request: %v", err)
|
|
}
|
|
|
|
keySet := new(jsonWebKeySet)
|
|
if err = httphelper.HttpRequest(r.httpClient, req, keySet); err != nil {
|
|
return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
|
|
}
|
|
return keySet.Keys, nil
|
|
}
|
|
|
|
//jsonWebKeySet is an alias for jose.JSONWebKeySet which ignores unknown key types (kty)
|
|
type jsonWebKeySet jose.JSONWebKeySet
|
|
|
|
//UnmarshalJSON overrides the default jose.JSONWebKeySet method to ignore any error
|
|
//which might occur because of unknown key types (kty)
|
|
func (k *jsonWebKeySet) UnmarshalJSON(data []byte) (err error) {
|
|
var raw rawJSONWebKeySet
|
|
err = json.Unmarshal(data, &raw)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, key := range raw.Keys {
|
|
webKey := new(jose.JSONWebKey)
|
|
err = webKey.UnmarshalJSON(key)
|
|
if err == nil {
|
|
k.Keys = append(k.Keys, *webKey)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// rawJSONWebKeySet is the first decoding stage of a JSON Web Key Set: each
// entry of "keys" is kept as undecoded JSON so that jsonWebKeySet.UnmarshalJSON
// can parse the keys one at a time and skip those that fail.
type rawJSONWebKeySet struct {
	Keys []json.RawMessage `json:"keys"`
}
|