diff --git a/.releaserc.js b/.releaserc.js index 2847184..6500ace 100644 --- a/.releaserc.js +++ b/.releaserc.js @@ -5,4 +5,4 @@ module.exports = { "@semantic-release/release-notes-generator", "@semantic-release/github" ] - }; +}; diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 98ed501..3fb26c5 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -3,7 +3,6 @@ package rp import ( "context" "encoding/json" - "errors" "fmt" "net/http" "sync" @@ -15,14 +14,31 @@ import ( "github.com/caos/oidc/pkg/oidc" ) -func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet { - return &remoteKeySet{httpClient: client, jwksURL: jwksURL} +func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet { + keyset := &remoteKeySet{httpClient: client, jwksURL: jwksURL} + for _, opt := range opts { + opt(keyset) + } + return keyset +} + +//SkipRemoteCheck will suppress checking for new remote keys if signature validation fails with cached keys +//and no kid header is set in the JWT +// +//this might be handy to save some unnecessary round trips in cases where the JWT does not contain a kid header and +//there is only a single remote key +//please notice that remote keys will then only be fetched if cached keys are empty +func SkipRemoteCheck() func(set *remoteKeySet) { + return func(set *remoteKeySet) { + set.skipRemoteCheck = true + } } type remoteKeySet struct { - jwksURL string - httpClient *http.Client - defaultAlg string + jwksURL string + httpClient *http.Client + defaultAlg string + skipRemoteCheck bool // guard all other fields mu sync.Mutex @@ -73,22 +89,37 @@ func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSig alg = r.defaultAlg } keys := r.keysFromCache() - key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...) - if ok && keyID != "" { - payload, err := jws.Verify(&key) - return payload, err + if len(keys) == 0 { + return r.verifySignatureRemote(ctx, jws, keyID, alg) } + payload, err := r.verifySignatureCached(jws, keys, keyID, alg) + if payload != nil { + return payload, nil + } + if err != nil && keyID != "" || r.skipRemoteCheck { + return nil, fmt.Errorf("invalid signature: %w", err) + } + return r.verifySignatureRemote(ctx, jws, keyID, alg) +} +func (r *remoteKeySet) verifySignatureCached(jws *jose.JSONWebSignature, keys []jose.JSONWebKey, keyID, alg string) ([]byte, error) { + key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...) + if err != nil { + return nil, err + } + return jws.Verify(&key) +} + +func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) { keys, err := r.keysFromRemote(ctx) if err != nil { - return nil, fmt.Errorf("fetching keys %v", err) + return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err) } - key, ok = oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...) - if ok { - payload, err := jws.Verify(&key) - return payload, err + key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...) + if err != nil { + return nil, fmt.Errorf("unable to validate signature: %w", err) } - return nil, errors.New("invalid key") + return jws.Verify(&key) } func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) { diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go index adfffcf..3eca654 100644 --- a/pkg/oidc/keyset.go +++ b/pkg/oidc/keyset.go @@ -5,6 +5,7 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "errors" "gopkg.in/square/go-jose.v2" ) @@ -13,6 +14,11 @@ const ( KeyUseSignature = "sig" ) +var ( + ErrKeyMultiple = errors.New("multiple possible keys match") + ErrKeyNone = errors.New("no possible keys matches") +) + //KeySet represents a set of JSON Web Keys // - remotely fetch via discovery and jwks_uri -> `remoteKeySet` // - held by the OP itself in storage -> `openIDKeySet` @@ -39,20 +45,38 @@ func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) { //will return the key immediately if matches exact (id, usage, type) // //will return false none or multiple match +// +//deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool +//moved implementation already to FindMatchingKey func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) { + key, err := FindMatchingKey(keyID, use, expectedAlg, keys...) + return key, err == nil +} + +//FindMatchingKey searches the given JSON Web Keys for the requested key ID, usage and key type +// +//will return the key immediately if matches exact (id, usage, type) +// +//will return a specific error if none (ErrKeyNone) or multiple (ErrKeyMultiple) match +func FindMatchingKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (key jose.JSONWebKey, err error) { var validKeys []jose.JSONWebKey - for _, key := range keys { - if key.KeyID == keyID && key.Use == use && algToKeyType(key.Key, expectedAlg) { - if keyID != "" { - return key, true + for _, k := range keys { + if k.Use == use && algToKeyType(k.Key, expectedAlg) { + if k.KeyID == keyID && keyID != "" { + return k, nil + } + if k.KeyID == "" || keyID == "" { + validKeys = append(validKeys, k) } - validKeys = append(validKeys, key) } } if len(validKeys) == 1 { - return validKeys[0], true + return validKeys[0], nil } - return jose.JSONWebKey{}, false + if len(validKeys) > 1 { + return key, ErrKeyMultiple + } + return key, ErrKeyNone } func algToKeyType(key interface{}, alg string) bool { diff --git a/pkg/oidc/keyset_test.go b/pkg/oidc/keyset_test.go new file mode 100644 index 0000000..802edec --- /dev/null +++ b/pkg/oidc/keyset_test.go @@ -0,0 +1,319 @@ +package oidc + +import ( + "crypto/rsa" + "errors" + "reflect" + "testing" + + "gopkg.in/square/go-jose.v2" +) + +func TestFindKey(t *testing.T) { + type args struct { + keyID string + use string + expectedAlg string + keys []jose.JSONWebKey + } + type res struct { + key jose.JSONWebKey + err error + } + tests := []struct { + name string + args args + res res + }{ + { + "no keys, ErrKeyNone", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: nil, + }, + res{ + key: jose.JSONWebKey{}, + err: ErrKeyNone, + }, + }, + { + "single key enc, ErrKeyNone", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "enc", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{}, + err: ErrKeyNone, + }, + }, + { + "single key wrong algorithm, ErrKeyNone", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + Key: &rsa.PrivateKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{}, + err: ErrKeyNone, + }, + }, + { + "single key no kid, no jwt kid, match", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{ + Use: "sig", + Key: &rsa.PublicKey{}, + }, + err: nil, + }, + }, + { + "single key kid, jwt no kid, match", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + KeyID: "id", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{ + Use: "sig", + KeyID: "id", + Key: &rsa.PublicKey{}, + }, + err: nil, + }, + }, + { + "single key no kid, jwt with kid, match", + args{ + keyID: "id", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{ + Use: "sig", + Key: &rsa.PublicKey{}, + }, + err: nil, + }, + }, + { + "single key wrong kid, ErrKeyNone", + args{ + keyID: "id", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + KeyID: "id2", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{}, + err: ErrKeyNone, + }, + }, + { + "multiple keys no kid, jwt no kid, ErrKeyMultiple", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{}, + err: ErrKeyMultiple, + }, + }, + { + "multiple keys with kid, jwt no kid, ErrKeyMultiple", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + KeyID: "id1", + Key: &rsa.PublicKey{}, + }, + { + Use: "sig", + KeyID: "id2", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{}, + err: ErrKeyMultiple, + }, + }, + { + "multiple keys, single sig key, jwt no kid, match", + args{ + keyID: "", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + { + Use: "enc", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{ + Use: "sig", + Key: &rsa.PublicKey{}, + }, + err: nil, + }, + }, + { + "multiple keys no kid, jwt with kid, ErrKeyMultiple", + args{ + keyID: "id", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{}, + err: ErrKeyMultiple, + }, + }, + { + "multiple keys with kid, jwt with kid, match", + args{ + keyID: "id1", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + KeyID: "id1", + Key: &rsa.PublicKey{}, + }, + { + Use: "sig", + KeyID: "id2", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{ + Use: "sig", + KeyID: "id1", + Key: &rsa.PublicKey{}, + }, + err: nil, + }, + }, + { + "multiple keys, single sig key, jwt with kid, match", + args{ + keyID: "id1", + use: KeyUseSignature, + expectedAlg: "RS256", + keys: []jose.JSONWebKey{ + { + Use: "sig", + Key: &rsa.PublicKey{}, + }, + { + Use: "enc", + Key: &rsa.PublicKey{}, + }, + }, + }, + res{ + key: jose.JSONWebKey{ + Use: "sig", + Key: &rsa.PublicKey{}, + }, + err: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FindMatchingKey(tt.args.keyID, tt.args.use, tt.args.expectedAlg, tt.args.keys...) + if (tt.res.err != nil && !errors.Is(err, tt.res.err)) || (tt.res.err == nil && err != nil) { + t.Errorf("FindKey() error, got = %v, want = %v", err, tt.res.err) + } + if !reflect.DeepEqual(got, tt.res.key) { + t.Errorf("FindKey() got = %v, want %v", got, tt.res.key) + } + }) + } +} diff --git a/pkg/op/op.go b/pkg/op/op.go index 772b5f7..3841227 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -2,7 +2,7 @@ package op import ( "context" - "errors" + "fmt" "net/http" "time" @@ -280,12 +280,12 @@ type openIDKeySet struct { func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { keySet, err := o.Storage.GetKeySet(ctx) if err != nil { - return nil, errors.New("error fetching keys") + return nil, fmt.Errorf("error fetching keys: %w", err) } keyID, alg := oidc.GetKeyIDAndAlg(jws) - key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...) - if !ok { - return nil, errors.New("invalid kid") + key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...) + if err != nil { + return nil, fmt.Errorf("invalid signature: %w", err) } return jws.Verify(&key) }