fix: rp verification process (#95)

* fix: rp verification process

* types

* comments

* fix cli client
This commit is contained in:
Livio Amstutz 2021-06-23 11:08:54 +02:00 committed by GitHub
parent 400f5c4de4
commit 850faa159d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 175 additions and 55 deletions

View file

@ -59,11 +59,9 @@ func main() {
//including state handling with secure cookie and the possibility to use PKCE
http.Handle("/login", rp.AuthURLHandler(state, provider))
//for demonstration purposes the returned tokens (access token, id_token an its parsed claims)
//are written as JSON objects onto response
marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) {
_ = state
data, err := json.Marshal(tokens)
//for demonstration purposes the returned userinfo response is written as JSON object onto response
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
data, err := json.Marshal(info)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -71,10 +69,27 @@ func main() {
w.Write(data)
}
//you could also just take the access_token and id_token without calling the userinfo endpoint:
//
//marshalToken := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
// data, err := json.Marshal(tokens)
// if err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
// w.Write(data)
//}
//register the CodeExchangeHandler at the callbackPath
//the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function
//with the returned tokens from the token endpoint
http.Handle(callbackPath, rp.CodeExchangeHandler(marshal, provider))
//in this example the callback function itself is wrapped by the UserinfoCallback which
//will call the Userinfo endpoint, check the sub and pass the info into the callback function
http.Handle(callbackPath, rp.CodeExchangeHandler(rp.UserinfoCallback(marshalUserinfo), provider))
//if you would use the callback without calling the userinfo endpoint, simply switch the callback handler for:
//
//http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider))
lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis)

View file

@ -36,6 +36,9 @@ func Discover(issuer string, httpClient *http.Client) (*oidc.DiscoveryConfigurat
if err != nil {
return nil, err
}
if discoveryConfig.Issuer != issuer {
return nil, oidc.ErrIssuerInvalid
}
return discoveryConfig, nil
}

View file

@ -19,7 +19,7 @@ func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, p
tokenChan := make(chan *oidc.Tokens, 1)
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) {
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
tokenChan <- tokens
msg := "<p><strong>Success!</strong></p>"
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"

View file

@ -2,6 +2,7 @@ package rp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -21,6 +22,7 @@ func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
type remoteKeySet struct {
jwksURL string
httpClient *http.Client
defaultAlg string
// guard all other fields
mu sync.Mutex
@ -66,29 +68,27 @@ func (i *inflight) result() ([]jose.JSONWebKey, error) {
}
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
// We don't support JWTs signed with multiple signatures.
keyID := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
break
keyID, alg := oidc.GetKeyIDAndAlg(jws)
if alg == "" {
alg = r.defaultAlg
}
keys := r.keysFromCache()
payload, err, ok := oidc.CheckKey(keyID, jws, keys...)
if ok {
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
if ok && keyID != "" {
payload, err := jws.Verify(&key)
return payload, err
}
keys, err = r.keysFromRemote(ctx)
keys, err := r.keysFromRemote(ctx)
if err != nil {
return nil, fmt.Errorf("fetching keys %v", err)
}
payload, err, ok = oidc.CheckKey(keyID, jws, keys...)
if !ok {
return nil, errors.New("invalid kid")
key, ok = oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
if ok {
payload, err := jws.Verify(&key)
return payload, err
}
return payload, err
return nil, errors.New("invalid key")
}
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
@ -147,10 +147,34 @@ func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey,
return nil, fmt.Errorf("oidc: can't create request: %v", err)
}
keySet := new(jose.JSONWebKeySet)
keySet := new(jsonWebKeySet)
if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil {
return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
}
return keySet.Keys, nil
}
//jsonWebKeySet is an alias for jose.JSONWebKeySet which ignores unknown key types (kty)
type jsonWebKeySet jose.JSONWebKeySet
//UnmarshalJSON overrides the default jose.JSONWebKeySet method to ignore any error
//which might occur because of unknown key types (kty)
func (k *jsonWebKeySet) UnmarshalJSON(data []byte) (err error) {
var raw rawJSONWebKeySet
err = json.Unmarshal(data, &raw)
if err != nil {
return err
}
for _, key := range raw.Keys {
webKey := new(jose.JSONWebKey)
err = webKey.UnmarshalJSON(key)
if err == nil {
k.Keys = append(k.Keys, *webKey)
}
}
return nil
}
type rawJSONWebKeySet struct {
Keys []json.RawMessage `json:"keys"`
}

View file

@ -22,6 +22,10 @@ const (
pkceCode = "pkce"
)
var (
ErrUserInfoSubNotMatching = errors.New("sub from userinfo does not match the sub from the id_token")
)
//RelyingParty declares the minimal interface for oidc clients
type RelyingParty interface {
//OAuthConfig returns the oauth2 Config
@ -244,6 +248,9 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
if err != nil {
return Endpoints{}, err
}
if discoveryConfig.Issuer != issuer {
return Endpoints{}, oidc.ErrIssuerInvalid
}
return GetEndpoints(discoveryConfig), nil
}
@ -319,10 +326,12 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
//CodeExchangeHandler extends the `CodeExchange` method with a http handler
//including cookie handling for secure `state` transfer
//and optional PKCE code verifier checking
func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string), rp RelyingParty) http.HandlerFunc {
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp)
if err != nil {
@ -356,21 +365,40 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
}
callback(w, r, tokens, state)
callback(w, r, tokens, state, rp)
}
}
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
//UserinfoCallback wraps the callback function of the CodeExchangeHandler
//and calls the userinfo endpoint with the access token
//on success it will pass the userinfo into its callback function as well
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
return
}
f(w, r, tokens, state, rp, info)
}
}
//Userinfo will call the OIDC Userinfo Endpoint with the provided token
func Userinfo(token string, rp RelyingParty) (oidc.UserInfo, error) {
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) {
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
if err != nil {
return nil, err
}
req.Header.Set("authorization", token)
req.Header.Set("authorization", tokenType+" "+token)
userinfo := oidc.NewUserInfo()
if err := utils.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
return nil, err
}
if userinfo.GetSubject() != subject {
return nil, ErrUserInfoSubNotMatching
}
return userinfo, nil
}

View file

@ -46,7 +46,11 @@ func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.I
return nil, err
}
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
if err := oidc.CheckSubject(claims); err != nil {
return nil, err
}
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err
}

View file

@ -2,10 +2,17 @@ package oidc
import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"gopkg.in/square/go-jose.v2"
)
const (
KeyUseSignature = "sig"
)
//KeySet represents a set of JSON Web Keys
// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
// - held by the OP itself in storage -> `openIDKeySet`
@ -15,16 +22,51 @@ type KeySet interface {
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
}
//CheckKey searches the given JSON Web Keys for the requested key ID
//and verifies the JSON Web Signature with the found key
//GetKeyIDAndAlg returns the `kid` and `alg` claim from the JWS header
func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) {
keyID := ""
alg := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
alg = sig.Header.Algorithm
break
}
return keyID, alg
}
//FindKey searches the given JSON Web Keys for the requested key ID, usage and key type
//
//will return false but no error if key ID is not found
func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) {
//will return the key immediately if matches exact (id, usage, type)
//
//will return false none or multiple match
func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) {
var validKeys []jose.JSONWebKey
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
payload, err := jws.Verify(&key)
return payload, err, true
if key.KeyID == keyID && key.Use == use && algToKeyType(key.Key, expectedAlg) {
if keyID != "" {
return key, true
}
validKeys = append(validKeys, key)
}
}
return nil, nil, false
if len(validKeys) == 1 {
return validKeys[0], true
}
return jose.JSONWebKey{}, false
}
func algToKeyType(key interface{}, alg string) bool {
switch alg[0] {
case 'R', 'P':
_, ok := key.(*rsa.PublicKey)
return ok
case 'E':
_, ok := key.(*ecdsa.PublicKey)
return ok
case 'O':
_, ok := key.(*ed25519.PublicKey)
return ok
default:
return false
}
}

View file

@ -17,6 +17,7 @@ import (
type Claims interface {
GetIssuer() string
GetSubject() string
GetAudience() []string
GetExpiration() time.Time
GetIssuedAt() time.Time
@ -30,6 +31,7 @@ type Claims interface {
var (
ErrParse = errors.New("parsing of request failed")
ErrIssuerInvalid = errors.New("issuer does not match")
ErrSubjectMissing = errors.New("subject missing")
ErrAudience = errors.New("audience is not valid")
ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
ErrAzpInvalid = errors.New("authorized party is not valid")
@ -38,6 +40,7 @@ var (
ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
ErrExpired = errors.New("token has expired")
ErrIatMissing = errors.New("issuedAt of token is missing")
ErrIatInFuture = errors.New("issuedAt of token is in the future")
ErrIatToOld = errors.New("issuedAt of token is to old")
ErrNonceInvalid = errors.New("nonce does not match")
@ -84,6 +87,13 @@ func ParseToken(tokenString string, claims interface{}) ([]byte, error) {
return payload, err
}
func CheckSubject(claims Claims) error {
if claims.GetSubject() == "" {
return ErrSubjectMissing
}
return nil
}
func CheckIssuer(claims Claims, issuer string) error {
if claims.GetIssuer() != issuer {
return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, claims.GetIssuer())
@ -155,6 +165,9 @@ func CheckExpiration(claims Claims, offset time.Duration) error {
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
issuedAt := claims.GetIssuedAt().Round(time.Second)
if issuedAt.IsZero() {
return ErrIatMissing
}
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second)
if issuedAt.After(nowWithOffset) {
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
@ -170,9 +183,6 @@ func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
}
func CheckNonce(claims Claims, nonce string) error {
if nonce == "" {
return nil
}
if claims.GetNonce() != nonce {
return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, claims.GetNonce())
}

View file

@ -272,20 +272,16 @@ type openIDKeySet struct {
//VerifySignature implements the oidc.KeySet interface
//providing an implementation for the keys stored in the OP Storage interface
func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
keyID := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
break
}
keySet, err := o.Storage.GetKeySet(ctx)
if err != nil {
return nil, errors.New("error fetching keys")
}
payload, err, ok := oidc.CheckKey(keyID, jws, keySet.Keys...)
keyID, alg := oidc.GetKeyIDAndAlg(jws)
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...)
if !ok {
return nil, errors.New("invalid kid")
}
return payload, err
return jws.Verify(key)
}
type Option func(o *openidProvider) error

View file

@ -91,18 +91,13 @@ type jwtProfileKeySet struct {
//VerifySignature implements oidc.KeySet by getting the public key from Storage implementation
func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
keyID := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
break
}
keyID, alg := oidc.GetKeyIDAndAlg(jws)
key, err := k.Storage.GetKeyByIDAndUserID(ctx, keyID, k.userID)
if err != nil {
return nil, fmt.Errorf("error fetching keys: %w", err)
}
payload, err, ok := oidc.CheckKey(keyID, jws, *key)
if !ok {
return nil, errors.New("invalid kid")
if key.Algorithm != alg {
}
return payload, err
return jws.Verify(&key)
}

View file

@ -24,6 +24,9 @@ func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
}
func HashString(hash hash.Hash, s string, firstHalf bool) string {
if hash == nil {
return s
}
//nolint:errcheck
hash.Write([]byte(s))
size := hash.Size()