zitadel-oidc/pkg/oidc/verifier.go
Livio Amstutz a63fbee93d
fix: improve JWS and key verification (#128)
* fix: improve JWS and key verification

* fix: get remote keys if no cached key matches

* fix: get remote keys if no cached key matches

* fix exactMatch

* fix exactMatch

* chore: change default branch name in .releaserc.js
2021-09-14 15:13:44 +02:00

214 lines
6.7 KiB
Go

package oidc
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils"
)
// Claims is the set of accessors the Check* functions in this file need from
// a parsed token in order to validate it.
type Claims interface {
// GetIssuer is compared against the expected issuer by CheckIssuer.
GetIssuer() string
// GetSubject must be non-empty to pass CheckSubject.
GetSubject() string
// GetAudience must contain the client ID to pass CheckAudience; its length
// also decides whether CheckAuthorizedParty requires an authorized party.
GetAudience() []string
// GetExpiration is compared against the current time by CheckExpiration.
GetExpiration() time.Time
// GetIssuedAt is validated by CheckIssuedAt; a zero value counts as missing.
GetIssuedAt() time.Time
// GetNonce is compared against the expected nonce by CheckNonce.
GetNonce() string
// GetAuthenticationContextClassReference is handed to the ACRVerifier by
// CheckAuthorizationContextClassReference.
GetAuthenticationContextClassReference() string
// GetAuthTime is validated by CheckAuthTime; a zero value counts as missing.
GetAuthTime() time.Time
// GetAuthorizedParty is validated against the client ID by
// CheckAuthorizedParty.
GetAuthorizedParty() string
// SetSignatureAlgorithm is called by CheckSignature, after the signature
// was verified, with the algorithm taken from the signature header.
SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
}
// Sentinel errors produced by the parsing and Check* functions in this file.
// Some are returned as-is, others are wrapped with additional detail through
// fmt.Errorf and %w, so callers should match them with errors.Is rather than
// with ==.
var (
// Token parsing (ParseToken, CheckSignature).
ErrParse = errors.New("parsing of request failed")
// Claim validation (CheckIssuer, CheckSubject, CheckAudience, CheckAuthorizedParty).
ErrIssuerInvalid = errors.New("issuer does not match")
ErrSubjectMissing = errors.New("subject missing")
ErrAudience = errors.New("audience is not valid")
ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
ErrAzpInvalid = errors.New("authorized party is not valid")
// Signature validation (CheckSignature).
ErrSignatureMissing = errors.New("id_token does not contain a signature")
ErrSignatureMultiple = errors.New("id_token contains multiple signatures")
ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
ErrSignatureInvalid = errors.New("invalid signature")
// Time based validation (CheckExpiration, CheckIssuedAt).
ErrExpired = errors.New("token has expired")
ErrIatMissing = errors.New("issuedAt of token is missing")
ErrIatInFuture = errors.New("issuedAt of token is in the future")
ErrIatToOld = errors.New("issuedAt of token is to old")
// Nonce, acr and auth_time validation (CheckNonce,
// CheckAuthorizationContextClassReference, CheckAuthTime).
ErrNonceInvalid = errors.New("nonce does not match")
ErrAcrInvalid = errors.New("acr is invalid")
ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing")
ErrAuthTimeToOld = errors.New("auth time of token is to old")
// ErrAtHash is not referenced by any function visible in this file.
ErrAtHash = errors.New("at_hash does not correspond to access token")
)
// Verifier exposes the settings common to token verifiers.
// NOTE(review): nothing in this file consumes the interface directly; the
// values presumably feed CheckIssuer (Issuer), CheckIssuedAt (MaxAgeIAT,
// Offset) and CheckExpiration (Offset) - confirm against the callers.
type Verifier interface {
// Issuer returns the issuer the verifier is configured with.
Issuer() string
// MaxAgeIAT returns a duration; presumably the maximum accepted age of the
// issued-at claim - TODO confirm.
MaxAgeIAT() time.Duration
// Offset returns a duration; presumably the offset applied to the current
// time during time based checks - TODO confirm.
Offset() time.Duration
}
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim.
//It receives the value of the acr claim and returns a non-nil error to reject it;
//CheckAuthorizationContextClassReference wraps that error in ErrAcrInvalid and
//skips the check entirely when the ACRVerifier is nil.
type ACRVerifier func(string) error
// DefaultACRVerifier builds an `ACRVerifier` that accepts an acr claim only if
// utils.Contains reports it as part of possibleValues. For any other value the
// returned verifier yields an error naming the accepted values and the value
// it actually received.
func DefaultACRVerifier(possibleValues []string) ACRVerifier {
	verify := func(acr string) error {
		if utils.Contains(possibleValues, acr) {
			return nil
		}
		return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr)
	}
	return verify
}
// DecryptToken is a placeholder for the decryption of encrypted tokens.
// Decryption is not implemented yet: tokenString is handed back unchanged and
// the returned error is always nil.
func DecryptToken(tokenString string) (string, error) {
	//TODO: impl
	decrypted := tokenString
	return decrypted, nil
}
// ParseToken splits tokenString on "." into its three segments, decodes the
// second one (the payload) from unpadded base64url and unmarshals the JSON
// into claims.
//
// It returns the decoded payload bytes. A token without exactly three segments
// or with an undecodable payload yields a nil slice and an error wrapping
// ErrParse. If only the JSON unmarshalling fails, the decoded payload is still
// returned together with the unwrapped json error.
func ParseToken(tokenString string, claims interface{}) ([]byte, error) {
	segments := strings.Split(tokenString, ".")
	if len(segments) != 3 {
		return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrParse)
	}

	payload, decodeErr := base64.RawURLEncoding.DecodeString(segments[1])
	if decodeErr != nil {
		return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrParse, decodeErr)
	}

	return payload, json.Unmarshal(payload, claims)
}
// CheckSubject verifies that the claims carry a subject.
// It returns ErrSubjectMissing when GetSubject yields an empty string.
func CheckSubject(claims Claims) error {
	if subject := claims.GetSubject(); subject != "" {
		return nil
	}
	return ErrSubjectMissing
}
// CheckIssuer verifies that the issuer of the claims equals issuer exactly.
// On mismatch it returns an error wrapping ErrIssuerInvalid that states both
// the expected and the received issuer.
func CheckIssuer(claims Claims, issuer string) error {
	actual := claims.GetIssuer()
	if actual == issuer {
		return nil
	}
	return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, actual)
}
// CheckAudience verifies that clientID is part of the audience of the claims,
// as reported by utils.Contains. Otherwise it returns an error wrapping
// ErrAudience.
func CheckAudience(claims Claims, clientID string) error {
	if utils.Contains(claims.GetAudience(), clientID) {
		//TODO: check aud trusted
		return nil
	}
	return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID)
}
// CheckAuthorizedParty validates the authorized party (azp) of the claims:
//   - an empty azp is only accepted when the audience has at most one entry,
//     otherwise ErrAzpMissing is returned;
//   - a non-empty azp must equal clientID, otherwise an error wrapping
//     ErrAzpInvalid is returned.
func CheckAuthorizedParty(claims Claims, clientID string) error {
	azp := claims.GetAuthorizedParty()
	switch {
	case azp == "":
		// Without an azp the token may only target a single audience.
		if len(claims.GetAudience()) > 1 {
			return ErrAzpMissing
		}
		return nil
	case azp != clientID:
		return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, azp, clientID)
	default:
		return nil
	}
}
// CheckSignature parses token as a JWS and verifies its signature with set.
//
// The steps, in order:
//   - the token must parse (ErrParse otherwise);
//   - it must carry exactly one signature (ErrSignatureMissing or
//     ErrSignatureMultiple otherwise);
//   - the algorithm from the signature header must be listed in
//     supportedSigAlgs, which falls back to RS256 when empty (an error
//     wrapping ErrSignatureUnsupportedAlg otherwise);
//   - set.VerifySignature must succeed (an error wrapping ErrSignatureInvalid
//     otherwise);
//   - the bytes it returns must equal payload (ErrSignatureInvalidPayload
//     otherwise).
//
// On success the header algorithm is stored on claims through
// SetSignatureAlgorithm and nil is returned.
func CheckSignature(ctx context.Context, token string, payload []byte, claims Claims, supportedSigAlgs []string, set KeySet) error {
	jws, parseErr := jose.ParseSigned(token)
	if parseErr != nil {
		return ErrParse
	}

	switch count := len(jws.Signatures); {
	case count == 0:
		return ErrSignatureMissing
	case count > 1:
		return ErrSignatureMultiple
	}
	// Read the algorithm before verification, from the one and only signature.
	algorithm := jws.Signatures[0].Header.Algorithm

	accepted := supportedSigAlgs
	if len(accepted) == 0 {
		accepted = []string{"RS256"}
	}
	if !utils.Contains(accepted, algorithm) {
		return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, accepted, algorithm)
	}

	verified, verifyErr := set.VerifySignature(ctx, jws)
	if verifyErr != nil {
		return fmt.Errorf("%w (%v)", ErrSignatureInvalid, verifyErr)
	}
	// The verified bytes must be the very payload the claims were parsed from.
	if !bytes.Equal(verified, payload) {
		return ErrSignatureInvalidPayload
	}

	claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(algorithm))
	return nil
}
// CheckExpiration verifies that the token has not expired yet.
// The expiration, rounded to the second, must lie strictly after the current
// UTC time shifted by offset; otherwise ErrExpired is returned.
func CheckExpiration(claims Claims, offset time.Duration) error {
	expiresAt := claims.GetExpiration().Round(time.Second)
	shiftedNow := time.Now().UTC().Add(offset)
	if shiftedNow.Before(expiresAt) {
		return nil
	}
	return ErrExpired
}
// CheckIssuedAt validates the issued-at time of the claims, rounded to the
// second:
//   - a zero value yields ErrIatMissing;
//   - a value after the current UTC time shifted by offset yields an error
//     wrapping ErrIatInFuture;
//   - if maxAgeIAT is non-zero, a value before now minus maxAgeIAT yields an
//     error wrapping ErrIatToOld. The offset is not applied to this bound.
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
	iat := claims.GetIssuedAt().Round(time.Second)
	if iat.IsZero() {
		return ErrIatMissing
	}

	if latest := time.Now().UTC().Add(offset).Round(time.Second); iat.After(latest) {
		return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, iat, latest)
	}

	// A max age of zero disables the lower bound.
	if maxAgeIAT == 0 {
		return nil
	}

	if oldest := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second); iat.Before(oldest) {
		return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, oldest, iat, oldest.Sub(iat))
	}
	return nil
}
// CheckNonce verifies that the nonce of the claims equals nonce exactly.
// On mismatch it returns an error wrapping ErrNonceInvalid that quotes both
// the expected and the received value.
func CheckNonce(claims Claims, nonce string) error {
	got := claims.GetNonce()
	if got == nonce {
		return nil
	}
	return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, got)
}
// CheckAuthorizationContextClassReference runs acr against the authentication
// context class reference of the claims. A nil acr disables the check. An
// error reported by acr is returned wrapped in ErrAcrInvalid.
func CheckAuthorizationContextClassReference(claims Claims, acr ACRVerifier) error {
	if acr == nil {
		return nil
	}
	err := acr(claims.GetAuthenticationContextClassReference())
	if err == nil {
		return nil
	}
	return fmt.Errorf("%w: %v", ErrAcrInvalid, err)
}
// CheckAuthTime verifies that the authentication time of the claims is not
// older than maxAge.
//   - A maxAge of zero disables the check.
//   - A zero auth time yields ErrAuthTimeNotPresent.
//   - An auth time (rounded to the second) before the current UTC time minus
//     maxAge yields an error wrapping ErrAuthTimeToOld.
func CheckAuthTime(claims Claims, maxAge time.Duration) error {
	if maxAge == 0 {
		return nil
	}

	rawAuthTime := claims.GetAuthTime()
	if rawAuthTime.IsZero() {
		return ErrAuthTimeNotPresent
	}

	authenticatedAt := rawAuthTime.Round(time.Second)
	oldestAllowed := time.Now().UTC().Add(-maxAge).Round(time.Second)
	if !authenticatedAt.Before(oldestAllowed) {
		return nil
	}
	return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authenticatedAt, oldestAllowed.Sub(authenticatedAt))
}