diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index febb28c..fc21e4c 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -151,8 +151,8 @@ func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequ } return a, nil } -func (s *AuthStorage) CreateToken(_ context.Context, authReq op.AuthRequest) (string, time.Time, error) { - return authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil +func (s *AuthStorage) CreateToken(_ context.Context, authReq op.TokenRequest) (string, time.Time, error) { + return "authReq.GetID()", time.Now().UTC().Add(5 * time.Minute), nil } func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error { return nil @@ -174,6 +174,22 @@ func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) }, }, nil } +func (s *AuthStorage) GetKeyByID(_ context.Context, _ string) (*jose.JSONWebKeySet, error) { + pubkey := s.key.Public() + return &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + {Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"}, + }, + }, nil +} +func (s *AuthStorage) GetKeysByServiceAccount(_ context.Context, _ string) (*jose.JSONWebKeySet, error) { + pubkey := s.key.Public() + return &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + {Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"}, + }, + }, nil +} func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) { if id == "none" { @@ -182,20 +198,24 @@ func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Clie var appType op.ApplicationType var authMethod op.AuthMethod var accessTokenType op.AccessTokenType + var responseTypes []oidc.ResponseType if id == "web" { appType = op.ApplicationTypeWeb authMethod = op.AuthMethodBasic accessTokenType = op.AccessTokenTypeBearer + responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} } else if id == "native" { appType = op.ApplicationTypeNative authMethod = op.AuthMethodNone accessTokenType = op.AccessTokenTypeBearer + responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} } else { appType = op.ApplicationTypeUserAgent authMethod = op.AuthMethodNone accessTokenType = op.AccessTokenTypeJWT + responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken, oidc.ResponseTypeIDTokenOnly} } - return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType, devMode: false}, nil + return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType, responseTypes: responseTypes, devMode: false}, nil } func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error { diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 6e3e699..47012ad 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -28,35 +28,47 @@ type Claims interface { } var ( - ErrParse = errors.New("") - ErrIssuerInvalid = errors.New("issuer does not match") - - ErrAudience = errors.New("audience is not valid") - - ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty") - ErrAzpInvalid = errors.New("authorized party is not valid") - + ErrParse = errors.New("") + ErrIssuerInvalid = errors.New("issuer does not match") + ErrAudience = errors.New("audience is not valid") + ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty") + ErrAzpInvalid = errors.New("authorized party is not valid") ErrSignatureMissing = errors.New("id_token does not contain a signature") ErrSignatureMultiple = errors.New("id_token contains multiple signatures") ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported") ErrSignatureInvalidPayload = errors.New("signature does not match Payload") - - ErrExpired = errors.New("token has expired") - - ErrIatInFuture = errors.New("issuedAt of token is in the future") - - ErrIatToOld = errors.New("issuedAt of token is to old") - // - //ErrNonceInvalid = func(expected, actual string) *validationError { - // return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual) - //} - ErrAcrInvalid = errors.New("acr is invalid") - ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing") - ErrAuthTimeToOld = errors.New("auth time of token is to old") - - ErrAtHash = errors.New("at_hash does not correspond to access token") + ErrExpired = errors.New("token has expired") + ErrIatInFuture = errors.New("issuedAt of token is in the future") + ErrIatToOld = errors.New("issuedAt of token is to old") + ErrNonceInvalid = errors.New("nonce does not match") + ErrAcrInvalid = errors.New("acr is invalid") + ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing") + ErrAuthTimeToOld = errors.New("auth time of token is to old") + ErrAtHash = errors.New("at_hash does not correspond to access token") ) +type Verifier interface { + Issuer() string + MaxAgeIAT() time.Duration + Offset() time.Duration +} + +type verifierConfig struct { + issuer string + clientID string + nonce string + ignoreAudience bool + ignoreExpiration bool + //iat *iatConfig + acr ACRVerifier + maxAge time.Duration + supportedSignAlgs []string + + // httpClient *http.Client + + now time.Time +} + //ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim type ACRVerifier func(string) error @@ -77,43 +89,30 @@ func ParseToken(tokenString string, claims interface{}) ([]byte, error) { return payload, err } -type Verifier interface { - Issuer() string - ClientID() string - SupportedSignAlgs() []string - KeySet() KeySet - ACR() ACRVerifier - MaxAge() time.Duration - MaxAgeIAT() time.Duration - Offset() time.Duration -} - -func CheckIssuer(issuer string, i Verifier) error { - if i.Issuer() != issuer { - return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, i.Issuer(), issuer) +func CheckIssuer(claims Claims, issuer string) error { + if claims.GetIssuer() != issuer { + return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, claims.GetIssuer()) } return nil } -func CheckAudience(audiences []string, i Verifier) error { - if !utils.Contains(audiences, i.ClientID()) { - return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, i.ClientID()) +func CheckAudience(claims Claims, clientID string) error { + if !utils.Contains(claims.GetAudience(), clientID) { + return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID) } //TODO: check aud trusted return nil } -//4. if multiple aud strings --> check if azp -//5. if azp --> check azp == client_id -func CheckAuthorizedParty(audiences []string, authorizedParty string, v Verifier) error { - if len(audiences) > 1 { - if authorizedParty == "" { +func CheckAuthorizedParty(claims Claims, clientID string) error { + if len(claims.GetAudience()) > 1 { + if claims.GetAuthorizedParty() == "" { return ErrAzpMissing } } - if authorizedParty != "" && authorizedParty != v.ClientID() { - return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, authorizedParty, v.ClientID()) + if claims.GetAuthorizedParty() != "" && claims.GetAuthorizedParty() != clientID { + return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, claims.GetAuthorizedParty(), clientID) } return nil } @@ -151,59 +150,59 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl return nil } -func CheckExpiration(expiration time.Time, v Verifier) error { - expiration = expiration.Round(time.Second) - if !time.Now().UTC().Add(v.Offset()).Before(expiration) { +func CheckExpiration(claims Claims, offset time.Duration) error { + expiration := claims.GetExpiration().Round(time.Second) + if !time.Now().UTC().Add(offset).Before(expiration) { return ErrExpired } return nil } -func CheckIssuedAt(issuedAt time.Time, v Verifier) error { - issuedAt = issuedAt.Round(time.Second) - offset := time.Now().UTC().Add(v.Offset()).Round(time.Second) - if issuedAt.After(offset) { - return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, offset) +func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { + issuedAt := claims.GetIssuedAt().Round(time.Second) + nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) + if issuedAt.After(nowWithOffset) { + return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) } - if v.MaxAgeIAT() == 0 { + if maxAgeIAT == 0 { return nil } - maxAge := time.Now().UTC().Add(-v.MaxAgeIAT()).Round(time.Second) + maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second) if issuedAt.Before(maxAge) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) } return nil } -/* -func (v *DefaultVerifier) CheckNonce(nonce string) error { - if v.config.nonce == "" { +func CheckNonce(claims Claims, nonce string) error { + if nonce == "" { return nil } - if v.config.nonce != nonce { - return ErrNonceInvalid(v.config.nonce, nonce) + if claims.GetNonce() != nonce { + return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, claims.GetNonce()) } return nil -}*/ -func CheckAuthorizationContextClassReference(acr string, v Verifier) error { - if v.ACR() != nil { - if err := v.ACR()(acr); err != nil { +} + +func CheckAuthorizationContextClassReference(claims Claims, acr ACRVerifier) error { + if acr != nil { + if err := acr(claims.GetAuthenticationContextClassReference()); err != nil { return fmt.Errorf("%w: %v", ErrAcrInvalid, err) } } return nil } -func CheckAuthTime(authTime time.Time, v Verifier) error { - if v.MaxAge() == 0 { +func CheckAuthTime(claims Claims, maxAge time.Duration) error { + if maxAge == 0 { return nil } - if authTime.IsZero() { + if claims.GetAuthTime().IsZero() { return ErrAuthTimeNotPresent } - authTime = authTime.Round(time.Second) - maxAge := time.Now().UTC().Add(-v.MaxAge()).Round(time.Second) - if authTime.Before(maxAge) { - return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAge.Sub(authTime)) + authTime := claims.GetAuthTime().Round(time.Second) + maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second) + if authTime.Before(maxAuthTime) { + return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime)) } return nil } diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index ba93c5a..aa8a36c 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -9,7 +9,6 @@ import ( "github.com/gorilla/mux" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/utils" ) @@ -18,7 +17,7 @@ type Authorizer interface { Decoder() utils.Decoder Encoder() utils.Encoder Signer() Signer - IDTokenVerifier() rp.Verifier + IDTokenVerifier() IDTokenHintVerifier Crypto() Crypto Issuer() string } @@ -27,10 +26,10 @@ type Authorizer interface { //implementing it's own validation mechanism for the auth request type AuthorizeValidator interface { Authorizer - ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error) + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) } -//ValidationAuthorizer is an extension of Authorizer interface +//ValidationAuthorizer is an extension of Authorizer interface //implementing it's own validation mechanism for the auth request // //Deprecated: ValidationAuthorizer exists for historical compatibility. Use ValidationAuthorizer itself @@ -78,6 +77,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { RedirectToLogin(req.GetID(), client, w, r) } +//ParseAuthorizeRequest parsed the http request into a AuthRequest func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) { err := r.ParseForm() if err != nil { @@ -91,7 +91,8 @@ func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRe return authReq, nil } -func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) { +//ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (string, error) { client, err := storage.GetClientByClientID(ctx, authReq.ClientID) if err != nil { return "", ErrServerError(err.Error()) @@ -108,6 +109,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier) } +//ValidateAuthReqScopes validates the passed scopes func ValidateAuthReqScopes(scopes []string) error { if len(scopes) == 0 { return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") @@ -118,6 +120,7 @@ func ValidateAuthReqScopes(scopes []string) error { return nil } +//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error { if uri == "" { return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") @@ -150,6 +153,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res return nil } +//ValidateAuthReqResponseType validates the passed response_type to the registered response types func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error { if responseType == "" { return ErrInvalidRequest("The response type is missing in your request. If you have any questions, you may contact the administrator of the application.") @@ -160,7 +164,9 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) return nil } -func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) { +//ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request) +//and returns the `sub` claim +func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) { if idTokenHint == "" { return "", nil } @@ -171,11 +177,13 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie return claims.Subject, nil } +//RedirectToLogin redirects the end user to the Login UI for authentication func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) { login := client.LoginURL(authReqID) http.Redirect(w, r, login, http.StatusFound) } +//AuthorizeCallback handles the callback after authentication in the Login UI func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { params := mux.Vars(r) id := params["id"] @@ -192,6 +200,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author AuthResponse(authReq, authorizer, w, r) } +//AuthResponse creates the successful authentication response (either code or tokens) func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) if err != nil { @@ -205,6 +214,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri return } +//AuthResponseCode creates the successful code authentication response func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { @@ -218,6 +228,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques http.Redirect(w, r, callback, http.StatusFound) } +//AuthResponseToken creates the successful token(s) authentication response func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) { createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "") @@ -234,6 +245,7 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque http.Redirect(w, r, callback, http.StatusFound) } +//CreateAuthRequestCode creates and stores a code for the auth code response func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) { code, err := BuildAuthRequestCode(authReq, crypto) if err != nil { diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index cd21b03..374cfb7 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -47,7 +47,7 @@ type DefaultOP struct { endpoints *endpoints storage Storage signer Signer - verifier rp.Verifier + verifier IDTokenHintVerifier crypto Crypto http http.Handler decoder *schema.Decoder @@ -184,7 +184,7 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts . p.signer = NewDefaultSigner(ctx, storage, keyCh) go p.ensureKey(ctx, storage, keyCh, p.timer) - p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration()) + p.verifier = NewIDTokenHintVerifier(config.Issuer, p) p.http = CreateRouter(p, p.interceptors...) @@ -238,10 +238,6 @@ func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) { Discover(w, CreateDiscoveryConfig(p, p.Signer())) } -func (p *DefaultOP) Probes() []ProbesFn { - return nil -} - func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { keyID := "" for _, sig := range jws.Signatures { @@ -279,7 +275,7 @@ func (p *DefaultOP) Crypto() Crypto { return p.crypto } -func (p *DefaultOP) ClientJWTVerifier() rp.Verifier { +func (p *DefaultOP) ClientJWTVerifier() oidc.Verifier { return p.verifier } @@ -330,7 +326,7 @@ func (p *DefaultOP) HandleEndSession(w http.ResponseWriter, r *http.Request) { func (p *DefaultOP) DefaultLogoutRedirectURI() string { return p.config.DefaultLogoutRedirectURI } -func (p *DefaultOP) IDTokenVerifier() rp.Verifier { +func (p *DefaultOP) IDTokenVerifier() IDTokenHintVerifier { return p.verifier } diff --git a/pkg/op/op.go b/pkg/op/op.go index 2a65320..153e015 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -16,12 +16,11 @@ const ( type OpenIDProvider interface { Configuration - HandleKeys(w http.ResponseWriter, r *http.Request) - HttpHandler() http.Handler Authorizer SessionEnder Signer() Signer Probes() []ProbesFn + HttpHandler() http.Handler } type HttpInterceptor func(http.Handler) http.Handler diff --git a/pkg/op/session.go b/pkg/op/session.go index abbc114..5a3936a 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -5,14 +5,13 @@ import ( "net/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/utils" ) type SessionEnder interface { Decoder() utils.Decoder Storage() Storage - IDTokenVerifier() rp.Verifier + IDTokenVerifier() IDTokenHintVerifier DefaultLogoutRedirectURI() string } @@ -63,7 +62,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, if req.IdTokenHint == "" { return session, nil } - claims, err := ender.IDTokenVerifier().VerifyIDToken(ctx, req.IdTokenHint) + claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenVerifier()) if err != nil { return nil, ErrInvalidRequest("id_token_hint invalid") } diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 0ab61f6..72a688c 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -24,7 +24,7 @@ type Exchanger interface { type VerifyExchanger interface { Exchanger - ClientJWTVerifier() rp.Verifier + ClientJWTVerifier() oidc.Verifier } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { @@ -34,7 +34,8 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque CodeExchange(w, r, exchanger) return case string(oidc.GrantTypeBearer): - JWTExchange(w, r, exchanger) + ex, _ := exchanger.(VerifyExchanger) + JWTExchange(w, r, ex) return case "excahnge": TokenExchange(w, r, exchanger) @@ -161,23 +162,6 @@ func (c ClientJWTVerifier) ClientID() string { return c.issuer } -func (c ClientJWTVerifier) SupportedSignAlgs() []string { - panic("implement me") -} - -func (c ClientJWTVerifier) KeySet() oidc.KeySet { - // return c.claims - return nil -} - -func (c ClientJWTVerifier) ACR() oidc.ACRVerifier { - panic("implement me") -} - -func (c ClientJWTVerifier) MaxAge() time.Duration { - panic("implement me") -} - func (c ClientJWTVerifier) MaxAgeIAT() time.Duration { //TODO: define in conf/opts return 1 * time.Hour @@ -224,15 +208,15 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, exchanger Exchang return nil, err } - if err = oidc.CheckAudience(verifier.claims.GetAudience(), verifier); err != nil { + if err = oidc.CheckAudience(verifier.claims, verifier.issuer); err != nil { return nil, err } - if err = oidc.CheckExpiration(verifier.claims.GetExpiration(), verifier); err != nil { + if err = oidc.CheckExpiration(verifier.claims, verifier.Offset()); err != nil { return nil, err } - if err = oidc.CheckIssuedAt(verifier.claims.GetIssuedAt(), verifier); err != nil { + if err = oidc.CheckIssuedAt(verifier.claims, verifier.MaxAgeIAT(), verifier.Offset()); err != nil { return nil, err } diff --git a/pkg/op/verifier.go b/pkg/op/verifier.go index 863428c..3268a5e 100644 --- a/pkg/op/verifier.go +++ b/pkg/op/verifier.go @@ -2,14 +2,66 @@ package op import ( "context" + "time" "github.com/caos/oidc/pkg/oidc" ) type IDTokenHintVerifier interface { + oidc.Verifier + SupportedSignAlgs() []string + KeySet() oidc.KeySet + ACR() oidc.ACRVerifier + MaxAge() time.Duration } -//VerifyIDToken validates the id token according to +type idTokenHintVerifier struct { + issuer string + maxAgeIAT time.Duration + offset time.Duration + supportedSignAlgs []string + maxAge time.Duration + acr oidc.ACRVerifier + keySet oidc.KeySet +} + +func (i *idTokenHintVerifier) Issuer() string { + return i.issuer +} + +func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration { + return i.maxAgeIAT +} + +func (i *idTokenHintVerifier) Offset() time.Duration { + return i.offset +} + +func (i *idTokenHintVerifier) SupportedSignAlgs() []string { + return i.supportedSignAlgs +} + +func (i *idTokenHintVerifier) KeySet() oidc.KeySet { + return i.keySet +} + +func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier { + return i.acr +} + +func (i *idTokenHintVerifier) MaxAge() time.Duration { + return i.maxAge +} + +func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifier { + verifier := &idTokenHintVerifier{ + issuer: issuer, + keySet: keySet, + } + return verifier +} + +//VerifyIDTokenHint validates the id token according to //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) { claims := new(oidc.IDTokenClaims) @@ -22,51 +74,28 @@ func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) if err != nil { return nil, err } - //2, check issuer (exact match) - if err := oidc.CheckIssuer(claims.GetIssuer(), v); err != nil { + + if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { return nil, err } - //3. check aud (aud must contain client_id, all aud strings must be allowed) - if err = oidc.CheckAudience(claims.GetAudience(), v); err != nil { - return nil, err - } - - if err = oidc.CheckAuthorizedParty(claims.GetAudience(), claims.GetAuthorizedParty(), v); err != nil { - return nil, err - } - - //6. check signature by keys - //7. check alg default is rs256 - //8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { return nil, err } - //9. check exp before now - if err = oidc.CheckExpiration(claims.GetExpiration(), v); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { return nil, err } - //10. check iat duration is optional (can be checked) - if err = oidc.CheckIssuedAt(claims.GetIssuedAt(), v); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { return nil, err } - /* - //11. check nonce (check if optional possible) id_token.nonce == sentNonce - if err = oidc.CheckNonce(claims.GetNonce()); err != nil { - return nil, err - } - */ - - //12. if acr requested check acr - if err = oidc.CheckAuthorizationContextClassReference(claims.GetAuthenticationContextClassReference(), v); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { return nil, err } - //13. if auth_time requested check if auth_time is less than max age - if err = oidc.CheckAuthTime(claims.GetAuthTime(), v); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { return nil, err } return claims, nil diff --git a/pkg/rp/default_rp.go b/pkg/rp/default_rp.go index 3a830bb..f701ec1 100644 --- a/pkg/rp/default_rp.go +++ b/pkg/rp/default_rp.go @@ -27,7 +27,7 @@ var ( } ) -//DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface +//DefaultRP implements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface type DefaultRP struct { endpoints Endpoints @@ -40,9 +40,9 @@ type DefaultRP struct { errorHandler func(http.ResponseWriter, *http.Request, string, string, string) - verifier Verifier - verifierOpts []ConfFunc - onlyOAuth2 bool + idTokenVerifier IDTokenVerifier + verifierOpts []ConfFunc + onlyOAuth2 bool } //NewDefaultRP creates `DefaultRP` with the given @@ -79,8 +79,8 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc p.errorHandler = DefaultErrorHandler } - if p.verifier == nil { - p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL), p.verifierOpts...) + if p.idTokenVerifier == nil { + p.idTokenVerifier = NewIDTokenVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) } return p, nil @@ -181,7 +181,7 @@ func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeE idToken := new(oidc.IDTokenClaims) if !p.onlyOAuth2 { - idToken, err = p.verifier.Verify(ctx, token.AccessToken, idTokenString) + idToken, err = VerifyTokens(ctx, token.AccessToken, idTokenString, p.idTokenVerifier) if err != nil { return nil, err //TODO: err } diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go index 45d8373..c13c135 100644 --- a/pkg/rp/default_verifier.go +++ b/pkg/rp/default_verifier.go @@ -15,7 +15,7 @@ type DefaultVerifier struct { keySet oidc.KeySet } -//ConfFunc is the type for providing dynamic options to the DefaultVerfifier +//ConfFunc is the type for providing dynamic options to the DefaultVerifier type ConfFunc func(*verifierConfig) //NewDefaultVerifier creates `DefaultVerifier` with the given @@ -145,7 +145,7 @@ func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString //Verify implements the `VerifyIDToken` method of the `Verifier` interface //according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) { - return VerifywIDToken(ctx, idTokenString, v) + return VerifyIDToken(ctx, idTokenString, v) } func (v *DefaultVerifier) now() time.Time { @@ -186,3 +186,7 @@ func (v *DefaultVerifier) MaxAgeIAT() time.Duration { func (v *DefaultVerifier) Offset() time.Duration { return v.config.iat.offset } + +func (v *DefaultVerifier) Nonce(ctx context.Context) string { + return "" +} diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go index 27d30cb..c98134e 100644 --- a/pkg/rp/verifier.go +++ b/pkg/rp/verifier.go @@ -2,6 +2,7 @@ package rp import ( "context" + "time" "gopkg.in/square/go-jose.v2" @@ -20,6 +21,69 @@ type Verifier interface { type IDTokenVerifier interface { oidc.Verifier + ClientID() string + SupportedSignAlgs() []string + KeySet() oidc.KeySet + Nonce(context.Context) string + ACR() oidc.ACRVerifier + MaxAge() time.Duration +} + +type idTokenVerifier struct { + issuer string + maxAgeIAT time.Duration + offset time.Duration + clientID string + supportedSignAlgs []string + keySet oidc.KeySet + acr oidc.ACRVerifier + maxAge time.Duration + nonce func(ctx context.Context) string +} + +func (i *idTokenVerifier) Issuer() string { + return i.issuer +} + +func (i *idTokenVerifier) MaxAgeIAT() time.Duration { + return i.maxAgeIAT +} + +func (i *idTokenVerifier) Offset() time.Duration { + return i.offset +} + +func (i *idTokenVerifier) ClientID() string { + return i.clientID +} + +func (i *idTokenVerifier) SupportedSignAlgs() []string { + return i.supportedSignAlgs +} + +func (i *idTokenVerifier) KeySet() oidc.KeySet { + return i.keySet +} + +func (i *idTokenVerifier) Nonce(ctx context.Context) string { + return i.nonce(ctx) +} + +func (i *idTokenVerifier) ACR() oidc.ACRVerifier { + return i.acr +} + +func (i *idTokenVerifier) MaxAge() time.Duration { + return i.maxAge +} + +func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet) IDTokenVerifier { + return &idTokenVerifier{ + issuer: issuer, + clientID: clientID, + keySet: keySet, + offset: 5 * time.Second, + } } //VerifyTokens implement the Token Response Validation as defined in OIDC specification @@ -48,51 +112,40 @@ func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc. if err != nil { return nil, err } - //2, check issuer (exact match) - if err := oidc.CheckIssuer(claims.GetIssuer(), v); err != nil { + + if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { return nil, err } - //3. check aud (aud must contain client_id, all aud strings must be allowed) - if err = oidc.CheckAudience(claims.GetAudience(), v); err != nil { + if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { return nil, err } - if err = oidc.CheckAuthorizedParty(claims.GetAudience(), claims.GetAuthorizedParty(), v); err != nil { + if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { return nil, err } - //6. check signature by keys - //7. check alg default is rs256 - //8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { return nil, err } - //9. check exp before now - if err = oidc.CheckExpiration(claims.GetExpiration(), v); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { return nil, err } - //10. check iat duration is optional (can be checked) - if err = oidc.CheckIssuedAt(claims.GetIssuedAt(), v); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { return nil, err } - /* - //11. check nonce (check if optional possible) id_token.nonce == sentNonce - if err = oidc.CheckNonce(claims.GetNonce()); err != nil { - return nil, err - } - */ - - //12. if acr requested check acr - if err = oidc.CheckAuthorizationContextClassReference(claims.GetAuthenticationContextClassReference(), v); err != nil { + if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil { return nil, err } - //13. if auth_time requested check if auth_time is less than max age - if err = oidc.CheckAuthTime(claims.GetAuthTime(), v); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { + return nil, err + } + + if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { return nil, err } return claims, nil