refactoring

This commit is contained in:
Livio Amstutz 2020-09-25 16:41:25 +02:00
parent 6cfd02e4c9
commit 542ec6ed7b
26 changed files with 1412 additions and 625 deletions

View file

@ -168,7 +168,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
}
return claims.Subject, nil
return claims.GetSubject(), nil
}
//RedirectToLogin redirects the end user to the Login UI for authentication

View file

@ -81,7 +81,7 @@ func (s *Sig) Health(ctx context.Context) error {
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
func (s *Sig) SignAccessToken(*oidc.accessTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {

View file

@ -50,7 +50,7 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
}
// SignAccessToken mocks base method
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
func (m *MockSigner) SignAccessToken(arg0 *oidc.accessTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
ret0, _ := ret[0].(string)

View file

@ -184,10 +184,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
}
// GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) {
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.userinfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.Userinfo)
ret0, _ := ret[0].(*oidc.userinfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -199,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interf
}
// GetUserinfoFromToken mocks base method
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.Userinfo, error) {
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.userinfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.Userinfo)
ret0, _ := ret[0].(*oidc.userinfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View file

@ -130,7 +130,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
}
keyCh := make(chan jose.SigningKey)
o.signer = NewDefaultSigner(ctx, storage, keyCh)
o.signer = NewSigner(ctx, storage, keyCh)
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
o.httpHandler = CreateRouter(o, o.interceptors...)

View file

@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid")
}
session.UserID = claims.Subject
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty)
session.UserID = claims.GetSubject()
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil {
return nil, ErrServerError("")
}

View file

@ -2,19 +2,17 @@ package op
import (
"context"
"encoding/json"
"errors"
"github.com/caos/logging"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)
type Signer interface {
Health(ctx context.Context) error
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
//SignIDToken(claims *oidc.IDTokenClaims) (string, error)
//SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
Signer() jose.Signer
SignatureAlgorithm() jose.SignatureAlgorithm
}
@ -24,7 +22,7 @@ type tokenSigner struct {
alg jose.SignatureAlgorithm
}
func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
s := &tokenSigner{
storage: storage,
}
@ -41,6 +39,15 @@ func (s *tokenSigner) Health(_ context.Context) error {
return nil
}
func (s *tokenSigner) Signer() jose.Signer {
return s.signer
}
//
//func (s *tokenSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) {
// return s.signer.Sign(payload)
//}
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for {
select {
@ -55,30 +62,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S
}
}
func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) Sign(payload []byte) (string, error) {
result, err := s.signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.alg
}

View file

@ -38,13 +38,13 @@ import (
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// got, err := op.NewDefaultSigner(tt.args.storage)
// got, err := op.NewSigner(tt.args.storage)
// if (err != nil) != tt.wantErr {
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
// t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
// return
// }
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
// t.Errorf("NewSigner() = %v, want %v", got, tt.want)
// }
// })
// }

View file

@ -28,8 +28,8 @@ type AuthStorage interface {
type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error)
GetUserinfoFromToken(context.Context, string, string) (*oidc.Userinfo, error)
GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error)
GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
}

View file

@ -5,6 +5,7 @@ import (
"time"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type TokenCreator interface {
@ -82,51 +83,34 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) {
}
func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) {
now := time.Now().UTC()
nbf := now
claims := &oidc.AccessTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: now,
NotBefore: nbf,
JWTID: id,
}
return signer.SignAccessToken(claims)
claims := oidc.NewAccessTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, id)
return utils.Sign(claims, signer.Signer())
}
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
if err != nil {
return "", err
}
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
Userinfo: *userinfo,
}
claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID())
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
claims.SetAccessTokenHash(atHash)
} else {
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
if err != nil {
return "", err
}
claims.SetUserinfo(userInfo)
}
if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
claims.SetCodeHash(codeHash)
}
return signer.SignIDToken(claims)
return utils.Sign(claims, signer.Signer())
}

View file

@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi
//VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims)
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {