From b8d892443ce332472fc3cb11c4a1817e64490206 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Wed, 14 Oct 2020 16:41:04 +0200 Subject: [PATCH] claims assertion --- pkg/oidc/token.go | 433 +++++++++++++++++++------------- pkg/op/client.go | 4 + pkg/op/mock/client.mock.go | 28 +++ pkg/op/mock/storage.mock.go | 23 +- pkg/op/op.go | 19 +- pkg/op/storage.go | 5 +- pkg/op/token.go | 64 ++++- pkg/op/userinfo.go | 19 +- pkg/op/verifier_access_token.go | 85 +++++++ 9 files changed, 491 insertions(+), 189 deletions(-) create mode 100644 pkg/op/verifier_access_token.go diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index e445e7e..2a8c0ad 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -24,6 +24,8 @@ type Tokens struct { type AccessTokenClaims interface { Claims + GetTokenID() string + SetPrivateClaims(map[string]interface{}) } type IDTokenClaims interface { @@ -36,67 +38,13 @@ type IDTokenClaims interface { GetClientID() string GetSignatureAlgorithm() jose.SignatureAlgorithm SetAccessTokenHash(hash string) - SetUserinfo(userinfo UserInfoSetter) + SetUserinfo(userinfo UserInfo) SetCodeHash(hash string) UserInfo } -type accessTokenClaims struct { - Issuer string - Subject string - Audience Audience - Expiration Time - IssuedAt Time - NotBefore Time - JWTID string - AuthorizedParty string - Nonce string - AuthTime Time - CodeHash string - AuthenticationContextClassReference string - AuthenticationMethodsReferences []string - SessionID string - Scopes []string - ClientID string - AccessTokenUseNumber int - - signatureAlg jose.SignatureAlgorithm -} - -func (a accessTokenClaims) GetIssuer() string { - return a.Issuer -} - -func (a accessTokenClaims) GetAudience() []string { - return a.Audience -} - -func (a accessTokenClaims) GetExpiration() time.Time { - return time.Time(a.Expiration) -} - -func (a accessTokenClaims) GetIssuedAt() time.Time { - return time.Time(a.IssuedAt) -} - -func (a accessTokenClaims) GetNonce() string { - return a.Nonce -} - -func (a accessTokenClaims) GetAuthenticationContextClassReference() string { - return a.AuthenticationContextClassReference -} - -func (a accessTokenClaims) GetAuthTime() time.Time { - return time.Time(a.AuthTime) -} - -func (a accessTokenClaims) GetAuthorizedParty() string { - return a.AuthorizedParty -} - -func (a accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { - a.signatureAlg = algorithm +func EmptyAccessTokenClaims() AccessTokenClaims { + return new(accessTokenClaims) } func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims { @@ -112,6 +60,155 @@ func NewAccessTokenClaims(issuer, subject string, audience []string, expiration } } +type accessTokenClaims struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Audience Audience `json:"aud,omitempty"` + Expiration Time `json:"exp,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + JWTID string `json:"jti,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + CodeHash string `json:"c_hash,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + SessionID string `json:"sid,omitempty"` + Scopes []string `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` + + claims map[string]interface{} `json:"-"` + signatureAlg jose.SignatureAlgorithm `json:"-"` +} + +//GetIssuer implements the Claims interface +func (a *accessTokenClaims) GetIssuer() string { + return a.Issuer +} + +//GetAudience implements the Claims interface +func (a *accessTokenClaims) GetAudience() []string { + return a.Audience +} + +//GetExpiration implements the Claims interface +func (a *accessTokenClaims) GetExpiration() time.Time { + return time.Time(a.Expiration) +} + +//GetIssuedAt implements the Claims interface +func (a *accessTokenClaims) GetIssuedAt() time.Time { + return time.Time(a.IssuedAt) +} + +//GetNonce implements the Claims interface +func (a *accessTokenClaims) GetNonce() string { + return a.Nonce +} + +//GetAuthenticationContextClassReference implements the Claims interface +func (a *accessTokenClaims) GetAuthenticationContextClassReference() string { + return a.AuthenticationContextClassReference +} + +//GetAuthTime implements the Claims interface +func (a *accessTokenClaims) GetAuthTime() time.Time { + return time.Time(a.AuthTime) +} + +//GetAuthorizedParty implements the Claims interface +func (a *accessTokenClaims) GetAuthorizedParty() string { + return a.AuthorizedParty +} + +//SetSignatureAlgorithm implements the Claims interface +func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { + a.signatureAlg = algorithm +} + +//GetTokenID implements the AccessTokenClaims interface +func (a *accessTokenClaims) GetTokenID() string { + return a.JWTID +} + +//SetPrivateClaims implements the AccessTokenClaims interface +func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) { + a.claims = claims +} + +func (a *accessTokenClaims) MarshalJSON() ([]byte, error) { + type Alias accessTokenClaims + s := &struct { + *Alias + Expiration int64 `json:"exp,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + AuthTime int64 `json:"auth_time,omitempty"` + }{ + Alias: (*Alias)(a), + } + if !time.Time(a.Expiration).IsZero() { + s.Expiration = time.Time(a.Expiration).Unix() + } + if !time.Time(a.IssuedAt).IsZero() { + s.IssuedAt = time.Time(a.IssuedAt).Unix() + } + if !time.Time(a.NotBefore).IsZero() { + s.NotBefore = time.Time(a.NotBefore).Unix() + } + if !time.Time(a.AuthTime).IsZero() { + s.AuthTime = time.Time(a.AuthTime).Unix() + } + b, err := json.Marshal(s) + if err != nil { + return nil, err + } + + if a.claims == nil { + return b, nil + } + info, err := json.Marshal(a.claims) + if err != nil { + return nil, err + } + return utils.ConcatenateJSON(b, info) +} + +func (a *accessTokenClaims) UnmarshalJSON(data []byte) error { + type Alias accessTokenClaims + if err := json.Unmarshal(data, (*Alias)(a)); err != nil { + return err + } + claims := make(map[string]interface{}) + if err := json.Unmarshal(data, &claims); err != nil { + return err + } + a.claims = claims + + return nil +} + +func EmptyIDTokenClaims() IDTokenClaims { + return new(idTokenClaims) +} + +func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims { + return &idTokenClaims{ + Issuer: issuer, + Audience: audience, + Expiration: Time(expiration), + IssuedAt: Time(time.Now().UTC()), + AuthTime: Time(authTime), + Nonce: nonce, + AuthenticationContextClassReference: acr, + AuthenticationMethodsReferences: amr, + AuthorizedParty: clientID, + UserInfo: &userinfo{Subject: subject}, + } +} + type idTokenClaims struct { Issuer string `json:"iss,omitempty"` Audience Audience `json:"aud,omitempty"` @@ -132,65 +229,153 @@ type idTokenClaims struct { signatureAlg jose.SignatureAlgorithm } -func (t *idTokenClaims) SetAccessTokenHash(hash string) { - t.AccessTokenHash = hash +//GetIssuer implements the Claims interface +func (t *idTokenClaims) GetIssuer() string { + return t.Issuer } -func (t *idTokenClaims) SetUserinfo(info UserInfoSetter) { - t.UserInfo = info +//GetAudience implements the Claims interface +func (t *idTokenClaims) GetAudience() []string { + return t.Audience } -func (t *idTokenClaims) SetCodeHash(hash string) { - t.CodeHash = hash +//GetExpiration implements the Claims interface +func (t *idTokenClaims) GetExpiration() time.Time { + return time.Time(t.Expiration) } -func EmptyIDTokenClaims() IDTokenClaims { - return new(idTokenClaims) +//GetIssuedAt implements the Claims interface +func (t *idTokenClaims) GetIssuedAt() time.Time { + return time.Time(t.IssuedAt) } -func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims { - return &idTokenClaims{ - Issuer: issuer, - Audience: audience, - Expiration: Time(expiration), - IssuedAt: Time(time.Now().UTC()), - AuthTime: Time(authTime), - Nonce: nonce, - AuthenticationContextClassReference: acr, - AuthenticationMethodsReferences: amr, - AuthorizedParty: clientID, - UserInfo: &userinfo{Subject: subject}, - } +//GetNonce implements the Claims interface +func (t *idTokenClaims) GetNonce() string { + return t.Nonce } -func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { - return t.signatureAlg +//GetAuthenticationContextClassReference implements the Claims interface +func (t *idTokenClaims) GetAuthenticationContextClassReference() string { + return t.AuthenticationContextClassReference } +//GetAuthTime implements the Claims interface +func (t *idTokenClaims) GetAuthTime() time.Time { + return time.Time(t.AuthTime) +} + +//GetAuthorizedParty implements the Claims interface +func (t *idTokenClaims) GetAuthorizedParty() string { + return t.AuthorizedParty +} + +//SetSignatureAlgorithm implements the Claims interface +func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { + t.signatureAlg = alg +} + +//GetNotBefore implements the IDTokenClaims interface func (t *idTokenClaims) GetNotBefore() time.Time { return time.Time(t.NotBefore) } +//GetJWTID implements the IDTokenClaims interface func (t *idTokenClaims) GetJWTID() string { return t.JWTID } +//GetAccessTokenHash implements the IDTokenClaims interface func (t *idTokenClaims) GetAccessTokenHash() string { return t.AccessTokenHash } +//GetCodeHash implements the IDTokenClaims interface func (t *idTokenClaims) GetCodeHash() string { return t.CodeHash } +//GetAuthenticationMethodsReferences implements the IDTokenClaims interface func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string { return t.AuthenticationMethodsReferences } +//GetClientID implements the IDTokenClaims interface func (t *idTokenClaims) GetClientID() string { return t.ClientID } +//GetSignatureAlgorithm implements the IDTokenClaims interface +func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { + return t.signatureAlg +} + +//SetSignatureAlgorithm implements the IDTokenClaims interface +func (t *idTokenClaims) SetAccessTokenHash(hash string) { + t.AccessTokenHash = hash +} + +//SetUserinfo implements the IDTokenClaims interface +func (t *idTokenClaims) SetUserinfo(info UserInfo) { + t.UserInfo = info +} + +//SetCodeHash implements the IDTokenClaims interface +func (t *idTokenClaims) SetCodeHash(hash string) { + t.CodeHash = hash +} + +func (t *idTokenClaims) MarshalJSON() ([]byte, error) { + type Alias idTokenClaims + a := &struct { + *Alias + Expiration int64 `json:"exp,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + AuthTime int64 `json:"auth_time,omitempty"` + }{ + Alias: (*Alias)(t), + } + if !time.Time(t.Expiration).IsZero() { + a.Expiration = time.Time(t.Expiration).Unix() + } + if !time.Time(t.IssuedAt).IsZero() { + a.IssuedAt = time.Time(t.IssuedAt).Unix() + } + if !time.Time(t.NotBefore).IsZero() { + a.NotBefore = time.Time(t.NotBefore).Unix() + } + if !time.Time(t.AuthTime).IsZero() { + a.AuthTime = time.Time(t.AuthTime).Unix() + } + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if t.UserInfo == nil { + return b, nil + } + info, err := json.Marshal(t.UserInfo) + if err != nil { + return nil, err + } + return utils.ConcatenateJSON(b, info) +} + +func (t *idTokenClaims) UnmarshalJSON(data []byte) error { + type Alias idTokenClaims + if err := json.Unmarshal(data, (*Alias)(t)); err != nil { + return err + } + userinfo := new(userinfo) + if err := json.Unmarshal(data, userinfo); err != nil { + return err + } + t.UserInfo = userinfo + + return nil +} + type AccessTokenResponse struct { AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` @@ -242,94 +427,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) } } -func (t *idTokenClaims) MarshalJSON() ([]byte, error) { - type Alias idTokenClaims - a := &struct { - *Alias - Expiration int64 `json:"nbf,omitempty"` - IssuedAt int64 `json:"nbf,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - AuthTime int64 `json:"nbf,omitempty"` - }{ - Alias: (*Alias)(t), - } - if !time.Time(t.Expiration).IsZero() { - a.Expiration = time.Time(t.Expiration).Unix() - } - if !time.Time(t.IssuedAt).IsZero() { - a.IssuedAt = time.Time(t.IssuedAt).Unix() - } - if !time.Time(t.NotBefore).IsZero() { - a.NotBefore = time.Time(t.NotBefore).Unix() - } - if !time.Time(t.AuthTime).IsZero() { - a.AuthTime = time.Time(t.AuthTime).Unix() - } - b, err := json.Marshal(a) - if err != nil { - return nil, err - } - - if t.UserInfo == nil { - return b, nil - } - info, err := json.Marshal(t.UserInfo) - if err != nil { - return nil, err - } - return utils.ConcatenateJSON(b, info) -} - -func (t *idTokenClaims) UnmarshalJSON(data []byte) error { - type Alias idTokenClaims - if err := json.Unmarshal(data, (*Alias)(t)); err != nil { - return err - } - userinfo := new(userinfo) - if err := json.Unmarshal(data, userinfo); err != nil { - return err - } - t.UserInfo = userinfo - - return nil -} - -func (t *idTokenClaims) GetIssuer() string { - return t.Issuer -} - -func (t *idTokenClaims) GetAudience() []string { - return t.Audience -} - -func (t *idTokenClaims) GetExpiration() time.Time { - return time.Time(t.Expiration) -} - -func (t *idTokenClaims) GetIssuedAt() time.Time { - return time.Time(t.IssuedAt) -} - -func (t *idTokenClaims) GetNonce() string { - return t.Nonce -} - -func (t *idTokenClaims) GetAuthenticationContextClassReference() string { - return t.AuthenticationContextClassReference -} - -func (t *idTokenClaims) GetAuthTime() time.Time { - return time.Time(t.AuthTime) -} - -func (t *idTokenClaims) GetAuthorizedParty() string { - return t.AuthorizedParty -} - -func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { - t.signatureAlg = alg -} - func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { hash, err := utils.GetHashAlgorithm(sigAlgorithm) if err != nil { diff --git a/pkg/op/client.go b/pkg/op/client.go index 258ce6e..790933e 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -10,7 +10,9 @@ const ( ApplicationTypeWeb ApplicationType = iota ApplicationTypeUserAgent ApplicationTypeNative +) +const ( AccessTokenTypeBearer AccessTokenType = iota AccessTokenTypeJWT ) @@ -33,6 +35,8 @@ type Client interface { IDTokenLifetime() time.Duration DevMode() bool AllowedScopes() []string + AssertAdditionalIdTokenScopes() bool + AssertAdditionalAccessTokenScopes() bool } func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool { diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 8e18d56..0780623 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -77,6 +77,34 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) } +// AssertAdditionalAccessTokenScopes mocks base method +func (m *MockClient) AssertAdditionalAccessTokenScopes() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes") + ret0, _ := ret[0].(bool) + return ret0 +} + +// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes +func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes)) +} + +// AssertAdditionalIdTokenScopes mocks base method +func (m *MockClient) AssertAdditionalIdTokenScopes() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes") + ret0, _ := ret[0].(bool) + return ret0 +} + +// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes +func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes)) +} + // AuthMethod mocks base method func (m *MockClient) AuthMethod() op.AuthMethod { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 973f58b..a184597 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -171,6 +171,21 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0) } +// GetPrivateClaimsFromScopes mocks base method +func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPrivateClaimsFromScopes", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(map[string]interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPrivateClaimsFromScopes indicates an expected call of GetPrivateClaimsFromScopes +func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3) +} + // GetSigningKey mocks base method func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) { m.ctrl.T.Helper() @@ -184,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfoSetter, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(oidc.UserInfoSetter) + ret0, _ := ret[0].(oidc.UserInfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -199,10 +214,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 } // GetUserinfoFromToken mocks base method -func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (oidc.UserInfoSetter, error) { +func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (oidc.UserInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) - ret0, _ := ret[0].(oidc.UserInfoSetter) + ret0, _ := ret[0].(oidc.UserInfo) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/op/op.go b/pkg/op/op.go index 7e8279a..bba7a14 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -51,6 +51,7 @@ type OpenIDProvider interface { Encoder() utils.Encoder IDTokenHintVerifier() IDTokenHintVerifier JWTProfileVerifier() JWTProfileVerifier + AccessTokenVerifier() AccessTokenVerifier Crypto() Crypto DefaultLogoutRedirectURI() string Signer() Signer @@ -152,6 +153,8 @@ type openidProvider struct { signer Signer idTokenHintVerifier IDTokenHintVerifier jwtProfileVerifier JWTProfileVerifier + accessTokenVerifier AccessTokenVerifier + keySet *openIDKeySet crypto Crypto httpHandler http.Handler decoder *schema.Decoder @@ -207,7 +210,7 @@ func (o *openidProvider) Encoder() utils.Encoder { func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier { if o.idTokenHintVerifier == nil { - o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), &openIDKeySet{o.Storage()}) + o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet()) } return o.idTokenHintVerifier } @@ -219,6 +222,20 @@ func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier { return o.jwtProfileVerifier } +func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier { + if o.accessTokenVerifier == nil { + o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet()) + } + return o.accessTokenVerifier +} + +func (o *openidProvider) openIDKeySet() oidc.KeySet { + if o.keySet == nil { + o.keySet = &openIDKeySet{o.Storage()} + } + return o.keySet +} + func (o *openidProvider) Crypto() Crypto { return o.crypto } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 1c266d7..10e7779 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -28,8 +28,9 @@ type AuthStorage interface { type OPStorage interface { GetClientByClientID(context.Context, string) (Client, error) AuthorizeClientIDSecret(context.Context, string, string) error - GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfoSetter, error) - GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error) + GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfo, error) + GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfo, error) + GetPrivateClaimsFromScopes(context.Context, string, string, []string) (map[string]interface{}, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } diff --git a/pkg/op/token.go b/pkg/op/token.go index 670fca7..f542588 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -26,12 +26,12 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client var validity time.Duration if createAccessToken { var err error - accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator) + accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator, client) if err != nil { return nil, err } } - idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer()) + idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes()) if err != nil { return nil, err } @@ -51,7 +51,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client } func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) { - accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator) + accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, nil) if err != nil { return nil, err } @@ -64,14 +64,14 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea }, nil } -func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) { - id, exp, err := creator.Storage().CreateToken(ctx, authReq) +func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client) (token string, validity time.Duration, err error) { + id, exp, err := creator.Storage().CreateToken(ctx, tokenRequest) if err != nil { return "", 0, err } validity = exp.Sub(time.Now().UTC()) if accessTokenType == AccessTokenTypeJWT { - token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer()) + token, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage()) return } token, err = CreateBearerToken(id, creator.Crypto()) @@ -82,14 +82,22 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) { return crypto.Encrypt(id) } -func CreateJWT(issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer) (string, error) { +func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) { claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id) + if client != nil && client.AssertAdditionalAccessTokenScopes() { + privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes())) + if err != nil { + return "", err + } + claims.SetPrivateClaims(privateClaims) + } return utils.Sign(claims, signer.Signer()) } -func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) { +func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) { exp := time.Now().UTC().Add(validity) claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID()) + scopes := authReq.GetScopes() if accessToken != "" { atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) @@ -97,8 +105,13 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali return "", err } claims.SetAccessTokenHash(atHash) - } else { - userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), authReq.GetScopes()) + scopes = removeUserinfoScopes(scopes) + } + if !additonalScopes { + scopes = removeAdditionalScopes(scopes) + } + if len(scopes) > 0 { + userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes) if err != nil { return "", err } @@ -114,3 +127,34 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali return utils.Sign(claims, signer.Signer()) } + +func removeUserinfoScopes(scopes []string) []string { + for i := len(scopes) - 1; i >= 0; i-- { + if scopes[i] == oidc.ScopeProfile || + scopes[i] == oidc.ScopeEmail || + scopes[i] == oidc.ScopeAddress || + scopes[i] == oidc.ScopePhone { + + scopes[i] = scopes[len(scopes)-1] + scopes[len(scopes)-1] = "" + scopes = scopes[:len(scopes)-1] + } + } + return scopes +} + +func removeAdditionalScopes(scopes []string) []string { + for i := len(scopes) - 1; i >= 0; i-- { + if !(scopes[i] == oidc.ScopeOpenID || + scopes[i] == oidc.ScopeProfile || + scopes[i] == oidc.ScopeEmail || + scopes[i] == oidc.ScopeAddress || + scopes[i] == oidc.ScopePhone) { + + scopes[i] = scopes[len(scopes)-1] + scopes[len(scopes)-1] = "" + scopes = scopes[:len(scopes)-1] + } + } + return scopes +} diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 36ecd4a..0b27a5e 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -13,6 +13,7 @@ type UserinfoProvider interface { Decoder() utils.Decoder Crypto() Crypto Storage() Storage + AccessTokenVerifier() AccessTokenVerifier } func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { @@ -27,10 +28,20 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP http.Error(w, "access token missing", http.StatusUnauthorized) return } - tokenID, err := userinfoProvider.Crypto().Decrypt(accessToken) - if err != nil { - http.Error(w, "access token missing", http.StatusUnauthorized) - return + var tokenID string + if strings.HasPrefix(accessToken, "eyJhbGci") { //TODO: improve + accessTokenClaims, err := VerifyAccessToken(r.Context(), accessToken, userinfoProvider.AccessTokenVerifier()) + if err != nil { + http.Error(w, "access token invalid", http.StatusUnauthorized) + return + } + tokenID = accessTokenClaims.GetTokenID() + } else { + tokenID, err = userinfoProvider.Crypto().Decrypt(accessToken) + if err != nil { + http.Error(w, "access token invalid", http.StatusUnauthorized) + return + } } info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, r.Header.Get("origin")) if err != nil { diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go new file mode 100644 index 0000000..05168a6 --- /dev/null +++ b/pkg/op/verifier_access_token.go @@ -0,0 +1,85 @@ +package op + +import ( + "context" + "time" + + "github.com/caos/oidc/pkg/oidc" +) + +type AccessTokenVerifier interface { + oidc.Verifier + SupportedSignAlgs() []string + KeySet() oidc.KeySet +} + +type accessTokenVerifier struct { + issuer string + maxAgeIAT time.Duration + offset time.Duration + supportedSignAlgs []string + maxAge time.Duration + acr oidc.ACRVerifier + keySet oidc.KeySet +} + +//Issuer implements oidc.Verifier interface +func (i *accessTokenVerifier) Issuer() string { + return i.issuer +} + +//MaxAgeIAT implements oidc.Verifier interface +func (i *accessTokenVerifier) MaxAgeIAT() time.Duration { + return i.maxAgeIAT +} + +//Offset implements oidc.Verifier interface +func (i *accessTokenVerifier) Offset() time.Duration { + return i.offset +} + +//SupportedSignAlgs implements AccessTokenVerifier interface +func (i *accessTokenVerifier) SupportedSignAlgs() []string { + return i.supportedSignAlgs +} + +//KeySet implements AccessTokenVerifier interface +func (i *accessTokenVerifier) KeySet() oidc.KeySet { + return i.keySet +} + +func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier { + verifier := &idTokenHintVerifier{ + issuer: issuer, + keySet: keySet, + } + return verifier +} + +//VerifyAccessToken validates the access token (issuer, signature and expiration) +func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) { + claims := oidc.EmptyAccessTokenClaims() + + decrypted, err := oidc.DecryptToken(token) + if err != nil { + return nil, err + } + payload, err := oidc.ParseToken(decrypted, claims) + if err != nil { + return nil, err + } + + if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + return nil, err + } + + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + return nil, err + } + + if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + return nil, err + } + + return claims, nil +}