fix: rp verification process (#95)
* fix: rp verification process * types * comments * fix cli client
This commit is contained in:
parent
400f5c4de4
commit
850faa159d
11 changed files with 175 additions and 55 deletions
|
@ -2,10 +2,17 @@ package oidc
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
KeyUseSignature = "sig"
|
||||
)
|
||||
|
||||
//KeySet represents a set of JSON Web Keys
|
||||
// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
|
||||
// - held by the OP itself in storage -> `openIDKeySet`
|
||||
|
@ -15,16 +22,51 @@ type KeySet interface {
|
|||
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
|
||||
}
|
||||
|
||||
//CheckKey searches the given JSON Web Keys for the requested key ID
|
||||
//and verifies the JSON Web Signature with the found key
|
||||
//GetKeyIDAndAlg returns the `kid` and `alg` claim from the JWS header
|
||||
func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) {
|
||||
keyID := ""
|
||||
alg := ""
|
||||
for _, sig := range jws.Signatures {
|
||||
keyID = sig.Header.KeyID
|
||||
alg = sig.Header.Algorithm
|
||||
break
|
||||
}
|
||||
return keyID, alg
|
||||
}
|
||||
|
||||
//FindKey searches the given JSON Web Keys for the requested key ID, usage and key type
|
||||
//
|
||||
//will return false but no error if key ID is not found
|
||||
func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) {
|
||||
//will return the key immediately if matches exact (id, usage, type)
|
||||
//
|
||||
//will return false none or multiple match
|
||||
func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) {
|
||||
var validKeys []jose.JSONWebKey
|
||||
for _, key := range keys {
|
||||
if keyID == "" || key.KeyID == keyID {
|
||||
payload, err := jws.Verify(&key)
|
||||
return payload, err, true
|
||||
if key.KeyID == keyID && key.Use == use && algToKeyType(key.Key, expectedAlg) {
|
||||
if keyID != "" {
|
||||
return key, true
|
||||
}
|
||||
validKeys = append(validKeys, key)
|
||||
}
|
||||
}
|
||||
return nil, nil, false
|
||||
if len(validKeys) == 1 {
|
||||
return validKeys[0], true
|
||||
}
|
||||
return jose.JSONWebKey{}, false
|
||||
}
|
||||
|
||||
func algToKeyType(key interface{}, alg string) bool {
|
||||
switch alg[0] {
|
||||
case 'R', 'P':
|
||||
_, ok := key.(*rsa.PublicKey)
|
||||
return ok
|
||||
case 'E':
|
||||
_, ok := key.(*ecdsa.PublicKey)
|
||||
return ok
|
||||
case 'O':
|
||||
_, ok := key.(*ed25519.PublicKey)
|
||||
return ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
|
||||
type Claims interface {
|
||||
GetIssuer() string
|
||||
GetSubject() string
|
||||
GetAudience() []string
|
||||
GetExpiration() time.Time
|
||||
GetIssuedAt() time.Time
|
||||
|
@ -30,6 +31,7 @@ type Claims interface {
|
|||
var (
|
||||
ErrParse = errors.New("parsing of request failed")
|
||||
ErrIssuerInvalid = errors.New("issuer does not match")
|
||||
ErrSubjectMissing = errors.New("subject missing")
|
||||
ErrAudience = errors.New("audience is not valid")
|
||||
ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
|
||||
ErrAzpInvalid = errors.New("authorized party is not valid")
|
||||
|
@ -38,6 +40,7 @@ var (
|
|||
ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
|
||||
ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
|
||||
ErrExpired = errors.New("token has expired")
|
||||
ErrIatMissing = errors.New("issuedAt of token is missing")
|
||||
ErrIatInFuture = errors.New("issuedAt of token is in the future")
|
||||
ErrIatToOld = errors.New("issuedAt of token is to old")
|
||||
ErrNonceInvalid = errors.New("nonce does not match")
|
||||
|
@ -84,6 +87,13 @@ func ParseToken(tokenString string, claims interface{}) ([]byte, error) {
|
|||
return payload, err
|
||||
}
|
||||
|
||||
func CheckSubject(claims Claims) error {
|
||||
if claims.GetSubject() == "" {
|
||||
return ErrSubjectMissing
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckIssuer(claims Claims, issuer string) error {
|
||||
if claims.GetIssuer() != issuer {
|
||||
return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, claims.GetIssuer())
|
||||
|
@ -155,6 +165,9 @@ func CheckExpiration(claims Claims, offset time.Duration) error {
|
|||
|
||||
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
|
||||
issuedAt := claims.GetIssuedAt().Round(time.Second)
|
||||
if issuedAt.IsZero() {
|
||||
return ErrIatMissing
|
||||
}
|
||||
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second)
|
||||
if issuedAt.After(nowWithOffset) {
|
||||
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
|
||||
|
@ -170,9 +183,6 @@ func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
|
|||
}
|
||||
|
||||
func CheckNonce(claims Claims, nonce string) error {
|
||||
if nonce == "" {
|
||||
return nil
|
||||
}
|
||||
if claims.GetNonce() != nonce {
|
||||
return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, claims.GetNonce())
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue