lot of unfinished changes

This commit is contained in:
Livio Amstutz 2020-09-08 16:07:49 +02:00
parent 9cb0fff23f
commit a37a8461a5
16 changed files with 502 additions and 328 deletions

View file

@ -1,16 +1,10 @@
package rp
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
@ -24,9 +18,6 @@ type DefaultVerifier struct {
//ConfFunc is the type for providing dynamic options to the DefaultVerfifier
type ConfFunc func(*verifierConfig)
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
type ACRVerifier func(string) error
//NewDefaultVerifier creates `DefaultVerifier` with the given
//issuer, clientID, keyset and possible configOptions
func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier {
@ -90,7 +81,7 @@ func WithNonce(nonce string) func(*verifierConfig) {
}
//WithACRVerifier sets the verifier for the acr claim
func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) {
func WithACRVerifier(verifier oidc.ACRVerifier) func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.acr = verifier
}
@ -117,7 +108,7 @@ type verifierConfig struct {
ignoreAudience bool
ignoreExpiration bool
iat *iatConfig
acr ACRVerifier
acr oidc.ACRVerifier
maxAge time.Duration
supportedSignAlgs []string
@ -134,10 +125,10 @@ type iatConfig struct {
//DefaultACRVerifier implements `ACRVerifier` returning an error
//if non of the provided values matches the acr claim
func DefaultACRVerifier(possibleValues []string) ACRVerifier {
func DefaultACRVerifier(possibleValues []string) oidc.ACRVerifier {
return func(acr string) error {
if !utils.Contains(possibleValues, acr) {
return ErrAcrInvalid(possibleValues, acr)
return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr)
}
return nil
}
@ -148,88 +139,13 @@ func DefaultACRVerifier(possibleValues []string) ACRVerifier {
//and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) {
v.config.now = time.Now().UTC()
// idToken, err := v.VerifyIDToken(ctx, idTokenString)
// if err != nil {
// return nil, err
// }
// if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token
// return nil, err
// }
// return idToken, nil
// TODO: verifiy
decrypted, err := v.decryptToken(idTokenString)
if err != nil {
return nil, err
}
claims, _, err := v.parseToken(decrypted)
if err != nil {
return nil, err
}
return claims, nil
return VerifyTokens(ctx, accessToken, idTokenString, v)
}
//Verify implements the `VerifyIDToken` method of the `Verifier` interface
//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) {
//1. if encrypted --> decrypt
decrypted, err := v.decryptToken(idTokenString)
if err != nil {
return nil, err
}
claims, payload, err := v.parseToken(decrypted)
if err != nil {
return nil, err
}
// token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) {
//2, check issuer (exact match)
if err := v.checkIssuer(claims.Issuer); err != nil {
return nil, err
}
//3. check aud (aud must contain client_id, all aud strings must be allowed)
if err = v.checkAudience(claims.Audiences); err != nil {
return nil, err
}
if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil {
return nil, err
}
//6. check signature by keys
//7. check alg default is rs256
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
claims.Signature, err = v.checkSignature(ctx, decrypted, payload)
if err != nil {
return nil, err
}
//9. check exp before now
if err = v.checkExpiration(claims.Expiration); err != nil {
return nil, err
}
//10. check iat duration is optional (can be checked)
if err = v.checkIssuedAt(claims.IssuedAt); err != nil {
return nil, err
}
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
if err = v.checkNonce(claims.Nonce); err != nil {
return nil, err
}
//12. if acr requested check acr
if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil {
return nil, err
}
//13. if auth_time requested check if auth_time is less than max age
if err = v.checkAuthTime(claims.AuthTime); err != nil {
return nil, err
}
return claims, nil
return VerifyIDToken(ctx, idTokenString, v)
}
func (v *DefaultVerifier) now() time.Time {
@ -239,161 +155,34 @@ func (v *DefaultVerifier) now() time.Time {
return v.config.now
}
func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
idToken := new(oidc.IDTokenClaims)
err = json.Unmarshal(payload, idToken)
return idToken, payload, err
func (v *DefaultVerifier) Issuer() string {
return v.config.issuer
}
func (v *DefaultVerifier) checkIssuer(issuer string) error {
if v.config.issuer != issuer {
return ErrIssuerInvalid(v.config.issuer, issuer)
}
return nil
func (v *DefaultVerifier) ClientID() string {
return v.config.clientID
}
func (v *DefaultVerifier) checkAudience(audiences []string) error {
if v.config.ignoreAudience {
return nil
}
if !utils.Contains(audiences, v.config.clientID) {
return ErrAudienceMissingClientID(v.config.clientID)
}
//TODO: check aud trusted
return nil
func (v *DefaultVerifier) SupportedSignAlgs() []string {
return v.config.supportedSignAlgs
}
//4. if multiple aud strings --> check if azp
//5. if azp --> check azp == client_id
func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
if v.config.ignoreAudience {
return nil
}
if len(audiences) > 1 {
if authorizedParty == "" {
return ErrAzpMissing()
}
}
if authorizedParty != "" && authorizedParty != v.config.clientID {
return ErrAzpInvalid(authorizedParty, v.config.clientID)
}
return nil
func (v *DefaultVerifier) KeySet() oidc.KeySet {
return v.keySet
}
func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) {
jws, err := jose.ParseSigned(idTokenString)
if err != nil {
return "", err
}
if len(jws.Signatures) == 0 {
return "", ErrSignatureMissing()
}
if len(jws.Signatures) > 1 {
return "", ErrSignatureMultiple()
}
sig := jws.Signatures[0]
supportedSigAlgs := v.config.supportedSignAlgs
if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{"RS256"}
}
if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
}
signedPayload, err := v.keySet.VerifySignature(ctx, jws)
if err != nil {
return "", err
}
if !bytes.Equal(signedPayload, payload) {
return "", ErrSignatureInvalidPayload()
}
return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
func (v *DefaultVerifier) ACR() oidc.ACRVerifier {
return v.config.acr
}
func (v *DefaultVerifier) checkExpiration(expiration time.Time) error {
if v.config.ignoreExpiration {
return nil
}
expiration = expiration.Round(time.Second)
if !v.now().Before(expiration) {
return ErrExpInvalid(expiration)
}
return nil
func (v *DefaultVerifier) MaxAge() time.Duration {
return v.config.maxAge
}
func (v *DefaultVerifier) checkIssuedAt(issuedAt time.Time) error {
if v.config.iat.ignore {
return nil
}
issuedAt = issuedAt.Round(time.Second)
offset := v.now().Add(v.config.iat.offset).Round(time.Second)
if issuedAt.After(offset) {
return ErrIatInFuture(issuedAt, offset)
}
if v.config.iat.maxAge == 0 {
return nil
}
maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second)
if issuedAt.Before(maxAge) {
return ErrIatToOld(maxAge, issuedAt)
}
return nil
}
func (v *DefaultVerifier) checkNonce(nonce string) error {
if v.config.nonce == "" {
return nil
}
if v.config.nonce != nonce {
return ErrNonceInvalid(v.config.nonce, nonce)
}
return nil
}
func (v *DefaultVerifier) checkAuthorizationContextClassReference(acr string) error {
if v.config.acr != nil {
return v.config.acr(acr)
}
return nil
}
func (v *DefaultVerifier) checkAuthTime(authTime time.Time) error {
if v.config.maxAge == 0 {
return nil
}
if authTime.IsZero() {
return ErrAuthTimeNotPresent()
}
authTime = authTime.Round(time.Second)
maxAge := v.now().Add(-v.config.maxAge).Round(time.Second)
if authTime.Before(maxAge) {
return ErrAuthTimeToOld(maxAge, authTime)
}
return nil
func (v *DefaultVerifier) MaxAgeIAT() time.Duration {
return v.config.iat.maxAge
}
func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) {
return tokenString, nil //TODO: impl
}
func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
if atHash == "" {
return nil
}
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
if err != nil {
return err
}
if actual != atHash {
return ErrAtHash()
}
return nil
func (v *DefaultVerifier) Offset() time.Duration {
return v.config.iat.offset
}

View file

@ -1,67 +0,0 @@
package rp
import (
"fmt"
"time"
)
var (
ErrIssuerInvalid = func(expected, actual string) *validationError {
return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual)
}
ErrAudienceMissingClientID = func(clientID string) *validationError {
return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID)
}
ErrAzpMissing = func() *validationError {
return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty")
}
ErrAzpInvalid = func(azp, clientID string) *validationError {
return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID)
}
ErrExpInvalid = func(exp time.Time) *validationError {
return ValidationError("Token has expired %v", exp)
}
ErrIatInFuture = func(exp, now time.Time) *validationError {
return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now)
}
ErrIatToOld = func(maxAge, iat time.Time) *validationError {
return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat))
}
ErrNonceInvalid = func(expected, actual string) *validationError {
return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
}
ErrAcrInvalid = func(expected []string, actual string) *validationError {
return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual)
}
ErrAuthTimeNotPresent = func() *validationError {
return ValidationError("claim `auth_time` of token is missing")
}
ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
}
ErrSignatureMissing = func() *validationError {
return ValidationError("id_token does not contain a signature")
}
ErrSignatureMultiple = func() *validationError {
return ValidationError("id_token contains multiple signatures")
}
ErrSignatureInvalidPayload = func() *validationError {
return ValidationError("Signature does not match Payload")
}
ErrAtHash = func() *validationError {
return ValidationError("at_hash does not correspond to access token")
}
)
func ValidationError(message string, args ...interface{}) *validationError {
return &validationError{fmt.Sprintf(message, args...)} //TODO: impl
}
type validationError struct {
message string
}
func (v *validationError) Error() string {
return v.message
}

View file

@ -3,11 +3,12 @@ package rp
import (
"context"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)
//Verifier implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
//deprecated: Use IDTokenVerifier or oidc.Verifier
type Verifier interface {
//Verify checks the access_token and id_token and returns the `id token claims`
@ -16,3 +17,100 @@ type Verifier interface {
//VerifyIDToken checks the id_token only and returns its `id token claims`
VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error)
}
type IDTokenVerifier interface {
oidc.Verifier
}
//VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v)
if err != nil {
return nil, err
}
if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil {
return nil, err
}
return idToken, nil
}
//VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims)
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
}
payload, err := oidc.ParseToken(decrypted, claims)
if err != nil {
return nil, err
}
//2, check issuer (exact match)
if err := oidc.CheckIssuer(claims.GetIssuer(), v); err != nil {
return nil, err
}
//3. check aud (aud must contain client_id, all aud strings must be allowed)
if err = oidc.CheckAudience(claims.GetAudience(), v); err != nil {
return nil, err
}
if err = oidc.CheckAuthorizedParty(claims.GetAudience(), claims.GetAuthorizedParty(), v); err != nil {
return nil, err
}
//6. check signature by keys
//7. check alg default is rs256
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v); err != nil {
return nil, err
}
//9. check exp before now
if err = oidc.CheckExpiration(claims.GetExpiration(), v); err != nil {
return nil, err
}
//10. check iat duration is optional (can be checked)
if err = oidc.CheckIssuedAt(claims.GetIssuedAt(), v); err != nil {
return nil, err
}
/*
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
if err = oidc.CheckNonce(claims.GetNonce()); err != nil {
return nil, err
}
*/
//12. if acr requested check acr
if err = oidc.CheckAuthorizationContextClassReference(claims.GetAuthenticationContextClassReference(), v); err != nil {
return nil, err
}
//13. if auth_time requested check if auth_time is less than max age
if err = oidc.CheckAuthTime(claims.GetAuthTime(), v); err != nil {
return nil, err
}
return claims, nil
}
//VerifyAccessToken validates the access token according to
//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
if atHash == "" {
return nil
}
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
if err != nil {
return err
}
if actual != atHash {
return oidc.ErrAtHash
}
return nil
}

9
pkg/rp/verity.go Normal file
View file

@ -0,0 +1,9 @@
package rp
import (
"context"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)