diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index 6ad1016..a78d9dc 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,6 +1,7 @@ package oidc import ( + "encoding/json" "errors" "strings" "time" @@ -64,7 +65,7 @@ const ( PromptSelectAccount Prompt = "select_account" //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow - GrantTypeCode GrantType = "authorization_code" + GrantTypeCode GrantType = "authorization_code" //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" @@ -148,10 +149,67 @@ type AccessTokenResponse struct { } type JWTTokenRequest struct { - Scopes Scopes `schema:"scope"` - Audience []string `schema:"aud"` - IssuedAt time.Time `schema:"iat"` - ExpiresAt time.Time `schema:"exp"` + Issuer string `json:"iss"` + Subject string `json:"sub"` + Scopes Scopes `json:"scope"` + Audience string `json:"aud"` + IssuedAt Time `json:"iat"` + ExpiresAt Time `json:"exp"` +} + +func (j *JWTTokenRequest) GetClientID() string { + return j.Subject +} + +func (j *JWTTokenRequest) GetSubject() string { + return j.Subject +} + +func (j *JWTTokenRequest) GetScopes() []string { + return j.Scopes +} + +type Time time.Time + +func (t *Time) UnmarshalJSON(data []byte) error { + var i int64 + if err := json.Unmarshal(data, &i); err != nil { + return err + } + *t = Time(time.Unix(i, 0).UTC()) + return nil +} + +func (j *JWTTokenRequest) GetIssuer() string { + return j.Issuer +} + +func (j *JWTTokenRequest) GetAudience() []string { + return []string{j.Audience} +} + +func (j *JWTTokenRequest) GetExpiration() time.Time { + return time.Time(j.ExpiresAt) +} + +func (j *JWTTokenRequest) GetIssuedAt() time.Time { + return time.Time(j.IssuedAt) +} + +func (j *JWTTokenRequest) GetNonce() string { + return "" +} + +func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string { + return "" +} + +func (j *JWTTokenRequest) GetAuthTime() time.Time { + return time.Time{} +} + +func (j *JWTTokenRequest) GetAuthorizedParty() string { + return "" } type TokenExchangeRequest struct { diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index c468670..b06bc79 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -177,6 +177,42 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error { return nil } +func (t *IDTokenClaims) GetIssuer() string { + return t.Issuer +} + +func (t *IDTokenClaims) GetAudience() []string { + return t.Audiences +} + +func (t *IDTokenClaims) GetExpiration() time.Time { + return t.Expiration +} + +func (t *IDTokenClaims) GetIssuedAt() time.Time { + return t.IssuedAt +} + +func (t *IDTokenClaims) GetNonce() string { + return t.Nonce +} + +func (t *IDTokenClaims) GetAuthenticationContextClassReference() string { + return t.AuthenticationContextClassReference +} + +func (t *IDTokenClaims) GetAuthTime() time.Time { + return t.AuthTime +} + +func (t *IDTokenClaims) GetAuthorizedParty() string { + return t.AuthorizedParty +} + +func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) { + t.Signature = alg +} + func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile { locale, _ := language.Parse(j.Locale) return UserinfoProfile{ diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go new file mode 100644 index 0000000..4b0f684 --- /dev/null +++ b/pkg/oidc/verifier.go @@ -0,0 +1,210 @@ +package oidc + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/utils" +) + +type Claims interface { + GetIssuer() string + GetAudience() []string + GetExpiration() time.Time + GetIssuedAt() time.Time + GetNonce() string + GetAuthenticationContextClassReference() string + GetAuthTime() time.Time + GetAuthorizedParty() string + SetSignature(algorithm jose.SignatureAlgorithm) +} + +var ( + ErrParse = errors.New("") + ErrIssuerInvalid = errors.New("issuer does not match") + + ErrAudience = errors.New("audience is not valid") + + ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty") + ErrAzpInvalid = errors.New("authorized party is not valid") + + ErrSignatureMissing = errors.New("id_token does not contain a signature") + ErrSignatureMultiple = errors.New("id_token contains multiple signatures") + ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported") + ErrSignatureInvalidPayload = errors.New("signature does not match Payload") + + ErrExpired = errors.New("token has expired") + + ErrIatInFuture = errors.New("issuedAt of token is in the future") + + ErrIatToOld = errors.New("issuedAt of token is to old") + // + //ErrNonceInvalid = func(expected, actual string) *validationError { + // return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual) + //} + ErrAcrInvalid = errors.New("acr is invalid") + ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing") + ErrAuthTimeToOld = errors.New("auth time of token is to old") + + ErrAtHash = errors.New("at_hash does not correspond to access token") +) + +//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim +type ACRVerifier func(string) error + +func DecryptToken(tokenString string) (string, error) { + return tokenString, nil //TODO: impl +} + +func ParseToken(tokenString string, claims interface{}) ([]byte, error) { + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrParse) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrParse, err) + } + err = json.Unmarshal(payload, claims) + return payload, err +} + +type Verifier interface { + Issuer() string + ClientID() string + SupportedSignAlgs() []string + KeySet() KeySet + ACR() ACRVerifier + MaxAge() time.Duration + MaxAgeIAT() time.Duration + Offset() time.Duration +} + +func CheckIssuer(issuer string, i Verifier) error { + if i.Issuer() != issuer { + return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, i.Issuer(), issuer) + } + return nil +} + +func CheckAudience(audiences []string, i Verifier) error { + if !utils.Contains(audiences, i.ClientID()) { + return fmt.Errorf("%w: Audience must contain client_id (%s)", ErrAudience, i.ClientID()) + } + + //TODO: check aud trusted + return nil +} + +//4. if multiple aud strings --> check if azp +//5. if azp --> check azp == client_id +func CheckAuthorizedParty(audiences []string, authorizedParty string, v Verifier) error { + if len(audiences) > 1 { + if authorizedParty == "" { + return ErrAzpMissing + } + } + if authorizedParty != "" && authorizedParty != v.ClientID() { + return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, authorizedParty, v.ClientID()) + } + return nil +} + +func CheckSignature(ctx context.Context, idTokenString string, payload []byte, claims Claims, v Verifier) error { + jws, err := jose.ParseSigned(idTokenString) + if err != nil { + return err + } + if len(jws.Signatures) == 0 { + return ErrSignatureMissing + } + if len(jws.Signatures) > 1 { + return ErrSignatureMultiple + } + sig := jws.Signatures[0] + supportedSigAlgs := v.SupportedSignAlgs() + if len(supportedSigAlgs) == 0 { + supportedSigAlgs = []string{"RS256"} + } + if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) { + return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, supportedSigAlgs, sig.Header.Algorithm) + } + + signedPayload, err := v.KeySet().VerifySignature(ctx, jws) + if err != nil { + return err + } + + if !bytes.Equal(signedPayload, payload) { + return ErrSignatureInvalidPayload + } + + claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm)) + + return nil +} + +func CheckExpiration(expiration time.Time, v Verifier) error { + expiration = expiration.Round(time.Second) + if !time.Now().UTC().Add(v.Offset()).Before(expiration) { + return ErrExpired + } + return nil +} + +func CheckIssuedAt(issuedAt time.Time, v Verifier) error { + issuedAt = issuedAt.Round(time.Second) + offset := time.Now().UTC().Add(v.Offset()).Round(time.Second) + if issuedAt.After(offset) { + return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, offset) + } + if v.MaxAgeIAT() == 0 { + return nil + } + maxAge := time.Now().UTC().Add(-v.MaxAgeIAT()).Round(time.Second) + if issuedAt.Before(maxAge) { + return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) + } + return nil +} + +/* +func (v *DefaultVerifier) CheckNonce(nonce string) error { + if v.config.nonce == "" { + return nil + } + if v.config.nonce != nonce { + return ErrNonceInvalid(v.config.nonce, nonce) + } + return nil +}*/ +func CheckAuthorizationContextClassReference(acr string, v Verifier) error { + if v.ACR() != nil { + if err := v.ACR()(acr); err != nil { + return fmt.Errorf("%w: %v", ErrAcrInvalid, err) + } + } + return nil +} +func CheckAuthTime(authTime time.Time, v Verifier) error { + if v.MaxAge() == 0 { + return nil + } + if authTime.IsZero() { + return ErrAuthTimeNotPresent + } + authTime = authTime.Round(time.Second) + maxAge := time.Now().UTC().Add(-v.MaxAge()).Round(time.Second) + if authTime.Before(maxAge) { + return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAge.Sub(authTime)) + } + return nil +} diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 55d00fa..fbf34bc 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -275,7 +275,7 @@ func (p *DefaultOP) Crypto() Crypto { return p.crypto } -func (p *DefaultOP) Verifier() rp.Verifier { +func (p *DefaultOP) ClientJWTVerifier() rp.Verifier { return p.verifier } diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 54c473b..5461cfa 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -7,6 +7,12 @@ import ( "github.com/caos/oidc/pkg/utils" ) +func DiscoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + Discover(w, CreateDiscoveryConfig(c, s)) + } +} + func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { utils.MarshalJSON(w, config) } diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 9432616..83ef5ab 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -97,7 +97,7 @@ func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1, arg2 interface{ } // CreateToken mocks base method -func (m *MockStorage) CreateToken(arg0 context.Context, arg1 op.AuthRequest) (string, time.Time, error) { +func (m *MockStorage) CreateToken(arg0 context.Context, arg1 op.TokenRequest) (string, time.Time, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateToken", arg0, arg1) ret0, _ := ret[0].(string) diff --git a/pkg/op/op.go b/pkg/op/op.go index 624a8a1..eaba685 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -22,9 +22,12 @@ type OpenIDProvider interface { HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) HandleExchange(w http.ResponseWriter, r *http.Request) HandleUserinfo(w http.ResponseWriter, r *http.Request) - HandleEndSession(w http.ResponseWriter, r *http.Request) + //HandleEndSession(w http.ResponseWriter, r *http.Request) HandleKeys(w http.ResponseWriter, r *http.Request) HttpHandler() http.Handler + SessionEnder + Signer() Signer + Probes() []ProbesFn } type HttpInterceptor func(http.Handler) http.Handler @@ -42,13 +45,13 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router handlers.AllowedOriginValidator(allowAllOrigins), )) router.HandleFunc(healthzEndpoint, Healthz) - router.HandleFunc(readinessEndpoint, o.HandleReady) - router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) + router.HandleFunc(readinessEndpoint, Ready(o.Probes())) + router.HandleFunc(oidc.DiscoveryEndpoint, DiscoveryHandler(o, o.Signer())) router.Handle(o.AuthorizationEndpoint().Relative(), intercept(o.HandleAuthorize)) router.Handle(o.AuthorizationEndpoint().Relative()+"/{id}", intercept(o.HandleAuthorizeCallback)) router.Handle(o.TokenEndpoint().Relative(), intercept(o.HandleExchange)) router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo) - router.Handle(o.EndSessionEndpoint().Relative(), intercept(o.HandleEndSession)) + router.Handle(o.EndSessionEndpoint().Relative(), intercept(EndSessionHandler(o))) router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys) return router } diff --git a/pkg/op/probes.go b/pkg/op/probes.go index 50e8a0f..ab12851 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -14,6 +14,12 @@ func Healthz(w http.ResponseWriter, r *http.Request) { ok(w) } +func Ready(probes []ProbesFn) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + Readiness(w, r, probes...) + } +} + func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) { ctx := r.Context() for _, probe := range probes { diff --git a/pkg/op/session.go b/pkg/op/session.go index e60f71b..30dd5ea 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -16,6 +16,12 @@ type SessionEnder interface { DefaultLogoutRedirectURI() string } +func EndSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + EndSession(w, r, ender) + } +} + func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { req, err := ParseEndSessionRequest(r, ender.Decoder()) if err != nil { diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 17023c1..0eee936 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -16,7 +16,7 @@ type AuthStorage interface { SaveAuthCode(context.Context, string, string) error DeleteAuthRequest(context.Context, string) error - CreateToken(context.Context, AuthRequest) (string, time.Time, error) + CreateToken(context.Context, TokenRequest) (string, time.Time, error) TerminateSession(context.Context, string, string) error diff --git a/pkg/op/token.go b/pkg/op/token.go index 0fbcf60..10c8519 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -14,12 +14,19 @@ type TokenCreator interface { Crypto() Crypto } +type TokenRequest interface { + GetClientID() string + GetSubject() string + GetAudience() []string + GetScopes() []string +} + func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) { var accessToken string var validity time.Duration if createAccessToken { var err error - accessToken, validity, err = CreateAccessToken(ctx, authReq, client, creator) + accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator) if err != nil { return nil, err } @@ -43,8 +50,8 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client }, nil } -func CreateJWTTokenResponse(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator) (*oidc.AccessTokenResponse, error) { - accessToken, validity, err := CreateAccessToken(ctx, authReq, client, creator) +func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) { + accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator) if err != nil { return nil, err } @@ -57,13 +64,13 @@ func CreateJWTTokenResponse(ctx context.Context, authReq AuthRequest, client Cli }, nil } -func CreateAccessToken(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator) (token string, validity time.Duration, err error) { +func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) { id, exp, err := creator.Storage().CreateToken(ctx, authReq) if err != nil { return "", 0, err } validity = exp.Sub(time.Now().UTC()) - if client.AccessTokenType() == AccessTokenTypeJWT { + if accessTokenType == AccessTokenTypeJWT { token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer()) return } @@ -75,7 +82,7 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) { return crypto.Encrypt(id) } -func CreateJWT(issuer string, authReq AuthRequest, exp time.Time, id string, signer Signer) (string, error) { +func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) { now := time.Now().UTC() nbf := now claims := &oidc.AccessTokenClaims{ diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 7de2a70..f208fcc 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -3,7 +3,6 @@ package op import ( "context" "errors" - "fmt" "net/http" "github.com/caos/oidc/pkg/oidc" @@ -22,7 +21,7 @@ type Exchanger interface { type VerifyExchanger interface { Exchanger - Verifier() rp.Verifier + ClientJWTVerifier() rp.Verifier } func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { @@ -121,17 +120,31 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque return authReq, nil } +type ClientJWTVerifier struct { + claims *oidc.JWTTokenRequest + Storage +} + +func (c *ClientJWTVerifier) Issuer() string { + client, err := Storage.GetClientByClientID(context.TODO(), c.claims.Issuer) + return client.GetID() +} + func JWTExchange(w http.ResponseWriter, r *http.Request, exchanger VerifyExchanger) { assertion, err := ParseJWTTokenRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err) } - claims, err := exchanger.Verifier().Verify(r.Context(), "", assertion) + claims := new(oidc.JWTTokenRequest) + //var keyset oidc.KeySet + verifier := new(ClientJWTVerifier) + verifier.claims = claims + err = verifier.VerifyToken(r.Context(), assertion, claims) + if err != nil { + RequestError(w, r, err) + } - fmt.Println(claims, err) - var authReq AuthRequest - var client Client - resp, err := CreateJWTTokenResponse(r.Context(), authReq, client, exchanger) + resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger) if err != nil { RequestError(w, r, err) return @@ -139,7 +152,7 @@ func JWTExchange(w http.ResponseWriter, r *http.Request, exchanger VerifyExchang utils.MarshalJSON(w, resp) } -func ParseJWTTokenRequest(r *http.Request, decoder *schema.Decoder) (string, error) { +func ParseJWTTokenRequest(r *http.Request, decoder utils.Decoder) (string, error) { err := r.ParseForm() if err != nil { return "", ErrInvalidRequest("error parsing form") diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go index a75eddc..2820e39 100644 --- a/pkg/rp/default_verifier.go +++ b/pkg/rp/default_verifier.go @@ -1,16 +1,10 @@ package rp import ( - "bytes" "context" - "encoding/base64" - "encoding/json" "fmt" - "strings" "time" - "gopkg.in/square/go-jose.v2" - "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" ) @@ -24,9 +18,6 @@ type DefaultVerifier struct { //ConfFunc is the type for providing dynamic options to the DefaultVerfifier type ConfFunc func(*verifierConfig) -//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim -type ACRVerifier func(string) error - //NewDefaultVerifier creates `DefaultVerifier` with the given //issuer, clientID, keyset and possible configOptions func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier { @@ -90,7 +81,7 @@ func WithNonce(nonce string) func(*verifierConfig) { } //WithACRVerifier sets the verifier for the acr claim -func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) { +func WithACRVerifier(verifier oidc.ACRVerifier) func(*verifierConfig) { return func(conf *verifierConfig) { conf.acr = verifier } @@ -117,7 +108,7 @@ type verifierConfig struct { ignoreAudience bool ignoreExpiration bool iat *iatConfig - acr ACRVerifier + acr oidc.ACRVerifier maxAge time.Duration supportedSignAlgs []string @@ -134,10 +125,10 @@ type iatConfig struct { //DefaultACRVerifier implements `ACRVerifier` returning an error //if non of the provided values matches the acr claim -func DefaultACRVerifier(possibleValues []string) ACRVerifier { +func DefaultACRVerifier(possibleValues []string) oidc.ACRVerifier { return func(acr string) error { if !utils.Contains(possibleValues, acr) { - return ErrAcrInvalid(possibleValues, acr) + return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr) } return nil } @@ -148,88 +139,13 @@ func DefaultACRVerifier(possibleValues []string) ACRVerifier { //and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) { v.config.now = time.Now().UTC() - // idToken, err := v.VerifyIDToken(ctx, idTokenString) - // if err != nil { - // return nil, err - // } - // if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token - // return nil, err - // } - // return idToken, nil - - // TODO: verifiy - decrypted, err := v.decryptToken(idTokenString) - if err != nil { - return nil, err - } - claims, _, err := v.parseToken(decrypted) - if err != nil { - return nil, err - } - return claims, nil + return VerifyTokens(ctx, accessToken, idTokenString, v) } //Verify implements the `VerifyIDToken` method of the `Verifier` interface //according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) { - //1. if encrypted --> decrypt - decrypted, err := v.decryptToken(idTokenString) - if err != nil { - return nil, err - } - claims, payload, err := v.parseToken(decrypted) - if err != nil { - return nil, err - } - // token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) { - //2, check issuer (exact match) - if err := v.checkIssuer(claims.Issuer); err != nil { - return nil, err - } - - //3. check aud (aud must contain client_id, all aud strings must be allowed) - if err = v.checkAudience(claims.Audiences); err != nil { - return nil, err - } - - if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil { - return nil, err - } - - //6. check signature by keys - //7. check alg default is rs256 - //8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret - claims.Signature, err = v.checkSignature(ctx, decrypted, payload) - if err != nil { - return nil, err - } - - //9. check exp before now - if err = v.checkExpiration(claims.Expiration); err != nil { - return nil, err - } - - //10. check iat duration is optional (can be checked) - if err = v.checkIssuedAt(claims.IssuedAt); err != nil { - return nil, err - } - - //11. check nonce (check if optional possible) id_token.nonce == sentNonce - if err = v.checkNonce(claims.Nonce); err != nil { - return nil, err - } - - //12. if acr requested check acr - if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil { - return nil, err - } - - //13. if auth_time requested check if auth_time is less than max age - if err = v.checkAuthTime(claims.AuthTime); err != nil { - return nil, err - } - - return claims, nil + return VerifyIDToken(ctx, idTokenString, v) } func (v *DefaultVerifier) now() time.Time { @@ -239,161 +155,34 @@ func (v *DefaultVerifier) now() time.Time { return v.config.now } -func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) { - parts := strings.Split(tokenString, ".") - if len(parts) != 3 { - return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) - } - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err) - } - idToken := new(oidc.IDTokenClaims) - err = json.Unmarshal(payload, idToken) - return idToken, payload, err +func (v *DefaultVerifier) Issuer() string { + return v.config.issuer } -func (v *DefaultVerifier) checkIssuer(issuer string) error { - if v.config.issuer != issuer { - return ErrIssuerInvalid(v.config.issuer, issuer) - } - return nil +func (v *DefaultVerifier) ClientID() string { + return v.config.clientID } -func (v *DefaultVerifier) checkAudience(audiences []string) error { - if v.config.ignoreAudience { - return nil - } - if !utils.Contains(audiences, v.config.clientID) { - return ErrAudienceMissingClientID(v.config.clientID) - } - - //TODO: check aud trusted - return nil +func (v *DefaultVerifier) SupportedSignAlgs() []string { + return v.config.supportedSignAlgs } -//4. if multiple aud strings --> check if azp -//5. if azp --> check azp == client_id -func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error { - if v.config.ignoreAudience { - return nil - } - if len(audiences) > 1 { - if authorizedParty == "" { - return ErrAzpMissing() - } - } - if authorizedParty != "" && authorizedParty != v.config.clientID { - return ErrAzpInvalid(authorizedParty, v.config.clientID) - } - return nil +func (v *DefaultVerifier) KeySet() oidc.KeySet { + return v.keySet } -func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) { - jws, err := jose.ParseSigned(idTokenString) - if err != nil { - return "", err - } - if len(jws.Signatures) == 0 { - return "", ErrSignatureMissing() - } - if len(jws.Signatures) > 1 { - return "", ErrSignatureMultiple() - } - sig := jws.Signatures[0] - supportedSigAlgs := v.config.supportedSignAlgs - if len(supportedSigAlgs) == 0 { - supportedSigAlgs = []string{"RS256"} - } - if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) { - return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm) - } - - signedPayload, err := v.keySet.VerifySignature(ctx, jws) - if err != nil { - return "", err - } - - if !bytes.Equal(signedPayload, payload) { - return "", ErrSignatureInvalidPayload() - } - return jose.SignatureAlgorithm(sig.Header.Algorithm), nil +func (v *DefaultVerifier) ACR() oidc.ACRVerifier { + return v.config.acr } -func (v *DefaultVerifier) checkExpiration(expiration time.Time) error { - if v.config.ignoreExpiration { - return nil - } - expiration = expiration.Round(time.Second) - if !v.now().Before(expiration) { - return ErrExpInvalid(expiration) - } - return nil +func (v *DefaultVerifier) MaxAge() time.Duration { + return v.config.maxAge } -func (v *DefaultVerifier) checkIssuedAt(issuedAt time.Time) error { - if v.config.iat.ignore { - return nil - } - issuedAt = issuedAt.Round(time.Second) - offset := v.now().Add(v.config.iat.offset).Round(time.Second) - if issuedAt.After(offset) { - return ErrIatInFuture(issuedAt, offset) - } - if v.config.iat.maxAge == 0 { - return nil - } - maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second) - if issuedAt.Before(maxAge) { - return ErrIatToOld(maxAge, issuedAt) - } - return nil -} -func (v *DefaultVerifier) checkNonce(nonce string) error { - if v.config.nonce == "" { - return nil - } - if v.config.nonce != nonce { - return ErrNonceInvalid(v.config.nonce, nonce) - } - return nil -} -func (v *DefaultVerifier) checkAuthorizationContextClassReference(acr string) error { - if v.config.acr != nil { - return v.config.acr(acr) - } - return nil -} -func (v *DefaultVerifier) checkAuthTime(authTime time.Time) error { - if v.config.maxAge == 0 { - return nil - } - if authTime.IsZero() { - return ErrAuthTimeNotPresent() - } - authTime = authTime.Round(time.Second) - maxAge := v.now().Add(-v.config.maxAge).Round(time.Second) - if authTime.Before(maxAge) { - return ErrAuthTimeToOld(maxAge, authTime) - } - return nil +func (v *DefaultVerifier) MaxAgeIAT() time.Duration { + return v.config.iat.maxAge } -func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) { - return tokenString, nil //TODO: impl -} - -func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { - if atHash == "" { - return nil - } - - actual, err := oidc.ClaimHash(accessToken, sigAlgorithm) - if err != nil { - return err - } - if actual != atHash { - return ErrAtHash() - } - return nil +func (v *DefaultVerifier) Offset() time.Duration { + return v.config.iat.offset } diff --git a/pkg/rp/error.go b/pkg/rp/error.go deleted file mode 100644 index fa0ece9..0000000 --- a/pkg/rp/error.go +++ /dev/null @@ -1,67 +0,0 @@ -package rp - -import ( - "fmt" - "time" -) - -var ( - ErrIssuerInvalid = func(expected, actual string) *validationError { - return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual) - } - ErrAudienceMissingClientID = func(clientID string) *validationError { - return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID) - } - ErrAzpMissing = func() *validationError { - return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty") - } - ErrAzpInvalid = func(azp, clientID string) *validationError { - return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID) - } - ErrExpInvalid = func(exp time.Time) *validationError { - return ValidationError("Token has expired %v", exp) - } - ErrIatInFuture = func(exp, now time.Time) *validationError { - return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now) - } - ErrIatToOld = func(maxAge, iat time.Time) *validationError { - return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat)) - } - ErrNonceInvalid = func(expected, actual string) *validationError { - return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual) - } - ErrAcrInvalid = func(expected []string, actual string) *validationError { - return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual) - } - - ErrAuthTimeNotPresent = func() *validationError { - return ValidationError("claim `auth_time` of token is missing") - } - ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError { - return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime)) - } - ErrSignatureMissing = func() *validationError { - return ValidationError("id_token does not contain a signature") - } - ErrSignatureMultiple = func() *validationError { - return ValidationError("id_token contains multiple signatures") - } - ErrSignatureInvalidPayload = func() *validationError { - return ValidationError("Signature does not match Payload") - } - ErrAtHash = func() *validationError { - return ValidationError("at_hash does not correspond to access token") - } -) - -func ValidationError(message string, args ...interface{}) *validationError { - return &validationError{fmt.Sprintf(message, args...)} //TODO: impl -} - -type validationError struct { - message string -} - -func (v *validationError) Error() string { - return v.message -} diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go index 5add60f..61caed1 100644 --- a/pkg/rp/verifier.go +++ b/pkg/rp/verifier.go @@ -3,11 +3,12 @@ package rp import ( "context" + "gopkg.in/square/go-jose.v2" + "github.com/caos/oidc/pkg/oidc" ) -//Verifier implement the Token Response Validation as defined in OIDC specification -//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation +//deprecated: Use IDTokenVerifier or oidc.Verifier type Verifier interface { //Verify checks the access_token and id_token and returns the `id token claims` @@ -16,3 +17,100 @@ type Verifier interface { //VerifyIDToken checks the id_token only and returns its `id token claims` VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) } + +type IDTokenVerifier interface { + oidc.Verifier +} + +//VerifyTokens implement the Token Response Validation as defined in OIDC specification +//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation +func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { + idToken, err := VerifyIDToken(ctx, idTokenString, v) + if err != nil { + return nil, err + } + if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { + return nil, err + } + return idToken, nil +} + +//VerifyIDToken validates the id token according to +//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation +func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { + claims := new(oidc.IDTokenClaims) + + decrypted, err := oidc.DecryptToken(token) + if err != nil { + return nil, err + } + payload, err := oidc.ParseToken(decrypted, claims) + if err != nil { + return nil, err + } + //2, check issuer (exact match) + if err := oidc.CheckIssuer(claims.GetIssuer(), v); err != nil { + return nil, err + } + + //3. check aud (aud must contain client_id, all aud strings must be allowed) + if err = oidc.CheckAudience(claims.GetAudience(), v); err != nil { + return nil, err + } + + if err = oidc.CheckAuthorizedParty(claims.GetAudience(), claims.GetAuthorizedParty(), v); err != nil { + return nil, err + } + + //6. check signature by keys + //7. check alg default is rs256 + //8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v); err != nil { + return nil, err + } + + //9. check exp before now + if err = oidc.CheckExpiration(claims.GetExpiration(), v); err != nil { + return nil, err + } + + //10. check iat duration is optional (can be checked) + if err = oidc.CheckIssuedAt(claims.GetIssuedAt(), v); err != nil { + return nil, err + } + + /* + //11. check nonce (check if optional possible) id_token.nonce == sentNonce + if err = oidc.CheckNonce(claims.GetNonce()); err != nil { + return nil, err + } + */ + + //12. if acr requested check acr + if err = oidc.CheckAuthorizationContextClassReference(claims.GetAuthenticationContextClassReference(), v); err != nil { + return nil, err + } + + //13. if auth_time requested check if auth_time is less than max age + if err = oidc.CheckAuthTime(claims.GetAuthTime(), v); err != nil { + return nil, err + } + return claims, nil +} + +//VerifyAccessToken validates the access token according to +//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation +func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { + if atHash == "" { + return nil + } + + actual, err := oidc.ClaimHash(accessToken, sigAlgorithm) + if err != nil { + return err + } + if actual != atHash { + return oidc.ErrAtHash + } + return nil +} diff --git a/pkg/rp/verity.go b/pkg/rp/verity.go new file mode 100644 index 0000000..189ce90 --- /dev/null +++ b/pkg/rp/verity.go @@ -0,0 +1,9 @@ +package rp + +import ( + "context" + + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" +)