fix: rp verification process (#95)
* fix: rp verification process * types * comments * fix cli client
This commit is contained in:
parent
400f5c4de4
commit
850faa159d
11 changed files with 175 additions and 55 deletions
|
@ -19,7 +19,7 @@ func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, p
|
|||
|
||||
tokenChan := make(chan *oidc.Tokens, 1)
|
||||
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) {
|
||||
callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
|
||||
tokenChan <- tokens
|
||||
msg := "<p><strong>Success!</strong></p>"
|
||||
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"
|
||||
|
|
|
@ -2,6 +2,7 @@ package rp
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -21,6 +22,7 @@ func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
|
|||
type remoteKeySet struct {
|
||||
jwksURL string
|
||||
httpClient *http.Client
|
||||
defaultAlg string
|
||||
|
||||
// guard all other fields
|
||||
mu sync.Mutex
|
||||
|
@ -66,29 +68,27 @@ func (i *inflight) result() ([]jose.JSONWebKey, error) {
|
|||
}
|
||||
|
||||
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
// We don't support JWTs signed with multiple signatures.
|
||||
keyID := ""
|
||||
for _, sig := range jws.Signatures {
|
||||
keyID = sig.Header.KeyID
|
||||
break
|
||||
keyID, alg := oidc.GetKeyIDAndAlg(jws)
|
||||
if alg == "" {
|
||||
alg = r.defaultAlg
|
||||
}
|
||||
|
||||
keys := r.keysFromCache()
|
||||
payload, err, ok := oidc.CheckKey(keyID, jws, keys...)
|
||||
if ok {
|
||||
key, ok := oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
||||
if ok && keyID != "" {
|
||||
payload, err := jws.Verify(&key)
|
||||
return payload, err
|
||||
}
|
||||
|
||||
keys, err = r.keysFromRemote(ctx)
|
||||
keys, err := r.keysFromRemote(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching keys %v", err)
|
||||
}
|
||||
|
||||
payload, err, ok = oidc.CheckKey(keyID, jws, keys...)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid kid")
|
||||
key, ok = oidc.FindKey(keyID, oidc.KeyUseSignature, alg, keys...)
|
||||
if ok {
|
||||
payload, err := jws.Verify(&key)
|
||||
return payload, err
|
||||
}
|
||||
return payload, err
|
||||
return nil, errors.New("invalid key")
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
|
||||
|
@ -147,10 +147,34 @@ func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey,
|
|||
return nil, fmt.Errorf("oidc: can't create request: %v", err)
|
||||
}
|
||||
|
||||
keySet := new(jose.JSONWebKeySet)
|
||||
keySet := new(jsonWebKeySet)
|
||||
if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil {
|
||||
return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
|
||||
}
|
||||
|
||||
return keySet.Keys, nil
|
||||
}
|
||||
|
||||
//jsonWebKeySet is an alias for jose.JSONWebKeySet which ignores unknown key types (kty)
|
||||
type jsonWebKeySet jose.JSONWebKeySet
|
||||
|
||||
//UnmarshalJSON overrides the default jose.JSONWebKeySet method to ignore any error
|
||||
//which might occur because of unknown key types (kty)
|
||||
func (k *jsonWebKeySet) UnmarshalJSON(data []byte) (err error) {
|
||||
var raw rawJSONWebKeySet
|
||||
err = json.Unmarshal(data, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range raw.Keys {
|
||||
webKey := new(jose.JSONWebKey)
|
||||
err = webKey.UnmarshalJSON(key)
|
||||
if err == nil {
|
||||
k.Keys = append(k.Keys, *webKey)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type rawJSONWebKeySet struct {
|
||||
Keys []json.RawMessage `json:"keys"`
|
||||
}
|
||||
|
|
|
@ -22,6 +22,10 @@ const (
|
|||
pkceCode = "pkce"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserInfoSubNotMatching = errors.New("sub from userinfo does not match the sub from the id_token")
|
||||
)
|
||||
|
||||
//RelyingParty declares the minimal interface for oidc clients
|
||||
type RelyingParty interface {
|
||||
//OAuthConfig returns the oauth2 Config
|
||||
|
@ -244,6 +248,9 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
|
|||
if err != nil {
|
||||
return Endpoints{}, err
|
||||
}
|
||||
if discoveryConfig.Issuer != issuer {
|
||||
return Endpoints{}, oidc.ErrIssuerInvalid
|
||||
}
|
||||
return GetEndpoints(discoveryConfig), nil
|
||||
}
|
||||
|
||||
|
@ -319,10 +326,12 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
|
|||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty)
|
||||
|
||||
//CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
||||
//including cookie handling for secure `state` transfer
|
||||
//and optional PKCE code verifier checking
|
||||
func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string), rp RelyingParty) http.HandlerFunc {
|
||||
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := tryReadStateCookie(w, r, rp)
|
||||
if err != nil {
|
||||
|
@ -356,21 +365,40 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc
|
|||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
callback(w, r, tokens, state)
|
||||
callback(w, r, tokens, state, rp)
|
||||
}
|
||||
}
|
||||
|
||||
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo)
|
||||
|
||||
//UserinfoCallback wraps the callback function of the CodeExchangeHandler
|
||||
//and calls the userinfo endpoint with the access token
|
||||
//on success it will pass the userinfo into its callback function as well
|
||||
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
|
||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
|
||||
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
if err != nil {
|
||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
f(w, r, tokens, state, rp, info)
|
||||
}
|
||||
}
|
||||
|
||||
//Userinfo will call the OIDC Userinfo Endpoint with the provided token
|
||||
func Userinfo(token string, rp RelyingParty) (oidc.UserInfo, error) {
|
||||
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) {
|
||||
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("authorization", token)
|
||||
req.Header.Set("authorization", tokenType+" "+token)
|
||||
userinfo := oidc.NewUserInfo()
|
||||
if err := utils.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userinfo.GetSubject() != subject {
|
||||
return nil, ErrUserInfoSubNotMatching
|
||||
}
|
||||
return userinfo, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -46,7 +46,11 @@ func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.I
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
if err := oidc.CheckSubject(claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue