change verifier interfaces

This commit is contained in:
Livio Amstutz 2020-09-11 10:45:07 +02:00
parent 3777f1436d
commit 143ff3482c
11 changed files with 274 additions and 179 deletions

View file

@ -151,8 +151,8 @@ func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequ
} }
return a, nil return a, nil
} }
func (s *AuthStorage) CreateToken(_ context.Context, authReq op.AuthRequest) (string, time.Time, error) { func (s *AuthStorage) CreateToken(_ context.Context, authReq op.TokenRequest) (string, time.Time, error) {
return authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil return "authReq.GetID()", time.Now().UTC().Add(5 * time.Minute), nil
} }
func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error { func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error {
return nil return nil
@ -174,6 +174,22 @@ func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error)
}, },
}, nil }, nil
} }
func (s *AuthStorage) GetKeyByID(_ context.Context, _ string) (*jose.JSONWebKeySet, error) {
pubkey := s.key.Public()
return &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
},
}, nil
}
func (s *AuthStorage) GetKeysByServiceAccount(_ context.Context, _ string) (*jose.JSONWebKeySet, error) {
pubkey := s.key.Public()
return &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
},
}, nil
}
func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) { func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) {
if id == "none" { if id == "none" {
@ -182,20 +198,24 @@ func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Clie
var appType op.ApplicationType var appType op.ApplicationType
var authMethod op.AuthMethod var authMethod op.AuthMethod
var accessTokenType op.AccessTokenType var accessTokenType op.AccessTokenType
var responseTypes []oidc.ResponseType
if id == "web" { if id == "web" {
appType = op.ApplicationTypeWeb appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic authMethod = op.AuthMethodBasic
accessTokenType = op.AccessTokenTypeBearer accessTokenType = op.AccessTokenTypeBearer
responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
} else if id == "native" { } else if id == "native" {
appType = op.ApplicationTypeNative appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone authMethod = op.AuthMethodNone
accessTokenType = op.AccessTokenTypeBearer accessTokenType = op.AccessTokenTypeBearer
responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
} else { } else {
appType = op.ApplicationTypeUserAgent appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodNone authMethod = op.AuthMethodNone
accessTokenType = op.AccessTokenTypeJWT accessTokenType = op.AccessTokenTypeJWT
responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken, oidc.ResponseTypeIDTokenOnly}
} }
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType, devMode: false}, nil return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType, responseTypes: responseTypes, devMode: false}, nil
} }
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error { func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {

View file

@ -30,33 +30,45 @@ type Claims interface {
var ( var (
ErrParse = errors.New("") ErrParse = errors.New("")
ErrIssuerInvalid = errors.New("issuer does not match") ErrIssuerInvalid = errors.New("issuer does not match")
ErrAudience = errors.New("audience is not valid") ErrAudience = errors.New("audience is not valid")
ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty") ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
ErrAzpInvalid = errors.New("authorized party is not valid") ErrAzpInvalid = errors.New("authorized party is not valid")
ErrSignatureMissing = errors.New("id_token does not contain a signature") ErrSignatureMissing = errors.New("id_token does not contain a signature")
ErrSignatureMultiple = errors.New("id_token contains multiple signatures") ErrSignatureMultiple = errors.New("id_token contains multiple signatures")
ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported") ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
ErrSignatureInvalidPayload = errors.New("signature does not match Payload") ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
ErrExpired = errors.New("token has expired") ErrExpired = errors.New("token has expired")
ErrIatInFuture = errors.New("issuedAt of token is in the future") ErrIatInFuture = errors.New("issuedAt of token is in the future")
ErrIatToOld = errors.New("issuedAt of token is to old") ErrIatToOld = errors.New("issuedAt of token is to old")
// ErrNonceInvalid = errors.New("nonce does not match")
//ErrNonceInvalid = func(expected, actual string) *validationError {
// return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
//}
ErrAcrInvalid = errors.New("acr is invalid") ErrAcrInvalid = errors.New("acr is invalid")
ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing") ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing")
ErrAuthTimeToOld = errors.New("auth time of token is to old") ErrAuthTimeToOld = errors.New("auth time of token is to old")
ErrAtHash = errors.New("at_hash does not correspond to access token") ErrAtHash = errors.New("at_hash does not correspond to access token")
) )
type Verifier interface {
Issuer() string
MaxAgeIAT() time.Duration
Offset() time.Duration
}
type verifierConfig struct {
issuer string
clientID string
nonce string
ignoreAudience bool
ignoreExpiration bool
//iat *iatConfig
acr ACRVerifier
maxAge time.Duration
supportedSignAlgs []string
// httpClient *http.Client
now time.Time
}
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim //ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
type ACRVerifier func(string) error type ACRVerifier func(string) error
@ -77,43 +89,30 @@ func ParseToken(tokenString string, claims interface{}) ([]byte, error) {
return payload, err return payload, err
} }
type Verifier interface { func CheckIssuer(claims Claims, issuer string) error {
Issuer() string if claims.GetIssuer() != issuer {
ClientID() string return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, claims.GetIssuer())
SupportedSignAlgs() []string
KeySet() KeySet
ACR() ACRVerifier
MaxAge() time.Duration
MaxAgeIAT() time.Duration
Offset() time.Duration
}
func CheckIssuer(issuer string, i Verifier) error {
if i.Issuer() != issuer {
return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, i.Issuer(), issuer)
} }
return nil return nil
} }
func CheckAudience(audiences []string, i Verifier) error { func CheckAudience(claims Claims, clientID string) error {
if !utils.Contains(audiences, i.ClientID()) { if !utils.Contains(claims.GetAudience(), clientID) {
return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, i.ClientID()) return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID)
} }
//TODO: check aud trusted //TODO: check aud trusted
return nil return nil
} }
//4. if multiple aud strings --> check if azp func CheckAuthorizedParty(claims Claims, clientID string) error {
//5. if azp --> check azp == client_id if len(claims.GetAudience()) > 1 {
func CheckAuthorizedParty(audiences []string, authorizedParty string, v Verifier) error { if claims.GetAuthorizedParty() == "" {
if len(audiences) > 1 {
if authorizedParty == "" {
return ErrAzpMissing return ErrAzpMissing
} }
} }
if authorizedParty != "" && authorizedParty != v.ClientID() { if claims.GetAuthorizedParty() != "" && claims.GetAuthorizedParty() != clientID {
return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, authorizedParty, v.ClientID()) return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, claims.GetAuthorizedParty(), clientID)
} }
return nil return nil
} }
@ -151,59 +150,59 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
return nil return nil
} }
func CheckExpiration(expiration time.Time, v Verifier) error { func CheckExpiration(claims Claims, offset time.Duration) error {
expiration = expiration.Round(time.Second) expiration := claims.GetExpiration().Round(time.Second)
if !time.Now().UTC().Add(v.Offset()).Before(expiration) { if !time.Now().UTC().Add(offset).Before(expiration) {
return ErrExpired return ErrExpired
} }
return nil return nil
} }
func CheckIssuedAt(issuedAt time.Time, v Verifier) error { func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
issuedAt = issuedAt.Round(time.Second) issuedAt := claims.GetIssuedAt().Round(time.Second)
offset := time.Now().UTC().Add(v.Offset()).Round(time.Second) nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second)
if issuedAt.After(offset) { if issuedAt.After(nowWithOffset) {
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, offset) return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
} }
if v.MaxAgeIAT() == 0 { if maxAgeIAT == 0 {
return nil return nil
} }
maxAge := time.Now().UTC().Add(-v.MaxAgeIAT()).Round(time.Second) maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second)
if issuedAt.Before(maxAge) { if issuedAt.Before(maxAge) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
} }
return nil return nil
} }
/* func CheckNonce(claims Claims, nonce string) error {
func (v *DefaultVerifier) CheckNonce(nonce string) error { if nonce == "" {
if v.config.nonce == "" {
return nil return nil
} }
if v.config.nonce != nonce { if claims.GetNonce() != nonce {
return ErrNonceInvalid(v.config.nonce, nonce) return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, claims.GetNonce())
} }
return nil return nil
}*/ }
func CheckAuthorizationContextClassReference(acr string, v Verifier) error {
if v.ACR() != nil { func CheckAuthorizationContextClassReference(claims Claims, acr ACRVerifier) error {
if err := v.ACR()(acr); err != nil { if acr != nil {
if err := acr(claims.GetAuthenticationContextClassReference()); err != nil {
return fmt.Errorf("%w: %v", ErrAcrInvalid, err) return fmt.Errorf("%w: %v", ErrAcrInvalid, err)
} }
} }
return nil return nil
} }
func CheckAuthTime(authTime time.Time, v Verifier) error { func CheckAuthTime(claims Claims, maxAge time.Duration) error {
if v.MaxAge() == 0 { if maxAge == 0 {
return nil return nil
} }
if authTime.IsZero() { if claims.GetAuthTime().IsZero() {
return ErrAuthTimeNotPresent return ErrAuthTimeNotPresent
} }
authTime = authTime.Round(time.Second) authTime := claims.GetAuthTime().Round(time.Second)
maxAge := time.Now().UTC().Add(-v.MaxAge()).Round(time.Second) maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second)
if authTime.Before(maxAge) { if authTime.Before(maxAuthTime) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAge.Sub(authTime)) return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
} }
return nil return nil
} }

