try different packags structures

This commit is contained in:
Livio Amstutz 2019-11-19 13:10:40 +01:00
parent f6ba7ab75e
commit f19df22e8e
36 changed files with 1821 additions and 23 deletions

View file

@ -0,0 +1,202 @@
package defaults
import (
"context"
"encoding/base64"
"net/http"
"strings"
"github.com/caos/oidc/pkg/oidc/grants"
"golang.org/x/oauth2"
"github.com/caos/oidc/pkg/oidc"
grants_tx "github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
"github.com/caos/oidc/pkg/rp"
"github.com/caos/oidc/pkg/rp/tokenexchange"
"github.com/caos/oidc/pkg/utils"
)
const (
idTokenKey = "id_token"
stateParam = "state"
)
type DefaultRP struct {
endpoints rp.Endpoints
oauthConfig oauth2.Config
config *rp.Config
httpClient *http.Client
cookieHandler *utils.CookieHandler
verifier rp.Verifier
}
func NewDefaultRelayingParty(rpConfig *rp.Config, rpOpts ...DefaultReplayingPartyOpts) (tokenexchange.DelegationTokenExchangeRP, error) {
p := &DefaultRP{
config: rpConfig,
httpClient: utils.DefaultHTTPClient,
}
for _, optFunc := range rpOpts {
optFunc(p)
}
if err := p.discover(); err != nil {
return nil, err
}
if p.verifier == nil {
// p.verifier = NewVerifier(rpConfig.Issuer, rpConfig.ClientID, utils.NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint
}
return p, nil
}
type DefaultReplayingPartyOpts func(p *DefaultRP)
func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultReplayingPartyOpts {
return func(p *DefaultRP) {
p.cookieHandler = cookieHandler
}
}
func WithHTTPClient(client *http.Client) DefaultReplayingPartyOpts {
return func(p *DefaultRP) {
p.httpClient = client
}
}
func (p *DefaultRP) AuthURL(state string) string {
return p.oauthConfig.AuthCodeURL(state)
}
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := p.trySetStateCookie(w, state); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
return
}
http.Redirect(w, r, p.AuthURL(state), http.StatusFound)
}
}
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
token, err := p.oauthConfig.Exchange(ctx, code)
if err != nil {
return nil, err //TODO: our error
}
idTokenString, ok := token.Extra(idTokenKey).(string)
if !ok {
//TODO: implement
}
idToken, err := p.verifier.Verify(ctx, token.AccessToken, idTokenString)
if err != nil {
return nil, err //TODO: err
}
return &oidc.Tokens{Token: token, IDTokenClaims: idToken}, nil
}
func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := p.tryReadStateCookie(w, r)
if err != nil {
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
return
}
tokens, err := p.CodeExchange(r.Context(), r.URL.Query().Get("code"))
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
}
callback(w, r, tokens, state)
}
}
// func (p *DefaultRP) Introspect(ctx context.Context, accessToken string) (oidc.TokenIntrospectResponse, error) {
// // req := &http.Request{}
// // resp, err := p.httpClient.Do(req)
// // if err != nil {
// // }
// // p.endpoints.IntrospectURL
// return nil, nil
// }
func (p *DefaultRP) Userinfo() {}
func (p *DefaultRP) TokenExchange(ctx context.Context, request *grants_tx.TokenExchangeRequest) (newToken *oauth2.Token, err error) {
return p.callTokenEndpoint(request)
}
func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Token, err error) {
req, err := utils.FormRequest(p.endpoints.TokenURL, request)
if err != nil {
return nil, err
}
auth := base64.StdEncoding.EncodeToString([]byte(p.config.ClientID + ":" + p.config.ClientSecret))
req.Header.Set("Authorization", "Basic "+auth)
token := new(oauth2.Token)
if err := utils.HttpRequest(p.httpClient, req, token); err != nil {
return nil, err
}
return token, nil
}
func (p *DefaultRP) ClientCredentials(ctx context.Context, scopes ...string) (newToken *oauth2.Token, err error) {
return p.callTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...))
}
func (p *DefaultRP) DelegationTokenExchange(ctx context.Context, subjectToken string, reqOpts ...grants_tx.TokenExchangeOption) (newToken *oauth2.Token, err error) {
return p.TokenExchange(ctx, DelegationTokenRequest(subjectToken, reqOpts...))
}
func (p *DefaultRP) discover() error {
wellKnown := strings.TrimSuffix(p.config.Issuer, "/") + oidc.DiscoveryEndpoint
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return err
}
discoveryConfig := new(oidc.DiscoveryConfiguration)
err = utils.HttpRequest(p.httpClient, req, &discoveryConfig)
if err != nil {
return err
}
p.endpoints = rp.GetEndpoints(discoveryConfig)
p.oauthConfig = oauth2.Config{
ClientID: p.config.ClientID,
ClientSecret: p.config.ClientSecret,
Endpoint: p.endpoints.Endpoint,
RedirectURL: p.config.CallbackURL,
Scopes: p.config.Scopes,
}
return nil
}
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
if p.cookieHandler != nil {
if err := p.cookieHandler.SetQueryCookie(w, stateParam, state); err != nil {
return err
}
}
return nil
}
func (p *DefaultRP) tryReadStateCookie(w http.ResponseWriter, r *http.Request) (state string, err error) {
if p.cookieHandler == nil {
return r.FormValue(stateParam), nil
}
state, err = p.cookieHandler.CheckQueryCookie(r, stateParam)
if err != nil {
return "", err
}
p.cookieHandler.DeleteCookie(w, stateParam)
return state, nil
}

