fix: improve JWS and key verification
This commit is contained in:
parent
fcad98f4bd
commit
7bb6443cd0
5 changed files with 403 additions and 29 deletions
|
@ -5,4 +5,4 @@ module.exports = {
|
|||
"@semantic-release/release-notes-generator",
|
||||
"@semantic-release/github"
|
||||
]
|
||||
};
|
||||
};
|
||||
|
|
|
@ -3,7 +3,6 @@ package rp
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
@ -15,14 +14,31 @@ import (
|
|||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
|
||||
return &remoteKeySet{httpClient: client, jwksURL: jwksURL}
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
|
||||
keyset := &remoteKeySet{httpClient: client, jwksURL: jwksURL}
|
||||
for _, opt := range opts {
|
||||
opt(keyset)
|
||||
}
|
||||
return keyset
|
||||
}
|
||||
|
||||
//SkipRemoteCheck will suppress checking for new remote keys if signature validation fails with cached keys
|
||||
//and no kid header is set in the JWT
|
||||
//
|
||||
//this might be handy to save some unnecessary round trips in cases where the JWT does not contain a kid header and
|
||||
//there is only a single remote key
|
||||
//please notice that remote keys will then only be fetched if cached keys are empty
|
||||
func SkipRemoteCheck() func(set *remoteKeySet) {
|
||||
return func(set *remoteKeySet) {
|
||||
set.skipRemoteCheck = true
|
||||
}
|
||||
}
|
||||
|
||||
type remoteKeySet struct {
|
||||
jwksURL string
|
||||
httpClient *http.Client
|
||||
defaultAlg string
|
||||
skipRemoteCheck bool
|
||||
|
||||
// guard all other fields
|
||||
mu sync.Mutex
|
||||
|
@ -73,22 +89,37 @@ func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSig
|
|||
alg = r.defaultAlg
|
||||
}
|
||||
keys := r.keysFromCache()
|
||||
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
||||
if ok && keyID != "" {
|
||||
payload, err := jws.Verify(&key)
|
||||
return payload, err
|
||||
if len(keys) == 0 {
|
||||
return r.verifySignatureRemote(ctx, jws, keyID, alg)
|
||||
}
|
||||
payload, err := r.verifySignatureCached(jws, keys, keyID, alg)
|
||||
if payload != nil {
|
||||
return payload, nil
|
||||
}
|
||||
if err != nil && keyID != "" || r.skipRemoteCheck {
|
||||
return nil, fmt.Errorf("invalid signature: %w", err)
|
||||
}
|
||||
return r.verifySignatureRemote(ctx, jws, keyID, alg)
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) verifySignatureCached(jws *jose.JSONWebSignature, keys []jose.JSONWebKey, keyID, alg string) ([]byte, error) {
|
||||
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.Verify(&key)
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
|
||||
keys, err := r.keysFromRemote(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching keys %v", err)
|
||||
return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err)
|
||||
}
|
||||
key, ok = oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
||||
if ok {
|
||||
payload, err := jws.Verify(&key)
|
||||
return payload, err
|
||||
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to validate signature: %w", err)
|
||||
}
|
||||
return nil, errors.New("invalid key")
|
||||
return jws.Verify(&key)
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
@ -13,6 +14,11 @@ const (
|
|||
KeyUseSignature = "sig"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyMultiple = errors.New("multiple possible keys match")
|
||||
ErrKeyNone = errors.New("no possible keys matches")
|
||||
)
|
||||
|
||||
//KeySet represents a set of JSON Web Keys
|
||||
// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
|
||||
// - held by the OP itself in storage -> `openIDKeySet`
|
||||
|
@ -39,20 +45,38 @@ func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) {
|
|||
//will return the key immediately if matches exact (id, usage, type)
|
||||
//
|
||||
//will return false none or multiple match
|
||||
//
|
||||
//deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool
|
||||
//moved implementation already to FindMatchingKey
|
||||
func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) {
|
||||
key, err := FindMatchingKey(keyID, use, expectedAlg, keys...)
|
||||
return key, err == nil
|
||||
}
|
||||
|
||||
//FindMatchingKey searches the given JSON Web Keys for the requested key ID, usage and key type
|
||||
//
|
||||
//will return the key immediately if matches exact (id, usage, type)
|
||||
//
|
||||
//will return a specific error if none (ErrKeyNone) or multiple (ErrKeyMultiple) match
|
||||
func FindMatchingKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (key jose.JSONWebKey, err error) {
|
||||
var validKeys []jose.JSONWebKey
|
||||
for _, key := range keys {
|
||||
if key.KeyID == keyID && key.Use == use && algToKeyType(key.Key, expectedAlg) {
|
||||
if keyID != "" {
|
||||
return key, true
|
||||
for _, k := range keys {
|
||||
if k.Use == use && algToKeyType(k.Key, expectedAlg) {
|
||||
if k.KeyID == keyID && keyID != "" {
|
||||
return k, nil
|
||||
}
|
||||
if k.KeyID == "" || keyID == "" {
|
||||
validKeys = append(validKeys, k)
|
||||
}
|
||||
validKeys = append(validKeys, key)
|
||||
}
|
||||
}
|
||||
if len(validKeys) == 1 {
|
||||
return validKeys[0], true
|
||||
return validKeys[0], nil
|
||||
}
|
||||
return jose.JSONWebKey{}, false
|
||||
if len(validKeys) > 1 {
|
||||
return key, ErrKeyMultiple
|
||||
}
|
||||
return key, ErrKeyNone
|
||||
}
|
||||
|
||||
func algToKeyType(key interface{}, alg string) bool {
|
||||
|
|
319
pkg/oidc/keyset_test.go
Normal file
319
pkg/oidc/keyset_test.go
Normal file
|
@ -0,0 +1,319 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestFindKey(t *testing.T) {
|
||||
type args struct {
|
||||
keyID string
|
||||
use string
|
||||
expectedAlg string
|
||||
keys []jose.JSONWebKey
|
||||
}
|
||||
type res struct {
|
||||
key jose.JSONWebKey
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"no keys, ErrKeyNone",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: nil,
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{},
|
||||
err: ErrKeyNone,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key enc, ErrKeyNone",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "enc",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{},
|
||||
err: ErrKeyNone,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key wrong algorithm, ErrKeyNone",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PrivateKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{},
|
||||
err: ErrKeyNone,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key no kid, no jwt kid, match",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key kid, jwt no kid, match",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
KeyID: "id",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{
|
||||
Use: "sig",
|
||||
KeyID: "id",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key no kid, jwt with kid, match",
|
||||
args{
|
||||
keyID: "id",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key wrong kid, ErrKeyNone",
|
||||
args{
|
||||
keyID: "id",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
KeyID: "id2",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{},
|
||||
err: ErrKeyNone,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys no kid, jwt no kid, ErrKeyMultiple",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{},
|
||||
err: ErrKeyMultiple,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys with kid, jwt no kid, ErrKeyMultiple",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
KeyID: "id1",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
{
|
||||
Use: "sig",
|
||||
KeyID: "id2",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{},
|
||||
err: ErrKeyMultiple,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys, single sig key, jwt no kid, match",
|
||||
args{
|
||||
keyID: "",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
{
|
||||
Use: "enc",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys no kid, jwt with kid, ErrKeyMultiple",
|
||||
args{
|
||||
keyID: "id",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{},
|
||||
err: ErrKeyMultiple,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys with kid, jwt with kid, match",
|
||||
args{
|
||||
keyID: "id1",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
KeyID: "id1",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
{
|
||||
Use: "sig",
|
||||
KeyID: "id2",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{
|
||||
Use: "sig",
|
||||
KeyID: "id1",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys, single sig key, jwt with kid, match",
|
||||
args{
|
||||
keyID: "id1",
|
||||
use: KeyUseSignature,
|
||||
expectedAlg: "RS256",
|
||||
keys: []jose.JSONWebKey{
|
||||
{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
{
|
||||
Use: "enc",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
key: jose.JSONWebKey{
|
||||
Use: "sig",
|
||||
Key: &rsa.PublicKey{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := FindMatchingKey(tt.args.keyID, tt.args.use, tt.args.expectedAlg, tt.args.keys...)
|
||||
if (tt.res.err != nil && !errors.Is(err, tt.res.err)) || (tt.res.err == nil && err != nil) {
|
||||
t.Errorf("FindKey() error, got = %v, want = %v", err, tt.res.err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.res.key) {
|
||||
t.Errorf("FindKey() got = %v, want %v", got, tt.res.key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
10
pkg/op/op.go
10
pkg/op/op.go
|
@ -2,7 +2,7 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -280,12 +280,12 @@ type openIDKeySet struct {
|
|||
func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
keySet, err := o.Storage.GetKeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.New("error fetching keys")
|
||||
return nil, fmt.Errorf("error fetching keys: %w", err)
|
||||
}
|
||||
keyID, alg := oidc.GetKeyIDAndAlg(jws)
|
||||
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid kid")
|
||||
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid signature: %w", err)
|
||||
}
|
||||
return jws.Verify(&key)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue