From 65a58039dcb7a1b4ee09b5ff31620db1967cd6c1 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 26 Apr 2021 11:00:42 +0200 Subject: [PATCH] fix: rp verification process --- pkg/client/client.go | 3 ++ pkg/client/rp/jwks.go | 56 ++++++++++++++++++++++--------- pkg/client/rp/relaying_party.go | 30 +++++++++++++++-- pkg/client/rp/verifier.go | 6 +++- pkg/oidc/keyset.go | 58 ++++++++++++++++++++++++++++----- pkg/oidc/verifier.go | 16 +++++++-- pkg/op/op.go | 10 ++---- pkg/op/verifier_jwt_profile.go | 13 +++----- pkg/utils/hash.go | 3 ++ 9 files changed, 148 insertions(+), 47 deletions(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index b2b815e..2cc413a 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -36,6 +36,9 @@ func Discover(issuer string, httpClient *http.Client) (*oidc.DiscoveryConfigurat if err != nil { return nil, err } + if discoveryConfig.Issuer != issuer { + return nil, oidc.ErrIssuerInvalid + } return discoveryConfig, nil } diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 339fc93..98ed501 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -2,6 +2,7 @@ package rp import ( "context" + "encoding/json" "errors" "fmt" "net/http" @@ -21,6 +22,7 @@ func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet { type remoteKeySet struct { jwksURL string httpClient *http.Client + defaultAlg string // guard all other fields mu sync.Mutex @@ -66,29 +68,27 @@ func (i *inflight) result() ([]jose.JSONWebKey, error) { } func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { - // We don't support JWTs signed with multiple signatures. - keyID := "" - for _, sig := range jws.Signatures { - keyID = sig.Header.KeyID - break + keyID, alg := oidc.GetKeyIDAndAlg(jws) + if alg == "" { + alg = r.defaultAlg } - keys := r.keysFromCache() - payload, err, ok := oidc.CheckKey(keyID, jws, keys...) - if ok { + key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...) + if ok && keyID != "" { + payload, err := jws.Verify(&key) return payload, err } - keys, err = r.keysFromRemote(ctx) + keys, err := r.keysFromRemote(ctx) if err != nil { return nil, fmt.Errorf("fetching keys %v", err) } - - payload, err, ok = oidc.CheckKey(keyID, jws, keys...) - if !ok { - return nil, errors.New("invalid kid") + key, ok = oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...) + if ok { + payload, err := jws.Verify(&key) + return payload, err } - return payload, err + return nil, errors.New("invalid key") } func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) { @@ -147,10 +147,34 @@ func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, return nil, fmt.Errorf("oidc: can't create request: %v", err) } - keySet := new(jose.JSONWebKeySet) + keySet := new(jsonWebKeySet) if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil { return nil, fmt.Errorf("oidc: failed to get keys: %v", err) } - return keySet.Keys, nil } + +//jsonWebKeySet is an alias for jose.JSONWebKeySet which ignores unknown key types (kty) +type jsonWebKeySet jose.JSONWebKeySet + +//UnmarshalJSON overrides the default jose.JSONWebKeySet method to ignore any error +//which might occur because of unknown key types (kty) +func (k *jsonWebKeySet) UnmarshalJSON(data []byte) (err error) { + var raw rawJSONWebKeySet + err = json.Unmarshal(data, &raw) + if err != nil { + return err + } + for _, key := range raw.Keys { + webKey := new(jose.JSONWebKey) + err = webKey.UnmarshalJSON(key) + if err == nil { + k.Keys = append(k.Keys, *webKey) + } + } + return nil +} + +type rawJSONWebKeySet struct { + Keys []json.RawMessage `json:"keys"` +} diff --git a/pkg/client/rp/relaying_party.go b/pkg/client/rp/relaying_party.go index 528f554..9e02e65 100644 --- a/pkg/client/rp/relaying_party.go +++ b/pkg/client/rp/relaying_party.go @@ -22,6 +22,10 @@ const ( pkceCode = "pkce" ) +var ( + ErrUserInfoSubNotMatching = errors.New("sub from userinfo does not match the sub from the id_token") +) + //RelyingParty declares the minimal interface for oidc clients type RelyingParty interface { //OAuthConfig returns the oauth2 Config @@ -245,6 +249,9 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { if err != nil { return Endpoints{}, err } + if discoveryConfig.Issuer != issuer { + return Endpoints{}, oidc.ErrIssuerInvalid + } return GetEndpoints(discoveryConfig), nil } @@ -323,7 +330,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod //CodeExchangeHandler extends the `CodeExchange` method with a http handler //including cookie handling for secure `state` transfer //and optional PKCE code verifier checking -func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string), rp RelyingParty) http.HandlerFunc { +func CodeExchangeHandler(callback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string), rp RelyingParty) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { state, err := tryReadStateCookie(w, r, rp) if err != nil { @@ -361,17 +368,34 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc } } +//UserinfoCallback wraps the callback function of the CodeExchangeHandler +//and calls the userinfo endpoint with the access token +//on success it will pass the userinfo into its callback function as well +func UserinfoCallback(provider RelyingParty, f func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, info oidc.UserInfo, state string)) func(http.ResponseWriter, *http.Request, *oidc.Tokens, string) { + return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { + info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), provider) + if err != nil { + http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) + return + } + f(w, r, tokens, info, state) + } +} + //Userinfo will call the OIDC Userinfo Endpoint with the provided token -func Userinfo(token string, rp RelyingParty) (oidc.UserInfo, error) { +func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) { req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil) if err != nil { return nil, err } - req.Header.Set("authorization", token) + req.Header.Set("authorization", tokenType+" "+token) userinfo := oidc.NewUserInfo() if err := utils.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { return nil, err } + if userinfo.GetSubject() != subject { + return nil, ErrUserInfoSubNotMatching + } return userinfo, nil } diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 1f45ca8..027ca79 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -46,7 +46,11 @@ func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.I return nil, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckSubject(claims); err != nil { + return nil, err + } + + if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { return nil, err } diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go index 0d8e02c..adfffcf 100644 --- a/pkg/oidc/keyset.go +++ b/pkg/oidc/keyset.go @@ -2,10 +2,17 @@ package oidc import ( "context" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" "gopkg.in/square/go-jose.v2" ) +const ( + KeyUseSignature = "sig" +) + //KeySet represents a set of JSON Web Keys // - remotely fetch via discovery and jwks_uri -> `remoteKeySet` // - held by the OP itself in storage -> `openIDKeySet` @@ -15,16 +22,51 @@ type KeySet interface { VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) } -//CheckKey searches the given JSON Web Keys for the requested key ID -//and verifies the JSON Web Signature with the found key +//GetKeyIDAndAlg returns the `kid` and `alg` claim from the JWS header +func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) { + keyID := "" + alg := "" + for _, sig := range jws.Signatures { + keyID = sig.Header.KeyID + alg = sig.Header.Algorithm + break + } + return keyID, alg +} + +//FindKey searches the given JSON Web Keys for the requested key ID, usage and key type // -//will return false but no error if key ID is not found -func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) { +//will return the key immediately if matches exact (id, usage, type) +// +//will return false none or multiple match +func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) { + var validKeys []jose.JSONWebKey for _, key := range keys { - if keyID == "" || key.KeyID == keyID { - payload, err := jws.Verify(&key) - return payload, err, true + if key.KeyID == keyID && key.Use == use && algToKeyType(key.Key, expectedAlg) { + if keyID != "" { + return key, true + } + validKeys = append(validKeys, key) } } - return nil, nil, false + if len(validKeys) == 1 { + return validKeys[0], true + } + return jose.JSONWebKey{}, false +} + +func algToKeyType(key interface{}, alg string) bool { + switch alg[0] { + case 'R', 'P': + _, ok := key.(*rsa.PublicKey) + return ok + case 'E': + _, ok := key.(*ecdsa.PublicKey) + return ok + case 'O': + _, ok := key.(*ed25519.PublicKey) + return ok + default: + return false + } } diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 06470a0..f8470b5 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -17,6 +17,7 @@ import ( type Claims interface { GetIssuer() string + GetSubject() string GetAudience() []string GetExpiration() time.Time GetIssuedAt() time.Time @@ -30,6 +31,7 @@ type Claims interface { var ( ErrParse = errors.New("parsing of request failed") ErrIssuerInvalid = errors.New("issuer does not match") + ErrSubjectMissing = errors.New("subject missing") ErrAudience = errors.New("audience is not valid") ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty") ErrAzpInvalid = errors.New("authorized party is not valid") @@ -38,6 +40,7 @@ var ( ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported") ErrSignatureInvalidPayload = errors.New("signature does not match Payload") ErrExpired = errors.New("token has expired") + ErrIatMissing = errors.New("issuedAt of token is missing") ErrIatInFuture = errors.New("issuedAt of token is in the future") ErrIatToOld = errors.New("issuedAt of token is to old") ErrNonceInvalid = errors.New("nonce does not match") @@ -84,6 +87,13 @@ func ParseToken(tokenString string, claims interface{}) ([]byte, error) { return payload, err } +func CheckSubject(claims Claims) error { + if claims.GetSubject() == "" { + return ErrSubjectMissing + } + return nil +} + func CheckIssuer(claims Claims, issuer string) error { if claims.GetIssuer() != issuer { return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, claims.GetIssuer()) @@ -155,6 +165,9 @@ func CheckExpiration(claims Claims, offset time.Duration) error { func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { issuedAt := claims.GetIssuedAt().Round(time.Second) + if issuedAt.IsZero() { + return ErrIatMissing + } nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) if issuedAt.After(nowWithOffset) { return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) @@ -170,9 +183,6 @@ func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { } func CheckNonce(claims Claims, nonce string) error { - if nonce == "" { - return nil - } if claims.GetNonce() != nonce { return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, claims.GetNonce()) } diff --git a/pkg/op/op.go b/pkg/op/op.go index c91865d..a897dca 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -267,20 +267,16 @@ type openIDKeySet struct { //VerifySignature implements the oidc.KeySet interface //providing an implementation for the keys stored in the OP Storage interface func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { - keyID := "" - for _, sig := range jws.Signatures { - keyID = sig.Header.KeyID - break - } keySet, err := o.Storage.GetKeySet(ctx) if err != nil { return nil, errors.New("error fetching keys") } - payload, err, ok := oidc.CheckKey(keyID, jws, keySet.Keys...) + keyID, alg := oidc.GetKeyIDAndAlg(jws) + key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...) if !ok { return nil, errors.New("invalid kid") } - return payload, err + return jws.Verify(key) } type Option func(o *openidProvider) error diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index 338e39a..bb776b0 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -86,18 +86,13 @@ type jwtProfileKeySet struct { } func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { - keyID := "" - for _, sig := range jws.Signatures { - keyID = sig.Header.KeyID - break - } + keyID, alg := oidc.GetKeyIDAndAlg(jws) key, err := k.Storage.GetKeyByIDAndUserID(ctx, keyID, k.userID) if err != nil { return nil, fmt.Errorf("error fetching keys: %w", err) } - payload, err, ok := oidc.CheckKey(keyID, jws, *key) - if !ok { - return nil, errors.New("invalid kid") + if key.Algorithm != alg { + } - return payload, err + return jws.Verify(&key) } diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go index b7dfd9c..5dae03c 100644 --- a/pkg/utils/hash.go +++ b/pkg/utils/hash.go @@ -24,6 +24,9 @@ func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { } func HashString(hash hash.Hash, s string, firstHalf bool) string { + if hash == nil { + return s + } //nolint:errcheck hash.Write([]byte(s)) size := hash.Size()