fix: improve JWS and key verification (#128)

* fix: improve JWS and key verification

* fix: get remote keys if no cached key matches

* fix: get remote keys if no cached key matches

* fix exactMatch

* fix exactMatch

* chore: change default branch name in .releaserc.js
This commit is contained in:
Livio Amstutz 2021-09-14 15:13:44 +02:00 committed by GitHub
parent 2b5b436c41
commit a63fbee93d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 453 additions and 32 deletions

View file

@ -3,26 +3,41 @@ package rp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
"github.com/caos/oidc/pkg/utils"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)
func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
return &remoteKeySet{httpClient: client, jwksURL: jwksURL}
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
keyset := &remoteKeySet{httpClient: client, jwksURL: jwksURL}
for _, opt := range opts {
opt(keyset)
}
return keyset
}
//SkipRemoteCheck will suppress checking for new remote keys if signature validation fails with cached keys
//and no kid header is set in the JWT
//
//this might be handy to save some unnecessary round trips in cases where the JWT does not contain a kid header and
//there is only a single remote key
//please notice that remote keys will then only be fetched if cached keys are empty
func SkipRemoteCheck() func(set *remoteKeySet) {
return func(set *remoteKeySet) {
set.skipRemoteCheck = true
}
}
type remoteKeySet struct {
jwksURL string
httpClient *http.Client
defaultAlg string
jwksURL string
httpClient *http.Client
defaultAlg string
skipRemoteCheck bool
// guard all other fields
mu sync.Mutex
@ -72,23 +87,67 @@ func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSig
if alg == "" {
alg = r.defaultAlg
}
keys := r.keysFromCache()
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
if ok && keyID != "" {
payload, err := jws.Verify(&key)
return payload, err
payload, err := r.verifySignatureCached(jws, keyID, alg)
if payload != nil {
return payload, nil
}
if err != nil {
return nil, err
}
return r.verifySignatureRemote(ctx, jws, keyID, alg)
}
//verifySignatureCached checks for a matching key in the cached key list
//
//if there is only one possible, it tries to verify the signature and will return the payload if successful
//
//it only returns an error if signature validation fails and keys exactMatch which is if either:
// - both kid are empty and skipRemoteCheck is set to true
// - or both (JWT and JWK) kid are equal
//
//otherwise it will return no error (so remote keys will be loaded)
func (r *remoteKeySet) verifySignatureCached(jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
keys := r.keysFromCache()
if len(keys) == 0 {
return nil, nil
}
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
if err != nil {
//no key / multiple found, try with remote keys
return nil, nil //nolint:nilerr
}
payload, err := jws.Verify(&key)
if payload != nil {
return payload, nil
}
if !r.exactMatch(key.KeyID, keyID) {
//no exact key match, try getting better match with remote keys
return nil, nil
}
return nil, fmt.Errorf("signature verification failed: %w", err)
}
func (r *remoteKeySet) exactMatch(jwkID, jwsID string) bool {
if jwkID == "" && jwsID == "" {
return r.skipRemoteCheck
}
return jwkID == jwsID
}
func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
keys, err := r.keysFromRemote(ctx)
if err != nil {
return nil, fmt.Errorf("fetching keys %v", err)
return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err)
}
key, ok = oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
if ok {
payload, err := jws.Verify(&key)
return payload, err
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
if err != nil {
return nil, fmt.Errorf("unable to validate signature: %w", err)
}
return nil, errors.New("invalid key")
payload, err := jws.Verify(&key)
if err != nil {
return nil, fmt.Errorf("signature verification failed: %w", err)
}
return payload, nil
}
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {

View file

@ -5,6 +5,7 @@ import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"errors"
"gopkg.in/square/go-jose.v2"
)
@ -13,6 +14,11 @@ const (
KeyUseSignature = "sig"
)
var (
ErrKeyMultiple = errors.New("multiple possible keys match")
ErrKeyNone = errors.New("no possible keys matches")
)
//KeySet represents a set of JSON Web Keys
// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
// - held by the OP itself in storage -> `openIDKeySet`
@ -39,20 +45,38 @@ func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) {
//will return the key immediately if matches exact (id, usage, type)
//
//will return false none or multiple match
//
//deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool
//moved implementation already to FindMatchingKey
func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) {
key, err := FindMatchingKey(keyID, use, expectedAlg, keys...)
return key, err == nil
}
//FindMatchingKey searches the given JSON Web Keys for the requested key ID, usage and key type
//
//will return the key immediately if matches exact (id, usage, type)
//
//will return a specific error if none (ErrKeyNone) or multiple (ErrKeyMultiple) match
func FindMatchingKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (key jose.JSONWebKey, err error) {
var validKeys []jose.JSONWebKey
for _, key := range keys {
if key.KeyID == keyID && key.Use == use && algToKeyType(key.Key, expectedAlg) {
if keyID != "" {
return key, true
for _, k := range keys {
if k.Use == use && algToKeyType(k.Key, expectedAlg) {
if k.KeyID == keyID && keyID != "" {
return k, nil
}
if k.KeyID == "" || keyID == "" {
validKeys = append(validKeys, k)
}
validKeys = append(validKeys, key)
}
}
if len(validKeys) == 1 {
return validKeys[0], true
return validKeys[0], nil
}
return jose.JSONWebKey{}, false
if len(validKeys) > 1 {
return key, ErrKeyMultiple
}
return key, ErrKeyNone
}
func algToKeyType(key interface{}, alg string) bool {

319
pkg/oidc/keyset_test.go Normal file
View file

@ -0,0 +1,319 @@
package oidc
import (
"crypto/rsa"
"errors"
"reflect"
"testing"
"gopkg.in/square/go-jose.v2"
)
func TestFindKey(t *testing.T) {
type args struct {
keyID string
use string
expectedAlg string
keys []jose.JSONWebKey
}
type res struct {
key jose.JSONWebKey
err error
}
tests := []struct {
name string
args args
res res
}{
{
"no keys, ErrKeyNone",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: nil,
},
res{
key: jose.JSONWebKey{},
err: ErrKeyNone,
},
},
{
"single key enc, ErrKeyNone",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "enc",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{},
err: ErrKeyNone,
},
},
{
"single key wrong algorithm, ErrKeyNone",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
Key: &rsa.PrivateKey{},
},
},
},
res{
key: jose.JSONWebKey{},
err: ErrKeyNone,
},
},
{
"single key no kid, no jwt kid, match",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{
Use: "sig",
Key: &rsa.PublicKey{},
},
err: nil,
},
},
{
"single key kid, jwt no kid, match",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
KeyID: "id",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{
Use: "sig",
KeyID: "id",
Key: &rsa.PublicKey{},
},
err: nil,
},
},
{
"single key no kid, jwt with kid, match",
args{
keyID: "id",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{
Use: "sig",
Key: &rsa.PublicKey{},
},
err: nil,
},
},
{
"single key wrong kid, ErrKeyNone",
args{
keyID: "id",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
KeyID: "id2",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{},
err: ErrKeyNone,
},
},
{
"multiple keys no kid, jwt no kid, ErrKeyMultiple",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
Key: &rsa.PublicKey{},
},
{
Use: "sig",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{},
err: ErrKeyMultiple,
},
},
{
"multiple keys with kid, jwt no kid, ErrKeyMultiple",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
KeyID: "id1",
Key: &rsa.PublicKey{},
},
{
Use: "sig",
KeyID: "id2",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{},
err: ErrKeyMultiple,
},
},
{
"multiple keys, single sig key, jwt no kid, match",
args{
keyID: "",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
Key: &rsa.PublicKey{},
},
{
Use: "enc",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{
Use: "sig",
Key: &rsa.PublicKey{},
},
err: nil,
},
},
{
"multiple keys no kid, jwt with kid, ErrKeyMultiple",
args{
keyID: "id",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
Key: &rsa.PublicKey{},
},
{
Use: "sig",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{},
err: ErrKeyMultiple,
},
},
{
"multiple keys with kid, jwt with kid, match",
args{
keyID: "id1",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
KeyID: "id1",
Key: &rsa.PublicKey{},
},
{
Use: "sig",
KeyID: "id2",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{
Use: "sig",
KeyID: "id1",
Key: &rsa.PublicKey{},
},
err: nil,
},
},
{
"multiple keys, single sig key, jwt with kid, match",
args{
keyID: "id1",
use: KeyUseSignature,
expectedAlg: "RS256",
keys: []jose.JSONWebKey{
{
Use: "sig",
Key: &rsa.PublicKey{},
},
{
Use: "enc",
Key: &rsa.PublicKey{},
},
},
},
res{
key: jose.JSONWebKey{
Use: "sig",
Key: &rsa.PublicKey{},
},
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := FindMatchingKey(tt.args.keyID, tt.args.use, tt.args.expectedAlg, tt.args.keys...)
if (tt.res.err != nil && !errors.Is(err, tt.res.err)) || (tt.res.err == nil && err != nil) {
t.Errorf("FindKey() error, got = %v, want = %v", err, tt.res.err)
}
if !reflect.DeepEqual(got, tt.res.key) {
t.Errorf("FindKey() got = %v, want %v", got, tt.res.key)
}
})
}
}

View file

@ -39,6 +39,7 @@ var (
ErrSignatureMultiple = errors.New("id_token contains multiple signatures")
ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
ErrSignatureInvalid = errors.New("invalid signature")
ErrExpired = errors.New("token has expired")
ErrIatMissing = errors.New("issuedAt of token is missing")
ErrIatInFuture = errors.New("issuedAt of token is in the future")
@ -143,7 +144,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
signedPayload, err := set.VerifySignature(ctx, jws)
if err != nil {
return err
return fmt.Errorf("%w (%v)", ErrSignatureInvalid, err)
}
if !bytes.Equal(signedPayload, payload) {

View file

@ -2,7 +2,7 @@ package op
import (
"context"
"errors"
"fmt"
"net/http"
"time"
@ -280,12 +280,12 @@ type openIDKeySet struct {
func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
keySet, err := o.Storage.GetKeySet(ctx)
if err != nil {
return nil, errors.New("error fetching keys")
return nil, fmt.Errorf("error fetching keys: %w", err)
}
keyID, alg := oidc.GetKeyIDAndAlg(jws)
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...)
if !ok {
return nil, errors.New("invalid kid")
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...)
if err != nil {
return nil, fmt.Errorf("invalid signature: %w", err)
}
return jws.Verify(&key)
}