diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 068e8e6..f753120 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "fmt" "io/ioutil" "time" @@ -399,7 +400,19 @@ type AccessTokenResponse struct { IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` } -type JWTProfileAssertion struct { +type JWTProfileAssertionClaims interface { + GetKeyID() string + GetPrivateKey() []byte + GetIssuer() string + GetSubject() string + GetAudience() []string + GetExpiration() time.Time + GetIssuedAt() time.Time + SetCustomClaim(key string, value interface{}) + GetCustomClaim(key string) interface{} +} + +type jwtProfileAssertion struct { PrivateKeyID string `json:"-"` PrivateKey []byte `json:"-"` Issuer string `json:"iss"` @@ -407,17 +420,99 @@ type JWTProfileAssertion struct { Audience Audience `json:"aud"` Expiration Time `json:"exp"` IssuedAt Time `json:"iat"` + + customClaims map[string]interface{} } -func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) { +func (j *jwtProfileAssertion) MarshalJSON() ([]byte, error) { + type Alias jwtProfileAssertion + a := (*Alias)(j) + + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if len(j.customClaims) == 0 { + return b, nil + } + + err = json.Unmarshal(b, &j.customClaims) + if err != nil { + return nil, fmt.Errorf("jws: invalid map of custom claims %v", j.customClaims) + } + + return json.Marshal(j.customClaims) +} + +func (j *jwtProfileAssertion) UnmarshalJSON(data []byte) error { + type Alias jwtProfileAssertion + a := (*Alias)(j) + + err := json.Unmarshal(data, a) + if err != nil { + return err + } + + err = json.Unmarshal(data, &j.customClaims) + if err != nil { + return err + } + + return nil +} + +func (j *jwtProfileAssertion) GetKeyID() string { + return j.PrivateKeyID +} + +func (j *jwtProfileAssertion) GetPrivateKey() []byte { + return j.PrivateKey +} + +func (j *jwtProfileAssertion) SetCustomClaim(key string, value interface{}) { + if j.customClaims == nil { + j.customClaims = make(map[string]interface{}) + } + j.customClaims[key] = value +} + +func (j *jwtProfileAssertion) GetCustomClaim(key string) interface{} { + if j.customClaims == nil { + return nil + } + return j.customClaims[key] +} + +func (j *jwtProfileAssertion) GetIssuer() string { + return j.Issuer +} + +func (j *jwtProfileAssertion) GetSubject() string { + return j.Subject +} + +func (j *jwtProfileAssertion) GetAudience() []string { + return j.Audience +} + +func (j *jwtProfileAssertion) GetExpiration() time.Time { + return time.Time(j.Expiration) +} + +func (j *jwtProfileAssertion) GetIssuedAt() time.Time { + return time.Time(j.IssuedAt) +} + +func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) { data, err := ioutil.ReadFile(filename) if err != nil { return nil, err } - return NewJWTProfileAssertionFromFileData(data, audience) + return NewJWTProfileAssertionFromFileData(data, audience, opts...) } -func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string) (string, error) { +func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string, opts ...AssertionOption) (string, error) { keyData := new(struct { KeyID string `json:"keyId"` Key string `json:"key"` @@ -427,10 +522,22 @@ func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string) (s if err != nil { return "", err } - return GenerateJWTProfileToken(NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key))) + return GenerateJWTProfileToken(NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...)) } -func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) { +func JWTProfileDelegatedSubject(sub string) func(*jwtProfileAssertion) { + return func(j *jwtProfileAssertion) { + j.Subject = sub + } +} + +func JWTProfileCustomClaim(key string, value interface{}) func(*jwtProfileAssertion) { + return func(j *jwtProfileAssertion) { + j.customClaims[key] = value + } +} + +func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) { keyData := new(struct { KeyID string `json:"keyId"` Key string `json:"key"` @@ -440,11 +547,13 @@ func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTPro if err != nil { return nil, err } - return NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key)), nil + return NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...), nil } -func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) *JWTProfileAssertion { - return &JWTProfileAssertion{ +type AssertionOption func(*jwtProfileAssertion) + +func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) JWTProfileAssertionClaims { + j := &jwtProfileAssertion{ PrivateKey: key, PrivateKeyID: keyID, Issuer: userID, @@ -452,7 +561,14 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) IssuedAt: Time(time.Now().UTC()), Expiration: Time(time.Now().Add(1 * time.Hour).UTC()), Audience: audience, + customClaims: make(map[string]interface{}), } + + for _, opt := range opts { + opt(j) + } + + return j } func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { @@ -473,14 +589,14 @@ func AppendClientIDToAudience(clientID string, audience []string) []string { return append(audience, clientID) } -func GenerateJWTProfileToken(assertion *JWTProfileAssertion) (string, error) { - privateKey, err := bytesToPrivateKey(assertion.PrivateKey) +func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) { + privateKey, err := bytesToPrivateKey(assertion.GetPrivateKey()) if err != nil { return "", err } key := jose.SigningKey{ Algorithm: jose.RS256, - Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.PrivateKeyID}, + Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.GetKeyID()}, } signer, err := jose.NewSigner(key, &jose.SignerOptions{}) if err != nil { diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 4899c3a..6f9f1af 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -1,6 +1,8 @@ package oidc import ( + "encoding/json" + "fmt" "time" "gopkg.in/square/go-jose.v2" @@ -87,6 +89,50 @@ type JWTTokenRequest struct { Audience Audience `json:"aud"` IssuedAt Time `json:"iat"` ExpiresAt Time `json:"exp"` + + private map[string]interface{} +} + +func (j *JWTTokenRequest) MarshalJSON() ([]byte, error) { + type Alias JWTTokenRequest + a := (*Alias)(j) + + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if len(j.private) == 0 { + return b, nil + } + + err = json.Unmarshal(b, &j.private) + if err != nil { + return nil, fmt.Errorf("jws: invalid map of custom claims %v", j.private) + } + + return json.Marshal(j.private) +} + +func (j *JWTTokenRequest) UnmarshalJSON(data []byte) error { + type Alias JWTTokenRequest + a := (*Alias)(j) + + err := json.Unmarshal(data, a) + if err != nil { + return err + } + + err = json.Unmarshal(data, &j.private) + if err != nil { + return err + } + + return nil +} + +func (j *JWTTokenRequest) GetCustomClaim(key string) interface{} { + return j.private[key] } //GetIssuer implements the Claims interface diff --git a/pkg/op/keys.go b/pkg/op/keys.go index 4b8d607..c4b11d4 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -1,13 +1,16 @@ package op import ( + "context" "net/http" + "gopkg.in/square/go-jose.v2" + "github.com/caos/oidc/pkg/utils" ) type KeyProvider interface { - Storage() Storage + GetKeySet(context.Context) (*jose.JSONWebKeySet, error) } func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) { @@ -17,7 +20,7 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) { } func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { - keySet, err := k.Storage().GetKeySet(r.Context()) + keySet, err := k.GetKeySet(r.Context()) if err != nil { w.WriteHeader(http.StatusInternalServerError) utils.MarshalJSON(w, err) diff --git a/pkg/op/op.go b/pkg/op/op.go index 518ffdf..241f0ce 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -74,7 +74,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o))) - router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o)) + router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) return router } @@ -281,7 +281,7 @@ func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSig if !ok { return nil, errors.New("invalid kid") } - return jws.Verify(key) + return jws.Verify(&key) } type Option func(o *openidProvider) error diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index b425f80..e7784b5 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -13,23 +13,40 @@ import ( type JWTProfileVerifier interface { oidc.Verifier - Storage() Storage + Storage() jwtProfileKeyStorage + CheckSubject(request *oidc.JWTTokenRequest) error } type jwtProfileVerifier struct { - storage Storage - issuer string - maxAgeIAT time.Duration - offset time.Duration + storage jwtProfileKeyStorage + subjectCheck func(request *oidc.JWTTokenRequest) error + issuer string + maxAgeIAT time.Duration + offset time.Duration } //NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) -func NewJWTProfileVerifier(storage Storage, issuer string, maxAgeIAT, offset time.Duration) JWTProfileVerifier { - return &jwtProfileVerifier{ - storage: storage, - issuer: issuer, - maxAgeIAT: maxAgeIAT, - offset: offset, +func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier { + j := &jwtProfileVerifier{ + storage: storage, + subjectCheck: SubjectIsIssuer, + issuer: issuer, + maxAgeIAT: maxAgeIAT, + offset: offset, + } + + for _, opt := range opts { + opt(j) + } + + return j +} + +type JWTProfileVerifierOption func(*jwtProfileVerifier) + +func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption { + return func(verifier *jwtProfileVerifier) { + verifier.subjectCheck = check } } @@ -37,7 +54,7 @@ func (v *jwtProfileVerifier) Issuer() string { return v.issuer } -func (v *jwtProfileVerifier) Storage() Storage { +func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage { return v.storage } @@ -49,6 +66,10 @@ func (v *jwtProfileVerifier) Offset() time.Duration { return v.offset } +func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error { + return v.subjectCheck(request) +} + //VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication) // //checks audience, exp, iat, signature and that issuer and sub are the same @@ -71,9 +92,8 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif return nil, err } - if request.Issuer != request.Subject { - //TODO: implement delegation (openid core / oauth rfc) - return nil, errors.New("delegation not yet implemented, issuer and sub must be identical") + if err = v.CheckSubject(request); err != nil { + return nil, err } keySet := &jwtProfileKeySet{v.Storage(), request.Issuer} @@ -84,20 +104,28 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif return request, nil } +type jwtProfileKeyStorage interface { + GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) +} + +func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { + if request.Issuer != request.Subject { + return errors.New("delegation not allowed, issuer and sub must be identical") + } + return nil +} + type jwtProfileKeySet struct { - Storage - userID string + storage jwtProfileKeyStorage + userID string } //VerifySignature implements oidc.KeySet by getting the public key from Storage implementation func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { - keyID, alg := oidc.GetKeyIDAndAlg(jws) - key, err := k.Storage.GetKeyByIDAndUserID(ctx, keyID, k.userID) + keyID, _ := oidc.GetKeyIDAndAlg(jws) + key, err := k.storage.GetKeyByIDAndUserID(ctx, keyID, k.userID) if err != nil { return nil, fmt.Errorf("error fetching keys: %w", err) } - if key.Algorithm != alg { - - } - return jws.Verify(&key) + return jws.Verify(key) }