View file

@ -9,7 +9,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
@ -18,7 +17,7 @@ type Authorizer interface {
Decoder() utils.Decoder Decoder() utils.Decoder
Encoder() utils.Encoder Encoder() utils.Encoder
Signer() Signer Signer() Signer
IDTokenVerifier() rp.Verifier IDTokenVerifier() IDTokenHintVerifier
Crypto() Crypto Crypto() Crypto
Issuer() string Issuer() string
} }
@ -27,7 +26,7 @@ type Authorizer interface {
//implementing it's own validation mechanism for the auth request //implementing it's own validation mechanism for the auth request
type AuthorizeValidator interface { type AuthorizeValidator interface {
Authorizer Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error) ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
} }
//ValidationAuthorizer is an extension of Authorizer interface //ValidationAuthorizer is an extension of Authorizer interface
@ -78,6 +77,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
RedirectToLogin(req.GetID(), client, w, r) RedirectToLogin(req.GetID(), client, w, r)
} }
//ParseAuthorizeRequest parsed the http request into a AuthRequest
func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) { func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
@ -91,7 +91,8 @@ func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRe
return authReq, nil return authReq, nil
} }
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) { //ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (string, error) {
client, err := storage.GetClientByClientID(ctx, authReq.ClientID) client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil { if err != nil {
return "", ErrServerError(err.Error()) return "", ErrServerError(err.Error())
@ -108,6 +109,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier) return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier)
} }
//ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(scopes []string) error { func ValidateAuthReqScopes(scopes []string) error {
if len(scopes) == 0 { if len(scopes) == 0 {
return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
@ -118,6 +120,7 @@ func ValidateAuthReqScopes(scopes []string) error {
return nil return nil
} }
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error { func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error {
if uri == "" { if uri == "" {
return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
@ -150,6 +153,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
return nil return nil
} }
//ValidateAuthReqResponseType validates the passed response_type to the registered response types
func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error { func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error {
if responseType == "" { if responseType == "" {
return ErrInvalidRequest("The response type is missing in your request. If you have any questions, you may contact the administrator of the application.") return ErrInvalidRequest("The response type is missing in your request. If you have any questions, you may contact the administrator of the application.")
@ -160,7 +164,9 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType)
return nil return nil
} }
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) { //ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
//and returns the `sub` claim
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) {
if idTokenHint == "" { if idTokenHint == "" {
return "", nil return "", nil
} }
@ -171,11 +177,13 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
return claims.Subject, nil return claims.Subject, nil
} }
//RedirectToLogin redirects the end user to the Login UI for authentication
func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) { func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) {
login := client.LoginURL(authReqID) login := client.LoginURL(authReqID)
http.Redirect(w, r, login, http.StatusFound) http.Redirect(w, r, login, http.StatusFound)
} }
//AuthorizeCallback handles the callback after authentication in the Login UI
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
params := mux.Vars(r) params := mux.Vars(r)
id := params["id"] id := params["id"]
@ -192,6 +200,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
AuthResponse(authReq, authorizer, w, r) AuthResponse(authReq, authorizer, w, r)
} }
//AuthResponse creates the successful authentication response (either code or tokens)
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
if err != nil { if err != nil {
@ -205,6 +214,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
return return
} }
//AuthResponseCode creates the successful code authentication response
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil { if err != nil {
@ -218,6 +228,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }
//AuthResponseToken creates the successful token(s) authentication response
func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) { func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "") resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "")
@ -234,6 +245,7 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }
//CreateAuthRequestCode creates and stores a code for the auth code response
func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) { func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) {
code, err := BuildAuthRequestCode(authReq, crypto) code, err := BuildAuthRequestCode(authReq, crypto)
if err != nil { if err != nil {

View file

@ -47,7 +47,7 @@ type DefaultOP struct {
endpoints *endpoints endpoints *endpoints
storage Storage storage Storage
signer Signer signer Signer
verifier rp.Verifier verifier IDTokenHintVerifier
crypto Crypto crypto Crypto
http http.Handler http http.Handler
decoder *schema.Decoder decoder *schema.Decoder
@ -184,7 +184,7 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts .
p.signer = NewDefaultSigner(ctx, storage, keyCh) p.signer = NewDefaultSigner(ctx, storage, keyCh)
go p.ensureKey(ctx, storage, keyCh, p.timer) go p.ensureKey(ctx, storage, keyCh, p.timer)
p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration()) p.verifier = NewIDTokenHintVerifier(config.Issuer, p)
p.http = CreateRouter(p, p.interceptors...) p.http = CreateRouter(p, p.interceptors...)
@ -238,10 +238,6 @@ func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(p, p.Signer())) Discover(w, CreateDiscoveryConfig(p, p.Signer()))
} }
func (p *DefaultOP) Probes() []ProbesFn {
return nil
}
func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
keyID := "" keyID := ""
for _, sig := range jws.Signatures { for _, sig := range jws.Signatures {
@ -279,7 +275,7 @@ func (p *DefaultOP) Crypto() Crypto {
return p.crypto return p.crypto
} }
func (p *DefaultOP) ClientJWTVerifier() rp.Verifier { func (p *DefaultOP) ClientJWTVerifier() oidc.Verifier {
return p.verifier return p.verifier
} }
@ -330,7 +326,7 @@ func (p *DefaultOP) HandleEndSession(w http.ResponseWriter, r *http.Request) {
func (p *DefaultOP) DefaultLogoutRedirectURI() string { func (p *DefaultOP) DefaultLogoutRedirectURI() string {
return p.config.DefaultLogoutRedirectURI return p.config.DefaultLogoutRedirectURI
} }
func (p *DefaultOP) IDTokenVerifier() rp.Verifier { func (p *DefaultOP) IDTokenVerifier() IDTokenHintVerifier {
return p.verifier return p.verifier
} }

View file

@ -16,12 +16,11 @@ const (
type OpenIDProvider interface { type OpenIDProvider interface {
Configuration Configuration
HandleKeys(w http.ResponseWriter, r *http.Request)
HttpHandler() http.Handler
Authorizer Authorizer
SessionEnder SessionEnder
Signer() Signer Signer() Signer
Probes() []ProbesFn Probes() []ProbesFn
HttpHandler() http.Handler
} }
type HttpInterceptor func(http.Handler) http.Handler type HttpInterceptor func(http.Handler) http.Handler

View file

@ -5,14 +5,13 @@ import (
"net/http" "net/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
type SessionEnder interface { type SessionEnder interface {
Decoder() utils.Decoder Decoder() utils.Decoder
Storage() Storage Storage() Storage
IDTokenVerifier() rp.Verifier IDTokenVerifier() IDTokenHintVerifier
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
} }
@ -63,7 +62,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if req.IdTokenHint == "" { if req.IdTokenHint == "" {
return session, nil return session, nil
} }
claims, err := ender.IDTokenVerifier().VerifyIDToken(ctx, req.IdTokenHint) claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenVerifier())
if err != nil { if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid") return nil, ErrInvalidRequest("id_token_hint invalid")
} }

View file

@ -24,7 +24,7 @@ type Exchanger interface {
type VerifyExchanger interface { type VerifyExchanger interface {
Exchanger Exchanger
ClientJWTVerifier() rp.Verifier ClientJWTVerifier() oidc.Verifier
} }
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
@ -34,7 +34,8 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque
CodeExchange(w, r, exchanger) CodeExchange(w, r, exchanger)
return return
case string(oidc.GrantTypeBearer): case string(oidc.GrantTypeBearer):
JWTExchange(w, r, exchanger) ex, _ := exchanger.(VerifyExchanger)
JWTExchange(w, r, ex)
return return
case "excahnge": case "excahnge":
TokenExchange(w, r, exchanger) TokenExchange(w, r, exchanger)
@ -161,23 +162,6 @@ func (c ClientJWTVerifier) ClientID() string {
return c.issuer return c.issuer
} }
func (c ClientJWTVerifier) SupportedSignAlgs() []string {
panic("implement me")
}
func (c ClientJWTVerifier) KeySet() oidc.KeySet {
// return c.claims
return nil
}
func (c ClientJWTVerifier) ACR() oidc.ACRVerifier {
panic("implement me")
}
func (c ClientJWTVerifier) MaxAge() time.Duration {
panic("implement me")
}
func (c ClientJWTVerifier) MaxAgeIAT() time.Duration { func (c ClientJWTVerifier) MaxAgeIAT() time.Duration {
//TODO: define in conf/opts //TODO: define in conf/opts
return 1 * time.Hour return 1 * time.Hour
@ -224,15 +208,15 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, exchanger Exchang
return nil, err return nil, err
} }
if err = oidc.CheckAudience(verifier.claims.GetAudience(), verifier); err != nil { if err = oidc.CheckAudience(verifier.claims, verifier.issuer); err != nil {
return nil, err return nil, err
} }
if err = oidc.CheckExpiration(verifier.claims.GetExpiration(), verifier); err != nil { if err = oidc.CheckExpiration(verifier.claims, verifier.Offset()); err != nil {
return nil, err return nil, err
} }
if err = oidc.CheckIssuedAt(verifier.claims.GetIssuedAt(), verifier); err != nil { if err = oidc.CheckIssuedAt(verifier.claims, verifier.MaxAgeIAT(), verifier.Offset()); err != nil {
return nil, err return nil, err
} }

View file

@ -2,14 +2,66 @@ package op
import ( import (
"context" "context"
"time"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
) )
type IDTokenHintVerifier interface { type IDTokenHintVerifier interface {
oidc.Verifier
SupportedSignAlgs() []string
KeySet() oidc.KeySet
ACR() oidc.ACRVerifier
MaxAge() time.Duration
} }
//VerifyIDToken validates the id token according to type idTokenHintVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
supportedSignAlgs []string
maxAge time.Duration
acr oidc.ACRVerifier
keySet oidc.KeySet
}
func (i *idTokenHintVerifier) Issuer() string {
return i.issuer
}
func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
func (i *idTokenHintVerifier) Offset() time.Duration {
return i.offset
}
func (i *idTokenHintVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
func (i *idTokenHintVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier {
return i.acr
}
func (i *idTokenHintVerifier) MaxAge() time.Duration {
return i.maxAge
}
func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifier {
verifier := &idTokenHintVerifier{
issuer: issuer,
keySet: keySet,
}
return verifier
}
//VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) { func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims) claims := new(oidc.IDTokenClaims)
@ -22,51 +74,28 @@ func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier)
if err != nil { if err != nil {
return nil, err return nil, err
} }
//2, check issuer (exact match)
if err := oidc.CheckIssuer(claims.GetIssuer(), v); err != nil { if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err return nil, err
} }
//3. check aud (aud must contain client_id, all aud strings must be allowed)
if err = oidc.CheckAudience(claims.GetAudience(), v); err != nil {
return nil, err
}
if err = oidc.CheckAuthorizedParty(claims.GetAudience(), claims.GetAuthorizedParty(), v); err != nil {
return nil, err
}
//6. check signature by keys
//7. check alg default is rs256
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err return nil, err
} }
//9. check exp before now if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
if err = oidc.CheckExpiration(claims.GetExpiration(), v); err != nil {
return nil, err return nil, err
} }
//10. check iat duration is optional (can be checked) if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
if err = oidc.CheckIssuedAt(claims.GetIssuedAt(), v); err != nil {
return nil, err return nil, err
} }
/* if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
if err = oidc.CheckNonce(claims.GetNonce()); err != nil {
return nil, err
}
*/
//12. if acr requested check acr
if err = oidc.CheckAuthorizationContextClassReference(claims.GetAuthenticationContextClassReference(), v); err != nil {
return nil, err return nil, err
} }
//13. if auth_time requested check if auth_time is less than max age if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
if err = oidc.CheckAuthTime(claims.GetAuthTime(), v); err != nil {
return nil, err return nil, err
} }
return claims, nil return claims, nil

View file

@ -27,7 +27,7 @@ var (
} }
) )
//DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface //DefaultRP implements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface
type DefaultRP struct { type DefaultRP struct {
endpoints Endpoints endpoints Endpoints
@ -40,7 +40,7 @@ type DefaultRP struct {
errorHandler func(http.ResponseWriter, *http.Request, string, string, string) errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
verifier Verifier idTokenVerifier IDTokenVerifier
verifierOpts []ConfFunc verifierOpts []ConfFunc
onlyOAuth2 bool onlyOAuth2 bool
} }
@ -79,8 +79,8 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc
p.errorHandler = DefaultErrorHandler p.errorHandler = DefaultErrorHandler
} }
if p.verifier == nil { if p.idTokenVerifier == nil {
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL), p.verifierOpts...) p.idTokenVerifier = NewIDTokenVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
} }
return p, nil return p, nil
@ -181,7 +181,7 @@ func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeE
idToken := new(oidc.IDTokenClaims) idToken := new(oidc.IDTokenClaims)
if !p.onlyOAuth2 { if !p.onlyOAuth2 {
idToken, err = p.verifier.Verify(ctx, token.AccessToken, idTokenString) idToken, err = VerifyTokens(ctx, token.AccessToken, idTokenString, p.idTokenVerifier)
if err != nil { if err != nil {
return nil, err //TODO: err return nil, err //TODO: err
} }

View file

@ -15,7 +15,7 @@ type DefaultVerifier struct {
keySet oidc.KeySet keySet oidc.KeySet
} }
//ConfFunc is the type for providing dynamic options to the DefaultVerfifier //ConfFunc is the type for providing dynamic options to the DefaultVerifier
type ConfFunc func(*verifierConfig) type ConfFunc func(*verifierConfig)
//NewDefaultVerifier creates `DefaultVerifier` with the given //NewDefaultVerifier creates `DefaultVerifier` with the given
@ -145,7 +145,7 @@ func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString
//Verify implements the `VerifyIDToken` method of the `Verifier` interface //Verify implements the `VerifyIDToken` method of the `Verifier` interface
//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation //according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) { func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) {
return VerifywIDToken(ctx, idTokenString, v) return VerifyIDToken(ctx, idTokenString, v)
} }
func (v *DefaultVerifier) now() time.Time { func (v *DefaultVerifier) now() time.Time {
@ -186,3 +186,7 @@ func (v *DefaultVerifier) MaxAgeIAT() time.Duration {
func (v *DefaultVerifier) Offset() time.Duration { func (v *DefaultVerifier) Offset() time.Duration {
return v.config.iat.offset return v.config.iat.offset
} }
func (v *DefaultVerifier) Nonce(ctx context.Context) string {
return ""
}

View file

@ -2,6 +2,7 @@ package rp
import ( import (
"context" "context"
"time"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -20,6 +21,69 @@ type Verifier interface {
type IDTokenVerifier interface { type IDTokenVerifier interface {
oidc.Verifier oidc.Verifier
ClientID() string
SupportedSignAlgs() []string
KeySet() oidc.KeySet
Nonce(context.Context) string
ACR() oidc.ACRVerifier
MaxAge() time.Duration
}
type idTokenVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
clientID string
supportedSignAlgs []string
keySet oidc.KeySet
acr oidc.ACRVerifier
maxAge time.Duration
nonce func(ctx context.Context) string
}
func (i *idTokenVerifier) Issuer() string {
return i.issuer
}
func (i *idTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
func (i *idTokenVerifier) Offset() time.Duration {
return i.offset
}
func (i *idTokenVerifier) ClientID() string {
return i.clientID
}
func (i *idTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
func (i *idTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func (i *idTokenVerifier) Nonce(ctx context.Context) string {
return i.nonce(ctx)
}
func (i *idTokenVerifier) ACR() oidc.ACRVerifier {
return i.acr
}
func (i *idTokenVerifier) MaxAge() time.Duration {
return i.maxAge
}
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet) IDTokenVerifier {
return &idTokenVerifier{
issuer: issuer,
clientID: clientID,
keySet: keySet,
offset: 5 * time.Second,
}
} }
//VerifyTokens implement the Token Response Validation as defined in OIDC specification //VerifyTokens implement the Token Response Validation as defined in OIDC specification
@ -48,51 +112,40 @@ func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.
if err != nil { if err != nil {
return nil, err return nil, err
} }
//2, check issuer (exact match)
if err := oidc.CheckIssuer(claims.GetIssuer(), v); err != nil { if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err return nil, err
} }
//3. check aud (aud must contain client_id, all aud strings must be allowed) if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
if err = oidc.CheckAudience(claims.GetAudience(), v); err != nil {
return nil, err return nil, err
} }
if err = oidc.CheckAuthorizedParty(claims.GetAudience(), claims.GetAuthorizedParty(), v); err != nil { if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
return nil, err return nil, err
} }
//6. check signature by keys
//7. check alg default is rs256
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err return nil, err
} }
//9. check exp before now if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
if err = oidc.CheckExpiration(claims.GetExpiration(), v); err != nil {
return nil, err return nil, err
} }
//10. check iat duration is optional (can be checked) if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
if err = oidc.CheckIssuedAt(claims.GetIssuedAt(), v); err != nil {
return nil, err return nil, err
} }
/* if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
if err = oidc.CheckNonce(claims.GetNonce()); err != nil {
return nil, err
}
*/
//12. if acr requested check acr
if err = oidc.CheckAuthorizationContextClassReference(claims.GetAuthenticationContextClassReference(), v); err != nil {
return nil, err return nil, err
} }
//13. if auth_time requested check if auth_time is less than max age if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
if err = oidc.CheckAuthTime(claims.GetAuthTime(), v); err != nil { return nil, err
}
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
return nil, err return nil, err
} }
return claims, nil return claims, nil