lot of unfinished changes
This commit is contained in:
parent
9cb0fff23f
commit
a37a8461a5
16 changed files with 502 additions and 328 deletions
|
@ -1,6 +1,7 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -64,7 +65,7 @@ const (
|
|||
PromptSelectAccount Prompt = "select_account"
|
||||
|
||||
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
|
||||
GrantTypeCode GrantType = "authorization_code"
|
||||
GrantTypeCode GrantType = "authorization_code"
|
||||
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
|
||||
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
||||
|
||||
|
@ -148,10 +149,67 @@ type AccessTokenResponse struct {
|
|||
}
|
||||
|
||||
type JWTTokenRequest struct {
|
||||
Scopes Scopes `schema:"scope"`
|
||||
Audience []string `schema:"aud"`
|
||||
IssuedAt time.Time `schema:"iat"`
|
||||
ExpiresAt time.Time `schema:"exp"`
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Scopes Scopes `json:"scope"`
|
||||
Audience string `json:"aud"`
|
||||
IssuedAt Time `json:"iat"`
|
||||
ExpiresAt Time `json:"exp"`
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetClientID() string {
|
||||
return j.Subject
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetSubject() string {
|
||||
return j.Subject
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetScopes() []string {
|
||||
return j.Scopes
|
||||
}
|
||||
|
||||
type Time time.Time
|
||||
|
||||
func (t *Time) UnmarshalJSON(data []byte) error {
|
||||
var i int64
|
||||
if err := json.Unmarshal(data, &i); err != nil {
|
||||
return err
|
||||
}
|
||||
*t = Time(time.Unix(i, 0).UTC())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetIssuer() string {
|
||||
return j.Issuer
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetAudience() []string {
|
||||
return []string{j.Audience}
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetExpiration() time.Time {
|
||||
return time.Time(j.ExpiresAt)
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
|
||||
return time.Time(j.IssuedAt)
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetNonce() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetAuthTime() time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (j *JWTTokenRequest) GetAuthorizedParty() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type TokenExchangeRequest struct {
|
||||
|
|
|
@ -177,6 +177,42 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetIssuer() string {
|
||||
return t.Issuer
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetAudience() []string {
|
||||
return t.Audiences
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetExpiration() time.Time {
|
||||
return t.Expiration
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetIssuedAt() time.Time {
|
||||
return t.IssuedAt
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetNonce() string {
|
||||
return t.Nonce
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetAuthenticationContextClassReference() string {
|
||||
return t.AuthenticationContextClassReference
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetAuthTime() time.Time {
|
||||
return t.AuthTime
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) GetAuthorizedParty() string {
|
||||
return t.AuthorizedParty
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) {
|
||||
t.Signature = alg
|
||||
}
|
||||
|
||||
func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
|
||||
locale, _ := language.Parse(j.Locale)
|
||||
return UserinfoProfile{
|
||||
|
|
210
pkg/oidc/verifier.go
Normal file
210
pkg/oidc/verifier.go
Normal file
|
@ -0,0 +1,210 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type Claims interface {
|
||||
GetIssuer() string
|
||||
GetAudience() []string
|
||||
GetExpiration() time.Time
|
||||
GetIssuedAt() time.Time
|
||||
GetNonce() string
|
||||
GetAuthenticationContextClassReference() string
|
||||
GetAuthTime() time.Time
|
||||
GetAuthorizedParty() string
|
||||
SetSignature(algorithm jose.SignatureAlgorithm)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrParse = errors.New("")
|
||||
ErrIssuerInvalid = errors.New("issuer does not match")
|
||||
|
||||
ErrAudience = errors.New("audience is not valid")
|
||||
|
||||
ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
|
||||
ErrAzpInvalid = errors.New("authorized party is not valid")
|
||||
|
||||
ErrSignatureMissing = errors.New("id_token does not contain a signature")
|
||||
ErrSignatureMultiple = errors.New("id_token contains multiple signatures")
|
||||
ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
|
||||
ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
|
||||
|
||||
ErrExpired = errors.New("token has expired")
|
||||
|
||||
ErrIatInFuture = errors.New("issuedAt of token is in the future")
|
||||
|
||||
ErrIatToOld = errors.New("issuedAt of token is to old")
|
||||
//
|
||||
//ErrNonceInvalid = func(expected, actual string) *validationError {
|
||||
// return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
|
||||
//}
|
||||
ErrAcrInvalid = errors.New("acr is invalid")
|
||||
ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing")
|
||||
ErrAuthTimeToOld = errors.New("auth time of token is to old")
|
||||
|
||||
ErrAtHash = errors.New("at_hash does not correspond to access token")
|
||||
)
|
||||
|
||||
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
|
||||
type ACRVerifier func(string) error
|
||||
|
||||
func DecryptToken(tokenString string) (string, error) {
|
||||
return tokenString, nil //TODO: impl
|
||||
}
|
||||
|
||||
func ParseToken(tokenString string, claims interface{}) ([]byte, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrParse)
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrParse, err)
|
||||
}
|
||||
err = json.Unmarshal(payload, claims)
|
||||
return payload, err
|
||||
}
|
||||
|
||||
type Verifier interface {
|
||||
Issuer() string
|
||||
ClientID() string
|
||||
SupportedSignAlgs() []string
|
||||
KeySet() KeySet
|
||||
ACR() ACRVerifier
|
||||
MaxAge() time.Duration
|
||||
MaxAgeIAT() time.Duration
|
||||
Offset() time.Duration
|
||||
}
|
||||
|
||||
func CheckIssuer(issuer string, i Verifier) error {
|
||||
if i.Issuer() != issuer {
|
||||
return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, i.Issuer(), issuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckAudience(audiences []string, i Verifier) error {
|
||||
if !utils.Contains(audiences, i.ClientID()) {
|
||||
return fmt.Errorf("%w: Audience must contain client_id (%s)", ErrAudience, i.ClientID())
|
||||
}
|
||||
|
||||
//TODO: check aud trusted
|
||||
return nil
|
||||
}
|
||||
|
||||
//4. if multiple aud strings --> check if azp
|
||||
//5. if azp --> check azp == client_id
|
||||
func CheckAuthorizedParty(audiences []string, authorizedParty string, v Verifier) error {
|
||||
if len(audiences) > 1 {
|
||||
if authorizedParty == "" {
|
||||
return ErrAzpMissing
|
||||
}
|
||||
}
|
||||
if authorizedParty != "" && authorizedParty != v.ClientID() {
|
||||
return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, authorizedParty, v.ClientID())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckSignature(ctx context.Context, idTokenString string, payload []byte, claims Claims, v Verifier) error {
|
||||
jws, err := jose.ParseSigned(idTokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
return ErrSignatureMissing
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
return ErrSignatureMultiple
|
||||
}
|
||||
sig := jws.Signatures[0]
|
||||
supportedSigAlgs := v.SupportedSignAlgs()
|
||||
if len(supportedSigAlgs) == 0 {
|
||||
supportedSigAlgs = []string{"RS256"}
|
||||
}
|
||||
if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
|
||||
return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, supportedSigAlgs, sig.Header.Algorithm)
|
||||
}
|
||||
|
||||
signedPayload, err := v.KeySet().VerifySignature(ctx, jws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(signedPayload, payload) {
|
||||
return ErrSignatureInvalidPayload
|
||||
}
|
||||
|
||||
claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckExpiration(expiration time.Time, v Verifier) error {
|
||||
expiration = expiration.Round(time.Second)
|
||||
if !time.Now().UTC().Add(v.Offset()).Before(expiration) {
|
||||
return ErrExpired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckIssuedAt(issuedAt time.Time, v Verifier) error {
|
||||
issuedAt = issuedAt.Round(time.Second)
|
||||
offset := time.Now().UTC().Add(v.Offset()).Round(time.Second)
|
||||
if issuedAt.After(offset) {
|
||||
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, offset)
|
||||
}
|
||||
if v.MaxAgeIAT() == 0 {
|
||||
return nil
|
||||
}
|
||||
maxAge := time.Now().UTC().Add(-v.MaxAgeIAT()).Round(time.Second)
|
||||
if issuedAt.Before(maxAge) {
|
||||
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
func (v *DefaultVerifier) CheckNonce(nonce string) error {
|
||||
if v.config.nonce == "" {
|
||||
return nil
|
||||
}
|
||||
if v.config.nonce != nonce {
|
||||
return ErrNonceInvalid(v.config.nonce, nonce)
|
||||
}
|
||||
return nil
|
||||
}*/
|
||||
func CheckAuthorizationContextClassReference(acr string, v Verifier) error {
|
||||
if v.ACR() != nil {
|
||||
if err := v.ACR()(acr); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrAcrInvalid, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func CheckAuthTime(authTime time.Time, v Verifier) error {
|
||||
if v.MaxAge() == 0 {
|
||||
return nil
|
||||
}
|
||||
if authTime.IsZero() {
|
||||
return ErrAuthTimeNotPresent
|
||||
}
|
||||
authTime = authTime.Round(time.Second)
|
||||
maxAge := time.Now().UTC().Add(-v.MaxAge()).Round(time.Second)
|
||||
if authTime.Before(maxAge) {
|
||||
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAge.Sub(authTime))
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue