From 542ec6ed7beb36f0a0fde7ac3cea3f8d5eeeda3c Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Fri, 25 Sep 2020 16:41:25 +0200 Subject: [PATCH 01/17] refactoring --- example/client/app/app.go | 60 ++- example/client/jwt_profile.go | 39 ++ example/internal/mock/storage.go | 12 +- pkg/oidc/authorization.go | 168 -------- pkg/oidc/token.go | 621 +++++++++++++++++----------- pkg/oidc/token_request.go | 101 +++++ pkg/oidc/types.go | 113 +++++ pkg/oidc/types_test.go | 276 +++++++++++++ pkg/oidc/userinfo.go | 462 +++++++++++++++------ pkg/oidc/verifier.go | 4 +- pkg/op/authrequest.go | 2 +- pkg/op/mock/authorizer.mock.impl.go | 2 +- pkg/op/mock/signer.mock.go | 2 +- pkg/op/mock/storage.mock.go | 8 +- pkg/op/op.go | 2 +- pkg/op/session.go | 4 +- pkg/op/signer.go | 43 +- pkg/op/signer_test.go | 6 +- pkg/op/storage.go | 4 +- pkg/op/token.go | 48 +-- pkg/op/verifier_id_token_hint.go | 4 +- pkg/rp/mock/verifier.mock.impl.go | 2 +- pkg/rp/relaying_party.go | 9 +- pkg/rp/verifier.go | 8 +- pkg/utils/marshal.go | 14 + pkg/utils/sign.go | 23 ++ 26 files changed, 1412 insertions(+), 625 deletions(-) create mode 100644 example/client/jwt_profile.go create mode 100644 pkg/oidc/token_request.go create mode 100644 pkg/oidc/types.go create mode 100644 pkg/oidc/types_test.go create mode 100644 pkg/utils/sign.go diff --git a/example/client/app/app.go b/example/client/app/app.go index 4c0831b..ea1e6e7 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "html/template" + "io/ioutil" "net/http" "os" "time" @@ -30,7 +32,7 @@ func main() { ctx := context.Background() redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) - scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail} + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"} cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, rp.WithPKCE(cookieHandler), @@ -82,6 +84,62 @@ func main() { } w.Write(data) }) + + http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) { + tpl := ` + + + + + Login + + +
+ + + +
+ + ` + t, err := template.New("login").Parse(tpl) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = t.Execute(w, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }) + + http.HandleFunc("/jwt-profile-assertion", func(w http.ResponseWriter, r *http.Request) { + r.ParseMultipartForm(32 << 20) + file, handler, err := r.FormFile("key") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer file.Close() + + key, err := ioutil.ReadAll(file) + fmt.Println(handler.Header) + assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + token, err := rp.JWTProfileExchange(ctx, assertion, provider) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + data, err := json.Marshal(token) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(data) + }) lis := fmt.Sprintf("127.0.0.1:%s", port) logrus.Infof("listening on http://%s/", lis) logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil)) diff --git a/example/client/jwt_profile.go b/example/client/jwt_profile.go new file mode 100644 index 0000000..6dcd11b --- /dev/null +++ b/example/client/jwt_profile.go @@ -0,0 +1,39 @@ +package client + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/sirupsen/logrus" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/utils" +) + +var ( + callbackPath string = "/auth/callback" + key []byte = []byte("test1234test1234") +) + +func main() { + clientID := os.Getenv("CLIENT_ID") + clientSecret := os.Getenv("CLIENT_SECRET") + issuer := os.Getenv("ISSUER") + port := os.Getenv("PORT") + + ctx := context.Background() + + redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"} + cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) + provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, + rp.WithPKCE(cookieHandler), + rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)), + ) + if err != nil { + logrus.Fatalf("error creating provider %s", err.Error()) + } +} diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index f20fb9b..74e0ed7 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -210,24 +210,24 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st return nil } -func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.Userinfo, error) { +func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.userinfo, error) { return s.GetUserinfoFromScopes(ctx, "", []string{}) } -func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) { - return &oidc.Userinfo{ +func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.userinfo, error) { + return &oidc.userinfo{ Subject: a.GetSubject(), Address: &oidc.UserinfoAddress{ StreetAddress: "Hjkhkj 789\ndsf", }, - UserinfoEmail: oidc.UserinfoEmail{ + userinfoEmail: oidc.userinfoEmail{ Email: "test", EmailVerified: true, }, - UserinfoPhone: oidc.UserinfoPhone{ + userinfoPhone: oidc.userinfoPhone{ PhoneNumber: "sadsa", PhoneNumberVerified: true, }, - UserinfoProfile: oidc.UserinfoProfile{ + userinfoProfile: oidc.userinfoProfile{ UpdatedAt: time.Now(), }, // Claims: map[string]interface{}{ diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index 35da2fb..71776af 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,15 +1,5 @@ package oidc -import ( - "encoding/json" - "errors" - "strings" - "time" - - "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" -) - const ( //ScopeOpenID defines the scope `openid` //OpenID Connect requests MUST contain the `openid` scope value @@ -64,23 +54,8 @@ const ( //PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching) PromptSelectAccount Prompt = "select_account" - - //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow - GrantTypeCode GrantType = "authorization_code" - //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant - GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" - - //BearerToken defines the token_type `Bearer`, which is returned in a successful token response - BearerToken = "Bearer" ) -var displayValues = map[string]Display{ - "page": DisplayPage, - "popup": DisplayPopup, - "touch": DisplayTouch, - "wap": DisplayWAP, -} - //AuthRequest according to: //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type AuthRequest struct { @@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType { func (a *AuthRequest) GetState() string { return a.State } - -type TokenRequest interface { - // GrantType GrantType `schema:"grant_type"` - GrantType() GrantType -} - -type TokenRequestType GrantType - -type AccessTokenRequest struct { - Code string `schema:"code"` - RedirectURI string `schema:"redirect_uri"` - ClientID string `schema:"client_id"` - ClientSecret string `schema:"client_secret"` - CodeVerifier string `schema:"code_verifier"` -} - -func (a *AccessTokenRequest) GrantType() GrantType { - return GrantTypeCode -} - -type AccessTokenResponse struct { - AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` - TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` - RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` - ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` - IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` -} - -type JWTTokenRequest struct { - Issuer string `json:"iss"` - Subject string `json:"sub"` - Scopes Scopes `json:"scope"` - Audience interface{} `json:"aud"` - IssuedAt Time `json:"iat"` - ExpiresAt Time `json:"exp"` -} - -func (j *JWTTokenRequest) GetClientID() string { - return j.Subject -} - -func (j *JWTTokenRequest) GetSubject() string { - return j.Subject -} - -func (j *JWTTokenRequest) GetScopes() []string { - return j.Scopes -} - -type Time time.Time - -func (t *Time) UnmarshalJSON(data []byte) error { - var i int64 - if err := json.Unmarshal(data, &i); err != nil { - return err - } - *t = Time(time.Unix(i, 0).UTC()) - return nil -} - -func (j *JWTTokenRequest) GetIssuer() string { - return j.Issuer -} - -func (j *JWTTokenRequest) GetAudience() []string { - return audienceFromJSON(j.Audience) -} - -func (j *JWTTokenRequest) GetExpiration() time.Time { - return time.Time(j.ExpiresAt) -} - -func (j *JWTTokenRequest) GetIssuedAt() time.Time { - return time.Time(j.IssuedAt) -} - -func (j *JWTTokenRequest) GetNonce() string { - return "" -} - -func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string { - return "" -} - -func (j *JWTTokenRequest) GetAuthTime() time.Time { - return time.Time{} -} - -func (j *JWTTokenRequest) GetAuthorizedParty() string { - return "" -} - -func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {} - -type TokenExchangeRequest struct { - subjectToken string `schema:"subject_token"` - subjectTokenType string `schema:"subject_token_type"` - actorToken string `schema:"actor_token"` - actorTokenType string `schema:"actor_token_type"` - resource []string `schema:"resource"` - audience []string `schema:"audience"` - Scope []string `schema:"scope"` - requestedTokenType string `schema:"requested_token_type"` -} - -type Scopes []string - -func (s *Scopes) UnmarshalText(text []byte) error { - scopes := strings.Split(string(text), " ") - *s = Scopes(scopes) - return nil -} - -type ResponseType string - -type Display string - -func (d *Display) UnmarshalText(text []byte) error { - var ok bool - display := string(text) - *d, ok = displayValues[display] - if !ok { - return errors.New("") - } - return nil -} - -type Prompt string - -type Locales []language.Tag - -func (l *Locales) UnmarshalText(text []byte) error { - locales := strings.Split(string(text), " ") - for _, locale := range locales { - tag, err := language.Parse(locale) - if err == nil && !tag.IsRoot() { - *l = append(*l, tag) - } - } - return nil -} - -type GrantType string diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 21b0419..e20dd4a 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -3,33 +3,55 @@ package oidc import ( "encoding/json" "io/ioutil" - "strings" "time" "golang.org/x/oauth2" - "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/utils" ) +const ( + //BearerToken defines the token_type `Bearer`, which is returned in a successful token response + BearerToken = "Bearer" +) + type Tokens struct { *oauth2.Token - IDTokenClaims *IDTokenClaims + IDTokenClaims IDTokenClaims IDToken string } -type AccessTokenClaims struct { +type AccessTokenClaims interface { + Claims +} + +type IDTokenClaims interface { + Claims + GetNotBefore() time.Time + GetJWTID() string + GetAccessTokenHash() string + GetCodeHash() string + GetAuthenticationMethodsReferences() []string + GetClientID() string + GetSignatureAlgorithm() jose.SignatureAlgorithm + SetAccessTokenHash(hash string) + SetUserinfo(userinfo UserInfoSetter) + SetCodeHash(hash string) + UserInfo +} + +type accessTokenClaims struct { Issuer string Subject string - Audiences []string - Expiration time.Time - IssuedAt time.Time - NotBefore time.Time + Audience Audience + Expiration Time + IssuedAt Time + NotBefore Time JWTID string AuthorizedParty string Nonce string - AuthTime time.Time + AuthTime Time CodeHash string AuthenticationContextClassReference string AuthenticationMethodsReferences []string @@ -37,38 +59,155 @@ type AccessTokenClaims struct { Scopes []string ClientID string AccessTokenUseNumber int + + signatureAlg jose.SignatureAlgorithm } -type IDTokenClaims struct { - Issuer string - Audiences []string - Expiration time.Time - NotBefore time.Time - IssuedAt time.Time - JWTID string - UpdatedAt time.Time - AuthorizedParty string - Nonce string - AuthTime time.Time - AccessTokenHash string - CodeHash string - AuthenticationContextClassReference string - AuthenticationMethodsReferences []string - ClientID string - Userinfo +func (a accessTokenClaims) GetIssuer() string { + return a.Issuer +} - Signature jose.SignatureAlgorithm //TODO: ??? +func (a accessTokenClaims) GetAudience() []string { + return a.Audience +} + +func (a accessTokenClaims) GetExpiration() time.Time { + return time.Time(a.Expiration) +} + +func (a accessTokenClaims) GetIssuedAt() time.Time { + return time.Time(a.IssuedAt) +} + +func (a accessTokenClaims) GetNonce() string { + return a.Nonce +} + +func (a accessTokenClaims) GetAuthenticationContextClassReference() string { + return a.AuthenticationContextClassReference +} + +func (a accessTokenClaims) GetAuthTime() time.Time { + return time.Time(a.AuthTime) +} + +func (a accessTokenClaims) GetAuthorizedParty() string { + return a.AuthorizedParty +} + +func (a accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { + a.signatureAlg = algorithm +} + +func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims { + now := time.Now().UTC() + return &accessTokenClaims{ + Issuer: issuer, + Subject: subject, + Audience: audience, + Expiration: Time(expiration), + IssuedAt: Time(now), + NotBefore: Time(now), + JWTID: id, + } +} + +type idTokenClaims struct { + Issuer string `json:"iss,omitempty"` + Audience Audience `json:"aud,omitempty"` + Expiration Time `json:"exp,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + JWTID string `json:"jti,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + AccessTokenHash string `json:"at_hash,omitempty"` + CodeHash string `json:"c_hash,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + ClientID string `json:"client_id,omitempty"` + UserInfo `json:"-"` + + signatureAlg jose.SignatureAlgorithm +} + +func (t *idTokenClaims) SetAccessTokenHash(hash string) { + t.AccessTokenHash = hash +} + +func (t *idTokenClaims) SetUserinfo(info UserInfoSetter) { + t.UserInfo = info +} + +func (t *idTokenClaims) SetCodeHash(hash string) { + t.CodeHash = hash +} + +func EmptyIDTokenClaims() IDTokenClaims { + return new(idTokenClaims) +} + +func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims { + return &idTokenClaims{ + Issuer: issuer, + Audience: audience, + Expiration: Time(expiration), + IssuedAt: Time(time.Now().UTC()), + AuthTime: Time(authTime), + Nonce: nonce, + AuthenticationContextClassReference: acr, + AuthenticationMethodsReferences: amr, + AuthorizedParty: clientID, + UserInfo: &userinfo{Subject: subject}, + } +} + +func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { + return t.signatureAlg +} + +func (t *idTokenClaims) GetNotBefore() time.Time { + return time.Time(t.NotBefore) +} + +func (t *idTokenClaims) GetJWTID() string { + return t.JWTID +} + +func (t *idTokenClaims) GetAccessTokenHash() string { + return t.AccessTokenHash +} + +func (t *idTokenClaims) GetCodeHash() string { + return t.CodeHash +} + +func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string { + return t.AuthenticationMethodsReferences +} + +func (t *idTokenClaims) GetClientID() string { + return t.ClientID +} + +type AccessTokenResponse struct { + AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` + ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` + IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` } type JWTProfileAssertion struct { - PrivateKeyID string `json:"keyId"` - PrivateKey []byte `json:"key"` - Scopes []string `json:"-"` - Issuer string `json:"-"` - Subject string `json:"userId"` - Audience []string `json:"-"` - Expiration time.Time `json:"-"` - IssuedAt time.Time `json:"-"` + PrivateKeyID string `json:"-"` + PrivateKey []byte `json:"-"` + Scopes []string `json:"scopes"` + Issuer string `json:"issuer"` + Subject string `json:"sub"` + Audience Audience `json:"aud"` + Expiration Time `json:"exp"` + IssuedAt Time `json:"iat"` } func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) { @@ -76,12 +215,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT if err != nil { return nil, err } + return NewJWTProfileAssertionFromFileData(data, audience) +} + +func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) { keyData := new(struct { KeyID string `json:"keyId"` Key string `json:"key"` UserID string `json:"userId"` }) - err = json.Unmarshal(data, keyData) + err := json.Unmarshal(data, keyData) if err != nil { return nil, err } @@ -95,241 +238,251 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) Issuer: userID, Scopes: []string{ScopeOpenID}, Subject: userID, - IssuedAt: time.Now().UTC(), - Expiration: time.Now().Add(1 * time.Hour).UTC(), + IssuedAt: Time(time.Now().UTC()), + Expiration: Time(time.Now().Add(1 * time.Hour).UTC()), Audience: audience, } } -type jsonToken struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Audiences interface{} `json:"aud,omitempty"` - Expiration int64 `json:"exp,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - IssuedAt int64 `json:"iat,omitempty"` - JWTID string `json:"jti,omitempty"` - AuthorizedParty string `json:"azp,omitempty"` - Nonce string `json:"nonce,omitempty"` - AuthTime int64 `json:"auth_time,omitempty"` - AccessTokenHash string `json:"at_hash,omitempty"` - CodeHash string `json:"c_hash,omitempty"` - AuthenticationContextClassReference string `json:"acr,omitempty"` - AuthenticationMethodsReferences []string `json:"amr,omitempty"` - SessionID string `json:"sid,omitempty"` - Actor interface{} `json:"act,omitempty"` //TODO: impl - Scopes string `json:"scope,omitempty"` - ClientID string `json:"client_id,omitempty"` - AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl - AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` - jsonUserinfo -} +// +//type jsonToken struct { +// Issuer string `json:"iss,omitempty"` +// Subject string `json:"sub,omitempty"` +// Audiences interface{} `json:"aud,omitempty"` +// Expiration int64 `json:"exp,omitempty"` +// NotBefore int64 `json:"nbf,omitempty"` +// IssuedAt int64 `json:"iat,omitempty"` +// JWTID string `json:"jti,omitempty"` +// AuthorizedParty string `json:"azp,omitempty"` +// Nonce string `json:"nonce,omitempty"` +// AuthTime int64 `json:"auth_time,omitempty"` +// AccessTokenHash string `json:"at_hash,omitempty"` +// CodeHash string `json:"c_hash,omitempty"` +// AuthenticationContextClassReference string `json:"acr,omitempty"` +// AuthenticationMethodsReferences []string `json:"amr,omitempty"` +// SessionID string `json:"sid,omitempty"` +// Actor interface{} `json:"act,omitempty"` //TODO: impl +// Scopes string `json:"scope,omitempty"` +// ClientID string `json:"client_id,omitempty"` +// AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl +// AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` +// jsonUserinfo +//} -func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) { - j := jsonToken{ - Issuer: t.Issuer, - Subject: t.Subject, - Audiences: t.Audiences, - Expiration: timeToJSON(t.Expiration), - NotBefore: timeToJSON(t.NotBefore), - IssuedAt: timeToJSON(t.IssuedAt), - JWTID: t.JWTID, - AuthorizedParty: t.AuthorizedParty, - Nonce: t.Nonce, - AuthTime: timeToJSON(t.AuthTime), - CodeHash: t.CodeHash, - AuthenticationContextClassReference: t.AuthenticationContextClassReference, - AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, - SessionID: t.SessionID, - Scopes: strings.Join(t.Scopes, " "), - ClientID: t.ClientID, - AccessTokenUseNumber: t.AccessTokenUseNumber, +// +//func (t *accessTokenClaims) MarshalJSON() ([]byte, error) { +// j := jsonToken{ +// Issuer: t.Issuer, +// Subject: t.Subject, +// Audiences: t.Audiences, +// Expiration: timeToJSON(t.Expiration), +// NotBefore: timeToJSON(t.NotBefore), +// IssuedAt: timeToJSON(t.IssuedAt), +// JWTID: t.JWTID, +// AuthorizedParty: t.AuthorizedParty, +// Nonce: t.Nonce, +// AuthTime: timeToJSON(t.AuthTime), +// CodeHash: t.CodeHash, +// AuthenticationContextClassReference: t.AuthenticationContextClassReference, +// AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, +// SessionID: t.SessionID, +// Scopes: strings.Join(t.Scopes, " "), +// ClientID: t.ClientID, +// AccessTokenUseNumber: t.AccessTokenUseNumber, +// } +// return json.Marshal(j) +//} +// +//func (t *accessTokenClaims) UnmarshalJSON(b []byte) error { +// var j jsonToken +// if err := json.Unmarshal(b, &j); err != nil { +// return err +// } +// t.Issuer = j.Issuer +// t.Subject = j.Subject +// t.Audiences = audienceFromJSON(j.Audiences) +// t.Expiration = time.Unix(j.Expiration, 0).UTC() +// t.NotBefore = time.Unix(j.NotBefore, 0).UTC() +// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() +// t.JWTID = j.JWTID +// t.AuthorizedParty = j.AuthorizedParty +// t.Nonce = j.Nonce +// t.AuthTime = time.Unix(j.AuthTime, 0).UTC() +// t.CodeHash = j.CodeHash +// t.AuthenticationContextClassReference = j.AuthenticationContextClassReference +// t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences +// t.SessionID = j.SessionID +// t.Scopes = strings.Split(j.Scopes, " ") +// t.ClientID = j.ClientID +// t.AccessTokenUseNumber = j.AccessTokenUseNumber +// return nil +//} +// +func (t *idTokenClaims) MarshalJSON() ([]byte, error) { + type Alias idTokenClaims + a := &struct { + *Alias + Expiration int64 `json:"nbf,omitempty"` + IssuedAt int64 `json:"nbf,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + AuthTime int64 `json:"nbf,omitempty"` + }{ + Alias: (*Alias)(t), } - return json.Marshal(j) + if !time.Time(t.Expiration).IsZero() { + a.Expiration = time.Time(t.Expiration).Unix() + } + if !time.Time(t.IssuedAt).IsZero() { + a.IssuedAt = time.Time(t.IssuedAt).Unix() + } + if !time.Time(t.NotBefore).IsZero() { + a.NotBefore = time.Time(t.NotBefore).Unix() + } + if !time.Time(t.AuthTime).IsZero() { + a.AuthTime = time.Time(t.AuthTime).Unix() + } + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if t.UserInfo == nil { + return b, nil + } + info, err := json.Marshal(t.UserInfo) + if err != nil { + return nil, err + } + return utils.ConcatenateJSON(b, info) } -func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error { - var j jsonToken - if err := json.Unmarshal(b, &j); err != nil { +func (t *idTokenClaims) UnmarshalJSON(data []byte) error { + type Alias idTokenClaims + if err := json.Unmarshal(data, (*Alias)(t)); err != nil { return err } - t.Issuer = j.Issuer - t.Subject = j.Subject - t.Audiences = audienceFromJSON(j.Audiences) - t.Expiration = time.Unix(j.Expiration, 0).UTC() - t.NotBefore = time.Unix(j.NotBefore, 0).UTC() - t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() - t.JWTID = j.JWTID - t.AuthorizedParty = j.AuthorizedParty - t.Nonce = j.Nonce - t.AuthTime = time.Unix(j.AuthTime, 0).UTC() - t.CodeHash = j.CodeHash - t.AuthenticationContextClassReference = j.AuthenticationContextClassReference - t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences - t.SessionID = j.SessionID - t.Scopes = strings.Split(j.Scopes, " ") - t.ClientID = j.ClientID - t.AccessTokenUseNumber = j.AccessTokenUseNumber + userinfo := new(userinfo) + if err := json.Unmarshal(data, userinfo); err != nil { + return err + } + t.UserInfo = userinfo + return nil } -func (t *IDTokenClaims) MarshalJSON() ([]byte, error) { - j := jsonToken{ - Issuer: t.Issuer, - Subject: t.Subject, - Audiences: t.Audiences, - Expiration: timeToJSON(t.Expiration), - NotBefore: timeToJSON(t.NotBefore), - IssuedAt: timeToJSON(t.IssuedAt), - JWTID: t.JWTID, - AuthorizedParty: t.AuthorizedParty, - Nonce: t.Nonce, - AuthTime: timeToJSON(t.AuthTime), - AccessTokenHash: t.AccessTokenHash, - CodeHash: t.CodeHash, - AuthenticationContextClassReference: t.AuthenticationContextClassReference, - AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, - ClientID: t.ClientID, - } - j.setUserinfo(t.Userinfo) - return json.Marshal(j) -} - -func (t *IDTokenClaims) UnmarshalJSON(b []byte) error { - var i jsonToken - if err := json.Unmarshal(b, &i); err != nil { - return err - } - t.Issuer = i.Issuer - t.Subject = i.Subject - t.Audiences = audienceFromJSON(i.Audiences) - t.Expiration = time.Unix(i.Expiration, 0).UTC() - t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC() - t.AuthTime = time.Unix(i.AuthTime, 0).UTC() - t.Nonce = i.Nonce - t.AuthenticationContextClassReference = i.AuthenticationContextClassReference - t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences - t.AuthorizedParty = i.AuthorizedParty - t.AccessTokenHash = i.AccessTokenHash - t.CodeHash = i.CodeHash - t.UserinfoProfile = i.UnmarshalUserinfoProfile() - t.UserinfoEmail = i.UnmarshalUserinfoEmail() - t.UserinfoPhone = i.UnmarshalUserinfoPhone() - t.Address = i.UnmarshalUserinfoAddress() - return nil -} - -func (t *IDTokenClaims) GetIssuer() string { +func (t *idTokenClaims) GetIssuer() string { return t.Issuer } -func (t *IDTokenClaims) GetAudience() []string { - return t.Audiences +func (t *idTokenClaims) GetAudience() []string { + return t.Audience } -func (t *IDTokenClaims) GetExpiration() time.Time { - return t.Expiration +func (t *idTokenClaims) GetExpiration() time.Time { + return time.Time(t.Expiration) } -func (t *IDTokenClaims) GetIssuedAt() time.Time { - return t.IssuedAt +func (t *idTokenClaims) GetIssuedAt() time.Time { + return time.Time(t.IssuedAt) } -func (t *IDTokenClaims) GetNonce() string { +func (t *idTokenClaims) GetNonce() string { return t.Nonce } -func (t *IDTokenClaims) GetAuthenticationContextClassReference() string { +func (t *idTokenClaims) GetAuthenticationContextClassReference() string { return t.AuthenticationContextClassReference } -func (t *IDTokenClaims) GetAuthTime() time.Time { - return t.AuthTime +func (t *idTokenClaims) GetAuthTime() time.Time { + return time.Time(t.AuthTime) } -func (t *IDTokenClaims) GetAuthorizedParty() string { +func (t *idTokenClaims) GetAuthorizedParty() string { return t.AuthorizedParty } -func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) { - t.Signature = alg +func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { + t.signatureAlg = alg } -func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) { - j := jsonToken{ - Issuer: t.Issuer, - Subject: t.Subject, - Audiences: t.Audience, - Expiration: timeToJSON(t.Expiration), - IssuedAt: timeToJSON(t.IssuedAt), - Scopes: strings.Join(t.Scopes, " "), - } - return json.Marshal(j) -} +// +//func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) { +// j := jsonToken{ +// Issuer: t.Issuer, +// Subject: t.Subject, +// Audiences: t.Audience, +// Expiration: timeToJSON(t.Expiration), +// IssuedAt: timeToJSON(t.IssuedAt), +// Scopes: strings.Join(t.Scopes, " "), +// } +// return json.Marshal(j) +//} -func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error { - var j jsonToken - if err := json.Unmarshal(b, &j); err != nil { - return err - } +//func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error { +// var j jsonToken +// if err := json.Unmarshal(b, &j); err != nil { +// return err +// } +// +// t.Issuer = j.Issuer +// t.Subject = j.Subject +// t.Audience = audienceFromJSON(j.Audiences) +// t.Expiration = time.Unix(j.Expiration, 0).UTC() +// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() +// t.Scopes = strings.Split(j.Scopes, " ") +// +// return nil +//} - t.Issuer = j.Issuer - t.Subject = j.Subject - t.Audience = audienceFromJSON(j.Audiences) - t.Expiration = time.Unix(j.Expiration, 0).UTC() - t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() - t.Scopes = strings.Split(j.Scopes, " ") - - return nil -} - -func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile { - locale, _ := language.Parse(j.Locale) - return UserinfoProfile{ - Name: j.Name, - GivenName: j.GivenName, - FamilyName: j.FamilyName, - MiddleName: j.MiddleName, - Nickname: j.Nickname, - Profile: j.Profile, - Picture: j.Picture, - Website: j.Website, - Gender: Gender(j.Gender), - Birthdate: j.Birthdate, - Zoneinfo: j.Zoneinfo, - Locale: locale, - UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(), - PreferredUsername: j.PreferredUsername, - } -} - -func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail { - return UserinfoEmail{ - Email: j.Email, - EmailVerified: j.EmailVerified, - } -} - -func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone { - return UserinfoPhone{ - PhoneNumber: j.Phone, - PhoneNumberVerified: j.PhoneVerified, - } -} - -func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress { - if j.JsonUserinfoAddress == nil { - return nil - } - return &UserinfoAddress{ - Country: j.JsonUserinfoAddress.Country, - Formatted: j.JsonUserinfoAddress.Formatted, - Locality: j.JsonUserinfoAddress.Locality, - PostalCode: j.JsonUserinfoAddress.PostalCode, - Region: j.JsonUserinfoAddress.Region, - StreetAddress: j.JsonUserinfoAddress.StreetAddress, - } -} +// +//func (j *jsonToken) UnmarshalUserinfoProfile() userInfoProfile { +// locale, _ := language.Parse(j.Locale) +// return userInfoProfile{ +// Name: j.Name, +// GivenName: j.GivenName, +// FamilyName: j.FamilyName, +// MiddleName: j.MiddleName, +// Nickname: j.Nickname, +// Profile: j.Profile, +// Picture: j.Picture, +// Website: j.Website, +// Gender: Gender(j.Gender), +// Birthdate: j.Birthdate, +// Zoneinfo: j.Zoneinfo, +// Locale: locale, +// UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(), +// PreferredUsername: j.PreferredUsername, +// } +//} +// +//func (j *jsonToken) UnmarshalUserinfoEmail() userInfoEmail { +// return userInfoEmail{ +// Email: j.Email, +// EmailVerified: j.EmailVerified, +// } +//} +// +//func (j *jsonToken) UnmarshalUserinfoPhone() userInfoPhone { +// return userInfoPhone{ +// PhoneNumber: j.Phone, +// PhoneNumberVerified: j.PhoneVerified, +// } +//} +// +//func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress { +// if j.JsonUserinfoAddress == nil { +// return nil +// } +// return &UserinfoAddress{ +// Country: j.JsonUserinfoAddress.Country, +// Formatted: j.JsonUserinfoAddress.Formatted, +// Locality: j.JsonUserinfoAddress.Locality, +// PostalCode: j.JsonUserinfoAddress.PostalCode, +// Region: j.JsonUserinfoAddress.Region, +// StreetAddress: j.JsonUserinfoAddress.StreetAddress, +// } +//} func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { hash, err := utils.GetHashAlgorithm(sigAlgorithm) diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go new file mode 100644 index 0000000..c04dfb4 --- /dev/null +++ b/pkg/oidc/token_request.go @@ -0,0 +1,101 @@ +package oidc + +import ( + "time" + + "gopkg.in/square/go-jose.v2" +) + +const ( + //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow + GrantTypeCode GrantType = "authorization_code" + //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant + GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" +) + +type GrantType string + +type TokenRequest interface { + // GrantType GrantType `schema:"grant_type"` + GrantType() GrantType +} + +type TokenRequestType GrantType + +type AccessTokenRequest struct { + Code string `schema:"code"` + RedirectURI string `schema:"redirect_uri"` + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` + CodeVerifier string `schema:"code_verifier"` +} + +func (a *AccessTokenRequest) GrantType() GrantType { + return GrantTypeCode +} + +type JWTTokenRequest struct { + Issuer string `json:"iss"` + Subject string `json:"sub"` + Scopes Scopes `json:"scope"` + Audience Audience `json:"aud"` + IssuedAt Time `json:"iat"` + ExpiresAt Time `json:"exp"` +} + +func (j *JWTTokenRequest) GetClientID() string { + return j.Subject +} + +func (j *JWTTokenRequest) GetSubject() string { + return j.Subject +} + +func (j *JWTTokenRequest) GetScopes() []string { + return j.Scopes +} + +func (j *JWTTokenRequest) GetIssuer() string { + return j.Issuer +} + +func (j *JWTTokenRequest) GetAudience() []string { + return j.Audience +} + +func (j *JWTTokenRequest) GetExpiration() time.Time { + return time.Time(j.ExpiresAt) +} + +func (j *JWTTokenRequest) GetIssuedAt() time.Time { + return time.Time(j.IssuedAt) +} + +func (j *JWTTokenRequest) GetNonce() string { + return "" +} + +func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string { + return "" +} + +func (j *JWTTokenRequest) GetAuthTime() time.Time { + return time.Time{} +} + +func (j *JWTTokenRequest) GetAuthorizedParty() string { + return "" +} + +func (j *JWTTokenRequest) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {} + +type TokenExchangeRequest struct { + subjectToken string `schema:"subject_token"` + subjectTokenType string `schema:"subject_token_type"` + actorToken string `schema:"actor_token"` + actorTokenType string `schema:"actor_token_type"` + resource []string `schema:"resource"` + audience Audience `schema:"audience"` + Scope Scopes `schema:"scope"` + requestedTokenType string `schema:"requested_token_type"` +} diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go new file mode 100644 index 0000000..ad19684 --- /dev/null +++ b/pkg/oidc/types.go @@ -0,0 +1,113 @@ +package oidc + +import ( + "encoding/json" + "strings" + "time" + + "golang.org/x/text/language" +) + +type Audience []string + +func (a *Audience) UnmarshalJSON(text []byte) error { + var i interface{} + err := json.Unmarshal(text, &i) + if err != nil { + return err + } + switch aud := i.(type) { + case []interface{}: + *a = make([]string, len(aud)) + for i, audience := range aud { + (*a)[i] = audience.(string) + } + case string: + *a = []string{aud} + } + return nil +} + +type Display string + +func (d *Display) UnmarshalText(text []byte) error { + display := Display(text) + switch display { + case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP: + *d = display + } + return nil +} + +type Gender string + +type Locale language.Tag + +//{ +// SetLocale(language.Tag) +// Get() language.Tag +//} +// +//func NewLocale(tag language.Tag) Locale { +// if tag.IsRoot() { +// return nil +// } +// return &locale{Tag: tag} +//} +// +//type locale struct { +// language.Tag +//} +// +//func (l *locale) SetLocale(tag language.Tag) { +// l.Tag = tag +//} +//func (l *locale) Get() language.Tag { +// return l.Tag +//} + +//func (l *locale) MarshalJSON() ([]byte, error) { +// if l != nil && !l.IsRoot() { +// return l.MarshalText() +// } +// return []byte("null"), nil +//} + +type Locales []language.Tag + +func (l *Locales) UnmarshalText(text []byte) error { + locales := strings.Split(string(text), " ") + for _, locale := range locales { + tag, err := language.Parse(locale) + if err == nil && !tag.IsRoot() { + *l = append(*l, tag) + } + } + return nil +} + +type Prompt string + +type ResponseType string + +type Scopes []string + +func (s *Scopes) UnmarshalText(text []byte) error { + *s = strings.Split(string(text), " ") + return nil +} + +type Time time.Time + +func (t *Time) UnmarshalJSON(data []byte) error { + var i int64 + if err := json.Unmarshal(data, &i); err != nil { + return err + } + *t = Time(time.Unix(i, 0).UTC()) + return nil +} + +func (t *Time) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Time(*t).UTC().Unix()) +} diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go new file mode 100644 index 0000000..c451f8c --- /dev/null +++ b/pkg/oidc/types_test.go @@ -0,0 +1,276 @@ +package oidc + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/text/language" +) + +func TestAudience_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + audience Audience + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte(`{"aud": "single audience"}`), + }, + res{ + []string{"single audience"}, + }, + false, + }, + { + "page", + args{ + []byte(`{"aud": ["multiple", "audience"]}`), + }, + res{ + []string{"multiple", "audience"}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := new(struct { + Audience Audience `json:"aud"` + }) + if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, a.Audience, tt.res.audience) + }) + } +} + +func TestDisplay_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + display Display + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{}, + false, + }, + { + "page", + args{ + []byte("page"), + }, + res{DisplayPage}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var d Display + if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + if d != tt.res.display { + t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display) + } + }) + } +} + +func TestLocales_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + tags []language.Tag + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{}, + false, + }, + { + "undefined", + args{ + []byte("und"), + }, + res{}, + false, + }, + { + "single language", + args{ + []byte("de"), + }, + res{[]language.Tag{language.German}}, + false, + }, + { + "multiple languages", + args{ + []byte("de en"), + }, + res{[]language.Tag{language.German, language.English}}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var locales Locales + if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, locales, tt.res.tags) + }) + } +} + +func TestScopes_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + scopes []string + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{ + []string{"unknown"}, + }, + false, + }, + { + "struct", + args{ + []byte(`{"unknown":"value"}`), + }, + res{ + []string{`{"unknown":"value"}`}, + }, + false, + }, + { + "openid", + args{ + []byte("openid"), + }, + res{ + []string{"openid"}, + }, + false, + }, + { + "multiple scopes", + args{ + []byte("openid email custom:scope"), + }, + res{ + []string{"openid", "email", "custom:scope"}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var scopes Scopes + if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, scopes, tt.res.scopes) + }) + } +} + +func TestTime_UnmarshalJSON(t *testing.T) { + type args struct { + text []byte + } + type res struct { + scopes []string + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{ + []string{"unknown"}, + }, + false, + }, + { + "openid", + args{ + []byte("openid"), + }, + res{ + []string{"openid"}, + }, + false, + }, + { + "multiple scopes", + args{ + []byte("openid email custom:scope"), + }, + res{ + []string{"openid", "email", "custom:scope"}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var scopes Scopes + if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, scopes, tt.res.scopes) + }) + } +} diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index f8b4e6c..31de85f 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -2,89 +2,312 @@ package oidc import ( "encoding/json" + "fmt" "time" "golang.org/x/text/language" + + "github.com/caos/oidc/pkg/utils" ) -type Userinfo struct { - Subject string - UserinfoProfile - UserinfoEmail - UserinfoPhone - Address *UserinfoAddress +type UserInfo interface { + GetSubject() string + UserInfoProfile + UserInfoEmail + UserInfoPhone + GetAddress() UserInfoAddress + GetClaim(key string) interface{} +} - Authorizations []string +type UserInfoProfile interface { + GetName() string + GetGivenName() string + GetFamilyName() string + GetMiddleName() string + GetNickname() string + GetProfile() string + GetPicture() string + GetWebsite() string + GetGender() Gender + GetBirthdate() string + GetZoneinfo() string + GetLocale() language.Tag + GetPreferredUsername() string +} + +type UserInfoEmail interface { + GetEmail() string + IsEmailVerified() bool +} + +type UserInfoPhone interface { + GetPhoneNumber() string + IsPhoneNumberVerified() bool +} + +type UserInfoAddress interface { + GetFormatted() string + GetStreetAddress() string + GetLocality() string + GetRegion() string + GetPostalCode() string + GetCountry() string +} + +type UserInfoSetter interface { + UserInfo + SetSubject(sub string) + UserInfoProfileSetter + SetEmail(email string, verified bool) + SetPhone(phone string, verified bool) + SetAddress(address UserInfoAddress) + AppendClaims(key string, values interface{}) +} + +type UserInfoProfileSetter interface { + SetName(name string) + SetGivenName(name string) + SetFamilyName(name string) + SetMiddleName(name string) + SetNickname(name string) + SetUpdatedAt(date time.Time) + SetProfile(profile string) + SetPicture(profile string) + SetWebsite(website string) + SetGender(gender Gender) + SetBirthdate(birthdate string) + SetZoneinfo(zoneInfo string) + SetLocale(locale language.Tag) + SetPreferredUsername(name string) +} + +func NewUserInfo() UserInfoSetter { + return &userinfo{} +} + +type userinfo struct { + Subject string `json:"sub,omitempty"` + userInfoProfile + userInfoEmail + userInfoPhone + Address UserInfoAddress `json:"address,omitempty"` claims map[string]interface{} } -type UserinfoProfile struct { - Name string - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender Gender - Birthdate string - Zoneinfo string - Locale language.Tag - UpdatedAt time.Time - PreferredUsername string +func (u *userinfo) GetSubject() string { + return u.Subject } -type Gender string - -type UserinfoEmail struct { - Email string - EmailVerified bool +func (u *userinfo) GetName() string { + return u.Name } -type UserinfoPhone struct { - PhoneNumber string - PhoneNumberVerified bool +func (u *userinfo) GetGivenName() string { + return u.GivenName } -type UserinfoAddress struct { - Formatted string - StreetAddress string - Locality string - Region string - PostalCode string - Country string +func (u *userinfo) GetFamilyName() string { + return u.FamilyName } -type jsonUserinfoProfile struct { - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale string `json:"locale,omitempty"` - UpdatedAt int64 `json:"updated_at,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` +func (u *userinfo) GetMiddleName() string { + return u.MiddleName } -type jsonUserinfoEmail struct { +func (u *userinfo) GetNickname() string { + return u.Nickname +} + +func (u *userinfo) GetProfile() string { + return u.Profile +} + +func (u *userinfo) GetPicture() string { + return u.Picture +} + +func (u *userinfo) GetWebsite() string { + return u.Website +} + +func (u *userinfo) GetGender() Gender { + return u.Gender +} + +func (u *userinfo) GetBirthdate() string { + return u.Birthdate +} + +func (u *userinfo) GetZoneinfo() string { + return u.Zoneinfo +} + +func (u *userinfo) GetLocale() language.Tag { + return u.Locale +} + +func (u *userinfo) GetPreferredUsername() string { + return u.PreferredUsername +} + +func (u *userinfo) GetEmail() string { + return u.Email +} + +func (u *userinfo) IsEmailVerified() bool { + return u.EmailVerified +} + +func (u *userinfo) GetPhoneNumber() string { + return u.PhoneNumber +} + +func (u *userinfo) IsPhoneNumberVerified() bool { + return u.PhoneNumberVerified +} + +func (u *userinfo) GetAddress() UserInfoAddress { + return u.Address +} + +func (u *userinfo) GetClaim(key string) interface{} { + return u.claims[key] +} + +func (u *userinfo) SetSubject(sub string) { + u.Subject = sub +} + +func (u *userinfo) SetName(name string) { + u.Name = name +} + +func (u *userinfo) SetGivenName(name string) { + u.GivenName = name +} + +func (u *userinfo) SetFamilyName(name string) { + u.FamilyName = name +} + +func (u *userinfo) SetMiddleName(name string) { + u.MiddleName = name +} + +func (u *userinfo) SetNickname(name string) { + u.Nickname = name +} + +func (u *userinfo) SetUpdatedAt(date time.Time) { + u.UpdatedAt = Time(date) +} + +func (u *userinfo) SetProfile(profile string) { + u.Profile = profile +} + +func (u *userinfo) SetPicture(picture string) { + u.Picture = picture +} + +func (u *userinfo) SetWebsite(website string) { + u.Website = website +} + +func (u *userinfo) SetGender(gender Gender) { + u.Gender = gender +} + +func (u *userinfo) SetBirthdate(birthdate string) { + u.Birthdate = birthdate +} + +func (u *userinfo) SetZoneinfo(zoneInfo string) { + u.Zoneinfo = zoneInfo +} + +func (u *userinfo) SetLocale(locale language.Tag) { + u.Locale = locale +} + +func (u *userinfo) SetPreferredUsername(name string) { + u.PreferredUsername = name +} + +func (u *userinfo) SetEmail(email string, verified bool) { + u.Email = email + u.EmailVerified = verified +} + +func (u *userinfo) SetPhone(phone string, verified bool) { + u.PhoneNumber = phone + u.PhoneNumberVerified = verified +} + +func (u *userinfo) SetAddress(address UserInfoAddress) { + u.Address = address +} + +func (u *userinfo) AppendClaims(key string, value interface{}) { + if u.claims == nil { + u.claims = make(map[string]interface{}) + } + u.claims[key] = value +} + +func (u *userInfoAddress) GetFormatted() string { + panic("implement me") +} + +func (u *userInfoAddress) GetStreetAddress() string { + panic("implement me") +} + +func (u *userInfoAddress) GetLocality() string { + panic("implement me") +} + +func (u *userInfoAddress) GetRegion() string { + panic("implement me") +} + +func (u *userInfoAddress) GetPostalCode() string { + panic("implement me") +} + +func (u *userInfoAddress) GetCountry() string { + panic("implement me") +} + +type userInfoProfile struct { + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Gender Gender `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale language.Tag `json:"locale,omitempty"` + UpdatedAt Time `json:"updated_at,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` +} + +type userInfoEmail struct { Email string `json:"email,omitempty"` EmailVerified bool `json:"email_verified,omitempty"` } -type jsonUserinfoPhone struct { - Phone string `json:"phone_number,omitempty"` - PhoneVerified bool `json:"phone_number_verified,omitempty"` +type userInfoPhone struct { + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` } -type jsonUserinfoAddress struct { +type userInfoAddress struct { Formatted string `json:"formatted,omitempty"` StreetAddress string `json:"street_address,omitempty"` Locality string `json:"locality,omitempty"` @@ -93,81 +316,68 @@ type jsonUserinfoAddress struct { Country string `json:"country,omitempty"` } -func (i *Userinfo) MarshalJSON() ([]byte, error) { - j := new(jsonUserinfo) - j.Subject = i.Subject - j.setUserinfo(*i) - j.Authorizations = i.Authorizations - return json.Marshal(j) +func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress { + return &userInfoAddress{ + StreetAddress: streetAddress, + Locality: locality, + Region: region, + PostalCode: postalCode, + Country: country, + Formatted: formatted, + } +} +func (i *userinfo) MarshalJSON() ([]byte, error) { + type Alias userinfo + a := &struct { + *Alias + Locale interface{} `json:"locale,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + }{ + Alias: (*Alias)(i), + } + if !i.Locale.IsRoot() { + a.Locale = i.Locale + } + fmt.Println(time.Time(i.UpdatedAt).String()) + if !time.Time(i.UpdatedAt).IsZero() { + a.UpdatedAt = time.Time(i.UpdatedAt).Unix() + } + + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if len(i.claims) == 0 { + return b, nil + } + + claims, err := json.Marshal(i.claims) + if err != nil { + return nil, fmt.Errorf("jws: invalid map of private claims %v", i.claims) + } + return utils.ConcatenateJSON(b, claims) } -func (i *Userinfo) UnmmarshalJSON(data []byte) error { - if err := json.Unmarshal(data, i); err != nil { +func (i *userinfo) UnmarshalJSON(data []byte) error { + type Alias userinfo + a := &struct { + *Alias + //Locale interface{} `json:"locale,omitempty"` + UpdatedAt int64 `json:"update_at,omitempty"` + }{ + Alias: (*Alias)(i), + } + if err := json.Unmarshal(data, &a); err != nil { return err } - return json.Unmarshal(data, &i.claims) -} + //if !i.Locale.IsRoot() { + // a.Locale = i.Locale + //} -type jsonUserinfo struct { - Subject string `json:"sub,omitempty"` - jsonUserinfoProfile - jsonUserinfoEmail - jsonUserinfoPhone - JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"` - Authorizations []string `json:"authorizations,omitempty"` -} + i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC()) -func (j *jsonUserinfo) setUserinfo(i Userinfo) { - j.setUserinfoProfile(i.UserinfoProfile) - j.setUserinfoEmail(i.UserinfoEmail) - j.setUserinfoPhone(i.UserinfoPhone) - j.setUserinfoAddress(i.Address) -} - -func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) { - j.Name = i.Name - j.GivenName = i.GivenName - j.FamilyName = i.FamilyName - j.MiddleName = i.MiddleName - j.Nickname = i.Nickname - j.Profile = i.Profile - j.Picture = i.Picture - j.Website = i.Website - j.Gender = string(i.Gender) - j.Birthdate = i.Birthdate - j.Zoneinfo = i.Zoneinfo - if i.Locale != language.Und { - j.Locale = i.Locale.String() - } - j.UpdatedAt = timeToJSON(i.UpdatedAt) - j.PreferredUsername = i.PreferredUsername -} - -func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) { - j.Email = i.Email - j.EmailVerified = i.EmailVerified -} - -func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) { - j.Phone = i.PhoneNumber - j.PhoneVerified = i.PhoneNumberVerified -} - -func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) { - if i == nil { - return - } - if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" { - return - } - j.JsonUserinfoAddress = &jsonUserinfoAddress{ - Country: i.Country, - Formatted: i.Formatted, - Locality: i.Locality, - PostalCode: i.PostalCode, - Region: i.Region, - StreetAddress: i.StreetAddress, - } + return nil } type UserInfoRequest struct { diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 492664b..06470a0 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -24,7 +24,7 @@ type Claims interface { GetAuthenticationContextClassReference() string GetAuthTime() time.Time GetAuthorizedParty() string - SetSignature(algorithm jose.SignatureAlgorithm) + SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) } var ( @@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl return ErrSignatureInvalidPayload } - claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm)) + claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm)) return nil } diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index cf40e62..cee6184 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -168,7 +168,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie if err != nil { return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") } - return claims.Subject, nil + return claims.GetSubject(), nil } //RedirectToLogin redirects the end user to the Login UI for authentication diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 0a6b6a5..7dfcfff 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -81,7 +81,7 @@ func (s *Sig) Health(ctx context.Context) error { func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { return "", nil } -func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) { +func (s *Sig) SignAccessToken(*oidc.accessTokenClaims) (string, error) { return "", nil } func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index a7d909c..16592a7 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -50,7 +50,7 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call { } // SignAccessToken mocks base method -func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) { +func (m *MockSigner) SignAccessToken(arg0 *oidc.accessTokenClaims) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SignAccessToken", arg0) ret0, _ := ret[0].(string) diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index a7ca4cb..bcc04da 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -184,10 +184,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.userinfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) - ret0, _ := ret[0].(*oidc.Userinfo) + ret0, _ := ret[0].(*oidc.userinfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -199,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interf } // GetUserinfoFromToken mocks base method -func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.Userinfo, error) { +func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.userinfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) - ret0, _ := ret[0].(*oidc.Userinfo) + ret0, _ := ret[0].(*oidc.userinfo) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/op/op.go b/pkg/op/op.go index d913c7f..7e8279a 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -130,7 +130,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO } keyCh := make(chan jose.SigningKey) - o.signer = NewDefaultSigner(ctx, storage, keyCh) + o.signer = NewSigner(ctx, storage, keyCh) go EnsureKey(ctx, storage, keyCh, o.timer, o.retry) o.httpHandler = CreateRouter(o, o.interceptors...) diff --git a/pkg/op/session.go b/pkg/op/session.go index d04e361..19ebab4 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, if err != nil { return nil, ErrInvalidRequest("id_token_hint invalid") } - session.UserID = claims.Subject - session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty) + session.UserID = claims.GetSubject() + session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty()) if err != nil { return nil, ErrServerError("") } diff --git a/pkg/op/signer.go b/pkg/op/signer.go index e9926cd..5cf585e 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -2,19 +2,17 @@ package op import ( "context" - "encoding/json" "errors" "github.com/caos/logging" "gopkg.in/square/go-jose.v2" - - "github.com/caos/oidc/pkg/oidc" ) type Signer interface { Health(ctx context.Context) error - SignIDToken(claims *oidc.IDTokenClaims) (string, error) - SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) + //SignIDToken(claims *oidc.IDTokenClaims) (string, error) + //SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) + Signer() jose.Signer SignatureAlgorithm() jose.SignatureAlgorithm } @@ -24,7 +22,7 @@ type tokenSigner struct { alg jose.SignatureAlgorithm } -func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { +func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { s := &tokenSigner{ storage: storage, } @@ -41,6 +39,15 @@ func (s *tokenSigner) Health(_ context.Context) error { return nil } +func (s *tokenSigner) Signer() jose.Signer { + return s.signer +} + +// +//func (s *tokenSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { +// return s.signer.Sign(payload) +//} + func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) { for { select { @@ -55,30 +62,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S } } -func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) { - payload, err := json.Marshal(claims) - if err != nil { - return "", err - } - return s.Sign(payload) -} - -func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) { - payload, err := json.Marshal(claims) - if err != nil { - return "", err - } - return s.Sign(payload) -} - -func (s *tokenSigner) Sign(payload []byte) (string, error) { - result, err := s.signer.Sign(payload) - if err != nil { - return "", err - } - return result.CompactSerialize() -} - func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { return s.alg } diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go index 75e184b..c751c76 100644 --- a/pkg/op/signer_test.go +++ b/pkg/op/signer_test.go @@ -38,13 +38,13 @@ import ( // } // for _, tt := range tests { // t.Run(tt.name, func(t *testing.T) { -// got, err := op.NewDefaultSigner(tt.args.storage) +// got, err := op.NewSigner(tt.args.storage) // if (err != nil) != tt.wantErr { -// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr) +// t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) // return // } // if !reflect.DeepEqual(got, tt.want) { -// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want) +// t.Errorf("NewSigner() = %v, want %v", got, tt.want) // } // }) // } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 669b08e..69784ee 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -28,8 +28,8 @@ type AuthStorage interface { type OPStorage interface { GetClientByClientID(context.Context, string) (Client, error) AuthorizeClientIDSecret(context.Context, string, string) error - GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error) - GetUserinfoFromToken(context.Context, string, string) (*oidc.Userinfo, error) + GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error) + GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } diff --git a/pkg/op/token.go b/pkg/op/token.go index 87494b9..bb2b3c5 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -5,6 +5,7 @@ import ( "time" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" ) type TokenCreator interface { @@ -82,51 +83,34 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) { } func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) { - now := time.Now().UTC() - nbf := now - claims := &oidc.AccessTokenClaims{ - Issuer: issuer, - Subject: authReq.GetSubject(), - Audiences: authReq.GetAudience(), - Expiration: exp, - IssuedAt: now, - NotBefore: nbf, - JWTID: id, - } - return signer.SignAccessToken(claims) + claims := oidc.NewAccessTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, id) + return utils.Sign(claims, signer.Signer()) } func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) { - var err error exp := time.Now().UTC().Add(validity) - userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) - if err != nil { - return "", err - } - claims := &oidc.IDTokenClaims{ - Issuer: issuer, - Audiences: authReq.GetAudience(), - Expiration: exp, - IssuedAt: time.Now().UTC(), - AuthTime: authReq.GetAuthTime(), - Nonce: authReq.GetNonce(), - AuthenticationContextClassReference: authReq.GetACR(), - AuthenticationMethodsReferences: authReq.GetAMR(), - AuthorizedParty: authReq.GetClientID(), - Userinfo: *userinfo, - } + claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID()) + if accessToken != "" { - claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) + atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) if err != nil { return "", err } + claims.SetAccessTokenHash(atHash) + } else { + userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) + if err != nil { + return "", err + } + claims.SetUserinfo(userInfo) } if code != "" { - claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm()) + codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm()) if err != nil { return "", err } + claims.SetCodeHash(codeHash) } - return signer.SignIDToken(claims) + return utils.Sign(claims, signer.Signer()) } diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index 3268a5e..7baa075 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi //VerifyIDTokenHint validates the id token according to //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) { - claims := new(oidc.IDTokenClaims) +func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) { + claims := oidc.EmptyIDTokenClaims() decrypted, err := oidc.DecryptToken(token) if err != nil { diff --git a/pkg/rp/mock/verifier.mock.impl.go b/pkg/rp/mock/verifier.mock.impl.go index 53b2f03..0b6dd1c 100644 --- a/pkg/rp/mock/verifier.mock.impl.go +++ b/pkg/rp/mock/verifier.mock.impl.go @@ -33,5 +33,5 @@ func NewMockVerifierExpectValid(t *testing.T) rp.Verifier { func ExpectVerifyValid(v rp.Verifier) { mock := v.(*MockVerifier) - mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil) + mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.userinfo{Subject: "id"}}, nil) } diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go index 3fe8b4b..fd9ee95 100644 --- a/pkg/rp/relaying_party.go +++ b/pkg/rp/relaying_party.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "net/url" "strings" "github.com/google/uuid" @@ -329,10 +330,10 @@ func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2. } func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) { - form := make(map[string][]string) - form["assertion"] = []string{assertion} - form["grant_type"] = []string{jwtProfileKey} - req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil) + form := url.Values{} + form.Add("assertion", assertion) + form.Add("grant_type", jwtProfileKey) + req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, strings.NewReader(form.Encode())) if err != nil { return nil, err } diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go index ef2cf87..a156f6d 100644 --- a/pkg/rp/verifier.go +++ b/pkg/rp/verifier.go @@ -21,12 +21,12 @@ type IDTokenVerifier interface { //VerifyTokens implement the Token Response Validation as defined in OIDC specification //https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { +func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) { idToken, err := VerifyIDToken(ctx, idTokenString, v) if err != nil { return nil, err } - if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { + if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil { return nil, err } return idToken, nil @@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo //VerifyIDToken validates the id token according to //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { - claims := new(oidc.IDTokenClaims) +func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) { + claims := oidc.EmptyIDTokenClaims() decrypted, err := oidc.DecryptToken(token) if err != nil { diff --git a/pkg/utils/marshal.go b/pkg/utils/marshal.go index e279341..4f53b4e 100644 --- a/pkg/utils/marshal.go +++ b/pkg/utils/marshal.go @@ -1,7 +1,9 @@ package utils import ( + "bytes" "encoding/json" + "fmt" "net/http" "github.com/sirupsen/logrus" @@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) { logrus.Error("error writing response") } } + +func ConcatenateJSON(first, second []byte) ([]byte, error) { + if !bytes.HasSuffix(first, []byte{'}'}) { + return nil, fmt.Errorf("jws: invalid JSON %s", first) + } + if !bytes.HasPrefix(second, []byte{'{'}) { + return nil, fmt.Errorf("jws: invalid JSON %s", second) + } + first[len(first)-1] = ',' + first = append(first, second[1:]...) + return first, nil +} diff --git a/pkg/utils/sign.go b/pkg/utils/sign.go new file mode 100644 index 0000000..e1efe61 --- /dev/null +++ b/pkg/utils/sign.go @@ -0,0 +1,23 @@ +package utils + +import ( + "encoding/json" + + "gopkg.in/square/go-jose.v2" +) + +func Sign(object interface{}, signer jose.Signer) (string, error) { + payload, err := json.Marshal(object) + if err != nil { + return "", err + } + return SignPayload(payload, signer) +} + +func SignPayload(payload []byte, signer jose.Signer) (string, error) { + result, err := signer.Sign(payload) + if err != nil { + return "", err + } + return result.CompactSerialize() +} From d7ed59db2bf2916fdb13e9479281a539d86e3658 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 28 Sep 2020 08:14:10 +0200 Subject: [PATCH 02/17] refactoring --- pkg/oidc/keyset.go | 20 +++++++++----------- pkg/oidc/session.go | 2 ++ pkg/oidc/userinfo.go | 19 +++++++------------ 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go index abe55d1..0d8e02c 100644 --- a/pkg/oidc/keyset.go +++ b/pkg/oidc/keyset.go @@ -6,21 +6,19 @@ import ( "gopkg.in/square/go-jose.v2" ) -// KeySet is a set of publc JSON Web Keys that can be used to validate the signature -// of JSON web tokens. This is expected to be backed by a remote key set through -// provider metadata discovery or an in-memory set of keys delivered out-of-band. +//KeySet represents a set of JSON Web Keys +// - remotely fetch via discovery and jwks_uri -> `remoteKeySet` +// - held by the OP itself in storage -> `openIDKeySet` +// - dynamically aggregated by request for OAuth JWT Profile Assertion -> `jwtProfileKeySet` type KeySet interface { - // VerifySignature parses the JSON web token, verifies the signature, and returns - // the raw payload. Header and claim fields are validated by other parts of the - // package. For example, the KeySet does not need to check values such as signature - // algorithm, issuer, and audience since the IDTokenVerifier validates these values - // independently. - // - // If VerifySignature makes HTTP requests to verify the token, it's expected to - // use any HTTP client associated with the context through ClientContext. + //VerifySignature verifies the signature with the given keyset and returns the raw payload VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) } +//CheckKey searches the given JSON Web Keys for the requested key ID +//and verifies the JSON Web Signature with the found key +// +//will return false but no error if key ID is not found func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) { for _, key := range keys { if keyID == "" || key.KeyID == keyID { diff --git a/pkg/oidc/session.go b/pkg/oidc/session.go index 418439e..d6735b4 100644 --- a/pkg/oidc/session.go +++ b/pkg/oidc/session.go @@ -1,5 +1,7 @@ package oidc +//EndSessionRequest for the RP-Initiated Logout according to: +//https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout type EndSessionRequest struct { IdTokenHint string `schema:"id_token_hint"` PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"` diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index 31de85f..3c77b7b 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -257,27 +257,27 @@ func (u *userinfo) AppendClaims(key string, value interface{}) { } func (u *userInfoAddress) GetFormatted() string { - panic("implement me") + return u.Formatted } func (u *userInfoAddress) GetStreetAddress() string { - panic("implement me") + return u.StreetAddress } func (u *userInfoAddress) GetLocality() string { - panic("implement me") + return u.Locality } func (u *userInfoAddress) GetRegion() string { - panic("implement me") + return u.Region } func (u *userInfoAddress) GetPostalCode() string { - panic("implement me") + return u.PostalCode } func (u *userInfoAddress) GetCountry() string { - panic("implement me") + return u.Country } type userInfoProfile struct { @@ -338,7 +338,6 @@ func (i *userinfo) MarshalJSON() ([]byte, error) { if !i.Locale.IsRoot() { a.Locale = i.Locale } - fmt.Println(time.Time(i.UpdatedAt).String()) if !time.Time(i.UpdatedAt).IsZero() { a.UpdatedAt = time.Time(i.UpdatedAt).Unix() } @@ -354,7 +353,7 @@ func (i *userinfo) MarshalJSON() ([]byte, error) { claims, err := json.Marshal(i.claims) if err != nil { - return nil, fmt.Errorf("jws: invalid map of private claims %v", i.claims) + return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims) } return utils.ConcatenateJSON(b, claims) } @@ -363,7 +362,6 @@ func (i *userinfo) UnmarshalJSON(data []byte) error { type Alias userinfo a := &struct { *Alias - //Locale interface{} `json:"locale,omitempty"` UpdatedAt int64 `json:"update_at,omitempty"` }{ Alias: (*Alias)(i), @@ -371,9 +369,6 @@ func (i *userinfo) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &a); err != nil { return err } - //if !i.Locale.IsRoot() { - // a.Locale = i.Locale - //} i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC()) From d368b2d9506cf0e4a8fc7608672e5577e75a396c Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 28 Sep 2020 09:07:46 +0200 Subject: [PATCH 03/17] refactoring --- example/client/app/app.go | 80 ++++++++++++------------ example/internal/mock/storage.go | 33 +++------- pkg/op/mock/authorizer.mock.impl.go | 20 +++--- pkg/op/mock/signer.mock.go | 45 +++++--------- pkg/op/mock/storage.mock.go | 8 +-- pkg/op/signer.go | 7 --- pkg/op/signer_test.go | 95 ----------------------------- 7 files changed, 77 insertions(+), 211 deletions(-) delete mode 100644 pkg/op/signer_test.go diff --git a/example/client/app/app.go b/example/client/app/app.go index ea1e6e7..3a96830 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -86,7 +86,8 @@ func main() { }) http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) { - tpl := ` + if r.Method == "GET" { + tpl := ` @@ -94,51 +95,54 @@ func main() { Login -
+
` - t, err := template.New("login").Parse(tpl) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - err = t.Execute(w, nil) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - }) + t, err := template.New("login").Parse(tpl) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = t.Execute(w, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } else { + err := r.ParseMultipartForm(4 << 10) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + file, handler, err := r.FormFile("key") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer file.Close() - http.HandleFunc("/jwt-profile-assertion", func(w http.ResponseWriter, r *http.Request) { - r.ParseMultipartForm(32 << 20) - file, handler, err := r.FormFile("key") - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + key, err := ioutil.ReadAll(file) + fmt.Println(handler.Header) + assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + token, err := rp.JWTProfileExchange(ctx, assertion, provider) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + data, err := json.Marshal(token) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(data) } - defer file.Close() - - key, err := ioutil.ReadAll(file) - fmt.Println(handler.Header) - assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer}) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - token, err := rp.JWTProfileExchange(ctx, assertion, provider) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - data, err := json.Marshal(token) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Write(data) }) lis := fmt.Sprintf("127.0.0.1:%s", port) logrus.Infof("listening on http://%s/", lis) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 74e0ed7..1c33906 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -210,31 +210,18 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st return nil } -func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.userinfo, error) { +func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (oidc.UserInfoSetter, error) { return s.GetUserinfoFromScopes(ctx, "", []string{}) } -func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.userinfo, error) { - return &oidc.userinfo{ - Subject: a.GetSubject(), - Address: &oidc.UserinfoAddress{ - StreetAddress: "Hjkhkj 789\ndsf", - }, - userinfoEmail: oidc.userinfoEmail{ - Email: "test", - EmailVerified: true, - }, - userinfoPhone: oidc.userinfoPhone{ - PhoneNumber: "sadsa", - PhoneNumberVerified: true, - }, - userinfoProfile: oidc.userinfoProfile{ - UpdatedAt: time.Now(), - }, - // Claims: map[string]interface{}{ - // "test": "test", - // "hkjh": "", - // }, - }, nil +func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (oidc.UserInfoSetter, error) { + userinfo := oidc.NewUserInfo() + userinfo.SetSubject(a.GetSubject()) + userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", "")) + userinfo.SetEmail("test", true) + userinfo.SetPhone("0791234567", true) + userinfo.SetName("Test") + userinfo.AppendClaims("private_claim", "test") + return userinfo, nil } type ConfClient struct { diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 7dfcfff..a481a8b 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -72,18 +72,18 @@ func (v *Verifier) VerifyIDToken(ctx context.Context, idToken string) (*oidc.IDT return nil, nil } -type Sig struct{} +type Sig struct { + signer jose.Signer +} + +func (s *Sig) Signer() jose.Signer { + return s.signer +} func (s *Sig) Health(ctx context.Context) error { return nil } -func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { - return "", nil -} -func (s *Sig) SignAccessToken(*oidc.accessTokenClaims) (string, error) { - return "", nil -} func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { return jose.HS256 } @@ -92,9 +92,3 @@ func ExpectStorage(a op.Authorizer, t *testing.T) { mockA := a.(*MockAuthorizer) mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t)) } - -// func NewMockSignerAny(t *testing.T) op.Signer { -// m := NewMockSigner(gomock.NewController(t)) -// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil) -// return m -// } diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index 16592a7..b52f9d4 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -6,7 +6,6 @@ package mock import ( context "context" - oidc "github.com/caos/oidc/pkg/oidc" gomock "github.com/golang/mock/gomock" jose "gopkg.in/square/go-jose.v2" reflect "reflect" @@ -49,36 +48,6 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0) } -// SignAccessToken mocks base method -func (m *MockSigner) SignAccessToken(arg0 *oidc.accessTokenClaims) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SignAccessToken", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SignAccessToken indicates an expected call of SignAccessToken -func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0) -} - -// SignIDToken mocks base method -func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SignIDToken", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SignIDToken indicates an expected call of SignIDToken -func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0) -} - // SignatureAlgorithm mocks base method func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm { m.ctrl.T.Helper() @@ -92,3 +61,17 @@ func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm)) } + +// Signer mocks base method +func (m *MockSigner) Signer() jose.Signer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Signer") + ret0, _ := ret[0].(jose.Signer) + return ret0 +} + +// Signer indicates an expected call of Signer +func (mr *MockSignerMockRecorder) Signer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer)) +} diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index bcc04da..1bcd1a6 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -184,10 +184,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.userinfo, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (oidc.UserInfoSetter, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) - ret0, _ := ret[0].(*oidc.userinfo) + ret0, _ := ret[0].(oidc.UserInfoSetter) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -199,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interf } // GetUserinfoFromToken mocks base method -func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.userinfo, error) { +func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (oidc.UserInfoSetter, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) - ret0, _ := ret[0].(*oidc.userinfo) + ret0, _ := ret[0].(oidc.UserInfoSetter) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/op/signer.go b/pkg/op/signer.go index 5cf585e..76bb9c7 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -10,8 +10,6 @@ import ( type Signer interface { Health(ctx context.Context) error - //SignIDToken(claims *oidc.IDTokenClaims) (string, error) - //SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) Signer() jose.Signer SignatureAlgorithm() jose.SignatureAlgorithm } @@ -43,11 +41,6 @@ func (s *tokenSigner) Signer() jose.Signer { return s.signer } -// -//func (s *tokenSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { -// return s.signer.Sign(payload) -//} - func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) { for { select { diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go deleted file mode 100644 index c751c76..0000000 --- a/pkg/op/signer_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package op - -import ( - "testing" - - "github.com/stretchr/testify/require" - "gopkg.in/square/go-jose.v2" -) - -// func TestNewDefaultSigner(t *testing.T) { -// type args struct { -// storage Storage -// } -// tests := []struct { -// name string -// args args -// want Signer -// wantErr bool -// }{ -// { -// "err initialize storage fails", -// args{mock.NewMockStorageSigningKeyError(t)}, -// nil, -// true, -// }, -// { -// "err initialize storage fails", -// args{mock.NewMockStorageSigningKeyInvalid(t)}, -// nil, -// true, -// }, -// { -// "initialize ok", -// args{mock.NewMockStorageSigningKey(t)}, -// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)}, -// false, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// got, err := op.NewSigner(tt.args.storage) -// if (err != nil) != tt.wantErr { -// t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) -// return -// } -// if !reflect.DeepEqual(got, tt.want) { -// t.Errorf("NewSigner() = %v, want %v", got, tt.want) -// } -// }) -// } -// } - -func Test_idTokenSigner_Sign(t *testing.T) { - signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{}) - require.NoError(t, err) - - type fields struct { - signer jose.Signer - storage Storage - } - type args struct { - payload []byte - } - tests := []struct { - name string - fields fields - args args - want string - wantErr bool - }{ - { - "ok", - fields{signer, nil}, - args{[]byte("test")}, - "eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww", - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &tokenSigner{ - signer: tt.fields.signer, - storage: tt.fields.storage, - } - got, err := s.Sign(tt.args.payload) - if (err != nil) != tt.wantErr { - t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want) - } - }) - } -} From 0cad2e4652b30ac4d31cf35c989b372ad23558f1 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 28 Sep 2020 13:55:22 +0200 Subject: [PATCH 04/17] jwt profile and authorization handling --- example/client/app/app.go | 2 +- example/client/jwt_profile.go | 39 ---- .../grants/tokenexchange/tokenexchange.go | 18 +- pkg/oidc/token.go | 177 ------------------ pkg/oidc/token_request.go | 2 +- pkg/oidc/types.go | 40 +--- pkg/oidc/types_test.go | 52 +++-- pkg/op/tokenrequest.go | 14 +- pkg/op/verifier_jwt_profile.go | 8 +- pkg/rp/relaying_party.go | 52 ++--- pkg/rp/tockenexchange.go | 9 +- pkg/utils/http.go | 24 ++- 12 files changed, 128 insertions(+), 309 deletions(-) delete mode 100644 example/client/jwt_profile.go diff --git a/example/client/app/app.go b/example/client/app/app.go index 3a96830..a2fff44 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -131,7 +131,7 @@ func main() { http.Error(w, err.Error(), http.StatusInternalServerError) return } - token, err := rp.JWTProfileExchange(ctx, assertion, provider) + token, err := rp.JWTProfileAssertionExchange(ctx, assertion, oidc.Scopes{oidc.ScopeOpenID, oidc.ScopeProfile}, provider) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/example/client/jwt_profile.go b/example/client/jwt_profile.go deleted file mode 100644 index 6dcd11b..0000000 --- a/example/client/jwt_profile.go +++ /dev/null @@ -1,39 +0,0 @@ -package client - -import ( - "context" - "fmt" - "os" - "time" - - "github.com/sirupsen/logrus" - - "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/rp" - "github.com/caos/oidc/pkg/utils" -) - -var ( - callbackPath string = "/auth/callback" - key []byte = []byte("test1234test1234") -) - -func main() { - clientID := os.Getenv("CLIENT_ID") - clientSecret := os.Getenv("CLIENT_SECRET") - issuer := os.Getenv("ISSUER") - port := os.Getenv("PORT") - - ctx := context.Background() - - redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) - scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"} - cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) - provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, - rp.WithPKCE(cookieHandler), - rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)), - ) - if err != nil { - logrus.Fatalf("error creating provider %s", err.Error()) - } -} diff --git a/pkg/oidc/grants/tokenexchange/tokenexchange.go b/pkg/oidc/grants/tokenexchange/tokenexchange.go index 9464605..5cb6e79 100644 --- a/pkg/oidc/grants/tokenexchange/tokenexchange.go +++ b/pkg/oidc/grants/tokenexchange/tokenexchange.go @@ -1,5 +1,9 @@ package tokenexchange +import ( + "github.com/caos/oidc/pkg/oidc" +) + const ( AccessTokenType = "urn:ietf:params:oauth:token-type:access_token" RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token" @@ -23,7 +27,19 @@ type TokenExchangeRequest struct { } type JWTProfileRequest struct { - Assertion string `schema:"assertion"` + Assertion string `schema:"assertion"` + Scope oidc.Scopes `schema:"scope"` + GrantType oidc.GrantType `schema:"grant_type"` +} + +//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant +//sneding client_id and client_secret as basic auth header +func NewJWTProfileRequest(assertion string, scopes ...string) *JWTProfileRequest { + return &JWTProfileRequest{ + GrantType: oidc.GrantTypeBearer, + Assertion: assertion, + Scope: scopes, + } } func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest { diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index e20dd4a..e445e7e 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -202,7 +202,6 @@ type AccessTokenResponse struct { type JWTProfileAssertion struct { PrivateKeyID string `json:"-"` PrivateKey []byte `json:"-"` - Scopes []string `json:"scopes"` Issuer string `json:"issuer"` Subject string `json:"sub"` Audience Audience `json:"aud"` @@ -236,7 +235,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) PrivateKey: key, PrivateKeyID: keyID, Issuer: userID, - Scopes: []string{ScopeOpenID}, Subject: userID, IssuedAt: Time(time.Now().UTC()), Expiration: Time(time.Now().Add(1 * time.Hour).UTC()), @@ -244,80 +242,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) } } -// -//type jsonToken struct { -// Issuer string `json:"iss,omitempty"` -// Subject string `json:"sub,omitempty"` -// Audiences interface{} `json:"aud,omitempty"` -// Expiration int64 `json:"exp,omitempty"` -// NotBefore int64 `json:"nbf,omitempty"` -// IssuedAt int64 `json:"iat,omitempty"` -// JWTID string `json:"jti,omitempty"` -// AuthorizedParty string `json:"azp,omitempty"` -// Nonce string `json:"nonce,omitempty"` -// AuthTime int64 `json:"auth_time,omitempty"` -// AccessTokenHash string `json:"at_hash,omitempty"` -// CodeHash string `json:"c_hash,omitempty"` -// AuthenticationContextClassReference string `json:"acr,omitempty"` -// AuthenticationMethodsReferences []string `json:"amr,omitempty"` -// SessionID string `json:"sid,omitempty"` -// Actor interface{} `json:"act,omitempty"` //TODO: impl -// Scopes string `json:"scope,omitempty"` -// ClientID string `json:"client_id,omitempty"` -// AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl -// AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` -// jsonUserinfo -//} - -// -//func (t *accessTokenClaims) MarshalJSON() ([]byte, error) { -// j := jsonToken{ -// Issuer: t.Issuer, -// Subject: t.Subject, -// Audiences: t.Audiences, -// Expiration: timeToJSON(t.Expiration), -// NotBefore: timeToJSON(t.NotBefore), -// IssuedAt: timeToJSON(t.IssuedAt), -// JWTID: t.JWTID, -// AuthorizedParty: t.AuthorizedParty, -// Nonce: t.Nonce, -// AuthTime: timeToJSON(t.AuthTime), -// CodeHash: t.CodeHash, -// AuthenticationContextClassReference: t.AuthenticationContextClassReference, -// AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, -// SessionID: t.SessionID, -// Scopes: strings.Join(t.Scopes, " "), -// ClientID: t.ClientID, -// AccessTokenUseNumber: t.AccessTokenUseNumber, -// } -// return json.Marshal(j) -//} -// -//func (t *accessTokenClaims) UnmarshalJSON(b []byte) error { -// var j jsonToken -// if err := json.Unmarshal(b, &j); err != nil { -// return err -// } -// t.Issuer = j.Issuer -// t.Subject = j.Subject -// t.Audiences = audienceFromJSON(j.Audiences) -// t.Expiration = time.Unix(j.Expiration, 0).UTC() -// t.NotBefore = time.Unix(j.NotBefore, 0).UTC() -// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() -// t.JWTID = j.JWTID -// t.AuthorizedParty = j.AuthorizedParty -// t.Nonce = j.Nonce -// t.AuthTime = time.Unix(j.AuthTime, 0).UTC() -// t.CodeHash = j.CodeHash -// t.AuthenticationContextClassReference = j.AuthenticationContextClassReference -// t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences -// t.SessionID = j.SessionID -// t.Scopes = strings.Split(j.Scopes, " ") -// t.ClientID = j.ClientID -// t.AccessTokenUseNumber = j.AccessTokenUseNumber -// return nil -//} -// func (t *idTokenClaims) MarshalJSON() ([]byte, error) { type Alias idTokenClaims a := &struct { @@ -406,84 +330,6 @@ func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { t.signatureAlg = alg } -// -//func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) { -// j := jsonToken{ -// Issuer: t.Issuer, -// Subject: t.Subject, -// Audiences: t.Audience, -// Expiration: timeToJSON(t.Expiration), -// IssuedAt: timeToJSON(t.IssuedAt), -// Scopes: strings.Join(t.Scopes, " "), -// } -// return json.Marshal(j) -//} - -//func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error { -// var j jsonToken -// if err := json.Unmarshal(b, &j); err != nil { -// return err -// } -// -// t.Issuer = j.Issuer -// t.Subject = j.Subject -// t.Audience = audienceFromJSON(j.Audiences) -// t.Expiration = time.Unix(j.Expiration, 0).UTC() -// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() -// t.Scopes = strings.Split(j.Scopes, " ") -// -// return nil -//} - -// -//func (j *jsonToken) UnmarshalUserinfoProfile() userInfoProfile { -// locale, _ := language.Parse(j.Locale) -// return userInfoProfile{ -// Name: j.Name, -// GivenName: j.GivenName, -// FamilyName: j.FamilyName, -// MiddleName: j.MiddleName, -// Nickname: j.Nickname, -// Profile: j.Profile, -// Picture: j.Picture, -// Website: j.Website, -// Gender: Gender(j.Gender), -// Birthdate: j.Birthdate, -// Zoneinfo: j.Zoneinfo, -// Locale: locale, -// UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(), -// PreferredUsername: j.PreferredUsername, -// } -//} -// -//func (j *jsonToken) UnmarshalUserinfoEmail() userInfoEmail { -// return userInfoEmail{ -// Email: j.Email, -// EmailVerified: j.EmailVerified, -// } -//} -// -//func (j *jsonToken) UnmarshalUserinfoPhone() userInfoPhone { -// return userInfoPhone{ -// PhoneNumber: j.Phone, -// PhoneNumberVerified: j.PhoneVerified, -// } -//} -// -//func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress { -// if j.JsonUserinfoAddress == nil { -// return nil -// } -// return &UserinfoAddress{ -// Country: j.JsonUserinfoAddress.Country, -// Formatted: j.JsonUserinfoAddress.Formatted, -// Locality: j.JsonUserinfoAddress.Locality, -// PostalCode: j.JsonUserinfoAddress.PostalCode, -// Region: j.JsonUserinfoAddress.Region, -// StreetAddress: j.JsonUserinfoAddress.StreetAddress, -// } -//} - func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { hash, err := utils.GetHashAlgorithm(sigAlgorithm) if err != nil { @@ -492,26 +338,3 @@ func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, erro return utils.HashString(hash, claim, true), nil } - -func timeToJSON(t time.Time) int64 { - if t.IsZero() { - return 0 - } - return t.Unix() -} - -func audienceFromJSON(i interface{}) []string { - switch aud := i.(type) { - case []string: - return aud - case []interface{}: - audience := make([]string, len(aud)) - for i, a := range aud { - audience[i] = a.(string) - } - return audience - case string: - return []string{aud} - } - return nil -} diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index c04dfb4..1d958df 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -37,7 +37,7 @@ func (a *AccessTokenRequest) GrantType() GrantType { type JWTTokenRequest struct { Issuer string `json:"iss"` Subject string `json:"sub"` - Scopes Scopes `json:"scope"` + Scopes Scopes `json:"-"` Audience Audience `json:"aud"` IssuedAt Time `json:"iat"` ExpiresAt Time `json:"exp"` diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index ad19684..8423cff 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -41,38 +41,6 @@ func (d *Display) UnmarshalText(text []byte) error { type Gender string -type Locale language.Tag - -//{ -// SetLocale(language.Tag) -// Get() language.Tag -//} -// -//func NewLocale(tag language.Tag) Locale { -// if tag.IsRoot() { -// return nil -// } -// return &locale{Tag: tag} -//} -// -//type locale struct { -// language.Tag -//} -// -//func (l *locale) SetLocale(tag language.Tag) { -// l.Tag = tag -//} -//func (l *locale) Get() language.Tag { -// return l.Tag -//} - -//func (l *locale) MarshalJSON() ([]byte, error) { -// if l != nil && !l.IsRoot() { -// return l.MarshalText() -// } -// return []byte("null"), nil -//} - type Locales []language.Tag func (l *Locales) UnmarshalText(text []byte) error { @@ -92,11 +60,19 @@ type ResponseType string type Scopes []string +func (s *Scopes) Encode() string { + return strings.Join(*s, " ") +} + func (s *Scopes) UnmarshalText(text []byte) error { *s = strings.Split(string(text), " ") return nil } +func (s *Scopes) MarshalText() ([]byte, error) { + return []byte(s.Encode()), nil +} + type Time time.Time func (t *Time) UnmarshalJSON(data []byte) error { diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index c451f8c..830fb02 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -1,6 +1,7 @@ package oidc import ( + "bytes" "encoding/json" "testing" @@ -22,7 +23,15 @@ func TestAudience_UnmarshalText(t *testing.T) { wantErr bool }{ { - "unknown value", + "invalid value", + args{ + []byte(`{"aud": {"a": }}}`), + }, + res{}, + false, + }, + { + "single audience", args{ []byte(`{"aud": "single audience"}`), }, @@ -32,7 +41,7 @@ func TestAudience_UnmarshalText(t *testing.T) { false, }, { - "page", + "multiple audience", args{ []byte(`{"aud": ["multiple", "audience"]}`), }, @@ -219,13 +228,12 @@ func TestScopes_UnmarshalText(t *testing.T) { }) } } - -func TestTime_UnmarshalJSON(t *testing.T) { +func TestScopes_MarshalText(t *testing.T) { type args struct { - text []byte + scopes Scopes } type res struct { - scopes []string + scopes []byte } tests := []struct { name string @@ -236,41 +244,53 @@ func TestTime_UnmarshalJSON(t *testing.T) { { "unknown value", args{ - []byte("unknown"), + Scopes{"unknown"}, }, res{ - []string{"unknown"}, + []byte("unknown"), + }, + false, + }, + { + "struct", + args{ + Scopes{`{"unknown":"value"}`}, + }, + res{ + []byte(`{"unknown":"value"}`), }, false, }, { "openid", args{ - []byte("openid"), + Scopes{"openid"}, }, res{ - []string{"openid"}, + []byte("openid"), }, false, }, { "multiple scopes", args{ - []byte("openid email custom:scope"), + Scopes{"openid", "email", "custom:scope"}, }, res{ - []string{"openid", "email", "custom:scope"}, + []byte("openid email custom:scope"), }, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var scopes Scopes - if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { - t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + text, err := tt.args.scopes.MarshalText() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + if !bytes.Equal(text, tt.res.scopes) { + t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes) } - assert.ElementsMatch(t, scopes, tt.res.scopes) }) } } diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index b3613ce..cba70f3 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -138,18 +138,18 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque } func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { - assertion, err := ParseJWTProfileRequest(r, exchanger.Decoder()) + profileRequest, err := ParseJWTProfileRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err) } - claims, err := VerifyJWTAssertion(r.Context(), assertion, exchanger.JWTProfileVerifier()) + tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest, exchanger.JWTProfileVerifier()) if err != nil { RequestError(w, r, err) return } - resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger) + resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger) if err != nil { RequestError(w, r, err) return @@ -157,17 +157,17 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { utils.MarshalJSON(w, resp) } -func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (string, error) { +func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*tokenexchange.JWTProfileRequest, error) { err := r.ParseForm() if err != nil { - return "", ErrInvalidRequest("error parsing form") + return nil, ErrInvalidRequest("error parsing form") } tokenReq := new(tokenexchange.JWTProfileRequest) err = decoder.Decode(tokenReq, r.Form) if err != nil { - return "", ErrInvalidRequest("error decoding form") + return nil, ErrInvalidRequest("error decoding form") } - return tokenReq.Assertion, nil + return tokenReq, nil } func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index b30bdc5..8a31253 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -8,6 +8,7 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/oidc/grants/tokenexchange" ) type JWTProfileVerifier interface { @@ -47,9 +48,9 @@ func (v *jwtProfileVerifier) Offset() time.Duration { return v.offset } -func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { +func VerifyJWTAssertion(ctx context.Context, profileRequest *tokenexchange.JWTProfileRequest, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { request := new(oidc.JWTTokenRequest) - payload, err := oidc.ParseToken(assertion, request) + payload, err := oidc.ParseToken(profileRequest.Assertion, request) if err != nil { return nil, err } @@ -72,9 +73,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif keySet := &jwtProfileKeySet{v.Storage(), request.Subject} - if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { + if err = oidc.CheckSignature(ctx, profileRequest.Assertion, payload, request, nil, keySet); err != nil { return nil, err } + request.Scopes = profileRequest.Scope return request, nil } diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go index fd9ee95..a8bb9bb 100644 --- a/pkg/rp/relaying_party.go +++ b/pkg/rp/relaying_party.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/google/uuid" @@ -313,38 +314,45 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc //ClientCredentials is the `RelayingParty` interface implementation //handling the oauth2 client credentials grant func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) { - return CallTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...), rp) + return CallTokenEndpointAuthorized(grants.ClientCredentialsGrantBasic(scopes...), rp) +} + +func CallTokenEndpointAuthorized(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { + config := rp.OAuthConfig() + var fn interface{} = utils.AuthorizeBasic(config.ClientID, config.ClientSecret) + if config.Endpoint.AuthStyle == oauth2.AuthStyleInParams { + fn = func(form url.Values) { + form.Set("client_id", config.ClientID) + form.Set("client_secret", config.ClientSecret) + } + } + return callTokenEndpoint(request, fn, rp) } func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { - config := rp.OAuthConfig() - req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, config.ClientID, config.ClientSecret, config.Endpoint.AuthStyle != oauth2.AuthStyleInParams) - if err != nil { - return nil, err - } - token := new(oauth2.Token) - if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil { - return nil, err - } - return token, nil + return callTokenEndpoint(request, nil, rp) } -func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) { - form := url.Values{} - form.Add("assertion", assertion) - form.Add("grant_type", jwtProfileKey) - req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, strings.NewReader(form.Encode())) +func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { + req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, authFn) if err != nil { return nil, err } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - token := new(oauth2.Token) - if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil { + var tokenRes struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + if err := utils.HttpRequest(rp.HttpClient(), req, &tokenRes); err != nil { return nil, err } - return token, nil + return &oauth2.Token{ + AccessToken: tokenRes.AccessToken, + TokenType: tokenRes.TokenType, + RefreshToken: tokenRes.RefreshToken, + Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second), + }, nil } func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error { diff --git a/pkg/rp/tockenexchange.go b/pkg/rp/tockenexchange.go index 24b588a..4396dc4 100644 --- a/pkg/rp/tockenexchange.go +++ b/pkg/rp/tockenexchange.go @@ -43,12 +43,17 @@ func DelegationTokenExchange(ctx context.Context, subjectToken string, rp Relayi } //JWTProfileExchange handles the oauth2 jwt profile exchange -func JWTProfileExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, rp RelayingParty) (*oauth2.Token, error) { +func JWTProfileExchange(ctx context.Context, jwtProfileRequest *tokenexchange.JWTProfileRequest, rp RelayingParty) (*oauth2.Token, error) { + return CallTokenEndpoint(jwtProfileRequest, rp) +} + +//JWTProfileExchange handles the oauth2 jwt profile exchange +func JWTProfileAssertionExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, scopes oidc.Scopes, rp RelayingParty) (*oauth2.Token, error) { token, err := generateJWTProfileToken(assertion) if err != nil { return nil, err } - return CallJWTProfileEndpoint(token, rp) + return JWTProfileExchange(ctx, tokenexchange.NewJWTProfileRequest(token, scopes...), rp) } func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) { diff --git a/pkg/utils/http.go b/pkg/utils/http.go index 993febb..e785472 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -27,23 +27,31 @@ type Encoder interface { Encode(src interface{}, dst map[string][]string) error } -func FormRequest(endpoint string, request interface{}, clientID, clientSecret string, header bool) (*http.Request, error) { - form := make(map[string][]string) +type FormAuthorization func(url.Values) +type RequestAuthorization func(*http.Request) + +func AuthorizeBasic(user, password string) RequestAuthorization { + return func(req *http.Request) { + req.SetBasicAuth(user, password) + } +} + +func FormRequest(endpoint string, request interface{}, authFn interface{}) (*http.Request, error) { + form := url.Values{} encoder := schema.NewEncoder() if err := encoder.Encode(request, form); err != nil { return nil, err } - if !header { - form["client_id"] = []string{clientID} - form["client_secret"] = []string{clientSecret} + if fn, ok := authFn.(FormAuthorization); ok { + fn(form) } - body := strings.NewReader(url.Values(form).Encode()) + body := strings.NewReader(form.Encode()) req, err := http.NewRequest("POST", endpoint, body) if err != nil { return nil, err } - if header { - req.SetBasicAuth(clientID, clientSecret) + if fn, ok := authFn.(RequestAuthorization); ok { + fn(req) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") return req, nil From 1661b40fbe21ea5679c6ae9545ad2363a48111e5 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 28 Sep 2020 15:06:14 +0200 Subject: [PATCH 05/17] fix tests --- example/client/github/github.go | 2 +- pkg/oidc/types_test.go | 2 +- pkg/rp/mock/verifier.mock.impl.go | 37 ------------------------------- 3 files changed, 2 insertions(+), 39 deletions(-) delete mode 100644 pkg/rp/mock/verifier.mock.impl.go diff --git a/example/client/github/github.go b/example/client/github/github.go index 5489389..c136091 100644 --- a/example/client/github/github.go +++ b/example/client/github/github.go @@ -45,7 +45,7 @@ func main() { } token := cli.CodeFlow(relayingParty, callbackPath, port, state) - client := github.NewClient(relayingParty.Client(ctx, token.Token)) + client := github.NewClient(relayingParty.OAuthConfig().Client(ctx, token.Token)) _, _, err = client.Users.Get(ctx, "") if err != nil { diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index 830fb02..8138b4b 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -28,7 +28,7 @@ func TestAudience_UnmarshalText(t *testing.T) { []byte(`{"aud": {"a": }}}`), }, res{}, - false, + true, }, { "single audience", diff --git a/pkg/rp/mock/verifier.mock.impl.go b/pkg/rp/mock/verifier.mock.impl.go deleted file mode 100644 index 0b6dd1c..0000000 --- a/pkg/rp/mock/verifier.mock.impl.go +++ /dev/null @@ -1,37 +0,0 @@ -package mock - -import ( - "errors" - "testing" - - "github.com/golang/mock/gomock" - - "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/rp" -) - -func NewVerifier(t *testing.T) rp.Verifier { - return NewMockVerifier(gomock.NewController(t)) -} - -func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier { - m := NewVerifier(t) - ExpectVerifyInvalid(m) - return m -} - -func ExpectVerifyInvalid(v rp.Verifier) { - mock := v.(*MockVerifier) - mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid")) -} - -func NewMockVerifierExpectValid(t *testing.T) rp.Verifier { - m := NewVerifier(t) - ExpectVerifyValid(m) - return m -} - -func ExpectVerifyValid(v rp.Verifier) { - mock := v.(*MockVerifier) - mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.userinfo{Subject: "id"}}, nil) -} From 507a437c564348728ab877e33aab258bc71dbc11 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 29 Sep 2020 08:13:51 +0200 Subject: [PATCH 06/17] scope form encoding --- pkg/oidc/types.go | 4 ++-- pkg/rp/relaying_party.go | 14 +++++++++++++- pkg/utils/http.go | 5 +---- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 8423cff..86e5d06 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -60,8 +60,8 @@ type ResponseType string type Scopes []string -func (s *Scopes) Encode() string { - return strings.Join(*s, " ") +func (s Scopes) Encode() string { + return strings.Join(s, " ") } func (s *Scopes) UnmarshalText(text []byte) error { diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go index a8bb9bb..6807221 100644 --- a/pkg/rp/relaying_party.go +++ b/pkg/rp/relaying_party.go @@ -5,10 +5,12 @@ import ( "errors" "net/http" "net/url" + "reflect" "strings" "time" "github.com/google/uuid" + "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc/grants" @@ -24,6 +26,16 @@ const ( jwtProfileKey = "urn:ietf:params:oauth:grant-type:jwt-bearer" ) +var ( + encoder = func() utils.Encoder { + e := schema.NewEncoder() + e.RegisterEncoder(oidc.Scopes{}, func(value reflect.Value) string { + return value.Interface().(oidc.Scopes).Encode() + }) + return e + }() +) + //RelayingParty declares the minimal interface for oidc clients type RelayingParty interface { //OAuthConfig returns the oauth2 Config @@ -334,7 +346,7 @@ func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2. } func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { - req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, authFn) + req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, encoder, authFn) if err != nil { return nil, err } diff --git a/pkg/utils/http.go b/pkg/utils/http.go index e785472..fa51815 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -10,8 +10,6 @@ import ( "net/url" "strings" "time" - - "github.com/gorilla/schema" ) var ( @@ -36,9 +34,8 @@ func AuthorizeBasic(user, password string) RequestAuthorization { } } -func FormRequest(endpoint string, request interface{}, authFn interface{}) (*http.Request, error) { +func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) { form := url.Values{} - encoder := schema.NewEncoder() if err := encoder.Encode(request, form); err != nil { return nil, err } From f845ce2010c986b93b4b48d9672df56d86b59612 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 29 Sep 2020 08:34:37 +0200 Subject: [PATCH 07/17] comments --- pkg/oidc/token_request.go | 33 ++++++++++++++++++++------------- pkg/op/token.go | 4 ++-- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 1d958df..800c515 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -43,51 +43,58 @@ type JWTTokenRequest struct { ExpiresAt Time `json:"exp"` } -func (j *JWTTokenRequest) GetClientID() string { - return j.Subject -} - -func (j *JWTTokenRequest) GetSubject() string { - return j.Subject -} - -func (j *JWTTokenRequest) GetScopes() []string { - return j.Scopes -} - +//GetSubject implements the Claims interface func (j *JWTTokenRequest) GetIssuer() string { return j.Issuer } +//GetAudience implements the Claims and TokenRequest interfaces func (j *JWTTokenRequest) GetAudience() []string { return j.Audience } +//GetExpiration implements the Claims interface func (j *JWTTokenRequest) GetExpiration() time.Time { return time.Time(j.ExpiresAt) } +//GetIssuedAt implements the Claims interface func (j *JWTTokenRequest) GetIssuedAt() time.Time { return time.Time(j.IssuedAt) } +//GetNonce implements the Claims interface func (j *JWTTokenRequest) GetNonce() string { return "" } +//GetAuthenticationContextClassReference implements the Claims interface func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string { return "" } +//GetAuthTime implements the Claims interface func (j *JWTTokenRequest) GetAuthTime() time.Time { return time.Time{} } +//GetAuthorizedParty implements the Claims interface func (j *JWTTokenRequest) GetAuthorizedParty() string { return "" } -func (j *JWTTokenRequest) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {} +//SetSignatureAlgorithm implements the Claims interface +func (j *JWTTokenRequest) SetSignatureAlgorithm(_ jose.SignatureAlgorithm) {} + +//GetSubject implements the TokenRequest interface +func (j *JWTTokenRequest) GetSubject() string { + return j.Subject +} + +//GetSubject implements the TokenRequest interface +func (j *JWTTokenRequest) GetScopes() []string { + return j.Scopes +} type TokenExchangeRequest struct { subjectToken string `schema:"subject_token"` diff --git a/pkg/op/token.go b/pkg/op/token.go index bb2b3c5..a2236d4 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -82,8 +82,8 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) { return crypto.Encrypt(id) } -func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) { - claims := oidc.NewAccessTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, id) +func CreateJWT(issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer) (string, error) { + claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id) return utils.Sign(claims, signer.Signer()) } From 707029d431ade2a88a34645989f9dc37e686226c Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 29 Sep 2020 08:40:32 +0200 Subject: [PATCH 08/17] update example --- example/client/app/app.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/client/app/app.go b/example/client/app/app.go index a2fff44..b747474 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -97,8 +97,8 @@ func main() {
- - + +
` @@ -131,7 +131,7 @@ func main() { http.Error(w, err.Error(), http.StatusInternalServerError) return } - token, err := rp.JWTProfileAssertionExchange(ctx, assertion, oidc.Scopes{oidc.ScopeOpenID, oidc.ScopeProfile}, provider) + token, err := rp.JWTProfileAssertionExchange(ctx, assertion, scopes, provider) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return From b2903212ab05fa2e95f83bb6959411c4509e686a Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Wed, 30 Sep 2020 08:40:28 +0200 Subject: [PATCH 09/17] cleanup --- example/client/app/app.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/client/app/app.go b/example/client/app/app.go index b747474..1c9c469 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -32,7 +32,7 @@ func main() { ctx := context.Background() redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) - scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"} + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress} cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, rp.WithPKCE(cookieHandler), From b311610d063e02d27d7dc7771b4d8b4563376209 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Wed, 7 Oct 2020 08:44:26 +0200 Subject: [PATCH 10/17] feat: check allowed scopes (and pass clientID to GetUserinfoFromScopes) --- example/internal/mock/storage.go | 8 +++-- pkg/op/authrequest.go | 32 ++++++++++++++++---- pkg/op/authrequest_test.go | 52 +++++++++++++++++++++++++++----- pkg/op/client.go | 1 + pkg/op/mock/client.go | 1 + pkg/op/mock/client.mock.go | 14 +++++++++ pkg/op/mock/storage.mock.go | 8 ++--- pkg/op/mock/storage.mock.impl.go | 3 ++ pkg/op/storage.go | 2 +- pkg/op/token.go | 2 +- 10 files changed, 101 insertions(+), 22 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 1c33906..e3a4e1a 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -211,9 +211,9 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st } func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (oidc.UserInfoSetter, error) { - return s.GetUserinfoFromScopes(ctx, "", []string{}) + return s.GetUserinfoFromScopes(ctx, "", "", []string{}) } -func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (oidc.UserInfoSetter, error) { +func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfoSetter, error) { userinfo := oidc.NewUserInfo() userinfo.SetSubject(a.GetSubject()) userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", "")) @@ -276,3 +276,7 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType { func (c *ConfClient) DevMode() bool { return c.devMode } + +func (c *ConfClient) AllowedScopes() []string { + return nil +} diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index cee6184..86e2275 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage if err != nil { return "", ErrServerError(err.Error()) } - if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { + authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes) + if err != nil { return "", err } if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil { @@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage } //ValidateAuthReqScopes validates the passed scopes -func ValidateAuthReqScopes(scopes []string) error { +func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { if len(scopes) == 0 { - return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") + return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") } - if !utils.Contains(scopes, oidc.ScopeOpenID) { - return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") + openID := false + for i := len(scopes) - 1; i >= 0; i-- { + switch scopes[i] { + case oidc.ScopeOpenID: + openID = true + case oidc.ScopeProfile, + oidc.ScopeEmail, + oidc.ScopePhone, + oidc.ScopeAddress, + oidc.ScopeOfflineAccess: + default: + if !utils.Contains(client.AllowedScopes(), scopes[i]) { + scopes[i] = scopes[len(scopes)-1] + scopes[len(scopes)-1] = "" + scopes = scopes[:len(scopes)-1] + } + } } - return nil + if !openID { + return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") + } + + return scopes, nil } //ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go index d74d365..3856acd 100644 --- a/pkg/op/authrequest_test.go +++ b/pkg/op/authrequest_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/gorilla/schema" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/caos/oidc/pkg/oidc" @@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) { func TestValidateAuthReqScopes(t *testing.T) { type args struct { + client op.Client + scopes []string + } + type res struct { + err bool scopes []string } tests := []struct { - name string - args args - wantErr bool + name string + args args + res res }{ { - "scopes missing fails", args{}, true, + "scopes missing fails", + args{}, + res{ + err: true, + }, }, { - "scope openid missing fails", args{[]string{"email"}}, true, + "scope openid missing fails", + args{ + mock.NewClientExpectAny(t, op.ApplicationTypeWeb), + []string{"email"}, + }, + res{ + err: true, + }, }, { - "scope ok", args{[]string{"openid"}}, false, + "scope ok", + args{ + mock.NewClientExpectAny(t, op.ApplicationTypeWeb), + []string{"openid"}, + }, + res{ + scopes: []string{"openid"}, + }, + }, + { + "scope with drop ok", + args{ + mock.NewClientExpectAny(t, op.ApplicationTypeWeb), + []string{"openid", "email", "unknown"}, + }, + res{ + scopes: []string{"openid", "email"}, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr { - t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr) + scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes) + if (err != nil) != tt.res.err { + t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err) } + assert.ElementsMatch(t, scopes, tt.res.scopes) }) } } diff --git a/pkg/op/client.go b/pkg/op/client.go index 3184b90..258ce6e 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -32,6 +32,7 @@ type Client interface { AccessTokenType() AccessTokenType IDTokenLifetime() time.Duration DevMode() bool + AllowedScopes() []string } func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool { diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go index eed21d5..12c00cc 100644 --- a/pkg/op/mock/client.go +++ b/pkg/op/mock/client.go @@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client { func(id string) string { return "login?id=" + id }) + m.EXPECT().AllowedScopes().AnyTimes().Return(nil) return c } diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 4007347..8e18d56 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType)) } +// AllowedScopes mocks base method +func (m *MockClient) AllowedScopes() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AllowedScopes") + ret0, _ := ret[0].([]string) + return ret0 +} + +// AllowedScopes indicates an expected call of AllowedScopes +func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes)) +} + // ApplicationType mocks base method func (m *MockClient) ApplicationType() op.ApplicationType { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 1bcd1a6..973f58b 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -184,18 +184,18 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (oidc.UserInfoSetter, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfoSetter, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(oidc.UserInfoSetter) ret1, _ := ret[1].(error) return ret0, ret1 } // GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes -func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3) } // GetUserinfoFromToken mocks base method diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 6fd2760..54cd059 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -168,3 +168,6 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType { func (c *ConfClient) DevMode() bool { return c.devMode } +func (c *ConfClient) AllowedScopes() []string { + return nil +} diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 69784ee..1c266d7 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -28,7 +28,7 @@ type AuthStorage interface { type OPStorage interface { GetClientByClientID(context.Context, string) (Client, error) AuthorizeClientIDSecret(context.Context, string, string) error - GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error) + GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfoSetter, error) GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } diff --git a/pkg/op/token.go b/pkg/op/token.go index a2236d4..670fca7 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -98,7 +98,7 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali } claims.SetAccessTokenHash(atHash) } else { - userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) + userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), authReq.GetScopes()) if err != nil { return "", err } From d6203fb0d57b90780b9831ff899f89775b5ae05c Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Wed, 7 Oct 2020 08:49:23 +0200 Subject: [PATCH 11/17] chore: move CAOS_OIDC_DEV to const (and ensure TestValidateIssuer runs (even on machines with env set)) --- pkg/op/config.go | 4 +++- pkg/op/config_test.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/op/config.go b/pkg/op/config.go index b3df943..d64c0ee 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -7,6 +7,8 @@ import ( "strings" ) +const OidcDevMode = "CAOS_OIDC_DEV" + type Configuration interface { Issuer() string AuthorizationEndpoint() Endpoint @@ -42,7 +44,7 @@ func ValidateIssuer(issuer string) error { } func devLocalAllowed(url *url.URL) bool { - _, b := os.LookupEnv("CAOS_OIDC_DEV") + _, b := os.LookupEnv(OidcDevMode) if !b { return b } diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index 79173fb..e140074 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -60,6 +60,8 @@ func TestValidateIssuer(t *testing.T) { true, }, } + //ensure env is not set + os.Unsetenv(OidcDevMode) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { @@ -84,7 +86,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) { false, }, } - os.Setenv("CAOS_OIDC_DEV", "") + os.Setenv(OidcDevMode, "true") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { From b8d892443ce332472fc3cb11c4a1817e64490206 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Wed, 14 Oct 2020 16:41:04 +0200 Subject: [PATCH 12/17] claims assertion --- pkg/oidc/token.go | 433 +++++++++++++++++++------------- pkg/op/client.go | 4 + pkg/op/mock/client.mock.go | 28 +++ pkg/op/mock/storage.mock.go | 23 +- pkg/op/op.go | 19 +- pkg/op/storage.go | 5 +- pkg/op/token.go | 64 ++++- pkg/op/userinfo.go | 19 +- pkg/op/verifier_access_token.go | 85 +++++++ 9 files changed, 491 insertions(+), 189 deletions(-) create mode 100644 pkg/op/verifier_access_token.go diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index e445e7e..2a8c0ad 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -24,6 +24,8 @@ type Tokens struct { type AccessTokenClaims interface { Claims + GetTokenID() string + SetPrivateClaims(map[string]interface{}) } type IDTokenClaims interface { @@ -36,67 +38,13 @@ type IDTokenClaims interface { GetClientID() string GetSignatureAlgorithm() jose.SignatureAlgorithm SetAccessTokenHash(hash string) - SetUserinfo(userinfo UserInfoSetter) + SetUserinfo(userinfo UserInfo) SetCodeHash(hash string) UserInfo } -type accessTokenClaims struct { - Issuer string - Subject string - Audience Audience - Expiration Time - IssuedAt Time - NotBefore Time - JWTID string - AuthorizedParty string - Nonce string - AuthTime Time - CodeHash string - AuthenticationContextClassReference string - AuthenticationMethodsReferences []string - SessionID string - Scopes []string - ClientID string - AccessTokenUseNumber int - - signatureAlg jose.SignatureAlgorithm -} - -func (a accessTokenClaims) GetIssuer() string { - return a.Issuer -} - -func (a accessTokenClaims) GetAudience() []string { - return a.Audience -} - -func (a accessTokenClaims) GetExpiration() time.Time { - return time.Time(a.Expiration) -} - -func (a accessTokenClaims) GetIssuedAt() time.Time { - return time.Time(a.IssuedAt) -} - -func (a accessTokenClaims) GetNonce() string { - return a.Nonce -} - -func (a accessTokenClaims) GetAuthenticationContextClassReference() string { - return a.AuthenticationContextClassReference -} - -func (a accessTokenClaims) GetAuthTime() time.Time { - return time.Time(a.AuthTime) -} - -func (a accessTokenClaims) GetAuthorizedParty() string { - return a.AuthorizedParty -} - -func (a accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { - a.signatureAlg = algorithm +func EmptyAccessTokenClaims() AccessTokenClaims { + return new(accessTokenClaims) } func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims { @@ -112,6 +60,155 @@ func NewAccessTokenClaims(issuer, subject string, audience []string, expiration } } +type accessTokenClaims struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Audience Audience `json:"aud,omitempty"` + Expiration Time `json:"exp,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + JWTID string `json:"jti,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + CodeHash string `json:"c_hash,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + SessionID string `json:"sid,omitempty"` + Scopes []string `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` + + claims map[string]interface{} `json:"-"` + signatureAlg jose.SignatureAlgorithm `json:"-"` +} + +//GetIssuer implements the Claims interface +func (a *accessTokenClaims) GetIssuer() string { + return a.Issuer +} + +//GetAudience implements the Claims interface +func (a *accessTokenClaims) GetAudience() []string { + return a.Audience +} + +//GetExpiration implements the Claims interface +func (a *accessTokenClaims) GetExpiration() time.Time { + return time.Time(a.Expiration) +} + +//GetIssuedAt implements the Claims interface +func (a *accessTokenClaims) GetIssuedAt() time.Time { + return time.Time(a.IssuedAt) +} + +//GetNonce implements the Claims interface +func (a *accessTokenClaims) GetNonce() string { + return a.Nonce +} + +//GetAuthenticationContextClassReference implements the Claims interface +func (a *accessTokenClaims) GetAuthenticationContextClassReference() string { + return a.AuthenticationContextClassReference +} + +//GetAuthTime implements the Claims interface +func (a *accessTokenClaims) GetAuthTime() time.Time { + return time.Time(a.AuthTime) +} + +//GetAuthorizedParty implements the Claims interface +func (a *accessTokenClaims) GetAuthorizedParty() string { + return a.AuthorizedParty +} + +//SetSignatureAlgorithm implements the Claims interface +func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { + a.signatureAlg = algorithm +} + +//GetTokenID implements the AccessTokenClaims interface +func (a *accessTokenClaims) GetTokenID() string { + return a.JWTID +} + +//SetPrivateClaims implements the AccessTokenClaims interface +func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) { + a.claims = claims +} + +func (a *accessTokenClaims) MarshalJSON() ([]byte, error) { + type Alias accessTokenClaims + s := &struct { + *Alias + Expiration int64 `json:"exp,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + AuthTime int64 `json:"auth_time,omitempty"` + }{ + Alias: (*Alias)(a), + } + if !time.Time(a.Expiration).IsZero() { + s.Expiration = time.Time(a.Expiration).Unix() + } + if !time.Time(a.IssuedAt).IsZero() { + s.IssuedAt = time.Time(a.IssuedAt).Unix() + } + if !time.Time(a.NotBefore).IsZero() { + s.NotBefore = time.Time(a.NotBefore).Unix() + } + if !time.Time(a.AuthTime).IsZero() { + s.AuthTime = time.Time(a.AuthTime).Unix() + } + b, err := json.Marshal(s) + if err != nil { + return nil, err + } + + if a.claims == nil { + return b, nil + } + info, err := json.Marshal(a.claims) + if err != nil { + return nil, err + } + return utils.ConcatenateJSON(b, info) +} + +func (a *accessTokenClaims) UnmarshalJSON(data []byte) error { + type Alias accessTokenClaims + if err := json.Unmarshal(data, (*Alias)(a)); err != nil { + return err + } + claims := make(map[string]interface{}) + if err := json.Unmarshal(data, &claims); err != nil { + return err + } + a.claims = claims + + return nil +} + +func EmptyIDTokenClaims() IDTokenClaims { + return new(idTokenClaims) +} + +func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims { + return &idTokenClaims{ + Issuer: issuer, + Audience: audience, + Expiration: Time(expiration), + IssuedAt: Time(time.Now().UTC()), + AuthTime: Time(authTime), + Nonce: nonce, + AuthenticationContextClassReference: acr, + AuthenticationMethodsReferences: amr, + AuthorizedParty: clientID, + UserInfo: &userinfo{Subject: subject}, + } +} + type idTokenClaims struct { Issuer string `json:"iss,omitempty"` Audience Audience `json:"aud,omitempty"` @@ -132,65 +229,153 @@ type idTokenClaims struct { signatureAlg jose.SignatureAlgorithm } -func (t *idTokenClaims) SetAccessTokenHash(hash string) { - t.AccessTokenHash = hash +//GetIssuer implements the Claims interface +func (t *idTokenClaims) GetIssuer() string { + return t.Issuer } -func (t *idTokenClaims) SetUserinfo(info UserInfoSetter) { - t.UserInfo = info +//GetAudience implements the Claims interface +func (t *idTokenClaims) GetAudience() []string { + return t.Audience } -func (t *idTokenClaims) SetCodeHash(hash string) { - t.CodeHash = hash +//GetExpiration implements the Claims interface +func (t *idTokenClaims) GetExpiration() time.Time { + return time.Time(t.Expiration) } -func EmptyIDTokenClaims() IDTokenClaims { - return new(idTokenClaims) +//GetIssuedAt implements the Claims interface +func (t *idTokenClaims) GetIssuedAt() time.Time { + return time.Time(t.IssuedAt) } -func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims { - return &idTokenClaims{ - Issuer: issuer, - Audience: audience, - Expiration: Time(expiration), - IssuedAt: Time(time.Now().UTC()), - AuthTime: Time(authTime), - Nonce: nonce, - AuthenticationContextClassReference: acr, - AuthenticationMethodsReferences: amr, - AuthorizedParty: clientID, - UserInfo: &userinfo{Subject: subject}, - } +//GetNonce implements the Claims interface +func (t *idTokenClaims) GetNonce() string { + return t.Nonce } -func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { - return t.signatureAlg +//GetAuthenticationContextClassReference implements the Claims interface +func (t *idTokenClaims) GetAuthenticationContextClassReference() string { + return t.AuthenticationContextClassReference } +//GetAuthTime implements the Claims interface +func (t *idTokenClaims) GetAuthTime() time.Time { + return time.Time(t.AuthTime) +} + +//GetAuthorizedParty implements the Claims interface +func (t *idTokenClaims) GetAuthorizedParty() string { + return t.AuthorizedParty +} + +//SetSignatureAlgorithm implements the Claims interface +func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { + t.signatureAlg = alg +} + +//GetNotBefore implements the IDTokenClaims interface func (t *idTokenClaims) GetNotBefore() time.Time { return time.Time(t.NotBefore) } +//GetJWTID implements the IDTokenClaims interface func (t *idTokenClaims) GetJWTID() string { return t.JWTID } +//GetAccessTokenHash implements the IDTokenClaims interface func (t *idTokenClaims) GetAccessTokenHash() string { return t.AccessTokenHash } +//GetCodeHash implements the IDTokenClaims interface func (t *idTokenClaims) GetCodeHash() string { return t.CodeHash } +//GetAuthenticationMethodsReferences implements the IDTokenClaims interface func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string { return t.AuthenticationMethodsReferences } +//GetClientID implements the IDTokenClaims interface func (t *idTokenClaims) GetClientID() string { return t.ClientID } +//GetSignatureAlgorithm implements the IDTokenClaims interface +func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { + return t.signatureAlg +} + +//SetSignatureAlgorithm implements the IDTokenClaims interface +func (t *idTokenClaims) SetAccessTokenHash(hash string) { + t.AccessTokenHash = hash +} + +//SetUserinfo implements the IDTokenClaims interface +func (t *idTokenClaims) SetUserinfo(info UserInfo) { + t.UserInfo = info +} + +//SetCodeHash implements the IDTokenClaims interface +func (t *idTokenClaims) SetCodeHash(hash string) { + t.CodeHash = hash +} + +func (t *idTokenClaims) MarshalJSON() ([]byte, error) { + type Alias idTokenClaims + a := &struct { + *Alias + Expiration int64 `json:"exp,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + AuthTime int64 `json:"auth_time,omitempty"` + }{ + Alias: (*Alias)(t), + } + if !time.Time(t.Expiration).IsZero() { + a.Expiration = time.Time(t.Expiration).Unix() + } + if !time.Time(t.IssuedAt).IsZero() { + a.IssuedAt = time.Time(t.IssuedAt).Unix() + } + if !time.Time(t.NotBefore).IsZero() { + a.NotBefore = time.Time(t.NotBefore).Unix() + } + if !time.Time(t.AuthTime).IsZero() { + a.AuthTime = time.Time(t.AuthTime).Unix() + } + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if t.UserInfo == nil { + return b, nil + } + info, err := json.Marshal(t.UserInfo) + if err != nil { + return nil, err + } + return utils.ConcatenateJSON(b, info) +} + +func (t *idTokenClaims) UnmarshalJSON(data []byte) error { + type Alias idTokenClaims + if err := json.Unmarshal(data, (*Alias)(t)); err != nil { + return err + } + userinfo := new(userinfo) + if err := json.Unmarshal(data, userinfo); err != nil { + return err + } + t.UserInfo = userinfo + + return nil +} + type AccessTokenResponse struct { AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` @@ -242,94 +427,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) } } -func (t *idTokenClaims) MarshalJSON() ([]byte, error) { - type Alias idTokenClaims - a := &struct { - *Alias - Expiration int64 `json:"nbf,omitempty"` - IssuedAt int64 `json:"nbf,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - AuthTime int64 `json:"nbf,omitempty"` - }{ - Alias: (*Alias)(t), - } - if !time.Time(t.Expiration).IsZero() { - a.Expiration = time.Time(t.Expiration).Unix() - } - if !time.Time(t.IssuedAt).IsZero() { - a.IssuedAt = time.Time(t.IssuedAt).Unix() - } - if !time.Time(t.NotBefore).IsZero() { - a.NotBefore = time.Time(t.NotBefore).Unix() - } - if !time.Time(t.AuthTime).IsZero() { - a.AuthTime = time.Time(t.AuthTime).Unix() - } - b, err := json.Marshal(a) - if err != nil { - return nil, err - } - - if t.UserInfo == nil { - return b, nil - } - info, err := json.Marshal(t.UserInfo) - if err != nil { - return nil, err - } - return utils.ConcatenateJSON(b, info) -} - -func (t *idTokenClaims) UnmarshalJSON(data []byte) error { - type Alias idTokenClaims - if err := json.Unmarshal(data, (*Alias)(t)); err != nil { - return err - } - userinfo := new(userinfo) - if err := json.Unmarshal(data, userinfo); err != nil { - return err - } - t.UserInfo = userinfo - - return nil -} - -func (t *idTokenClaims) GetIssuer() string { - return t.Issuer -} - -func (t *idTokenClaims) GetAudience() []string { - return t.Audience -} - -func (t *idTokenClaims) GetExpiration() time.Time { - return time.Time(t.Expiration) -} - -func (t *idTokenClaims) GetIssuedAt() time.Time { - return time.Time(t.IssuedAt) -} - -func (t *idTokenClaims) GetNonce() string { - return t.Nonce -} - -func (t *idTokenClaims) GetAuthenticationContextClassReference() string { - return t.AuthenticationContextClassReference -} - -func (t *idTokenClaims) GetAuthTime() time.Time { - return time.Time(t.AuthTime) -} - -func (t *idTokenClaims) GetAuthorizedParty() string { - return t.AuthorizedParty -} - -func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { - t.signatureAlg = alg -} - func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { hash, err := utils.GetHashAlgorithm(sigAlgorithm) if err != nil { diff --git a/pkg/op/client.go b/pkg/op/client.go index 258ce6e..790933e 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -10,7 +10,9 @@ const ( ApplicationTypeWeb ApplicationType = iota ApplicationTypeUserAgent ApplicationTypeNative +) +const ( AccessTokenTypeBearer AccessTokenType = iota AccessTokenTypeJWT ) @@ -33,6 +35,8 @@ type Client interface { IDTokenLifetime() time.Duration DevMode() bool AllowedScopes() []string + AssertAdditionalIdTokenScopes() bool + AssertAdditionalAccessTokenScopes() bool } func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool { diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 8e18d56..0780623 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -77,6 +77,34 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) } +// AssertAdditionalAccessTokenScopes mocks base method +func (m *MockClient) AssertAdditionalAccessTokenScopes() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes") + ret0, _ := ret[0].(bool) + return ret0 +} + +// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes +func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes)) +} + +// AssertAdditionalIdTokenScopes mocks base method +func (m *MockClient) AssertAdditionalIdTokenScopes() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes") + ret0, _ := ret[0].(bool) + return ret0 +} + +// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes +func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes)) +} + // AuthMethod mocks base method func (m *MockClient) AuthMethod() op.AuthMethod { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 973f58b..a184597 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -171,6 +171,21 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0) } +// GetPrivateClaimsFromScopes mocks base method +func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPrivateClaimsFromScopes", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(map[string]interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPrivateClaimsFromScopes indicates an expected call of GetPrivateClaimsFromScopes +func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3) +} + // GetSigningKey mocks base method func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) { m.ctrl.T.Helper() @@ -184,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfoSetter, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(oidc.UserInfoSetter) + ret0, _ := ret[0].(oidc.UserInfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -199,10 +214,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 } // GetUserinfoFromToken mocks base method -func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (oidc.UserInfoSetter, error) { +func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (oidc.UserInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) - ret0, _ := ret[0].(oidc.UserInfoSetter) + ret0, _ := ret[0].(oidc.UserInfo) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/op/op.go b/pkg/op/op.go index 7e8279a..bba7a14 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -51,6 +51,7 @@ type OpenIDProvider interface { Encoder() utils.Encoder IDTokenHintVerifier() IDTokenHintVerifier JWTProfileVerifier() JWTProfileVerifier + AccessTokenVerifier() AccessTokenVerifier Crypto() Crypto DefaultLogoutRedirectURI() string Signer() Signer @@ -152,6 +153,8 @@ type openidProvider struct { signer Signer idTokenHintVerifier IDTokenHintVerifier jwtProfileVerifier JWTProfileVerifier + accessTokenVerifier AccessTokenVerifier + keySet *openIDKeySet crypto Crypto httpHandler http.Handler decoder *schema.Decoder @@ -207,7 +210,7 @@ func (o *openidProvider) Encoder() utils.Encoder { func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier { if o.idTokenHintVerifier == nil { - o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), &openIDKeySet{o.Storage()}) + o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet()) } return o.idTokenHintVerifier } @@ -219,6 +222,20 @@ func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier { return o.jwtProfileVerifier } +func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier { + if o.accessTokenVerifier == nil { + o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet()) + } + return o.accessTokenVerifier +} + +func (o *openidProvider) openIDKeySet() oidc.KeySet { + if o.keySet == nil { + o.keySet = &openIDKeySet{o.Storage()} + } + return o.keySet +} + func (o *openidProvider) Crypto() Crypto { return o.crypto } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 1c266d7..10e7779 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -28,8 +28,9 @@ type AuthStorage interface { type OPStorage interface { GetClientByClientID(context.Context, string) (Client, error) AuthorizeClientIDSecret(context.Context, string, string) error - GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfoSetter, error) - GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error) + GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfo, error) + GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfo, error) + GetPrivateClaimsFromScopes(context.Context, string, string, []string) (map[string]interface{}, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } diff --git a/pkg/op/token.go b/pkg/op/token.go index 670fca7..f542588 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -26,12 +26,12 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client var validity time.Duration if createAccessToken { var err error - accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator) + accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator, client) if err != nil { return nil, err } } - idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer()) + idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes()) if err != nil { return nil, err } @@ -51,7 +51,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client } func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) { - accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator) + accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, nil) if err != nil { return nil, err } @@ -64,14 +64,14 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea }, nil } -func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) { - id, exp, err := creator.Storage().CreateToken(ctx, authReq) +func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client) (token string, validity time.Duration, err error) { + id, exp, err := creator.Storage().CreateToken(ctx, tokenRequest) if err != nil { return "", 0, err } validity = exp.Sub(time.Now().UTC()) if accessTokenType == AccessTokenTypeJWT { - token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer()) + token, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage()) return } token, err = CreateBearerToken(id, creator.Crypto()) @@ -82,14 +82,22 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) { return crypto.Encrypt(id) } -func CreateJWT(issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer) (string, error) { +func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) { claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id) + if client != nil && client.AssertAdditionalAccessTokenScopes() { + privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes())) + if err != nil { + return "", err + } + claims.SetPrivateClaims(privateClaims) + } return utils.Sign(claims, signer.Signer()) } -func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) { +func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) { exp := time.Now().UTC().Add(validity) claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID()) + scopes := authReq.GetScopes() if accessToken != "" { atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) @@ -97,8 +105,13 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali return "", err } claims.SetAccessTokenHash(atHash) - } else { - userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), authReq.GetScopes()) + scopes = removeUserinfoScopes(scopes) + } + if !additonalScopes { + scopes = removeAdditionalScopes(scopes) + } + if len(scopes) > 0 { + userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes) if err != nil { return "", err } @@ -114,3 +127,34 @@ func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, vali return utils.Sign(claims, signer.Signer()) } + +func removeUserinfoScopes(scopes []string) []string { + for i := len(scopes) - 1; i >= 0; i-- { + if scopes[i] == oidc.ScopeProfile || + scopes[i] == oidc.ScopeEmail || + scopes[i] == oidc.ScopeAddress || + scopes[i] == oidc.ScopePhone { + + scopes[i] = scopes[len(scopes)-1] + scopes[len(scopes)-1] = "" + scopes = scopes[:len(scopes)-1] + } + } + return scopes +} + +func removeAdditionalScopes(scopes []string) []string { + for i := len(scopes) - 1; i >= 0; i-- { + if !(scopes[i] == oidc.ScopeOpenID || + scopes[i] == oidc.ScopeProfile || + scopes[i] == oidc.ScopeEmail || + scopes[i] == oidc.ScopeAddress || + scopes[i] == oidc.ScopePhone) { + + scopes[i] = scopes[len(scopes)-1] + scopes[len(scopes)-1] = "" + scopes = scopes[:len(scopes)-1] + } + } + return scopes +} diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 36ecd4a..0b27a5e 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -13,6 +13,7 @@ type UserinfoProvider interface { Decoder() utils.Decoder Crypto() Crypto Storage() Storage + AccessTokenVerifier() AccessTokenVerifier } func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { @@ -27,10 +28,20 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP http.Error(w, "access token missing", http.StatusUnauthorized) return } - tokenID, err := userinfoProvider.Crypto().Decrypt(accessToken) - if err != nil { - http.Error(w, "access token missing", http.StatusUnauthorized) - return + var tokenID string + if strings.HasPrefix(accessToken, "eyJhbGci") { //TODO: improve + accessTokenClaims, err := VerifyAccessToken(r.Context(), accessToken, userinfoProvider.AccessTokenVerifier()) + if err != nil { + http.Error(w, "access token invalid", http.StatusUnauthorized) + return + } + tokenID = accessTokenClaims.GetTokenID() + } else { + tokenID, err = userinfoProvider.Crypto().Decrypt(accessToken) + if err != nil { + http.Error(w, "access token invalid", http.StatusUnauthorized) + return + } } info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, r.Header.Get("origin")) if err != nil { diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go new file mode 100644 index 0000000..05168a6 --- /dev/null +++ b/pkg/op/verifier_access_token.go @@ -0,0 +1,85 @@ +package op + +import ( + "context" + "time" + + "github.com/caos/oidc/pkg/oidc" +) + +type AccessTokenVerifier interface { + oidc.Verifier + SupportedSignAlgs() []string + KeySet() oidc.KeySet +} + +type accessTokenVerifier struct { + issuer string + maxAgeIAT time.Duration + offset time.Duration + supportedSignAlgs []string + maxAge time.Duration + acr oidc.ACRVerifier + keySet oidc.KeySet +} + +//Issuer implements oidc.Verifier interface +func (i *accessTokenVerifier) Issuer() string { + return i.issuer +} + +//MaxAgeIAT implements oidc.Verifier interface +func (i *accessTokenVerifier) MaxAgeIAT() time.Duration { + return i.maxAgeIAT +} + +//Offset implements oidc.Verifier interface +func (i *accessTokenVerifier) Offset() time.Duration { + return i.offset +} + +//SupportedSignAlgs implements AccessTokenVerifier interface +func (i *accessTokenVerifier) SupportedSignAlgs() []string { + return i.supportedSignAlgs +} + +//KeySet implements AccessTokenVerifier interface +func (i *accessTokenVerifier) KeySet() oidc.KeySet { + return i.keySet +} + +func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier { + verifier := &idTokenHintVerifier{ + issuer: issuer, + keySet: keySet, + } + return verifier +} + +//VerifyAccessToken validates the access token (issuer, signature and expiration) +func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) { + claims := oidc.EmptyAccessTokenClaims() + + decrypted, err := oidc.DecryptToken(token) + if err != nil { + return nil, err + } + payload, err := oidc.ParseToken(decrypted, claims) + if err != nil { + return nil, err + } + + if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + return nil, err + } + + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + return nil, err + } + + if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + return nil, err + } + + return claims, nil +} From 44c341d42e76b406a39600a916c071ba86096e24 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 15 Oct 2020 09:41:50 +0200 Subject: [PATCH 13/17] improve userinfo token handling --- pkg/op/userinfo.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 0b27a5e..6701eb3 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -28,20 +28,14 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP http.Error(w, "access token missing", http.StatusUnauthorized) return } - var tokenID string - if strings.HasPrefix(accessToken, "eyJhbGci") { //TODO: improve + tokenID, err := userinfoProvider.Crypto().Decrypt(accessToken) + if err != nil { accessTokenClaims, err := VerifyAccessToken(r.Context(), accessToken, userinfoProvider.AccessTokenVerifier()) if err != nil { http.Error(w, "access token invalid", http.StatusUnauthorized) return } tokenID = accessTokenClaims.GetTokenID() - } else { - tokenID, err = userinfoProvider.Crypto().Decrypt(accessToken) - if err != nil { - http.Error(w, "access token invalid", http.StatusUnauthorized) - return - } } info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, r.Header.Get("origin")) if err != nil { From d89470a33f339dc0d4d7b0c41c045c38ba103dd5 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 15 Oct 2020 12:39:07 +0200 Subject: [PATCH 14/17] improve userinfo token handling --- pkg/oidc/token.go | 6 ++++++ pkg/op/userinfo.go | 32 ++++++++++++++++++++++---------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 2a8c0ad..99f18c7 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -24,6 +24,7 @@ type Tokens struct { type AccessTokenClaims interface { Claims + GetSubject() string GetTokenID() string SetPrivateClaims(map[string]interface{}) } @@ -128,6 +129,11 @@ func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgori a.signatureAlg = algorithm } +//GetSubject implements the AccessTokenClaims interface +func (a *accessTokenClaims) GetSubject() string { + return a.Subject +} + //GetTokenID implements the AccessTokenClaims interface func (a *accessTokenClaims) GetTokenID() string { return a.JWTID diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index f1991ac..d5ca68e 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -1,6 +1,7 @@ package op import ( + "context" "errors" "net/http" "strings" @@ -28,17 +29,12 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP http.Error(w, "access token missing", http.StatusUnauthorized) return } - tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) - if err != nil { - accessTokenClaims, err := VerifyAccessToken(r.Context(), accessToken, userinfoProvider.AccessTokenVerifier()) - if err != nil { - http.Error(w, "access token invalid", http.StatusUnauthorized) - return - } - tokenID = accessTokenClaims.GetTokenID() + tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken) + if !ok { + http.Error(w, "access token invalid", http.StatusUnauthorized) + return } - splittedToken := strings.Split(tokenIDSubject, ":") - info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), splittedToken[0], splittedToken[1], r.Header.Get("origin")) + info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, subject, r.Header.Get("origin")) if err != nil { w.WriteHeader(http.StatusForbidden) utils.MarshalJSON(w, err) @@ -67,3 +63,19 @@ func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) { } return req.AccessToken, nil } + +func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) + if err == nil { + splittedToken := strings.Split(tokenIDSubject, ":") + if len(splittedToken) != 2 { + return "", "", false + } + return splittedToken[0], splittedToken[1], true + } + accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier()) + if err != nil { + return "", "", false + } + return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true +} From ed33332dce70fbc800007e2c0e70298149c3ae91 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 15 Oct 2020 13:41:31 +0200 Subject: [PATCH 15/17] merging and missing mocks --- example/internal/mock/storage.go | 15 +++++++++++++-- pkg/op/mock/storage.mock.go | 8 ++++---- pkg/op/mock/storage.mock.impl.go | 6 ++++++ pkg/op/token.go | 2 +- pkg/op/userinfo.go | 6 +++--- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index a9d5c9b..9671ec7 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -210,10 +210,10 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st return nil } -func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (oidc.UserInfoSetter, error) { +func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (oidc.UserInfo, error) { return s.GetUserinfoFromScopes(ctx, "", "", []string{}) } -func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfoSetter, error) { +func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfo, error) { userinfo := oidc.NewUserInfo() userinfo.SetSubject(a.GetSubject()) userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", "")) @@ -223,6 +223,9 @@ func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ [] userinfo.AppendClaims("private_claim", "test") return userinfo, nil } +func (s *AuthStorage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) { + return map[string]interface{}{"private_claim": "test"}, nil +} type ConfClient struct { applicationType op.ApplicationType @@ -280,3 +283,11 @@ func (c *ConfClient) DevMode() bool { func (c *ConfClient) AllowedScopes() []string { return nil } + +func (c *ConfClient) AssertAdditionalIdTokenScopes() bool { + return false +} + +func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool { + return false +} diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index a184597..9e4963a 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -214,18 +214,18 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 } // GetUserinfoFromToken mocks base method -func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (oidc.UserInfo, error) { +func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (oidc.UserInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(oidc.UserInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetUserinfoFromToken indicates an expected call of GetUserinfoFromToken -func (mr *MockStorageMockRecorder) GetUserinfoFromToken(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) GetUserinfoFromToken(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromToken), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromToken), arg0, arg1, arg2, arg3) } // Health mocks base method diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 54cd059..de9dee9 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -171,3 +171,9 @@ func (c *ConfClient) DevMode() bool { func (c *ConfClient) AllowedScopes() []string { return nil } +func (c *ConfClient) AssertAdditionalIdTokenScopes() bool { + return false +} +func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool { + return false +} diff --git a/pkg/op/token.go b/pkg/op/token.go index 67bcaae..2d66ef5 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -74,7 +74,7 @@ func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTok token, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage()) return } - token, err = CreateBearerToken(id, authReq.GetSubject(), creator.Crypto()) + token, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto()) return } diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index d5ca68e..1163598 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -67,11 +67,11 @@ func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) { func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) if err == nil { - splittedToken := strings.Split(tokenIDSubject, ":") - if len(splittedToken) != 2 { + splitToken := strings.Split(tokenIDSubject, ":") + if len(splitToken) != 2 { return "", "", false } - return splittedToken[0], splittedToken[1], true + return splitToken[0], splitToken[1], true } accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier()) if err != nil { From 5cc884766e12e24bf84c0c29bc34e569f8195bd8 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 15 Oct 2020 14:02:53 +0200 Subject: [PATCH 16/17] improve ValidateAuthReqScopes --- pkg/op/authrequest.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 86e2275..4d6118c 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -111,20 +111,20 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { } openID := false for i := len(scopes) - 1; i >= 0; i-- { - switch scopes[i] { - case oidc.ScopeOpenID: + scope := scopes[i] + if scope == oidc.ScopeOpenID { openID = true - case oidc.ScopeProfile, - oidc.ScopeEmail, - oidc.ScopePhone, - oidc.ScopeAddress, - oidc.ScopeOfflineAccess: - default: - if !utils.Contains(client.AllowedScopes(), scopes[i]) { - scopes[i] = scopes[len(scopes)-1] - scopes[len(scopes)-1] = "" - scopes = scopes[:len(scopes)-1] - } + continue + } + if !(scope == oidc.ScopeProfile || + scope == oidc.ScopeEmail || + scope == oidc.ScopePhone || + scope == oidc.ScopeAddress || + scope == oidc.ScopeOfflineAccess) && + !utils.Contains(client.AllowedScopes(), scope) { + scopes[i] = scopes[len(scopes)-1] + scopes[len(scopes)-1] = "" + scopes = scopes[:len(scopes)-1] } } if !openID { From 736d6902d964d0a8caed1e510e1711f14dc2ab20 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 15 Oct 2020 15:23:33 +0200 Subject: [PATCH 17/17] solve PR issues --- pkg/oidc/token_request.go | 2 +- pkg/op/storage.go | 8 +++--- pkg/utils/marshal_test.go | 60 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 pkg/utils/marshal_test.go diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 800c515..e80d28a 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -43,7 +43,7 @@ type JWTTokenRequest struct { ExpiresAt Time `json:"exp"` } -//GetSubject implements the Claims interface +//GetIssuer implements the Claims interface func (j *JWTTokenRequest) GetIssuer() string { return j.Issuer } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index e220c15..eba5003 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -26,11 +26,11 @@ type AuthStorage interface { } type OPStorage interface { - GetClientByClientID(context.Context, string) (Client, error) - AuthorizeClientIDSecret(context.Context, string, string) error - GetUserinfoFromScopes(context.Context, string, string, []string) (oidc.UserInfo, error) + GetClientByClientID(ctx context.Context, clientID string) (Client, error) + AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error + GetUserinfoFromScopes(ctx context.Context, userID, clientID string, scopes []string) (oidc.UserInfo, error) GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (oidc.UserInfo, error) - GetPrivateClaimsFromScopes(context.Context, string, string, []string) (map[string]interface{}, error) + GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } diff --git a/pkg/utils/marshal_test.go b/pkg/utils/marshal_test.go new file mode 100644 index 0000000..f9221f6 --- /dev/null +++ b/pkg/utils/marshal_test.go @@ -0,0 +1,60 @@ +package utils + +import ( + "bytes" + "testing" +) + +func TestConcatenateJSON(t *testing.T) { + type args struct { + first []byte + second []byte + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + "invalid first part, error", + args{ + []byte(`invalid`), + []byte(`{"some": "thing"}`), + }, + nil, + true, + }, + { + "invalid second part, error", + args{ + []byte(`{"some": "thing"}`), + []byte(`invalid`), + }, + nil, + true, + }, + { + "both valid, merged", + args{ + []byte(`{"some": "thing"}`), + []byte(`{"another": "thing"}`), + }, + + []byte(`{"some": "thing","another": "thing"}`), + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ConcatenateJSON(tt.args.first, tt.args.second) + if (err != nil) != tt.wantErr { + t.Errorf("ConcatenateJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !bytes.Equal(got, tt.want) { + t.Errorf("ConcatenateJSON() got = %v, want %v", got, tt.want) + } + }) + } +}