fix: improve JWS and key verification

This commit is contained in:
Livio Amstutz 2021-09-14 08:20:41 +02:00
parent fcad98f4bd
commit 7bb6443cd0
5 changed files with 403 additions and 29 deletions

View file

@ -3,7 +3,6 @@ package rp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
@ -15,14 +14,31 @@ import (
"github.com/caos/oidc/pkg/oidc"
)
func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
return &remoteKeySet{httpClient: client, jwksURL: jwksURL}
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
keyset := &remoteKeySet{httpClient: client, jwksURL: jwksURL}
for _, opt := range opts {
opt(keyset)
}
return keyset
}
//SkipRemoteCheck will suppress checking for new remote keys if signature validation fails with cached keys
//and no kid header is set in the JWT
//
//this might be handy to save some unnecessary round trips in cases where the JWT does not contain a kid header and
//there is only a single remote key
//please notice that remote keys will then only be fetched if cached keys are empty
func SkipRemoteCheck() func(set *remoteKeySet) {
return func(set *remoteKeySet) {
set.skipRemoteCheck = true
}
}
type remoteKeySet struct {
jwksURL string
httpClient *http.Client
defaultAlg string
jwksURL string
httpClient *http.Client
defaultAlg string
skipRemoteCheck bool
// guard all other fields
mu sync.Mutex
@ -73,22 +89,37 @@ func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSig
alg = r.defaultAlg
}
keys := r.keysFromCache()
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
if ok && keyID != "" {
payload, err := jws.Verify(&key)
return payload, err
if len(keys) == 0 {
return r.verifySignatureRemote(ctx, jws, keyID, alg)
}
payload, err := r.verifySignatureCached(jws, keys, keyID, alg)
if payload != nil {
return payload, nil
}
if err != nil && keyID != "" || r.skipRemoteCheck {
return nil, fmt.Errorf("invalid signature: %w", err)
}
return r.verifySignatureRemote(ctx, jws, keyID, alg)
}
func (r *remoteKeySet) verifySignatureCached(jws *jose.JSONWebSignature, keys []jose.JSONWebKey, keyID, alg string) ([]byte, error) {
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
if err != nil {
return nil, err
}
return jws.Verify(&key)
}
func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
keys, err := r.keysFromRemote(ctx)
if err != nil {
return nil, fmt.Errorf("fetching keys %v", err)
return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err)
}
key, ok = oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
if ok {
payload, err := jws.Verify(&key)
return payload, err
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
if err != nil {
return nil, fmt.Errorf("unable to validate signature: %w", err)
}
return nil, errors.New("invalid key")
return jws.Verify(&key)
}
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {