lot of unfinished changes

This commit is contained in:
Livio Amstutz 2020-09-08 16:07:49 +02:00
parent 9cb0fff23f
commit a37a8461a5
16 changed files with 502 additions and 328 deletions

View file

@ -1,6 +1,7 @@
package oidc package oidc
import ( import (
"encoding/json"
"errors" "errors"
"strings" "strings"
"time" "time"
@ -64,7 +65,7 @@ const (
PromptSelectAccount Prompt = "select_account" PromptSelectAccount Prompt = "select_account"
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code" GrantTypeCode GrantType = "authorization_code"
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
@ -148,10 +149,67 @@ type AccessTokenResponse struct {
} }
type JWTTokenRequest struct { type JWTTokenRequest struct {
Scopes Scopes `schema:"scope"` Issuer string `json:"iss"`
Audience []string `schema:"aud"` Subject string `json:"sub"`
IssuedAt time.Time `schema:"iat"` Scopes Scopes `json:"scope"`
ExpiresAt time.Time `schema:"exp"` Audience string `json:"aud"`
IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"`
}
func (j *JWTTokenRequest) GetClientID() string {
return j.Subject
}
func (j *JWTTokenRequest) GetSubject() string {
return j.Subject
}
func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
}
func (j *JWTTokenRequest) GetIssuer() string {
return j.Issuer
}
func (j *JWTTokenRequest) GetAudience() []string {
return []string{j.Audience}
}
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
}
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
func (j *JWTTokenRequest) GetNonce() string {
return ""
}
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
return ""
}
func (j *JWTTokenRequest) GetAuthTime() time.Time {
return time.Time{}
}
func (j *JWTTokenRequest) GetAuthorizedParty() string {
return ""
} }
type TokenExchangeRequest struct { type TokenExchangeRequest struct {

View file

@ -177,6 +177,42 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (t *IDTokenClaims) GetIssuer() string {
return t.Issuer
}
func (t *IDTokenClaims) GetAudience() []string {
return t.Audiences
}
func (t *IDTokenClaims) GetExpiration() time.Time {
return t.Expiration
}
func (t *IDTokenClaims) GetIssuedAt() time.Time {
return t.IssuedAt
}
func (t *IDTokenClaims) GetNonce() string {
return t.Nonce
}
func (t *IDTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
func (t *IDTokenClaims) GetAuthTime() time.Time {
return t.AuthTime
}
func (t *IDTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) {
t.Signature = alg
}
func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile { func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
locale, _ := language.Parse(j.Locale) locale, _ := language.Parse(j.Locale)
return UserinfoProfile{ return UserinfoProfile{

210
pkg/oidc/verifier.go Normal file
View file

@ -0,0 +1,210 @@
package oidc
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils"
)
type Claims interface {
GetIssuer() string
GetAudience() []string
GetExpiration() time.Time
GetIssuedAt() time.Time
GetNonce() string
GetAuthenticationContextClassReference() string
GetAuthTime() time.Time
GetAuthorizedParty() string
SetSignature(algorithm jose.SignatureAlgorithm)
}
var (
ErrParse = errors.New("")
ErrIssuerInvalid = errors.New("issuer does not match")
ErrAudience = errors.New("audience is not valid")
ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
ErrAzpInvalid = errors.New("authorized party is not valid")
ErrSignatureMissing = errors.New("id_token does not contain a signature")
ErrSignatureMultiple = errors.New("id_token contains multiple signatures")
ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
ErrExpired = errors.New("token has expired")
ErrIatInFuture = errors.New("issuedAt of token is in the future")
ErrIatToOld = errors.New("issuedAt of token is to old")
//
//ErrNonceInvalid = func(expected, actual string) *validationError {
// return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
//}
ErrAcrInvalid = errors.New("acr is invalid")
ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing")
ErrAuthTimeToOld = errors.New("auth time of token is to old")
ErrAtHash = errors.New("at_hash does not correspond to access token")
)
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
type ACRVerifier func(string) error
func DecryptToken(tokenString string) (string, error) {
return tokenString, nil //TODO: impl
}
func ParseToken(tokenString string, claims interface{}) ([]byte, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrParse)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrParse, err)
}
err = json.Unmarshal(payload, claims)
return payload, err
}
type Verifier interface {
Issuer() string
ClientID() string
SupportedSignAlgs() []string
KeySet() KeySet
ACR() ACRVerifier
MaxAge() time.Duration
MaxAgeIAT() time.Duration
Offset() time.Duration
}
func CheckIssuer(issuer string, i Verifier) error {
if i.Issuer() != issuer {
return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, i.Issuer(), issuer)
}
return nil
}
func CheckAudience(audiences []string, i Verifier) error {
if !utils.Contains(audiences, i.ClientID()) {
return fmt.Errorf("%w: Audience must contain client_id (%s)", ErrAudience, i.ClientID())
}
//TODO: check aud trusted
return nil
}
//4. if multiple aud strings --> check if azp
//5. if azp --> check azp == client_id
func CheckAuthorizedParty(audiences []string, authorizedParty string, v Verifier) error {
if len(audiences) > 1 {
if authorizedParty == "" {
return ErrAzpMissing
}
}
if authorizedParty != "" && authorizedParty != v.ClientID() {
return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, authorizedParty, v.ClientID())
}
return nil
}
func CheckSignature(ctx context.Context, idTokenString string, payload []byte, claims Claims, v Verifier) error {
jws, err := jose.ParseSigned(idTokenString)
if err != nil {
return err
}
if len(jws.Signatures) == 0 {
return ErrSignatureMissing
}
if len(jws.Signatures) > 1 {
return ErrSignatureMultiple
}
sig := jws.Signatures[0]
supportedSigAlgs := v.SupportedSignAlgs()
if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{"RS256"}
}
if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, supportedSigAlgs, sig.Header.Algorithm)
}
signedPayload, err := v.KeySet().VerifySignature(ctx, jws)
if err != nil {
return err
}
if !bytes.Equal(signedPayload, payload) {
return ErrSignatureInvalidPayload
}
claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm))
return nil
}
func CheckExpiration(expiration time.Time, v Verifier) error {
expiration = expiration.Round(time.Second)
if !time.Now().UTC().Add(v.Offset()).Before(expiration) {
return ErrExpired
}
return nil
}
func CheckIssuedAt(issuedAt time.Time, v Verifier) error {
issuedAt = issuedAt.Round(time.Second)
offset := time.Now().UTC().Add(v.Offset()).Round(time.Second)
if issuedAt.After(offset) {
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, offset)
}
if v.MaxAgeIAT() == 0 {
return nil
}
maxAge := time.Now().UTC().Add(-v.MaxAgeIAT()).Round(time.Second)
if issuedAt.Before(maxAge) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
}
return nil
}
/*
func (v *DefaultVerifier) CheckNonce(nonce string) error {
if v.config.nonce == "" {
return nil
}
if v.config.nonce != nonce {
return ErrNonceInvalid(v.config.nonce, nonce)
}
return nil
}*/
func CheckAuthorizationContextClassReference(acr string, v Verifier) error {
if v.ACR() != nil {
if err := v.ACR()(acr); err != nil {
return fmt.Errorf("%w: %v", ErrAcrInvalid, err)
}
}
return nil
}
func CheckAuthTime(authTime time.Time, v Verifier) error {
if v.MaxAge() == 0 {
return nil
}
if authTime.IsZero() {
return ErrAuthTimeNotPresent
}
authTime = authTime.Round(time.Second)
maxAge := time.Now().UTC().Add(-v.MaxAge()).Round(time.Second)
if authTime.Before(maxAge) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAge.Sub(authTime))
}
return nil
}

View file

@ -275,7 +275,7 @@ func (p *DefaultOP) Crypto() Crypto {
return p.crypto return p.crypto
} }
func (p *DefaultOP) Verifier() rp.Verifier { func (p *DefaultOP) ClientJWTVerifier() rp.Verifier {
return p.verifier return p.verifier
} }

View file

@ -7,6 +7,12 @@ import (
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
func DiscoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(c, s))
}
}
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
utils.MarshalJSON(w, config) utils.MarshalJSON(w, config)
} }

View file

@ -97,7 +97,7 @@ func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1, arg2 interface{
} }
// CreateToken mocks base method // CreateToken mocks base method
func (m *MockStorage) CreateToken(arg0 context.Context, arg1 op.AuthRequest) (string, time.Time, error) { func (m *MockStorage) CreateToken(arg0 context.Context, arg1 op.TokenRequest) (string, time.Time, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateToken", arg0, arg1) ret := m.ctrl.Call(m, "CreateToken", arg0, arg1)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)

View file

@ -22,9 +22,12 @@ type OpenIDProvider interface {
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
HandleExchange(w http.ResponseWriter, r *http.Request) HandleExchange(w http.ResponseWriter, r *http.Request)
HandleUserinfo(w http.ResponseWriter, r *http.Request) HandleUserinfo(w http.ResponseWriter, r *http.Request)
HandleEndSession(w http.ResponseWriter, r *http.Request) //HandleEndSession(w http.ResponseWriter, r *http.Request)
HandleKeys(w http.ResponseWriter, r *http.Request) HandleKeys(w http.ResponseWriter, r *http.Request)
HttpHandler() http.Handler HttpHandler() http.Handler
SessionEnder
Signer() Signer
Probes() []ProbesFn
} }
type HttpInterceptor func(http.Handler) http.Handler type HttpInterceptor func(http.Handler) http.Handler
@ -42,13 +45,13 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
handlers.AllowedOriginValidator(allowAllOrigins), handlers.AllowedOriginValidator(allowAllOrigins),
)) ))
router.HandleFunc(healthzEndpoint, Healthz) router.HandleFunc(healthzEndpoint, Healthz)
router.HandleFunc(readinessEndpoint, o.HandleReady) router.HandleFunc(readinessEndpoint, Ready(o.Probes()))
router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) router.HandleFunc(oidc.DiscoveryEndpoint, DiscoveryHandler(o, o.Signer()))
router.Handle(o.AuthorizationEndpoint().Relative(), intercept(o.HandleAuthorize)) router.Handle(o.AuthorizationEndpoint().Relative(), intercept(o.HandleAuthorize))
router.Handle(o.AuthorizationEndpoint().Relative()+"/{id}", intercept(o.HandleAuthorizeCallback)) router.Handle(o.AuthorizationEndpoint().Relative()+"/{id}", intercept(o.HandleAuthorizeCallback))
router.Handle(o.TokenEndpoint().Relative(), intercept(o.HandleExchange)) router.Handle(o.TokenEndpoint().Relative(), intercept(o.HandleExchange))
router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo) router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
router.Handle(o.EndSessionEndpoint().Relative(), intercept(o.HandleEndSession)) router.Handle(o.EndSessionEndpoint().Relative(), intercept(EndSessionHandler(o)))
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys) router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
return router return router
} }

View file

@ -14,6 +14,12 @@ func Healthz(w http.ResponseWriter, r *http.Request) {
ok(w) ok(w)
} }
func Ready(probes []ProbesFn) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Readiness(w, r, probes...)
}
}
func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) { func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) {
ctx := r.Context() ctx := r.Context()
for _, probe := range probes { for _, probe := range probes {

View file

@ -16,6 +16,12 @@ type SessionEnder interface {
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
} }
func EndSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
EndSession(w, r, ender)
}
}
func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
req, err := ParseEndSessionRequest(r, ender.Decoder()) req, err := ParseEndSessionRequest(r, ender.Decoder())
if err != nil { if err != nil {

View file

@ -16,7 +16,7 @@ type AuthStorage interface {
SaveAuthCode(context.Context, string, string) error SaveAuthCode(context.Context, string, string) error
DeleteAuthRequest(context.Context, string) error DeleteAuthRequest(context.Context, string) error
CreateToken(context.Context, AuthRequest) (string, time.Time, error) CreateToken(context.Context, TokenRequest) (string, time.Time, error)
TerminateSession(context.Context, string, string) error TerminateSession(context.Context, string, string) error

View file

@ -14,12 +14,19 @@ type TokenCreator interface {
Crypto() Crypto Crypto() Crypto
} }
type TokenRequest interface {
GetClientID() string
GetSubject() string
GetAudience() []string
GetScopes() []string
}
func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) { func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
var accessToken string var accessToken string
var validity time.Duration var validity time.Duration
if createAccessToken { if createAccessToken {
var err error var err error
accessToken, validity, err = CreateAccessToken(ctx, authReq, client, creator) accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -43,8 +50,8 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
}, nil }, nil
} }
func CreateJWTTokenResponse(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator) (*oidc.AccessTokenResponse, error) { func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
accessToken, validity, err := CreateAccessToken(ctx, authReq, client, creator) accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,13 +64,13 @@ func CreateJWTTokenResponse(ctx context.Context, authReq AuthRequest, client Cli
}, nil }, nil
} }
func CreateAccessToken(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator) (token string, validity time.Duration, err error) { func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) {
id, exp, err := creator.Storage().CreateToken(ctx, authReq) id, exp, err := creator.Storage().CreateToken(ctx, authReq)
if err != nil { if err != nil {
return "", 0, err return "", 0, err
} }
validity = exp.Sub(time.Now().UTC()) validity = exp.Sub(time.Now().UTC())
if client.AccessTokenType() == AccessTokenTypeJWT { if accessTokenType == AccessTokenTypeJWT {
token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer()) token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer())
return return
} }
@ -75,7 +82,7 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) {
return crypto.Encrypt(id) return crypto.Encrypt(id)
} }
func CreateJWT(issuer string, authReq AuthRequest, exp time.Time, id string, signer Signer) (string, error) { func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) {
now := time.Now().UTC() now := time.Now().UTC()
nbf := now nbf := now
claims := &oidc.AccessTokenClaims{ claims := &oidc.AccessTokenClaims{

View file

@ -3,7 +3,6 @@ package op
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/http" "net/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
@ -22,7 +21,7 @@ type Exchanger interface {
type VerifyExchanger interface { type VerifyExchanger interface {
Exchanger Exchanger
Verifier() rp.Verifier ClientJWTVerifier() rp.Verifier
} }
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
@ -121,17 +120,31 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
return authReq, nil return authReq, nil
} }
type ClientJWTVerifier struct {
claims *oidc.JWTTokenRequest
Storage
}
func (c *ClientJWTVerifier) Issuer() string {
client, err := Storage.GetClientByClientID(context.TODO(), c.claims.Issuer)
return client.GetID()
}
func JWTExchange(w http.ResponseWriter, r *http.Request, exchanger VerifyExchanger) { func JWTExchange(w http.ResponseWriter, r *http.Request, exchanger VerifyExchanger) {
assertion, err := ParseJWTTokenRequest(r, exchanger.Decoder()) assertion, err := ParseJWTTokenRequest(r, exchanger.Decoder())
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
} }
claims, err := exchanger.Verifier().Verify(r.Context(), "", assertion) claims := new(oidc.JWTTokenRequest)
//var keyset oidc.KeySet
verifier := new(ClientJWTVerifier)
verifier.claims = claims
err = verifier.VerifyToken(r.Context(), assertion, claims)
if err != nil {
RequestError(w, r, err)
}
fmt.Println(claims, err) resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger)
var authReq AuthRequest
var client Client
resp, err := CreateJWTTokenResponse(r.Context(), authReq, client, exchanger)
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return
@ -139,7 +152,7 @@ func JWTExchange(w http.ResponseWriter, r *http.Request, exchanger VerifyExchang
utils.MarshalJSON(w, resp) utils.MarshalJSON(w, resp)
} }
func ParseJWTTokenRequest(r *http.Request, decoder *schema.Decoder) (string, error) { func ParseJWTTokenRequest(r *http.Request, decoder utils.Decoder) (string, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return "", ErrInvalidRequest("error parsing form") return "", ErrInvalidRequest("error parsing form")

View file

@ -1,16 +1,10 @@
package rp package rp
import ( import (
"bytes"
"context" "context"
"encoding/base64"
"encoding/json"
"fmt" "fmt"
"strings"
"time" "time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
@ -24,9 +18,6 @@ type DefaultVerifier struct {
//ConfFunc is the type for providing dynamic options to the DefaultVerfifier //ConfFunc is the type for providing dynamic options to the DefaultVerfifier
type ConfFunc func(*verifierConfig) type ConfFunc func(*verifierConfig)
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
type ACRVerifier func(string) error
//NewDefaultVerifier creates `DefaultVerifier` with the given //NewDefaultVerifier creates `DefaultVerifier` with the given
//issuer, clientID, keyset and possible configOptions //issuer, clientID, keyset and possible configOptions
func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier { func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier {
@ -90,7 +81,7 @@ func WithNonce(nonce string) func(*verifierConfig) {
} }
//WithACRVerifier sets the verifier for the acr claim //WithACRVerifier sets the verifier for the acr claim
func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) { func WithACRVerifier(verifier oidc.ACRVerifier) func(*verifierConfig) {
return func(conf *verifierConfig) { return func(conf *verifierConfig) {
conf.acr = verifier conf.acr = verifier
} }
@ -117,7 +108,7 @@ type verifierConfig struct {
ignoreAudience bool ignoreAudience bool
ignoreExpiration bool ignoreExpiration bool
iat *iatConfig iat *iatConfig
acr ACRVerifier acr oidc.ACRVerifier
maxAge time.Duration maxAge time.Duration
supportedSignAlgs []string supportedSignAlgs []string
@ -134,10 +125,10 @@ type iatConfig struct {
//DefaultACRVerifier implements `ACRVerifier` returning an error //DefaultACRVerifier implements `ACRVerifier` returning an error
//if non of the provided values matches the acr claim //if non of the provided values matches the acr claim
func DefaultACRVerifier(possibleValues []string) ACRVerifier { func DefaultACRVerifier(possibleValues []string) oidc.ACRVerifier {
return func(acr string) error { return func(acr string) error {
if !utils.Contains(possibleValues, acr) { if !utils.Contains(possibleValues, acr) {
return ErrAcrInvalid(possibleValues, acr) return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr)
} }
return nil return nil
} }
@ -148,88 +139,13 @@ func DefaultACRVerifier(possibleValues []string) ACRVerifier {
//and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation //and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) { func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) {
v.config.now = time.Now().UTC() v.config.now = time.Now().UTC()
// idToken, err := v.VerifyIDToken(ctx, idTokenString) return VerifyTokens(ctx, accessToken, idTokenString, v)
// if err != nil {
// return nil, err
// }
// if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token
// return nil, err
// }
// return idToken, nil
// TODO: verifiy
decrypted, err := v.decryptToken(idTokenString)
if err != nil {
return nil, err
}
claims, _, err := v.parseToken(decrypted)
if err != nil {
return nil, err
}
return claims, nil
} }
//Verify implements the `VerifyIDToken` method of the `Verifier` interface //Verify implements the `VerifyIDToken` method of the `Verifier` interface
//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation //according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) { func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) {
//1. if encrypted --> decrypt return VerifyIDToken(ctx, idTokenString, v)
decrypted, err := v.decryptToken(idTokenString)
if err != nil {
return nil, err
}
claims, payload, err := v.parseToken(decrypted)
if err != nil {
return nil, err
}
// token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) {
//2, check issuer (exact match)
if err := v.checkIssuer(claims.Issuer); err != nil {
return nil, err
}
//3. check aud (aud must contain client_id, all aud strings must be allowed)
if err = v.checkAudience(claims.Audiences); err != nil {
return nil, err
}
if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil {
return nil, err
}
//6. check signature by keys
//7. check alg default is rs256
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
claims.Signature, err = v.checkSignature(ctx, decrypted, payload)
if err != nil {
return nil, err
}
//9. check exp before now
if err = v.checkExpiration(claims.Expiration); err != nil {
return nil, err
}
//10. check iat duration is optional (can be checked)
if err = v.checkIssuedAt(claims.IssuedAt); err != nil {
return nil, err
}
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
if err = v.checkNonce(claims.Nonce); err != nil {
return nil, err
}
//12. if acr requested check acr
if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil {
return nil, err
}
//13. if auth_time requested check if auth_time is less than max age
if err = v.checkAuthTime(claims.AuthTime); err != nil {
return nil, err
}
return claims, nil
} }
func (v *DefaultVerifier) now() time.Time { func (v *DefaultVerifier) now() time.Time {
@ -239,161 +155,34 @@ func (v *DefaultVerifier) now() time.Time {
return v.config.now return v.config.now
} }
func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) { func (v *DefaultVerifier) Issuer() string {
parts := strings.Split(tokenString, ".") return v.config.issuer
if len(parts) != 3 {
return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
idToken := new(oidc.IDTokenClaims)
err = json.Unmarshal(payload, idToken)
return idToken, payload, err
} }
func (v *DefaultVerifier) checkIssuer(issuer string) error { func (v *DefaultVerifier) ClientID() string {
if v.config.issuer != issuer { return v.config.clientID
return ErrIssuerInvalid(v.config.issuer, issuer)
}
return nil
} }
func (v *DefaultVerifier) checkAudience(audiences []string) error { func (v *DefaultVerifier) SupportedSignAlgs() []string {
if v.config.ignoreAudience { return v.config.supportedSignAlgs
return nil
}
if !utils.Contains(audiences, v.config.clientID) {
return ErrAudienceMissingClientID(v.config.clientID)
}
//TODO: check aud trusted
return nil
} }
//4. if multiple aud strings --> check if azp func (v *DefaultVerifier) KeySet() oidc.KeySet {
//5. if azp --> check azp == client_id return v.keySet
func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
if v.config.ignoreAudience {
return nil
}
if len(audiences) > 1 {
if authorizedParty == "" {
return ErrAzpMissing()
}
}
if authorizedParty != "" && authorizedParty != v.config.clientID {
return ErrAzpInvalid(authorizedParty, v.config.clientID)
}
return nil
} }
func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) { func (v *DefaultVerifier) ACR() oidc.ACRVerifier {
jws, err := jose.ParseSigned(idTokenString) return v.config.acr
if err != nil {
return "", err
}
if len(jws.Signatures) == 0 {
return "", ErrSignatureMissing()
}
if len(jws.Signatures) > 1 {
return "", ErrSignatureMultiple()
}
sig := jws.Signatures[0]
supportedSigAlgs := v.config.supportedSignAlgs
if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{"RS256"}
}
if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
}
signedPayload, err := v.keySet.VerifySignature(ctx, jws)
if err != nil {
return "", err
}
if !bytes.Equal(signedPayload, payload) {
return "", ErrSignatureInvalidPayload()
}
return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
} }
func (v *DefaultVerifier) checkExpiration(expiration time.Time) error { func (v *DefaultVerifier) MaxAge() time.Duration {
if v.config.ignoreExpiration { return v.config.maxAge
return nil
}
expiration = expiration.Round(time.Second)
if !v.now().Before(expiration) {
return ErrExpInvalid(expiration)
}
return nil
} }
func (v *DefaultVerifier) checkIssuedAt(issuedAt time.Time) error { func (v *DefaultVerifier) MaxAgeIAT() time.Duration {
if v.config.iat.ignore { return v.config.iat.maxAge
return nil
}
issuedAt = issuedAt.Round(time.Second)
offset := v.now().Add(v.config.iat.offset).Round(time.Second)
if issuedAt.After(offset) {
return ErrIatInFuture(issuedAt, offset)
}
if v.config.iat.maxAge == 0 {
return nil
}
maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second)
if issuedAt.Before(maxAge) {
return ErrIatToOld(maxAge, issuedAt)
}
return nil
}
func (v *DefaultVerifier) checkNonce(nonce string) error {
if v.config.nonce == "" {
return nil
}
if v.config.nonce != nonce {
return ErrNonceInvalid(v.config.nonce, nonce)
}
return nil
}
func (v *DefaultVerifier) checkAuthorizationContextClassReference(acr string) error {
if v.config.acr != nil {
return v.config.acr(acr)
}
return nil
}
func (v *DefaultVerifier) checkAuthTime(authTime time.Time) error {
if v.config.maxAge == 0 {
return nil
}
if authTime.IsZero() {
return ErrAuthTimeNotPresent()
}
authTime = authTime.Round(time.Second)
maxAge := v.now().Add(-v.config.maxAge).Round(time.Second)
if authTime.Before(maxAge) {
return ErrAuthTimeToOld(maxAge, authTime)
}
return nil
} }
func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) { func (v *DefaultVerifier) Offset() time.Duration {
return tokenString, nil //TODO: impl return v.config.iat.offset
}
func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
if atHash == "" {
return nil
}
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
if err != nil {
return err
}
if actual != atHash {
return ErrAtHash()
}
return nil
} }

View file

@ -1,67 +0,0 @@
package rp
import (
"fmt"
"time"
)
var (
ErrIssuerInvalid = func(expected, actual string) *validationError {
return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual)
}
ErrAudienceMissingClientID = func(clientID string) *validationError {
return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID)
}
ErrAzpMissing = func() *validationError {
return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty")
}
ErrAzpInvalid = func(azp, clientID string) *validationError {
return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID)
}
ErrExpInvalid = func(exp time.Time) *validationError {
return ValidationError("Token has expired %v", exp)
}
ErrIatInFuture = func(exp, now time.Time) *validationError {
return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now)
}
ErrIatToOld = func(maxAge, iat time.Time) *validationError {
return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat))
}
ErrNonceInvalid = func(expected, actual string) *validationError {
return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
}
ErrAcrInvalid = func(expected []string, actual string) *validationError {
return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual)
}
ErrAuthTimeNotPresent = func() *validationError {
return ValidationError("claim `auth_time` of token is missing")
}
ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
}
ErrSignatureMissing = func() *validationError {
return ValidationError("id_token does not contain a signature")
}
ErrSignatureMultiple = func() *validationError {
return ValidationError("id_token contains multiple signatures")
}
ErrSignatureInvalidPayload = func() *validationError {
return ValidationError("Signature does not match Payload")
}
ErrAtHash = func() *validationError {
return ValidationError("at_hash does not correspond to access token")
}
)
func ValidationError(message string, args ...interface{}) *validationError {
return &validationError{fmt.Sprintf(message, args...)} //TODO: impl
}
type validationError struct {
message string
}
func (v *validationError) Error() string {
return v.message
}

View file

@ -3,11 +3,12 @@ package rp
import ( import (
"context" "context"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
) )
//Verifier implement the Token Response Validation as defined in OIDC specification //deprecated: Use IDTokenVerifier or oidc.Verifier
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
type Verifier interface { type Verifier interface {
//Verify checks the access_token and id_token and returns the `id token claims` //Verify checks the access_token and id_token and returns the `id token claims`
@ -16,3 +17,100 @@ type Verifier interface {
//VerifyIDToken checks the id_token only and returns its `id token claims` //VerifyIDToken checks the id_token only and returns its `id token claims`
VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error)
} }
type IDTokenVerifier interface {
oidc.Verifier
}
//VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v)
if err != nil {
return nil, err
}
if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil {
return nil, err
}
return idToken, nil
}
//VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims)
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
}
payload, err := oidc.ParseToken(decrypted, claims)
if err != nil {
return nil, err
}
//2, check issuer (exact match)
if err := oidc.CheckIssuer(claims.GetIssuer(), v); err != nil {
return nil, err
}
//3. check aud (aud must contain client_id, all aud strings must be allowed)
if err = oidc.CheckAudience(claims.GetAudience(), v); err != nil {
return nil, err
}
if err = oidc.CheckAuthorizedParty(claims.GetAudience(), claims.GetAuthorizedParty(), v); err != nil {
return nil, err
}
//6. check signature by keys
//7. check alg default is rs256
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v); err != nil {
return nil, err
}
//9. check exp before now
if err = oidc.CheckExpiration(claims.GetExpiration(), v); err != nil {
return nil, err
}
//10. check iat duration is optional (can be checked)
if err = oidc.CheckIssuedAt(claims.GetIssuedAt(), v); err != nil {
return nil, err
}
/*
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
if err = oidc.CheckNonce(claims.GetNonce()); err != nil {
return nil, err
}
*/
//12. if acr requested check acr
if err = oidc.CheckAuthorizationContextClassReference(claims.GetAuthenticationContextClassReference(), v); err != nil {
return nil, err
}
//13. if auth_time requested check if auth_time is less than max age
if err = oidc.CheckAuthTime(claims.GetAuthTime(), v); err != nil {
return nil, err
}
return claims, nil
}
//VerifyAccessToken validates the access token according to
//https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
if atHash == "" {
return nil
}
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
if err != nil {
return err
}
if actual != atHash {
return oidc.ErrAtHash
}
return nil
}

9
pkg/rp/verity.go Normal file
View file

@ -0,0 +1,9 @@
package rp
import (
"context"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)