View file

@ -0,0 +1,13 @@
package defaults
import (
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
)
//DelegationTokenRequest is an implementation of TokenExchangeRequest
//it exchanges a "urn:ietf:params:oauth:token-type:access_token" with an optional
//"urn:ietf:params:oauth:token-type:access_token" actor token for a
//"urn:ietf:params:oauth:token-type:access_token" delegation token
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
}

58
pkg/rp/defaults/error.go Normal file
View file

@ -0,0 +1,58 @@
package defaults
import (
"fmt"
"time"
)
var (
ErrIssuerInvalid = func(expected, actual string) *validationError {
return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual)
}
ErrAudienceMissingClientID = func(clientID string) *validationError {
return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID)
}
ErrAzpMissing = func() *validationError {
return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty")
}
ErrAzpInvalid = func(azp, clientID string) *validationError {
return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID)
}
ErrExpInvalid = func(exp time.Time) *validationError {
return ValidationError("Token has expired %v", exp)
}
ErrIatInFuture = func(exp, now time.Time) *validationError {
return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now)
}
ErrIatToOld = func(maxAge, iat time.Time) *validationError {
return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat))
}
ErrNonceInvalid = func(expected, actual string) *validationError {
return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
}
ErrAcrInvalid = func(expected []string, actual string) *validationError {
return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual)
}
ErrAuthTimeNotPresent = func() *validationError {
return ValidationError("claim `auth_time` of token is missing")
}
ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
}
ErrSignatureInvalidPayload = func() *validationError {
return ValidationError("Signature does not match Payload")
}
)
func ValidationError(message string, args ...interface{}) *validationError {
return &validationError{fmt.Sprintf(message, args...)} //TODO: impl
}
type validationError struct {
message string
}
func (v *validationError) Error() string {
return v.message
}

456
pkg/rp/defaults/verifier.go Normal file
View file

@ -0,0 +1,456 @@
package defaults
import (
"bytes"
"context"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"fmt"
"hash"
"strings"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp"
str_utils "github.com/caos/utils/strings"
)
func NewVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) rp.Verifier {
conf := &verifierConfig{
issuer: issuer,
clientID: clientID,
iat: &iatConfig{
// offset: time.Duration(500 * time.Millisecond),
},
}
for _, opt := range confOpts {
if opt != nil {
opt(conf)
}
}
return &Verifier{config: conf, keySet: keySet}
}
type Verifier struct {
config *verifierConfig
keySet oidc.KeySet
}
type ConfFunc func(*verifierConfig)
func WithIgnoreIssuedAt() func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.iat.ignore = true
}
}
func WithIssuedAtOffset(offset time.Duration) func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.iat.offset = offset
}
}
func WithIssuedAtMaxAge(maxAge time.Duration) func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.iat.maxAge = maxAge
}
}
func WithNonce(nonce string) func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.nonce = nonce
}
}
func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.acr = verifier
}
}
func WithAuthTimeMaxAge(maxAge time.Duration) func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.maxAge = maxAge
}
}
func WithSupportedSigningAlgorithms(algs ...string) func(*verifierConfig) {
return func(conf *verifierConfig) {
conf.supportedSignAlgs = algs
}
}
// func WithVerifierHTTPClient(client *http.Client) func(*verifierConfig) {
// return func(conf *verifierConfig) {
// conf.httpClient = client
// }
// }
type verifierConfig struct {
issuer string
clientID string
nonce string
iat *iatConfig
acr ACRVerifier
maxAge time.Duration
supportedSignAlgs []string
// httpClient *http.Client
now time.Time
}
type iatConfig struct {
ignore bool
offset time.Duration
maxAge time.Duration
}
type ACRVerifier func(string) error
func DefaultACRVerifier(possibleValues []string) func(string) error {
return func(acr string) error {
if !str_utils.Contains(possibleValues, acr) {
return ErrAcrInvalid(possibleValues, acr)
}
return nil
}
}
func (v *Verifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) {
v.config.now = time.Now().UTC()
idToken, err := v.VerifyIDToken(ctx, idTokenString)
if err != nil {
return nil, err
}
if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token
return nil, err
}
return idToken, nil
}
func (v *Verifier) now() time.Time {
if v.config.now.IsZero() {
v.config.now = time.Now().UTC().Round(time.Second)
}
return v.config.now
}
func (v *Verifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) {
//1. if encrypted --> decrypt
decrypted, err := v.decryptToken(idTokenString)
if err != nil {
return nil, err
}
claims, payload, err := v.parseToken(decrypted)
if err != nil {
return nil, err
}
// token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) {
//2, check issuer (exact match)
if err := v.checkIssuer(claims.Issuer); err != nil {
return nil, err
}
//3. check aud (aud must contain client_id, all aud strings must be allowed)
if err = v.checkAudience(claims.Audiences); err != nil {
return nil, err
}
if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil {
return nil, err
}
//6. check signature by keys
//7. check alg default is rs256
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
claims.Signature, err = v.checkSignature(ctx, decrypted, payload)
if err != nil {
return nil, err
}
//9. check exp before now
if err = v.checkExpiration(claims.Expiration); err != nil {
return nil, err
}
//10. check iat duration is optional (can be checked)
if err = v.checkIssuedAt(claims.IssuedAt); err != nil {
return nil, err
}
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
if err = v.checkNonce(claims.Nonce); err != nil {
return nil, err
}
//12. if acr requested check acr
if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil {
return nil, err
}
//13. if auth_time requested check if auth_time is less than max age
if err = v.checkAuthTime(claims.AuthTime); err != nil {
return nil, err
}
//return idtoken struct, err
return claims, nil
// })
// _ = token
// return err
}
func (v *Verifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
idToken := new(oidc.IDTokenClaims)
err = json.Unmarshal(payload, idToken)
return idToken, payload, err
}
func (v *Verifier) checkIssuer(issuer string) error {
if v.config.issuer != issuer {
return ErrIssuerInvalid(v.config.issuer, issuer)
}
return nil
}
func (v *Verifier) checkAudience(audiences []string) error {
if !str_utils.Contains(audiences, v.config.clientID) {
return ErrAudienceMissingClientID(v.config.clientID)
}
//TODO: check aud trusted
return nil
}
//4. if multiple aud strings --> check if azp
//5. if azp --> check azp == client_id
func (v *Verifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
if len(audiences) > 1 {
if authorizedParty == "" {
return ErrAzpMissing()
}
}
if authorizedParty != "" && authorizedParty != v.config.clientID {
return ErrAzpInvalid(authorizedParty, v.config.clientID)
}
return nil
}
func (v *Verifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) {
jws, err := jose.ParseSigned(idTokenString)
if err != nil {
return "", err
}
if len(jws.Signatures) == 0 {
return "", nil //TODO: error
}
if len(jws.Signatures) > 1 {
return "", nil //TODO: error
}
sig := jws.Signatures[0]
supportedSigAlgs := v.config.supportedSignAlgs
if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{"RS256"}
}
if !str_utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
}
signedPayload, err := v.keySet.VerifySignature(ctx, jws)
if err != nil {
return "", err
//TODO:
}
if !bytes.Equal(signedPayload, payload) {
return "", ErrSignatureInvalidPayload() //TODO: err
}
return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
}
// type KeySet struct {
// remoteURL url.URL
// httpClient *http.Client
// keys []jose.JSONWebKey
// m sync.Mutex
// inflight *inflight
// }
// func (k *KeySet) GetKey(ctx context.Context, keyID string) (*jose.JSONWebKey, error) {
// key, err := k.getKey(keyID)
// if err != nil {
// //lock
// k.updateKeys(ctx)
// //unlock
// return k.getKey(keyID)
// }
// return key, nil
// }
// func (k *KeySet) getKey(keyID string) (*jose.JSONWebKey, error) {
// k.m.Lock()
// keys := k.keys
// k.m.Unlock()
// for _, key := range keys {
// if key.KeyID == keyID {
// return &key, nil
// }
// }
// return nil, nil //TODO: err
// }
// func (k *KeySet) retrieveNewKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
// resp, err := k.httpClient.Get(k.remoteURL.String())
// if err != nil {
// return nil, err
// }
// if resp.StatusCode != http.StatusOK {
// return nil, nil //TODO: errs
// }
// defer resp.Body.Close()
// body, err := ioutil.ReadAll(resp.Body)
// if err != nil {
// return nil, err
// }
// var keySet jose.JSONWebKeySet
// err = json.Unmarshal(body, keySet)
// if err != nil {
// return nil, err
// }
// return keySet.Keys, nil
// }
// func (k *KeySet) updateKeys(ctx context.Context) error {
// k.inflight
// k.m.Lock()
// k.keys = keySet.Keys
// return nil
// }
func (v *Verifier) checkExpiration(expiration time.Time) error {
expiration = expiration.Round(time.Second)
if !v.now().Before(expiration) {
return ErrExpInvalid(expiration)
}
return nil
}
func (v *Verifier) checkIssuedAt(issuedAt time.Time) error {
if v.config.iat.ignore {
return nil
}
issuedAt = issuedAt.Round(time.Second)
offset := v.now().Add(v.config.iat.offset).Round(time.Second)
if issuedAt.After(offset) {
return ErrIatInFuture(issuedAt, offset)
}
if v.config.iat.maxAge == 0 {
return nil
}
maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second)
if issuedAt.Before(maxAge) {
return ErrIatToOld(maxAge, issuedAt)
}
return nil
}
func (v *Verifier) checkNonce(nonce string) error {
if v.config.nonce == "" {
return nil
}
if v.config.nonce != nonce {
return ErrNonceInvalid(v.config.nonce, nonce)
}
return nil
}
func (v *Verifier) checkAuthorizationContextClassReference(acr string) error {
if v.config.acr != nil {
return v.config.acr(acr)
}
return nil
}
func (v *Verifier) checkAuthTime(authTime time.Time) error {
if v.config.maxAge == 0 {
return nil
}
if authTime.IsZero() {
return ErrAuthTimeNotPresent()
}
authTime = authTime.Round(time.Second)
maxAge := v.now().Add(-v.config.maxAge).Round(time.Second)
if authTime.Before(maxAge) {
return ErrAuthTimeToOld(maxAge, authTime)
}
return nil
}
func (v *Verifier) decryptToken(tokenString string) (string, error) {
return tokenString, nil //TODO: impl
}
// func (v *Verifier) parseIDToken(tokenString string) (IDToken, error) {
// var claims jwt.StandardClaims
// token, err := jwt.ParseWithClaims(tokenString, &claims, func(token *jwt.Token) (interface{}, error) {
// claims.VerifyIssuer(v.config.Issuer, true)
// // return token.Header["alg"]
// })
// payload, err := parseJWT(rawIDToken)
// if err != nil {
// return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
// }
// var token IDToken
// if err := json.Unmarshal(payload, &token); err != nil {
// return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err)
// }
// return token, nil //TODO: impl
// }
func (v *Verifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
if atHash == "" {
return nil //TODO: return error
}
tokenHash, err := getHashAlgorithm(sigAlgorithm)
if err != nil {
return err
}
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
actual := base64.RawURLEncoding.EncodeToString(sum)
if actual != atHash {
return nil //TODO: error
}
return nil
}
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256:
return sha256.New(), nil
case jose.RS384, jose.ES384, jose.PS384:
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
}
}