diff --git a/example/client/app/app.go b/example/client/app/app.go
index 4c0831b..1c9c469 100644
--- a/example/client/app/app.go
+++ b/example/client/app/app.go
@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
+ "html/template"
+ "io/ioutil"
"net/http"
"os"
"time"
@@ -30,7 +32,7 @@ func main() {
ctx := context.Background()
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
- scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}
+ scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress}
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
rp.WithPKCE(cookieHandler),
@@ -82,6 +84,66 @@ func main() {
}
w.Write(data)
})
+
+ http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "GET" {
+ tpl := `
+
+
+
+
+ Login
+
+
+
+
+ `
+ t, err := template.New("login").Parse(tpl)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ err = t.Execute(w, nil)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ } else {
+ err := r.ParseMultipartForm(4 << 10)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ file, handler, err := r.FormFile("key")
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ defer file.Close()
+
+ key, err := ioutil.ReadAll(file)
+ fmt.Println(handler.Header)
+ assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer})
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ token, err := rp.JWTProfileAssertionExchange(ctx, assertion, scopes, provider)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ data, err := json.Marshal(token)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ w.Write(data)
+ }
+ })
lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis)
logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))
diff --git a/example/client/github/github.go b/example/client/github/github.go
index 5489389..c136091 100644
--- a/example/client/github/github.go
+++ b/example/client/github/github.go
@@ -45,7 +45,7 @@ func main() {
}
token := cli.CodeFlow(relayingParty, callbackPath, port, state)
- client := github.NewClient(relayingParty.Client(ctx, token.Token))
+ client := github.NewClient(relayingParty.OAuthConfig().Client(ctx, token.Token))
_, _, err = client.Users.Get(ctx, "")
if err != nil {
diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go
index 686f5ee..9671ec7 100644
--- a/example/internal/mock/storage.go
+++ b/example/internal/mock/storage.go
@@ -210,31 +210,21 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
return nil
}
-func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (*oidc.Userinfo, error) {
- return s.GetUserinfoFromScopes(ctx, "", []string{})
+func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (oidc.UserInfo, error) {
+ return s.GetUserinfoFromScopes(ctx, "", "", []string{})
}
-func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) {
- return &oidc.Userinfo{
- Subject: a.GetSubject(),
- Address: &oidc.UserinfoAddress{
- StreetAddress: "Hjkhkj 789\ndsf",
- },
- UserinfoEmail: oidc.UserinfoEmail{
- Email: "test",
- EmailVerified: true,
- },
- UserinfoPhone: oidc.UserinfoPhone{
- PhoneNumber: "sadsa",
- PhoneNumberVerified: true,
- },
- UserinfoProfile: oidc.UserinfoProfile{
- UpdatedAt: time.Now(),
- },
- // Claims: map[string]interface{}{
- // "test": "test",
- // "hkjh": "",
- // },
- }, nil
+func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfo, error) {
+ userinfo := oidc.NewUserInfo()
+ userinfo.SetSubject(a.GetSubject())
+ userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
+ userinfo.SetEmail("test", true)
+ userinfo.SetPhone("0791234567", true)
+ userinfo.SetName("Test")
+ userinfo.AppendClaims("private_claim", "test")
+ return userinfo, nil
+}
+func (s *AuthStorage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) {
+ return map[string]interface{}{"private_claim": "test"}, nil
}
type ConfClient struct {
@@ -289,3 +279,15 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool {
return c.devMode
}
+
+func (c *ConfClient) AllowedScopes() []string {
+ return nil
+}
+
+func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
+ return false
+}
+
+func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
+ return false
+}
diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go
index 35da2fb..71776af 100644
--- a/pkg/oidc/authorization.go
+++ b/pkg/oidc/authorization.go
@@ -1,15 +1,5 @@
package oidc
-import (
- "encoding/json"
- "errors"
- "strings"
- "time"
-
- "golang.org/x/text/language"
- "gopkg.in/square/go-jose.v2"
-)
-
const (
//ScopeOpenID defines the scope `openid`
//OpenID Connect requests MUST contain the `openid` scope value
@@ -64,23 +54,8 @@ const (
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
PromptSelectAccount Prompt = "select_account"
-
- //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
- GrantTypeCode GrantType = "authorization_code"
- //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
- GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
-
- //BearerToken defines the token_type `Bearer`, which is returned in a successful token response
- BearerToken = "Bearer"
)
-var displayValues = map[string]Display{
- "page": DisplayPage,
- "popup": DisplayPopup,
- "touch": DisplayTouch,
- "wap": DisplayWAP,
-}
-
//AuthRequest according to:
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct {
@@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType {
func (a *AuthRequest) GetState() string {
return a.State
}
-
-type TokenRequest interface {
- // GrantType GrantType `schema:"grant_type"`
- GrantType() GrantType
-}
-
-type TokenRequestType GrantType
-
-type AccessTokenRequest struct {
- Code string `schema:"code"`
- RedirectURI string `schema:"redirect_uri"`
- ClientID string `schema:"client_id"`
- ClientSecret string `schema:"client_secret"`
- CodeVerifier string `schema:"code_verifier"`
-}
-
-func (a *AccessTokenRequest) GrantType() GrantType {
- return GrantTypeCode
-}
-
-type AccessTokenResponse struct {
- AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
- TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
- RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
- ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
- IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
-}
-
-type JWTTokenRequest struct {
- Issuer string `json:"iss"`
- Subject string `json:"sub"`
- Scopes Scopes `json:"scope"`
- Audience interface{} `json:"aud"`
- IssuedAt Time `json:"iat"`
- ExpiresAt Time `json:"exp"`
-}
-
-func (j *JWTTokenRequest) GetClientID() string {
- return j.Subject
-}
-
-func (j *JWTTokenRequest) GetSubject() string {
- return j.Subject
-}
-
-func (j *JWTTokenRequest) GetScopes() []string {
- return j.Scopes
-}
-
-type Time time.Time
-
-func (t *Time) UnmarshalJSON(data []byte) error {
- var i int64
- if err := json.Unmarshal(data, &i); err != nil {
- return err
- }
- *t = Time(time.Unix(i, 0).UTC())
- return nil
-}
-
-func (j *JWTTokenRequest) GetIssuer() string {
- return j.Issuer
-}
-
-func (j *JWTTokenRequest) GetAudience() []string {
- return audienceFromJSON(j.Audience)
-}
-
-func (j *JWTTokenRequest) GetExpiration() time.Time {
- return time.Time(j.ExpiresAt)
-}
-
-func (j *JWTTokenRequest) GetIssuedAt() time.Time {
- return time.Time(j.IssuedAt)
-}
-
-func (j *JWTTokenRequest) GetNonce() string {
- return ""
-}
-
-func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
- return ""
-}
-
-func (j *JWTTokenRequest) GetAuthTime() time.Time {
- return time.Time{}
-}
-
-func (j *JWTTokenRequest) GetAuthorizedParty() string {
- return ""
-}
-
-func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {}
-
-type TokenExchangeRequest struct {
- subjectToken string `schema:"subject_token"`
- subjectTokenType string `schema:"subject_token_type"`
- actorToken string `schema:"actor_token"`
- actorTokenType string `schema:"actor_token_type"`
- resource []string `schema:"resource"`
- audience []string `schema:"audience"`
- Scope []string `schema:"scope"`
- requestedTokenType string `schema:"requested_token_type"`
-}
-
-type Scopes []string
-
-func (s *Scopes) UnmarshalText(text []byte) error {
- scopes := strings.Split(string(text), " ")
- *s = Scopes(scopes)
- return nil
-}
-
-type ResponseType string
-
-type Display string
-
-func (d *Display) UnmarshalText(text []byte) error {
- var ok bool
- display := string(text)
- *d, ok = displayValues[display]
- if !ok {
- return errors.New("")
- }
- return nil
-}
-
-type Prompt string
-
-type Locales []language.Tag
-
-func (l *Locales) UnmarshalText(text []byte) error {
- locales := strings.Split(string(text), " ")
- for _, locale := range locales {
- tag, err := language.Parse(locale)
- if err == nil && !tag.IsRoot() {
- *l = append(*l, tag)
- }
- }
- return nil
-}
-
-type GrantType string
diff --git a/pkg/oidc/grants/tokenexchange/tokenexchange.go b/pkg/oidc/grants/tokenexchange/tokenexchange.go
index 9464605..5cb6e79 100644
--- a/pkg/oidc/grants/tokenexchange/tokenexchange.go
+++ b/pkg/oidc/grants/tokenexchange/tokenexchange.go
@@ -1,5 +1,9 @@
package tokenexchange
+import (
+ "github.com/caos/oidc/pkg/oidc"
+)
+
const (
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
@@ -23,7 +27,19 @@ type TokenExchangeRequest struct {
}
type JWTProfileRequest struct {
- Assertion string `schema:"assertion"`
+ Assertion string `schema:"assertion"`
+ Scope oidc.Scopes `schema:"scope"`
+ GrantType oidc.GrantType `schema:"grant_type"`
+}
+
+//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
+//sneding client_id and client_secret as basic auth header
+func NewJWTProfileRequest(assertion string, scopes ...string) *JWTProfileRequest {
+ return &JWTProfileRequest{
+ GrantType: oidc.GrantTypeBearer,
+ Assertion: assertion,
+ Scope: scopes,
+ }
}
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {
diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go
index abe55d1..0d8e02c 100644
--- a/pkg/oidc/keyset.go
+++ b/pkg/oidc/keyset.go
@@ -6,21 +6,19 @@ import (
"gopkg.in/square/go-jose.v2"
)
-// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
-// of JSON web tokens. This is expected to be backed by a remote key set through
-// provider metadata discovery or an in-memory set of keys delivered out-of-band.
+//KeySet represents a set of JSON Web Keys
+// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
+// - held by the OP itself in storage -> `openIDKeySet`
+// - dynamically aggregated by request for OAuth JWT Profile Assertion -> `jwtProfileKeySet`
type KeySet interface {
- // VerifySignature parses the JSON web token, verifies the signature, and returns
- // the raw payload. Header and claim fields are validated by other parts of the
- // package. For example, the KeySet does not need to check values such as signature
- // algorithm, issuer, and audience since the IDTokenVerifier validates these values
- // independently.
- //
- // If VerifySignature makes HTTP requests to verify the token, it's expected to
- // use any HTTP client associated with the context through ClientContext.
+ //VerifySignature verifies the signature with the given keyset and returns the raw payload
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
}
+//CheckKey searches the given JSON Web Keys for the requested key ID
+//and verifies the JSON Web Signature with the found key
+//
+//will return false but no error if key ID is not found
func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) {
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
diff --git a/pkg/oidc/session.go b/pkg/oidc/session.go
index 418439e..d6735b4 100644
--- a/pkg/oidc/session.go
+++ b/pkg/oidc/session.go
@@ -1,5 +1,7 @@
package oidc
+//EndSessionRequest for the RP-Initiated Logout according to:
+//https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
type EndSessionRequest struct {
IdTokenHint string `schema:"id_token_hint"`
PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`
diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go
index 21b0419..99f18c7 100644
--- a/pkg/oidc/token.go
+++ b/pkg/oidc/token.go
@@ -3,72 +3,401 @@ package oidc
import (
"encoding/json"
"io/ioutil"
- "strings"
"time"
"golang.org/x/oauth2"
- "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils"
)
+const (
+ //BearerToken defines the token_type `Bearer`, which is returned in a successful token response
+ BearerToken = "Bearer"
+)
+
type Tokens struct {
*oauth2.Token
- IDTokenClaims *IDTokenClaims
+ IDTokenClaims IDTokenClaims
IDToken string
}
-type AccessTokenClaims struct {
- Issuer string
- Subject string
- Audiences []string
- Expiration time.Time
- IssuedAt time.Time
- NotBefore time.Time
- JWTID string
- AuthorizedParty string
- Nonce string
- AuthTime time.Time
- CodeHash string
- AuthenticationContextClassReference string
- AuthenticationMethodsReferences []string
- SessionID string
- Scopes []string
- ClientID string
- AccessTokenUseNumber int
+type AccessTokenClaims interface {
+ Claims
+ GetSubject() string
+ GetTokenID() string
+ SetPrivateClaims(map[string]interface{})
}
-type IDTokenClaims struct {
- Issuer string
- Audiences []string
- Expiration time.Time
- NotBefore time.Time
- IssuedAt time.Time
- JWTID string
- UpdatedAt time.Time
- AuthorizedParty string
- Nonce string
- AuthTime time.Time
- AccessTokenHash string
- CodeHash string
- AuthenticationContextClassReference string
- AuthenticationMethodsReferences []string
- ClientID string
- Userinfo
+type IDTokenClaims interface {
+ Claims
+ GetNotBefore() time.Time
+ GetJWTID() string
+ GetAccessTokenHash() string
+ GetCodeHash() string
+ GetAuthenticationMethodsReferences() []string
+ GetClientID() string
+ GetSignatureAlgorithm() jose.SignatureAlgorithm
+ SetAccessTokenHash(hash string)
+ SetUserinfo(userinfo UserInfo)
+ SetCodeHash(hash string)
+ UserInfo
+}
- Signature jose.SignatureAlgorithm //TODO: ???
+func EmptyAccessTokenClaims() AccessTokenClaims {
+ return new(accessTokenClaims)
+}
+
+func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims {
+ now := time.Now().UTC()
+ return &accessTokenClaims{
+ Issuer: issuer,
+ Subject: subject,
+ Audience: audience,
+ Expiration: Time(expiration),
+ IssuedAt: Time(now),
+ NotBefore: Time(now),
+ JWTID: id,
+ }
+}
+
+type accessTokenClaims struct {
+ Issuer string `json:"iss,omitempty"`
+ Subject string `json:"sub,omitempty"`
+ Audience Audience `json:"aud,omitempty"`
+ Expiration Time `json:"exp,omitempty"`
+ IssuedAt Time `json:"iat,omitempty"`
+ NotBefore Time `json:"nbf,omitempty"`
+ JWTID string `json:"jti,omitempty"`
+ AuthorizedParty string `json:"azp,omitempty"`
+ Nonce string `json:"nonce,omitempty"`
+ AuthTime Time `json:"auth_time,omitempty"`
+ CodeHash string `json:"c_hash,omitempty"`
+ AuthenticationContextClassReference string `json:"acr,omitempty"`
+ AuthenticationMethodsReferences []string `json:"amr,omitempty"`
+ SessionID string `json:"sid,omitempty"`
+ Scopes []string `json:"scope,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
+
+ claims map[string]interface{} `json:"-"`
+ signatureAlg jose.SignatureAlgorithm `json:"-"`
+}
+
+//GetIssuer implements the Claims interface
+func (a *accessTokenClaims) GetIssuer() string {
+ return a.Issuer
+}
+
+//GetAudience implements the Claims interface
+func (a *accessTokenClaims) GetAudience() []string {
+ return a.Audience
+}
+
+//GetExpiration implements the Claims interface
+func (a *accessTokenClaims) GetExpiration() time.Time {
+ return time.Time(a.Expiration)
+}
+
+//GetIssuedAt implements the Claims interface
+func (a *accessTokenClaims) GetIssuedAt() time.Time {
+ return time.Time(a.IssuedAt)
+}
+
+//GetNonce implements the Claims interface
+func (a *accessTokenClaims) GetNonce() string {
+ return a.Nonce
+}
+
+//GetAuthenticationContextClassReference implements the Claims interface
+func (a *accessTokenClaims) GetAuthenticationContextClassReference() string {
+ return a.AuthenticationContextClassReference
+}
+
+//GetAuthTime implements the Claims interface
+func (a *accessTokenClaims) GetAuthTime() time.Time {
+ return time.Time(a.AuthTime)
+}
+
+//GetAuthorizedParty implements the Claims interface
+func (a *accessTokenClaims) GetAuthorizedParty() string {
+ return a.AuthorizedParty
+}
+
+//SetSignatureAlgorithm implements the Claims interface
+func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
+ a.signatureAlg = algorithm
+}
+
+//GetSubject implements the AccessTokenClaims interface
+func (a *accessTokenClaims) GetSubject() string {
+ return a.Subject
+}
+
+//GetTokenID implements the AccessTokenClaims interface
+func (a *accessTokenClaims) GetTokenID() string {
+ return a.JWTID
+}
+
+//SetPrivateClaims implements the AccessTokenClaims interface
+func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) {
+ a.claims = claims
+}
+
+func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
+ type Alias accessTokenClaims
+ s := &struct {
+ *Alias
+ Expiration int64 `json:"exp,omitempty"`
+ IssuedAt int64 `json:"iat,omitempty"`
+ NotBefore int64 `json:"nbf,omitempty"`
+ AuthTime int64 `json:"auth_time,omitempty"`
+ }{
+ Alias: (*Alias)(a),
+ }
+ if !time.Time(a.Expiration).IsZero() {
+ s.Expiration = time.Time(a.Expiration).Unix()
+ }
+ if !time.Time(a.IssuedAt).IsZero() {
+ s.IssuedAt = time.Time(a.IssuedAt).Unix()
+ }
+ if !time.Time(a.NotBefore).IsZero() {
+ s.NotBefore = time.Time(a.NotBefore).Unix()
+ }
+ if !time.Time(a.AuthTime).IsZero() {
+ s.AuthTime = time.Time(a.AuthTime).Unix()
+ }
+ b, err := json.Marshal(s)
+ if err != nil {
+ return nil, err
+ }
+
+ if a.claims == nil {
+ return b, nil
+ }
+ info, err := json.Marshal(a.claims)
+ if err != nil {
+ return nil, err
+ }
+ return utils.ConcatenateJSON(b, info)
+}
+
+func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
+ type Alias accessTokenClaims
+ if err := json.Unmarshal(data, (*Alias)(a)); err != nil {
+ return err
+ }
+ claims := make(map[string]interface{})
+ if err := json.Unmarshal(data, &claims); err != nil {
+ return err
+ }
+ a.claims = claims
+
+ return nil
+}
+
+func EmptyIDTokenClaims() IDTokenClaims {
+ return new(idTokenClaims)
+}
+
+func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims {
+ return &idTokenClaims{
+ Issuer: issuer,
+ Audience: audience,
+ Expiration: Time(expiration),
+ IssuedAt: Time(time.Now().UTC()),
+ AuthTime: Time(authTime),
+ Nonce: nonce,
+ AuthenticationContextClassReference: acr,
+ AuthenticationMethodsReferences: amr,
+ AuthorizedParty: clientID,
+ UserInfo: &userinfo{Subject: subject},
+ }
+}
+
+type idTokenClaims struct {
+ Issuer string `json:"iss,omitempty"`
+ Audience Audience `json:"aud,omitempty"`
+ Expiration Time `json:"exp,omitempty"`
+ NotBefore Time `json:"nbf,omitempty"`
+ IssuedAt Time `json:"iat,omitempty"`
+ JWTID string `json:"jti,omitempty"`
+ AuthorizedParty string `json:"azp,omitempty"`
+ Nonce string `json:"nonce,omitempty"`
+ AuthTime Time `json:"auth_time,omitempty"`
+ AccessTokenHash string `json:"at_hash,omitempty"`
+ CodeHash string `json:"c_hash,omitempty"`
+ AuthenticationContextClassReference string `json:"acr,omitempty"`
+ AuthenticationMethodsReferences []string `json:"amr,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ UserInfo `json:"-"`
+
+ signatureAlg jose.SignatureAlgorithm
+}
+
+//GetIssuer implements the Claims interface
+func (t *idTokenClaims) GetIssuer() string {
+ return t.Issuer
+}
+
+//GetAudience implements the Claims interface
+func (t *idTokenClaims) GetAudience() []string {
+ return t.Audience
+}
+
+//GetExpiration implements the Claims interface
+func (t *idTokenClaims) GetExpiration() time.Time {
+ return time.Time(t.Expiration)
+}
+
+//GetIssuedAt implements the Claims interface
+func (t *idTokenClaims) GetIssuedAt() time.Time {
+ return time.Time(t.IssuedAt)
+}
+
+//GetNonce implements the Claims interface
+func (t *idTokenClaims) GetNonce() string {
+ return t.Nonce
+}
+
+//GetAuthenticationContextClassReference implements the Claims interface
+func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
+ return t.AuthenticationContextClassReference
+}
+
+//GetAuthTime implements the Claims interface
+func (t *idTokenClaims) GetAuthTime() time.Time {
+ return time.Time(t.AuthTime)
+}
+
+//GetAuthorizedParty implements the Claims interface
+func (t *idTokenClaims) GetAuthorizedParty() string {
+ return t.AuthorizedParty
+}
+
+//SetSignatureAlgorithm implements the Claims interface
+func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
+ t.signatureAlg = alg
+}
+
+//GetNotBefore implements the IDTokenClaims interface
+func (t *idTokenClaims) GetNotBefore() time.Time {
+ return time.Time(t.NotBefore)
+}
+
+//GetJWTID implements the IDTokenClaims interface
+func (t *idTokenClaims) GetJWTID() string {
+ return t.JWTID
+}
+
+//GetAccessTokenHash implements the IDTokenClaims interface
+func (t *idTokenClaims) GetAccessTokenHash() string {
+ return t.AccessTokenHash
+}
+
+//GetCodeHash implements the IDTokenClaims interface
+func (t *idTokenClaims) GetCodeHash() string {
+ return t.CodeHash
+}
+
+//GetAuthenticationMethodsReferences implements the IDTokenClaims interface
+func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
+ return t.AuthenticationMethodsReferences
+}
+
+//GetClientID implements the IDTokenClaims interface
+func (t *idTokenClaims) GetClientID() string {
+ return t.ClientID
+}
+
+//GetSignatureAlgorithm implements the IDTokenClaims interface
+func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
+ return t.signatureAlg
+}
+
+//SetSignatureAlgorithm implements the IDTokenClaims interface
+func (t *idTokenClaims) SetAccessTokenHash(hash string) {
+ t.AccessTokenHash = hash
+}
+
+//SetUserinfo implements the IDTokenClaims interface
+func (t *idTokenClaims) SetUserinfo(info UserInfo) {
+ t.UserInfo = info
+}
+
+//SetCodeHash implements the IDTokenClaims interface
+func (t *idTokenClaims) SetCodeHash(hash string) {
+ t.CodeHash = hash
+}
+
+func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
+ type Alias idTokenClaims
+ a := &struct {
+ *Alias
+ Expiration int64 `json:"exp,omitempty"`
+ IssuedAt int64 `json:"iat,omitempty"`
+ NotBefore int64 `json:"nbf,omitempty"`
+ AuthTime int64 `json:"auth_time,omitempty"`
+ }{
+ Alias: (*Alias)(t),
+ }
+ if !time.Time(t.Expiration).IsZero() {
+ a.Expiration = time.Time(t.Expiration).Unix()
+ }
+ if !time.Time(t.IssuedAt).IsZero() {
+ a.IssuedAt = time.Time(t.IssuedAt).Unix()
+ }
+ if !time.Time(t.NotBefore).IsZero() {
+ a.NotBefore = time.Time(t.NotBefore).Unix()
+ }
+ if !time.Time(t.AuthTime).IsZero() {
+ a.AuthTime = time.Time(t.AuthTime).Unix()
+ }
+ b, err := json.Marshal(a)
+ if err != nil {
+ return nil, err
+ }
+
+ if t.UserInfo == nil {
+ return b, nil
+ }
+ info, err := json.Marshal(t.UserInfo)
+ if err != nil {
+ return nil, err
+ }
+ return utils.ConcatenateJSON(b, info)
+}
+
+func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
+ type Alias idTokenClaims
+ if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
+ return err
+ }
+ userinfo := new(userinfo)
+ if err := json.Unmarshal(data, userinfo); err != nil {
+ return err
+ }
+ t.UserInfo = userinfo
+
+ return nil
+}
+
+type AccessTokenResponse struct {
+ AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
+ TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
+ ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
+ IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
}
type JWTProfileAssertion struct {
- PrivateKeyID string `json:"keyId"`
- PrivateKey []byte `json:"key"`
- Scopes []string `json:"-"`
- Issuer string `json:"-"`
- Subject string `json:"userId"`
- Audience []string `json:"-"`
- Expiration time.Time `json:"-"`
- IssuedAt time.Time `json:"-"`
+ PrivateKeyID string `json:"-"`
+ PrivateKey []byte `json:"-"`
+ Issuer string `json:"issuer"`
+ Subject string `json:"sub"`
+ Audience Audience `json:"aud"`
+ Expiration Time `json:"exp"`
+ IssuedAt Time `json:"iat"`
}
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) {
@@ -76,12 +405,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT
if err != nil {
return nil, err
}
+ return NewJWTProfileAssertionFromFileData(data, audience)
+}
+
+func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) {
keyData := new(struct {
KeyID string `json:"keyId"`
Key string `json:"key"`
UserID string `json:"userId"`
})
- err = json.Unmarshal(data, keyData)
+ err := json.Unmarshal(data, keyData)
if err != nil {
return nil, err
}
@@ -93,244 +426,13 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
PrivateKey: key,
PrivateKeyID: keyID,
Issuer: userID,
- Scopes: []string{ScopeOpenID},
Subject: userID,
- IssuedAt: time.Now().UTC(),
- Expiration: time.Now().Add(1 * time.Hour).UTC(),
+ IssuedAt: Time(time.Now().UTC()),
+ Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience,
}
}
-type jsonToken struct {
- Issuer string `json:"iss,omitempty"`
- Subject string `json:"sub,omitempty"`
- Audiences interface{} `json:"aud,omitempty"`
- Expiration int64 `json:"exp,omitempty"`
- NotBefore int64 `json:"nbf,omitempty"`
- IssuedAt int64 `json:"iat,omitempty"`
- JWTID string `json:"jti,omitempty"`
- AuthorizedParty string `json:"azp,omitempty"`
- Nonce string `json:"nonce,omitempty"`
- AuthTime int64 `json:"auth_time,omitempty"`
- AccessTokenHash string `json:"at_hash,omitempty"`
- CodeHash string `json:"c_hash,omitempty"`
- AuthenticationContextClassReference string `json:"acr,omitempty"`
- AuthenticationMethodsReferences []string `json:"amr,omitempty"`
- SessionID string `json:"sid,omitempty"`
- Actor interface{} `json:"act,omitempty"` //TODO: impl
- Scopes string `json:"scope,omitempty"`
- ClientID string `json:"client_id,omitempty"`
- AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
- AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
- jsonUserinfo
-}
-
-func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audiences,
- Expiration: timeToJSON(t.Expiration),
- NotBefore: timeToJSON(t.NotBefore),
- IssuedAt: timeToJSON(t.IssuedAt),
- JWTID: t.JWTID,
- AuthorizedParty: t.AuthorizedParty,
- Nonce: t.Nonce,
- AuthTime: timeToJSON(t.AuthTime),
- CodeHash: t.CodeHash,
- AuthenticationContextClassReference: t.AuthenticationContextClassReference,
- AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
- SessionID: t.SessionID,
- Scopes: strings.Join(t.Scopes, " "),
- ClientID: t.ClientID,
- AccessTokenUseNumber: t.AccessTokenUseNumber,
- }
- return json.Marshal(j)
-}
-
-func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
- var j jsonToken
- if err := json.Unmarshal(b, &j); err != nil {
- return err
- }
- t.Issuer = j.Issuer
- t.Subject = j.Subject
- t.Audiences = audienceFromJSON(j.Audiences)
- t.Expiration = time.Unix(j.Expiration, 0).UTC()
- t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
- t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
- t.JWTID = j.JWTID
- t.AuthorizedParty = j.AuthorizedParty
- t.Nonce = j.Nonce
- t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
- t.CodeHash = j.CodeHash
- t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
- t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
- t.SessionID = j.SessionID
- t.Scopes = strings.Split(j.Scopes, " ")
- t.ClientID = j.ClientID
- t.AccessTokenUseNumber = j.AccessTokenUseNumber
- return nil
-}
-
-func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audiences,
- Expiration: timeToJSON(t.Expiration),
- NotBefore: timeToJSON(t.NotBefore),
- IssuedAt: timeToJSON(t.IssuedAt),
- JWTID: t.JWTID,
- AuthorizedParty: t.AuthorizedParty,
- Nonce: t.Nonce,
- AuthTime: timeToJSON(t.AuthTime),
- AccessTokenHash: t.AccessTokenHash,
- CodeHash: t.CodeHash,
- AuthenticationContextClassReference: t.AuthenticationContextClassReference,
- AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
- ClientID: t.ClientID,
- }
- j.setUserinfo(t.Userinfo)
- return json.Marshal(j)
-}
-
-func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
- var i jsonToken
- if err := json.Unmarshal(b, &i); err != nil {
- return err
- }
- t.Issuer = i.Issuer
- t.Subject = i.Subject
- t.Audiences = audienceFromJSON(i.Audiences)
- t.Expiration = time.Unix(i.Expiration, 0).UTC()
- t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
- t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
- t.Nonce = i.Nonce
- t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
- t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
- t.AuthorizedParty = i.AuthorizedParty
- t.AccessTokenHash = i.AccessTokenHash
- t.CodeHash = i.CodeHash
- t.UserinfoProfile = i.UnmarshalUserinfoProfile()
- t.UserinfoEmail = i.UnmarshalUserinfoEmail()
- t.UserinfoPhone = i.UnmarshalUserinfoPhone()
- t.Address = i.UnmarshalUserinfoAddress()
- return nil
-}
-
-func (t *IDTokenClaims) GetIssuer() string {
- return t.Issuer
-}
-
-func (t *IDTokenClaims) GetAudience() []string {
- return t.Audiences
-}
-
-func (t *IDTokenClaims) GetExpiration() time.Time {
- return t.Expiration
-}
-
-func (t *IDTokenClaims) GetIssuedAt() time.Time {
- return t.IssuedAt
-}
-
-func (t *IDTokenClaims) GetNonce() string {
- return t.Nonce
-}
-
-func (t *IDTokenClaims) GetAuthenticationContextClassReference() string {
- return t.AuthenticationContextClassReference
-}
-
-func (t *IDTokenClaims) GetAuthTime() time.Time {
- return t.AuthTime
-}
-
-func (t *IDTokenClaims) GetAuthorizedParty() string {
- return t.AuthorizedParty
-}
-
-func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) {
- t.Signature = alg
-}
-
-func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audience,
- Expiration: timeToJSON(t.Expiration),
- IssuedAt: timeToJSON(t.IssuedAt),
- Scopes: strings.Join(t.Scopes, " "),
- }
- return json.Marshal(j)
-}
-
-func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
- var j jsonToken
- if err := json.Unmarshal(b, &j); err != nil {
- return err
- }
-
- t.Issuer = j.Issuer
- t.Subject = j.Subject
- t.Audience = audienceFromJSON(j.Audiences)
- t.Expiration = time.Unix(j.Expiration, 0).UTC()
- t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
- t.Scopes = strings.Split(j.Scopes, " ")
-
- return nil
-}
-
-func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
- locale, _ := language.Parse(j.Locale)
- return UserinfoProfile{
- Name: j.Name,
- GivenName: j.GivenName,
- FamilyName: j.FamilyName,
- MiddleName: j.MiddleName,
- Nickname: j.Nickname,
- Profile: j.Profile,
- Picture: j.Picture,
- Website: j.Website,
- Gender: Gender(j.Gender),
- Birthdate: j.Birthdate,
- Zoneinfo: j.Zoneinfo,
- Locale: locale,
- UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
- PreferredUsername: j.PreferredUsername,
- }
-}
-
-func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail {
- return UserinfoEmail{
- Email: j.Email,
- EmailVerified: j.EmailVerified,
- }
-}
-
-func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone {
- return UserinfoPhone{
- PhoneNumber: j.Phone,
- PhoneNumberVerified: j.PhoneVerified,
- }
-}
-
-func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
- if j.JsonUserinfoAddress == nil {
- return nil
- }
- return &UserinfoAddress{
- Country: j.JsonUserinfoAddress.Country,
- Formatted: j.JsonUserinfoAddress.Formatted,
- Locality: j.JsonUserinfoAddress.Locality,
- PostalCode: j.JsonUserinfoAddress.PostalCode,
- Region: j.JsonUserinfoAddress.Region,
- StreetAddress: j.JsonUserinfoAddress.StreetAddress,
- }
-}
-
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
if err != nil {
@@ -339,26 +441,3 @@ func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, erro
return utils.HashString(hash, claim, true), nil
}
-
-func timeToJSON(t time.Time) int64 {
- if t.IsZero() {
- return 0
- }
- return t.Unix()
-}
-
-func audienceFromJSON(i interface{}) []string {
- switch aud := i.(type) {
- case []string:
- return aud
- case []interface{}:
- audience := make([]string, len(aud))
- for i, a := range aud {
- audience[i] = a.(string)
- }
- return audience
- case string:
- return []string{aud}
- }
- return nil
-}
diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go
new file mode 100644
index 0000000..e80d28a
--- /dev/null
+++ b/pkg/oidc/token_request.go
@@ -0,0 +1,108 @@
+package oidc
+
+import (
+ "time"
+
+ "gopkg.in/square/go-jose.v2"
+)
+
+const (
+ //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
+ GrantTypeCode GrantType = "authorization_code"
+ //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
+ GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
+)
+
+type GrantType string
+
+type TokenRequest interface {
+ // GrantType GrantType `schema:"grant_type"`
+ GrantType() GrantType
+}
+
+type TokenRequestType GrantType
+
+type AccessTokenRequest struct {
+ Code string `schema:"code"`
+ RedirectURI string `schema:"redirect_uri"`
+ ClientID string `schema:"client_id"`
+ ClientSecret string `schema:"client_secret"`
+ CodeVerifier string `schema:"code_verifier"`
+}
+
+func (a *AccessTokenRequest) GrantType() GrantType {
+ return GrantTypeCode
+}
+
+type JWTTokenRequest struct {
+ Issuer string `json:"iss"`
+ Subject string `json:"sub"`
+ Scopes Scopes `json:"-"`
+ Audience Audience `json:"aud"`
+ IssuedAt Time `json:"iat"`
+ ExpiresAt Time `json:"exp"`
+}
+
+//GetIssuer implements the Claims interface
+func (j *JWTTokenRequest) GetIssuer() string {
+ return j.Issuer
+}
+
+//GetAudience implements the Claims and TokenRequest interfaces
+func (j *JWTTokenRequest) GetAudience() []string {
+ return j.Audience
+}
+
+//GetExpiration implements the Claims interface
+func (j *JWTTokenRequest) GetExpiration() time.Time {
+ return time.Time(j.ExpiresAt)
+}
+
+//GetIssuedAt implements the Claims interface
+func (j *JWTTokenRequest) GetIssuedAt() time.Time {
+ return time.Time(j.IssuedAt)
+}
+
+//GetNonce implements the Claims interface
+func (j *JWTTokenRequest) GetNonce() string {
+ return ""
+}
+
+//GetAuthenticationContextClassReference implements the Claims interface
+func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
+ return ""
+}
+
+//GetAuthTime implements the Claims interface
+func (j *JWTTokenRequest) GetAuthTime() time.Time {
+ return time.Time{}
+}
+
+//GetAuthorizedParty implements the Claims interface
+func (j *JWTTokenRequest) GetAuthorizedParty() string {
+ return ""
+}
+
+//SetSignatureAlgorithm implements the Claims interface
+func (j *JWTTokenRequest) SetSignatureAlgorithm(_ jose.SignatureAlgorithm) {}
+
+//GetSubject implements the TokenRequest interface
+func (j *JWTTokenRequest) GetSubject() string {
+ return j.Subject
+}
+
+//GetSubject implements the TokenRequest interface
+func (j *JWTTokenRequest) GetScopes() []string {
+ return j.Scopes
+}
+
+type TokenExchangeRequest struct {
+ subjectToken string `schema:"subject_token"`
+ subjectTokenType string `schema:"subject_token_type"`
+ actorToken string `schema:"actor_token"`
+ actorTokenType string `schema:"actor_token_type"`
+ resource []string `schema:"resource"`
+ audience Audience `schema:"audience"`
+ Scope Scopes `schema:"scope"`
+ requestedTokenType string `schema:"requested_token_type"`
+}
diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go
new file mode 100644
index 0000000..86e5d06
--- /dev/null
+++ b/pkg/oidc/types.go
@@ -0,0 +1,89 @@
+package oidc
+
+import (
+ "encoding/json"
+ "strings"
+ "time"
+
+ "golang.org/x/text/language"
+)
+
+type Audience []string
+
+func (a *Audience) UnmarshalJSON(text []byte) error {
+ var i interface{}
+ err := json.Unmarshal(text, &i)
+ if err != nil {
+ return err
+ }
+ switch aud := i.(type) {
+ case []interface{}:
+ *a = make([]string, len(aud))
+ for i, audience := range aud {
+ (*a)[i] = audience.(string)
+ }
+ case string:
+ *a = []string{aud}
+ }
+ return nil
+}
+
+type Display string
+
+func (d *Display) UnmarshalText(text []byte) error {
+ display := Display(text)
+ switch display {
+ case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP:
+ *d = display
+ }
+ return nil
+}
+
+type Gender string
+
+type Locales []language.Tag
+
+func (l *Locales) UnmarshalText(text []byte) error {
+ locales := strings.Split(string(text), " ")
+ for _, locale := range locales {
+ tag, err := language.Parse(locale)
+ if err == nil && !tag.IsRoot() {
+ *l = append(*l, tag)
+ }
+ }
+ return nil
+}
+
+type Prompt string
+
+type ResponseType string
+
+type Scopes []string
+
+func (s Scopes) Encode() string {
+ return strings.Join(s, " ")
+}
+
+func (s *Scopes) UnmarshalText(text []byte) error {
+ *s = strings.Split(string(text), " ")
+ return nil
+}
+
+func (s *Scopes) MarshalText() ([]byte, error) {
+ return []byte(s.Encode()), nil
+}
+
+type Time time.Time
+
+func (t *Time) UnmarshalJSON(data []byte) error {
+ var i int64
+ if err := json.Unmarshal(data, &i); err != nil {
+ return err
+ }
+ *t = Time(time.Unix(i, 0).UTC())
+ return nil
+}
+
+func (t *Time) MarshalJSON() ([]byte, error) {
+ return json.Marshal(time.Time(*t).UTC().Unix())
+}
diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go
new file mode 100644
index 0000000..8138b4b
--- /dev/null
+++ b/pkg/oidc/types_test.go
@@ -0,0 +1,296 @@
+package oidc
+
+import (
+ "bytes"
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/text/language"
+)
+
+func TestAudience_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ audience Audience
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "invalid value",
+ args{
+ []byte(`{"aud": {"a": }}}`),
+ },
+ res{},
+ true,
+ },
+ {
+ "single audience",
+ args{
+ []byte(`{"aud": "single audience"}`),
+ },
+ res{
+ []string{"single audience"},
+ },
+ false,
+ },
+ {
+ "multiple audience",
+ args{
+ []byte(`{"aud": ["multiple", "audience"]}`),
+ },
+ res{
+ []string{"multiple", "audience"},
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := new(struct {
+ Audience Audience `json:"aud"`
+ })
+ if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, a.Audience, tt.res.audience)
+ })
+ }
+}
+
+func TestDisplay_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ display Display
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{},
+ false,
+ },
+ {
+ "page",
+ args{
+ []byte("page"),
+ },
+ res{DisplayPage},
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var d Display
+ if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if d != tt.res.display {
+ t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display)
+ }
+ })
+ }
+}
+
+func TestLocales_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ tags []language.Tag
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{},
+ false,
+ },
+ {
+ "undefined",
+ args{
+ []byte("und"),
+ },
+ res{},
+ false,
+ },
+ {
+ "single language",
+ args{
+ []byte("de"),
+ },
+ res{[]language.Tag{language.German}},
+ false,
+ },
+ {
+ "multiple languages",
+ args{
+ []byte("de en"),
+ },
+ res{[]language.Tag{language.German, language.English}},
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var locales Locales
+ if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, locales, tt.res.tags)
+ })
+ }
+}
+
+func TestScopes_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ scopes []string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{
+ []string{"unknown"},
+ },
+ false,
+ },
+ {
+ "struct",
+ args{
+ []byte(`{"unknown":"value"}`),
+ },
+ res{
+ []string{`{"unknown":"value"}`},
+ },
+ false,
+ },
+ {
+ "openid",
+ args{
+ []byte("openid"),
+ },
+ res{
+ []string{"openid"},
+ },
+ false,
+ },
+ {
+ "multiple scopes",
+ args{
+ []byte("openid email custom:scope"),
+ },
+ res{
+ []string{"openid", "email", "custom:scope"},
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var scopes Scopes
+ if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, scopes, tt.res.scopes)
+ })
+ }
+}
+func TestScopes_MarshalText(t *testing.T) {
+ type args struct {
+ scopes Scopes
+ }
+ type res struct {
+ scopes []byte
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ Scopes{"unknown"},
+ },
+ res{
+ []byte("unknown"),
+ },
+ false,
+ },
+ {
+ "struct",
+ args{
+ Scopes{`{"unknown":"value"}`},
+ },
+ res{
+ []byte(`{"unknown":"value"}`),
+ },
+ false,
+ },
+ {
+ "openid",
+ args{
+ Scopes{"openid"},
+ },
+ res{
+ []byte("openid"),
+ },
+ false,
+ },
+ {
+ "multiple scopes",
+ args{
+ Scopes{"openid", "email", "custom:scope"},
+ },
+ res{
+ []byte("openid email custom:scope"),
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ text, err := tt.args.scopes.MarshalText()
+ if (err != nil) != tt.wantErr {
+ t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if !bytes.Equal(text, tt.res.scopes) {
+ t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes)
+ }
+ })
+ }
+}
diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go
index f8b4e6c..3c77b7b 100644
--- a/pkg/oidc/userinfo.go
+++ b/pkg/oidc/userinfo.go
@@ -2,89 +2,312 @@ package oidc
import (
"encoding/json"
+ "fmt"
"time"
"golang.org/x/text/language"
+
+ "github.com/caos/oidc/pkg/utils"
)
-type Userinfo struct {
- Subject string
- UserinfoProfile
- UserinfoEmail
- UserinfoPhone
- Address *UserinfoAddress
+type UserInfo interface {
+ GetSubject() string
+ UserInfoProfile
+ UserInfoEmail
+ UserInfoPhone
+ GetAddress() UserInfoAddress
+ GetClaim(key string) interface{}
+}
- Authorizations []string
+type UserInfoProfile interface {
+ GetName() string
+ GetGivenName() string
+ GetFamilyName() string
+ GetMiddleName() string
+ GetNickname() string
+ GetProfile() string
+ GetPicture() string
+ GetWebsite() string
+ GetGender() Gender
+ GetBirthdate() string
+ GetZoneinfo() string
+ GetLocale() language.Tag
+ GetPreferredUsername() string
+}
+
+type UserInfoEmail interface {
+ GetEmail() string
+ IsEmailVerified() bool
+}
+
+type UserInfoPhone interface {
+ GetPhoneNumber() string
+ IsPhoneNumberVerified() bool
+}
+
+type UserInfoAddress interface {
+ GetFormatted() string
+ GetStreetAddress() string
+ GetLocality() string
+ GetRegion() string
+ GetPostalCode() string
+ GetCountry() string
+}
+
+type UserInfoSetter interface {
+ UserInfo
+ SetSubject(sub string)
+ UserInfoProfileSetter
+ SetEmail(email string, verified bool)
+ SetPhone(phone string, verified bool)
+ SetAddress(address UserInfoAddress)
+ AppendClaims(key string, values interface{})
+}
+
+type UserInfoProfileSetter interface {
+ SetName(name string)
+ SetGivenName(name string)
+ SetFamilyName(name string)
+ SetMiddleName(name string)
+ SetNickname(name string)
+ SetUpdatedAt(date time.Time)
+ SetProfile(profile string)
+ SetPicture(profile string)
+ SetWebsite(website string)
+ SetGender(gender Gender)
+ SetBirthdate(birthdate string)
+ SetZoneinfo(zoneInfo string)
+ SetLocale(locale language.Tag)
+ SetPreferredUsername(name string)
+}
+
+func NewUserInfo() UserInfoSetter {
+ return &userinfo{}
+}
+
+type userinfo struct {
+ Subject string `json:"sub,omitempty"`
+ userInfoProfile
+ userInfoEmail
+ userInfoPhone
+ Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{}
}
-type UserinfoProfile struct {
- Name string
- GivenName string
- FamilyName string
- MiddleName string
- Nickname string
- Profile string
- Picture string
- Website string
- Gender Gender
- Birthdate string
- Zoneinfo string
- Locale language.Tag
- UpdatedAt time.Time
- PreferredUsername string
+func (u *userinfo) GetSubject() string {
+ return u.Subject
}
-type Gender string
-
-type UserinfoEmail struct {
- Email string
- EmailVerified bool
+func (u *userinfo) GetName() string {
+ return u.Name
}
-type UserinfoPhone struct {
- PhoneNumber string
- PhoneNumberVerified bool
+func (u *userinfo) GetGivenName() string {
+ return u.GivenName
}
-type UserinfoAddress struct {
- Formatted string
- StreetAddress string
- Locality string
- Region string
- PostalCode string
- Country string
+func (u *userinfo) GetFamilyName() string {
+ return u.FamilyName
}
-type jsonUserinfoProfile struct {
- Name string `json:"name,omitempty"`
- GivenName string `json:"given_name,omitempty"`
- FamilyName string `json:"family_name,omitempty"`
- MiddleName string `json:"middle_name,omitempty"`
- Nickname string `json:"nickname,omitempty"`
- Profile string `json:"profile,omitempty"`
- Picture string `json:"picture,omitempty"`
- Website string `json:"website,omitempty"`
- Gender string `json:"gender,omitempty"`
- Birthdate string `json:"birthdate,omitempty"`
- Zoneinfo string `json:"zoneinfo,omitempty"`
- Locale string `json:"locale,omitempty"`
- UpdatedAt int64 `json:"updated_at,omitempty"`
- PreferredUsername string `json:"preferred_username,omitempty"`
+func (u *userinfo) GetMiddleName() string {
+ return u.MiddleName
}
-type jsonUserinfoEmail struct {
+func (u *userinfo) GetNickname() string {
+ return u.Nickname
+}
+
+func (u *userinfo) GetProfile() string {
+ return u.Profile
+}
+
+func (u *userinfo) GetPicture() string {
+ return u.Picture
+}
+
+func (u *userinfo) GetWebsite() string {
+ return u.Website
+}
+
+func (u *userinfo) GetGender() Gender {
+ return u.Gender
+}
+
+func (u *userinfo) GetBirthdate() string {
+ return u.Birthdate
+}
+
+func (u *userinfo) GetZoneinfo() string {
+ return u.Zoneinfo
+}
+
+func (u *userinfo) GetLocale() language.Tag {
+ return u.Locale
+}
+
+func (u *userinfo) GetPreferredUsername() string {
+ return u.PreferredUsername
+}
+
+func (u *userinfo) GetEmail() string {
+ return u.Email
+}
+
+func (u *userinfo) IsEmailVerified() bool {
+ return u.EmailVerified
+}
+
+func (u *userinfo) GetPhoneNumber() string {
+ return u.PhoneNumber
+}
+
+func (u *userinfo) IsPhoneNumberVerified() bool {
+ return u.PhoneNumberVerified
+}
+
+func (u *userinfo) GetAddress() UserInfoAddress {
+ return u.Address
+}
+
+func (u *userinfo) GetClaim(key string) interface{} {
+ return u.claims[key]
+}
+
+func (u *userinfo) SetSubject(sub string) {
+ u.Subject = sub
+}
+
+func (u *userinfo) SetName(name string) {
+ u.Name = name
+}
+
+func (u *userinfo) SetGivenName(name string) {
+ u.GivenName = name
+}
+
+func (u *userinfo) SetFamilyName(name string) {
+ u.FamilyName = name
+}
+
+func (u *userinfo) SetMiddleName(name string) {
+ u.MiddleName = name
+}
+
+func (u *userinfo) SetNickname(name string) {
+ u.Nickname = name
+}
+
+func (u *userinfo) SetUpdatedAt(date time.Time) {
+ u.UpdatedAt = Time(date)
+}
+
+func (u *userinfo) SetProfile(profile string) {
+ u.Profile = profile
+}
+
+func (u *userinfo) SetPicture(picture string) {
+ u.Picture = picture
+}
+
+func (u *userinfo) SetWebsite(website string) {
+ u.Website = website
+}
+
+func (u *userinfo) SetGender(gender Gender) {
+ u.Gender = gender
+}
+
+func (u *userinfo) SetBirthdate(birthdate string) {
+ u.Birthdate = birthdate
+}
+
+func (u *userinfo) SetZoneinfo(zoneInfo string) {
+ u.Zoneinfo = zoneInfo
+}
+
+func (u *userinfo) SetLocale(locale language.Tag) {
+ u.Locale = locale
+}
+
+func (u *userinfo) SetPreferredUsername(name string) {
+ u.PreferredUsername = name
+}
+
+func (u *userinfo) SetEmail(email string, verified bool) {
+ u.Email = email
+ u.EmailVerified = verified
+}
+
+func (u *userinfo) SetPhone(phone string, verified bool) {
+ u.PhoneNumber = phone
+ u.PhoneNumberVerified = verified
+}
+
+func (u *userinfo) SetAddress(address UserInfoAddress) {
+ u.Address = address
+}
+
+func (u *userinfo) AppendClaims(key string, value interface{}) {
+ if u.claims == nil {
+ u.claims = make(map[string]interface{})
+ }
+ u.claims[key] = value
+}
+
+func (u *userInfoAddress) GetFormatted() string {
+ return u.Formatted
+}
+
+func (u *userInfoAddress) GetStreetAddress() string {
+ return u.StreetAddress
+}
+
+func (u *userInfoAddress) GetLocality() string {
+ return u.Locality
+}
+
+func (u *userInfoAddress) GetRegion() string {
+ return u.Region
+}
+
+func (u *userInfoAddress) GetPostalCode() string {
+ return u.PostalCode
+}
+
+func (u *userInfoAddress) GetCountry() string {
+ return u.Country
+}
+
+type userInfoProfile struct {
+ Name string `json:"name,omitempty"`
+ GivenName string `json:"given_name,omitempty"`
+ FamilyName string `json:"family_name,omitempty"`
+ MiddleName string `json:"middle_name,omitempty"`
+ Nickname string `json:"nickname,omitempty"`
+ Profile string `json:"profile,omitempty"`
+ Picture string `json:"picture,omitempty"`
+ Website string `json:"website,omitempty"`
+ Gender Gender `json:"gender,omitempty"`
+ Birthdate string `json:"birthdate,omitempty"`
+ Zoneinfo string `json:"zoneinfo,omitempty"`
+ Locale language.Tag `json:"locale,omitempty"`
+ UpdatedAt Time `json:"updated_at,omitempty"`
+ PreferredUsername string `json:"preferred_username,omitempty"`
+}
+
+type userInfoEmail struct {
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
}
-type jsonUserinfoPhone struct {
- Phone string `json:"phone_number,omitempty"`
- PhoneVerified bool `json:"phone_number_verified,omitempty"`
+type userInfoPhone struct {
+ PhoneNumber string `json:"phone_number,omitempty"`
+ PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
}
-type jsonUserinfoAddress struct {
+type userInfoAddress struct {
Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"`
@@ -93,81 +316,63 @@ type jsonUserinfoAddress struct {
Country string `json:"country,omitempty"`
}
-func (i *Userinfo) MarshalJSON() ([]byte, error) {
- j := new(jsonUserinfo)
- j.Subject = i.Subject
- j.setUserinfo(*i)
- j.Authorizations = i.Authorizations
- return json.Marshal(j)
+func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
+ return &userInfoAddress{
+ StreetAddress: streetAddress,
+ Locality: locality,
+ Region: region,
+ PostalCode: postalCode,
+ Country: country,
+ Formatted: formatted,
+ }
+}
+func (i *userinfo) MarshalJSON() ([]byte, error) {
+ type Alias userinfo
+ a := &struct {
+ *Alias
+ Locale interface{} `json:"locale,omitempty"`
+ UpdatedAt int64 `json:"updated_at,omitempty"`
+ }{
+ Alias: (*Alias)(i),
+ }
+ if !i.Locale.IsRoot() {
+ a.Locale = i.Locale
+ }
+ if !time.Time(i.UpdatedAt).IsZero() {
+ a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
+ }
+
+ b, err := json.Marshal(a)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(i.claims) == 0 {
+ return b, nil
+ }
+
+ claims, err := json.Marshal(i.claims)
+ if err != nil {
+ return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims)
+ }
+ return utils.ConcatenateJSON(b, claims)
}
-func (i *Userinfo) UnmmarshalJSON(data []byte) error {
- if err := json.Unmarshal(data, i); err != nil {
+func (i *userinfo) UnmarshalJSON(data []byte) error {
+ type Alias userinfo
+ a := &struct {
+ *Alias
+ UpdatedAt int64 `json:"update_at,omitempty"`
+ }{
+ Alias: (*Alias)(i),
+ }
+ if err := json.Unmarshal(data, &a); err != nil {
return err
}
- return json.Unmarshal(data, &i.claims)
-}
-type jsonUserinfo struct {
- Subject string `json:"sub,omitempty"`
- jsonUserinfoProfile
- jsonUserinfoEmail
- jsonUserinfoPhone
- JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"`
- Authorizations []string `json:"authorizations,omitempty"`
-}
+ i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
-func (j *jsonUserinfo) setUserinfo(i Userinfo) {
- j.setUserinfoProfile(i.UserinfoProfile)
- j.setUserinfoEmail(i.UserinfoEmail)
- j.setUserinfoPhone(i.UserinfoPhone)
- j.setUserinfoAddress(i.Address)
-}
-
-func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) {
- j.Name = i.Name
- j.GivenName = i.GivenName
- j.FamilyName = i.FamilyName
- j.MiddleName = i.MiddleName
- j.Nickname = i.Nickname
- j.Profile = i.Profile
- j.Picture = i.Picture
- j.Website = i.Website
- j.Gender = string(i.Gender)
- j.Birthdate = i.Birthdate
- j.Zoneinfo = i.Zoneinfo
- if i.Locale != language.Und {
- j.Locale = i.Locale.String()
- }
- j.UpdatedAt = timeToJSON(i.UpdatedAt)
- j.PreferredUsername = i.PreferredUsername
-}
-
-func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) {
- j.Email = i.Email
- j.EmailVerified = i.EmailVerified
-}
-
-func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) {
- j.Phone = i.PhoneNumber
- j.PhoneVerified = i.PhoneNumberVerified
-}
-
-func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) {
- if i == nil {
- return
- }
- if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" {
- return
- }
- j.JsonUserinfoAddress = &jsonUserinfoAddress{
- Country: i.Country,
- Formatted: i.Formatted,
- Locality: i.Locality,
- PostalCode: i.PostalCode,
- Region: i.Region,
- StreetAddress: i.StreetAddress,
- }
+ return nil
}
type UserInfoRequest struct {
diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go
index 492664b..06470a0 100644
--- a/pkg/oidc/verifier.go
+++ b/pkg/oidc/verifier.go
@@ -24,7 +24,7 @@ type Claims interface {
GetAuthenticationContextClassReference() string
GetAuthTime() time.Time
GetAuthorizedParty() string
- SetSignature(algorithm jose.SignatureAlgorithm)
+ SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
}
var (
@@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
return ErrSignatureInvalidPayload
}
- claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm))
+ claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm))
return nil
}
diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go
index cf40e62..4d6118c 100644
--- a/pkg/op/authrequest.go
+++ b/pkg/op/authrequest.go
@@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
if err != nil {
return "", ErrServerError(err.Error())
}
- if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
+ authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
+ if err != nil {
return "", err
}
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
@@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
}
//ValidateAuthReqScopes validates the passed scopes
-func ValidateAuthReqScopes(scopes []string) error {
+func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 {
- return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
+ return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
}
- if !utils.Contains(scopes, oidc.ScopeOpenID) {
- return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
+ openID := false
+ for i := len(scopes) - 1; i >= 0; i-- {
+ scope := scopes[i]
+ if scope == oidc.ScopeOpenID {
+ openID = true
+ continue
+ }
+ if !(scope == oidc.ScopeProfile ||
+ scope == oidc.ScopeEmail ||
+ scope == oidc.ScopePhone ||
+ scope == oidc.ScopeAddress ||
+ scope == oidc.ScopeOfflineAccess) &&
+ !utils.Contains(client.AllowedScopes(), scope) {
+ scopes[i] = scopes[len(scopes)-1]
+ scopes[len(scopes)-1] = ""
+ scopes = scopes[:len(scopes)-1]
+ }
}
- return nil
+ if !openID {
+ return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
+ }
+
+ return scopes, nil
}
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
@@ -168,7 +188,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
}
- return claims.Subject, nil
+ return claims.GetSubject(), nil
}
//RedirectToLogin redirects the end user to the Login UI for authentication
diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go
index d74d365..3856acd 100644
--- a/pkg/op/authrequest_test.go
+++ b/pkg/op/authrequest_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"github.com/gorilla/schema"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc"
@@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) {
func TestValidateAuthReqScopes(t *testing.T) {
type args struct {
+ client op.Client
+ scopes []string
+ }
+ type res struct {
+ err bool
scopes []string
}
tests := []struct {
- name string
- args args
- wantErr bool
+ name string
+ args args
+ res res
}{
{
- "scopes missing fails", args{}, true,
+ "scopes missing fails",
+ args{},
+ res{
+ err: true,
+ },
},
{
- "scope openid missing fails", args{[]string{"email"}}, true,
+ "scope openid missing fails",
+ args{
+ mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
+ []string{"email"},
+ },
+ res{
+ err: true,
+ },
},
{
- "scope ok", args{[]string{"openid"}}, false,
+ "scope ok",
+ args{
+ mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
+ []string{"openid"},
+ },
+ res{
+ scopes: []string{"openid"},
+ },
+ },
+ {
+ "scope with drop ok",
+ args{
+ mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
+ []string{"openid", "email", "unknown"},
+ },
+ res{
+ scopes: []string{"openid", "email"},
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
- t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
+ scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
+ if (err != nil) != tt.res.err {
+ t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
}
+ assert.ElementsMatch(t, scopes, tt.res.scopes)
})
}
}
diff --git a/pkg/op/client.go b/pkg/op/client.go
index 3184b90..790933e 100644
--- a/pkg/op/client.go
+++ b/pkg/op/client.go
@@ -10,7 +10,9 @@ const (
ApplicationTypeWeb ApplicationType = iota
ApplicationTypeUserAgent
ApplicationTypeNative
+)
+const (
AccessTokenTypeBearer AccessTokenType = iota
AccessTokenTypeJWT
)
@@ -32,6 +34,9 @@ type Client interface {
AccessTokenType() AccessTokenType
IDTokenLifetime() time.Duration
DevMode() bool
+ AllowedScopes() []string
+ AssertAdditionalIdTokenScopes() bool
+ AssertAdditionalAccessTokenScopes() bool
}
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {
diff --git a/pkg/op/config.go b/pkg/op/config.go
index b3df943..d64c0ee 100644
--- a/pkg/op/config.go
+++ b/pkg/op/config.go
@@ -7,6 +7,8 @@ import (
"strings"
)
+const OidcDevMode = "CAOS_OIDC_DEV"
+
type Configuration interface {
Issuer() string
AuthorizationEndpoint() Endpoint
@@ -42,7 +44,7 @@ func ValidateIssuer(issuer string) error {
}
func devLocalAllowed(url *url.URL) bool {
- _, b := os.LookupEnv("CAOS_OIDC_DEV")
+ _, b := os.LookupEnv(OidcDevMode)
if !b {
return b
}
diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go
index 79173fb..e140074 100644
--- a/pkg/op/config_test.go
+++ b/pkg/op/config_test.go
@@ -60,6 +60,8 @@ func TestValidateIssuer(t *testing.T) {
true,
},
}
+ //ensure env is not set
+ os.Unsetenv(OidcDevMode)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
@@ -84,7 +86,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
false,
},
}
- os.Setenv("CAOS_OIDC_DEV", "")
+ os.Setenv(OidcDevMode, "true")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go
index 0a6b6a5..a481a8b 100644
--- a/pkg/op/mock/authorizer.mock.impl.go
+++ b/pkg/op/mock/authorizer.mock.impl.go
@@ -72,18 +72,18 @@ func (v *Verifier) VerifyIDToken(ctx context.Context, idToken string) (*oidc.IDT
return nil, nil
}
-type Sig struct{}
+type Sig struct {
+ signer jose.Signer
+}
+
+func (s *Sig) Signer() jose.Signer {
+ return s.signer
+}
func (s *Sig) Health(ctx context.Context) error {
return nil
}
-func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
- return "", nil
-}
-func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
- return "", nil
-}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
return jose.HS256
}
@@ -92,9 +92,3 @@ func ExpectStorage(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
}
-
-// func NewMockSignerAny(t *testing.T) op.Signer {
-// m := NewMockSigner(gomock.NewController(t))
-// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
-// return m
-// }
diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go
index eed21d5..12c00cc 100644
--- a/pkg/op/mock/client.go
+++ b/pkg/op/mock/client.go
@@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
func(id string) string {
return "login?id=" + id
})
+ m.EXPECT().AllowedScopes().AnyTimes().Return(nil)
return c
}
diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go
index 4007347..0780623 100644
--- a/pkg/op/mock/client.mock.go
+++ b/pkg/op/mock/client.mock.go
@@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
}
+// AllowedScopes mocks base method
+func (m *MockClient) AllowedScopes() []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AllowedScopes")
+ ret0, _ := ret[0].([]string)
+ return ret0
+}
+
+// AllowedScopes indicates an expected call of AllowedScopes
+func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes))
+}
+
// ApplicationType mocks base method
func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper()
@@ -63,6 +77,34 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
}
+// AssertAdditionalAccessTokenScopes mocks base method
+func (m *MockClient) AssertAdditionalAccessTokenScopes() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes
+func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes))
+}
+
+// AssertAdditionalIdTokenScopes mocks base method
+func (m *MockClient) AssertAdditionalIdTokenScopes() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes
+func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes))
+}
+
// AuthMethod mocks base method
func (m *MockClient) AuthMethod() op.AuthMethod {
m.ctrl.T.Helper()
diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go
index a7d909c..b52f9d4 100644
--- a/pkg/op/mock/signer.mock.go
+++ b/pkg/op/mock/signer.mock.go
@@ -6,7 +6,6 @@ package mock
import (
context "context"
- oidc "github.com/caos/oidc/pkg/oidc"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
reflect "reflect"
@@ -49,36 +48,6 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
}
-// SignAccessToken mocks base method
-func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "SignAccessToken", arg0)
- ret0, _ := ret[0].(string)
- ret1, _ := ret[1].(error)
- return ret0, ret1
-}
-
-// SignAccessToken indicates an expected call of SignAccessToken
-func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
-}
-
-// SignIDToken mocks base method
-func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "SignIDToken", arg0)
- ret0, _ := ret[0].(string)
- ret1, _ := ret[1].(error)
- return ret0, ret1
-}
-
-// SignIDToken indicates an expected call of SignIDToken
-func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
-}
-
// SignatureAlgorithm mocks base method
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper()
@@ -92,3 +61,17 @@ func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
}
+
+// Signer mocks base method
+func (m *MockSigner) Signer() jose.Signer {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Signer")
+ ret0, _ := ret[0].(jose.Signer)
+ return ret0
+}
+
+// Signer indicates an expected call of Signer
+func (mr *MockSignerMockRecorder) Signer() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer))
+}
diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go
index 0efbbd4..9e4963a 100644
--- a/pkg/op/mock/storage.mock.go
+++ b/pkg/op/mock/storage.mock.go
@@ -171,6 +171,21 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
}
+// GetPrivateClaimsFromScopes mocks base method
+func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetPrivateClaimsFromScopes", arg0, arg1, arg2, arg3)
+ ret0, _ := ret[0].(map[string]interface{})
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetPrivateClaimsFromScopes indicates an expected call of GetPrivateClaimsFromScopes
+func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
+}
+
// GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
m.ctrl.T.Helper()
@@ -184,25 +199,25 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
}
// GetUserinfoFromScopes mocks base method
-func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) {
+func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfo, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
- ret0, _ := ret[0].(*oidc.Userinfo)
+ ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3)
+ ret0, _ := ret[0].(oidc.UserInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
-func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call {
+func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3)
}
// GetUserinfoFromToken mocks base method
-func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (*oidc.Userinfo, error) {
+func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (oidc.UserInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2, arg3)
- ret0, _ := ret[0].(*oidc.Userinfo)
+ ret0, _ := ret[0].(oidc.UserInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go
index 6fd2760..de9dee9 100644
--- a/pkg/op/mock/storage.mock.impl.go
+++ b/pkg/op/mock/storage.mock.impl.go
@@ -168,3 +168,12 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool {
return c.devMode
}
+func (c *ConfClient) AllowedScopes() []string {
+ return nil
+}
+func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
+ return false
+}
+func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
+ return false
+}
diff --git a/pkg/op/op.go b/pkg/op/op.go
index d913c7f..bba7a14 100644
--- a/pkg/op/op.go
+++ b/pkg/op/op.go
@@ -51,6 +51,7 @@ type OpenIDProvider interface {
Encoder() utils.Encoder
IDTokenHintVerifier() IDTokenHintVerifier
JWTProfileVerifier() JWTProfileVerifier
+ AccessTokenVerifier() AccessTokenVerifier
Crypto() Crypto
DefaultLogoutRedirectURI() string
Signer() Signer
@@ -130,7 +131,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
}
keyCh := make(chan jose.SigningKey)
- o.signer = NewDefaultSigner(ctx, storage, keyCh)
+ o.signer = NewSigner(ctx, storage, keyCh)
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
o.httpHandler = CreateRouter(o, o.interceptors...)
@@ -152,6 +153,8 @@ type openidProvider struct {
signer Signer
idTokenHintVerifier IDTokenHintVerifier
jwtProfileVerifier JWTProfileVerifier
+ accessTokenVerifier AccessTokenVerifier
+ keySet *openIDKeySet
crypto Crypto
httpHandler http.Handler
decoder *schema.Decoder
@@ -207,7 +210,7 @@ func (o *openidProvider) Encoder() utils.Encoder {
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier {
if o.idTokenHintVerifier == nil {
- o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), &openIDKeySet{o.Storage()})
+ o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet())
}
return o.idTokenHintVerifier
}
@@ -219,6 +222,20 @@ func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier {
return o.jwtProfileVerifier
}
+func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier {
+ if o.accessTokenVerifier == nil {
+ o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet())
+ }
+ return o.accessTokenVerifier
+}
+
+func (o *openidProvider) openIDKeySet() oidc.KeySet {
+ if o.keySet == nil {
+ o.keySet = &openIDKeySet{o.Storage()}
+ }
+ return o.keySet
+}
+
func (o *openidProvider) Crypto() Crypto {
return o.crypto
}
diff --git a/pkg/op/session.go b/pkg/op/session.go
index d04e361..19ebab4 100644
--- a/pkg/op/session.go
+++ b/pkg/op/session.go
@@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid")
}
- session.UserID = claims.Subject
- session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty)
+ session.UserID = claims.GetSubject()
+ session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil {
return nil, ErrServerError("")
}
diff --git a/pkg/op/signer.go b/pkg/op/signer.go
index e9926cd..76bb9c7 100644
--- a/pkg/op/signer.go
+++ b/pkg/op/signer.go
@@ -2,19 +2,15 @@ package op
import (
"context"
- "encoding/json"
"errors"
"github.com/caos/logging"
"gopkg.in/square/go-jose.v2"
-
- "github.com/caos/oidc/pkg/oidc"
)
type Signer interface {
Health(ctx context.Context) error
- SignIDToken(claims *oidc.IDTokenClaims) (string, error)
- SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
+ Signer() jose.Signer
SignatureAlgorithm() jose.SignatureAlgorithm
}
@@ -24,7 +20,7 @@ type tokenSigner struct {
alg jose.SignatureAlgorithm
}
-func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
+func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
s := &tokenSigner{
storage: storage,
}
@@ -41,6 +37,10 @@ func (s *tokenSigner) Health(_ context.Context) error {
return nil
}
+func (s *tokenSigner) Signer() jose.Signer {
+ return s.signer
+}
+
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for {
select {
@@ -55,30 +55,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S
}
}
-func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
- payload, err := json.Marshal(claims)
- if err != nil {
- return "", err
- }
- return s.Sign(payload)
-}
-
-func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
- payload, err := json.Marshal(claims)
- if err != nil {
- return "", err
- }
- return s.Sign(payload)
-}
-
-func (s *tokenSigner) Sign(payload []byte) (string, error) {
- result, err := s.signer.Sign(payload)
- if err != nil {
- return "", err
- }
- return result.CompactSerialize()
-}
-
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.alg
}
diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go
deleted file mode 100644
index 75e184b..0000000
--- a/pkg/op/signer_test.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package op
-
-import (
- "testing"
-
- "github.com/stretchr/testify/require"
- "gopkg.in/square/go-jose.v2"
-)
-
-// func TestNewDefaultSigner(t *testing.T) {
-// type args struct {
-// storage Storage
-// }
-// tests := []struct {
-// name string
-// args args
-// want Signer
-// wantErr bool
-// }{
-// {
-// "err initialize storage fails",
-// args{mock.NewMockStorageSigningKeyError(t)},
-// nil,
-// true,
-// },
-// {
-// "err initialize storage fails",
-// args{mock.NewMockStorageSigningKeyInvalid(t)},
-// nil,
-// true,
-// },
-// {
-// "initialize ok",
-// args{mock.NewMockStorageSigningKey(t)},
-// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
-// false,
-// },
-// }
-// for _, tt := range tests {
-// t.Run(tt.name, func(t *testing.T) {
-// got, err := op.NewDefaultSigner(tt.args.storage)
-// if (err != nil) != tt.wantErr {
-// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
-// return
-// }
-// if !reflect.DeepEqual(got, tt.want) {
-// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
-// }
-// })
-// }
-// }
-
-func Test_idTokenSigner_Sign(t *testing.T) {
- signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
- require.NoError(t, err)
-
- type fields struct {
- signer jose.Signer
- storage Storage
- }
- type args struct {
- payload []byte
- }
- tests := []struct {
- name string
- fields fields
- args args
- want string
- wantErr bool
- }{
- {
- "ok",
- fields{signer, nil},
- args{[]byte("test")},
- "eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
- false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- s := &tokenSigner{
- signer: tt.fields.signer,
- storage: tt.fields.storage,
- }
- got, err := s.Sign(tt.args.payload)
- if (err != nil) != tt.wantErr {
- t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
- }
- })
- }
-}
diff --git a/pkg/op/storage.go b/pkg/op/storage.go
index 9f8e6bd..eba5003 100644
--- a/pkg/op/storage.go
+++ b/pkg/op/storage.go
@@ -26,10 +26,11 @@ type AuthStorage interface {
}
type OPStorage interface {
- GetClientByClientID(context.Context, string) (Client, error)
- AuthorizeClientIDSecret(context.Context, string, string) error
- GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error)
- GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (*oidc.Userinfo, error)
+ GetClientByClientID(ctx context.Context, clientID string) (Client, error)
+ AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
+ GetUserinfoFromScopes(ctx context.Context, userID, clientID string, scopes []string) (oidc.UserInfo, error)
+ GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (oidc.UserInfo, error)
+ GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
}
diff --git a/pkg/op/token.go b/pkg/op/token.go
index 0dd0663..2d66ef5 100644
--- a/pkg/op/token.go
+++ b/pkg/op/token.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/caos/oidc/pkg/oidc"
+ "github.com/caos/oidc/pkg/utils"
)
type TokenCreator interface {
@@ -25,12 +26,12 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
var validity time.Duration
if createAccessToken {
var err error
- accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator)
+ accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator, client)
if err != nil {
return nil, err
}
}
- idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer())
+ idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes())
if err != nil {
return nil, err
}
@@ -50,7 +51,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
}
func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
- accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator)
+ accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, nil)
if err != nil {
return nil, err
}
@@ -63,17 +64,17 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
}, nil
}
-func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) {
- id, exp, err := creator.Storage().CreateToken(ctx, authReq)
+func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client) (token string, validity time.Duration, err error) {
+ id, exp, err := creator.Storage().CreateToken(ctx, tokenRequest)
if err != nil {
return "", 0, err
}
validity = exp.Sub(time.Now().UTC())
if accessTokenType == AccessTokenTypeJWT {
- token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer())
+ token, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage())
return
}
- token, err = CreateBearerToken(id, authReq.GetSubject(), creator.Crypto())
+ token, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
return
}
@@ -81,52 +82,79 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
return crypto.Encrypt(tokenID + ":" + subject)
}
-func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) {
- now := time.Now().UTC()
- nbf := now
- claims := &oidc.AccessTokenClaims{
- Issuer: issuer,
- Subject: authReq.GetSubject(),
- Audiences: authReq.GetAudience(),
- Expiration: exp,
- IssuedAt: now,
- NotBefore: nbf,
- JWTID: id,
- }
- return signer.SignAccessToken(claims)
-}
-
-func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) {
- var err error
- exp := time.Now().UTC().Add(validity)
- userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
- if err != nil {
- return "", err
- }
- claims := &oidc.IDTokenClaims{
- Issuer: issuer,
- Audiences: authReq.GetAudience(),
- Expiration: exp,
- IssuedAt: time.Now().UTC(),
- AuthTime: authReq.GetAuthTime(),
- Nonce: authReq.GetNonce(),
- AuthenticationContextClassReference: authReq.GetACR(),
- AuthenticationMethodsReferences: authReq.GetAMR(),
- AuthorizedParty: authReq.GetClientID(),
- Userinfo: *userinfo,
- }
- if accessToken != "" {
- claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
+func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) {
+ claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id)
+ if client != nil && client.AssertAdditionalAccessTokenScopes() {
+ privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes()))
if err != nil {
return "", err
}
+ claims.SetPrivateClaims(privateClaims)
+ }
+ return utils.Sign(claims, signer.Signer())
+}
+
+func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) {
+ exp := time.Now().UTC().Add(validity)
+ claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID())
+ scopes := authReq.GetScopes()
+
+ if accessToken != "" {
+ atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
+ if err != nil {
+ return "", err
+ }
+ claims.SetAccessTokenHash(atHash)
+ scopes = removeUserinfoScopes(scopes)
+ }
+ if !additonalScopes {
+ scopes = removeAdditionalScopes(scopes)
+ }
+ if len(scopes) > 0 {
+ userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes)
+ if err != nil {
+ return "", err
+ }
+ claims.SetUserinfo(userInfo)
}
if code != "" {
- claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
+ codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
+ claims.SetCodeHash(codeHash)
}
- return signer.SignIDToken(claims)
+ return utils.Sign(claims, signer.Signer())
+}
+
+func removeUserinfoScopes(scopes []string) []string {
+ for i := len(scopes) - 1; i >= 0; i-- {
+ if scopes[i] == oidc.ScopeProfile ||
+ scopes[i] == oidc.ScopeEmail ||
+ scopes[i] == oidc.ScopeAddress ||
+ scopes[i] == oidc.ScopePhone {
+
+ scopes[i] = scopes[len(scopes)-1]
+ scopes[len(scopes)-1] = ""
+ scopes = scopes[:len(scopes)-1]
+ }
+ }
+ return scopes
+}
+
+func removeAdditionalScopes(scopes []string) []string {
+ for i := len(scopes) - 1; i >= 0; i-- {
+ if !(scopes[i] == oidc.ScopeOpenID ||
+ scopes[i] == oidc.ScopeProfile ||
+ scopes[i] == oidc.ScopeEmail ||
+ scopes[i] == oidc.ScopeAddress ||
+ scopes[i] == oidc.ScopePhone) {
+
+ scopes[i] = scopes[len(scopes)-1]
+ scopes[len(scopes)-1] = ""
+ scopes = scopes[:len(scopes)-1]
+ }
+ }
+ return scopes
}
diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go
index b3613ce..cba70f3 100644
--- a/pkg/op/tokenrequest.go
+++ b/pkg/op/tokenrequest.go
@@ -138,18 +138,18 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
}
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
- assertion, err := ParseJWTProfileRequest(r, exchanger.Decoder())
+ profileRequest, err := ParseJWTProfileRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
}
- claims, err := VerifyJWTAssertion(r.Context(), assertion, exchanger.JWTProfileVerifier())
+ tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest, exchanger.JWTProfileVerifier())
if err != nil {
RequestError(w, r, err)
return
}
- resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger)
+ resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
if err != nil {
RequestError(w, r, err)
return
@@ -157,17 +157,17 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp)
}
-func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (string, error) {
+func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*tokenexchange.JWTProfileRequest, error) {
err := r.ParseForm()
if err != nil {
- return "", ErrInvalidRequest("error parsing form")
+ return nil, ErrInvalidRequest("error parsing form")
}
tokenReq := new(tokenexchange.JWTProfileRequest)
err = decoder.Decode(tokenReq, r.Form)
if err != nil {
- return "", ErrInvalidRequest("error decoding form")
+ return nil, ErrInvalidRequest("error decoding form")
}
- return tokenReq.Assertion, nil
+ return tokenReq, nil
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go
index 2362e4f..1163598 100644
--- a/pkg/op/userinfo.go
+++ b/pkg/op/userinfo.go
@@ -1,6 +1,7 @@
package op
import (
+ "context"
"errors"
"net/http"
"strings"
@@ -13,6 +14,7 @@ type UserinfoProvider interface {
Decoder() utils.Decoder
Crypto() Crypto
Storage() Storage
+ AccessTokenVerifier() AccessTokenVerifier
}
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
@@ -27,17 +29,12 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
http.Error(w, "access token missing", http.StatusUnauthorized)
return
}
- tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
- if err != nil {
- http.Error(w, "access token missing", http.StatusUnauthorized)
- return
- }
- splittedToken := strings.Split(tokenIDSubject, ":")
- if len(splittedToken) != 2 {
+ tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken)
+ if !ok {
http.Error(w, "access token invalid", http.StatusUnauthorized)
return
}
- info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), splittedToken[0], splittedToken[1], r.Header.Get("origin"))
+ info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, subject, r.Header.Get("origin"))
if err != nil {
w.WriteHeader(http.StatusForbidden)
utils.MarshalJSON(w, err)
@@ -66,3 +63,19 @@ func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) {
}
return req.AccessToken, nil
}
+
+func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
+ tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
+ if err == nil {
+ splitToken := strings.Split(tokenIDSubject, ":")
+ if len(splitToken) != 2 {
+ return "", "", false
+ }
+ return splitToken[0], splitToken[1], true
+ }
+ accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
+ if err != nil {
+ return "", "", false
+ }
+ return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
+}
diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go
new file mode 100644
index 0000000..05168a6
--- /dev/null
+++ b/pkg/op/verifier_access_token.go
@@ -0,0 +1,85 @@
+package op
+
+import (
+ "context"
+ "time"
+
+ "github.com/caos/oidc/pkg/oidc"
+)
+
+type AccessTokenVerifier interface {
+ oidc.Verifier
+ SupportedSignAlgs() []string
+ KeySet() oidc.KeySet
+}
+
+type accessTokenVerifier struct {
+ issuer string
+ maxAgeIAT time.Duration
+ offset time.Duration
+ supportedSignAlgs []string
+ maxAge time.Duration
+ acr oidc.ACRVerifier
+ keySet oidc.KeySet
+}
+
+//Issuer implements oidc.Verifier interface
+func (i *accessTokenVerifier) Issuer() string {
+ return i.issuer
+}
+
+//MaxAgeIAT implements oidc.Verifier interface
+func (i *accessTokenVerifier) MaxAgeIAT() time.Duration {
+ return i.maxAgeIAT
+}
+
+//Offset implements oidc.Verifier interface
+func (i *accessTokenVerifier) Offset() time.Duration {
+ return i.offset
+}
+
+//SupportedSignAlgs implements AccessTokenVerifier interface
+func (i *accessTokenVerifier) SupportedSignAlgs() []string {
+ return i.supportedSignAlgs
+}
+
+//KeySet implements AccessTokenVerifier interface
+func (i *accessTokenVerifier) KeySet() oidc.KeySet {
+ return i.keySet
+}
+
+func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier {
+ verifier := &idTokenHintVerifier{
+ issuer: issuer,
+ keySet: keySet,
+ }
+ return verifier
+}
+
+//VerifyAccessToken validates the access token (issuer, signature and expiration)
+func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) {
+ claims := oidc.EmptyAccessTokenClaims()
+
+ decrypted, err := oidc.DecryptToken(token)
+ if err != nil {
+ return nil, err
+ }
+ payload, err := oidc.ParseToken(decrypted, claims)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
+ return nil, err
+ }
+
+ if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
+ return nil, err
+ }
+
+ if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
+ return nil, err
+ }
+
+ return claims, nil
+}
diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go
index 3268a5e..7baa075 100644
--- a/pkg/op/verifier_id_token_hint.go
+++ b/pkg/op/verifier_id_token_hint.go
@@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi
//VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
-func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) {
- claims := new(oidc.IDTokenClaims)
+func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
+ claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {
diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go
index b30bdc5..8a31253 100644
--- a/pkg/op/verifier_jwt_profile.go
+++ b/pkg/op/verifier_jwt_profile.go
@@ -8,6 +8,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
+ "github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
)
type JWTProfileVerifier interface {
@@ -47,9 +48,9 @@ func (v *jwtProfileVerifier) Offset() time.Duration {
return v.offset
}
-func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
+func VerifyJWTAssertion(ctx context.Context, profileRequest *tokenexchange.JWTProfileRequest, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
request := new(oidc.JWTTokenRequest)
- payload, err := oidc.ParseToken(assertion, request)
+ payload, err := oidc.ParseToken(profileRequest.Assertion, request)
if err != nil {
return nil, err
}
@@ -72,9 +73,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
keySet := &jwtProfileKeySet{v.Storage(), request.Subject}
- if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
+ if err = oidc.CheckSignature(ctx, profileRequest.Assertion, payload, request, nil, keySet); err != nil {
return nil, err
}
+ request.Scopes = profileRequest.Scope
return request, nil
}
diff --git a/pkg/rp/mock/verifier.mock.impl.go b/pkg/rp/mock/verifier.mock.impl.go
deleted file mode 100644
index 53b2f03..0000000
--- a/pkg/rp/mock/verifier.mock.impl.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package mock
-
-import (
- "errors"
- "testing"
-
- "github.com/golang/mock/gomock"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/rp"
-)
-
-func NewVerifier(t *testing.T) rp.Verifier {
- return NewMockVerifier(gomock.NewController(t))
-}
-
-func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier {
- m := NewVerifier(t)
- ExpectVerifyInvalid(m)
- return m
-}
-
-func ExpectVerifyInvalid(v rp.Verifier) {
- mock := v.(*MockVerifier)
- mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid"))
-}
-
-func NewMockVerifierExpectValid(t *testing.T) rp.Verifier {
- m := NewVerifier(t)
- ExpectVerifyValid(m)
- return m
-}
-
-func ExpectVerifyValid(v rp.Verifier) {
- mock := v.(*MockVerifier)
- mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil)
-}
diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go
index 3fe8b4b..6807221 100644
--- a/pkg/rp/relaying_party.go
+++ b/pkg/rp/relaying_party.go
@@ -4,9 +4,13 @@ import (
"context"
"errors"
"net/http"
+ "net/url"
+ "reflect"
"strings"
+ "time"
"github.com/google/uuid"
+ "github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/oidc/grants"
@@ -22,6 +26,16 @@ const (
jwtProfileKey = "urn:ietf:params:oauth:grant-type:jwt-bearer"
)
+var (
+ encoder = func() utils.Encoder {
+ e := schema.NewEncoder()
+ e.RegisterEncoder(oidc.Scopes{}, func(value reflect.Value) string {
+ return value.Interface().(oidc.Scopes).Encode()
+ })
+ return e
+ }()
+)
+
//RelayingParty declares the minimal interface for oidc clients
type RelayingParty interface {
//OAuthConfig returns the oauth2 Config
@@ -312,38 +326,45 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc
//ClientCredentials is the `RelayingParty` interface implementation
//handling the oauth2 client credentials grant
func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) {
- return CallTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...), rp)
+ return CallTokenEndpointAuthorized(grants.ClientCredentialsGrantBasic(scopes...), rp)
+}
+
+func CallTokenEndpointAuthorized(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
+ config := rp.OAuthConfig()
+ var fn interface{} = utils.AuthorizeBasic(config.ClientID, config.ClientSecret)
+ if config.Endpoint.AuthStyle == oauth2.AuthStyleInParams {
+ fn = func(form url.Values) {
+ form.Set("client_id", config.ClientID)
+ form.Set("client_secret", config.ClientSecret)
+ }
+ }
+ return callTokenEndpoint(request, fn, rp)
}
func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
- config := rp.OAuthConfig()
- req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, config.ClientID, config.ClientSecret, config.Endpoint.AuthStyle != oauth2.AuthStyleInParams)
- if err != nil {
- return nil, err
- }
- token := new(oauth2.Token)
- if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
- return nil, err
- }
- return token, nil
+ return callTokenEndpoint(request, nil, rp)
}
-func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) {
- form := make(map[string][]string)
- form["assertion"] = []string{assertion}
- form["grant_type"] = []string{jwtProfileKey}
- req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil)
+func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
+ req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, encoder, authFn)
if err != nil {
return nil, err
}
-
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-
- token := new(oauth2.Token)
- if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
+ var tokenRes struct {
+ AccessToken string `json:"access_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int64 `json:"expires_in"`
+ RefreshToken string `json:"refresh_token"`
+ }
+ if err := utils.HttpRequest(rp.HttpClient(), req, &tokenRes); err != nil {
return nil, err
}
- return token, nil
+ return &oauth2.Token{
+ AccessToken: tokenRes.AccessToken,
+ TokenType: tokenRes.TokenType,
+ RefreshToken: tokenRes.RefreshToken,
+ Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
+ }, nil
}
func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error {
diff --git a/pkg/rp/tockenexchange.go b/pkg/rp/tockenexchange.go
index 24b588a..4396dc4 100644
--- a/pkg/rp/tockenexchange.go
+++ b/pkg/rp/tockenexchange.go
@@ -43,12 +43,17 @@ func DelegationTokenExchange(ctx context.Context, subjectToken string, rp Relayi
}
//JWTProfileExchange handles the oauth2 jwt profile exchange
-func JWTProfileExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, rp RelayingParty) (*oauth2.Token, error) {
+func JWTProfileExchange(ctx context.Context, jwtProfileRequest *tokenexchange.JWTProfileRequest, rp RelayingParty) (*oauth2.Token, error) {
+ return CallTokenEndpoint(jwtProfileRequest, rp)
+}
+
+//JWTProfileExchange handles the oauth2 jwt profile exchange
+func JWTProfileAssertionExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, scopes oidc.Scopes, rp RelayingParty) (*oauth2.Token, error) {
token, err := generateJWTProfileToken(assertion)
if err != nil {
return nil, err
}
- return CallJWTProfileEndpoint(token, rp)
+ return JWTProfileExchange(ctx, tokenexchange.NewJWTProfileRequest(token, scopes...), rp)
}
func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) {
diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go
index ef2cf87..a156f6d 100644
--- a/pkg/rp/verifier.go
+++ b/pkg/rp/verifier.go
@@ -21,12 +21,12 @@ type IDTokenVerifier interface {
//VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
-func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
+func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v)
if err != nil {
return nil, err
}
- if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil {
+ if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err
}
return idToken, nil
@@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo
//VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
-func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
- claims := new(oidc.IDTokenClaims)
+func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
+ claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {
diff --git a/pkg/utils/http.go b/pkg/utils/http.go
index 993febb..fa51815 100644
--- a/pkg/utils/http.go
+++ b/pkg/utils/http.go
@@ -10,8 +10,6 @@ import (
"net/url"
"strings"
"time"
-
- "github.com/gorilla/schema"
)
var (
@@ -27,23 +25,30 @@ type Encoder interface {
Encode(src interface{}, dst map[string][]string) error
}
-func FormRequest(endpoint string, request interface{}, clientID, clientSecret string, header bool) (*http.Request, error) {
- form := make(map[string][]string)
- encoder := schema.NewEncoder()
+type FormAuthorization func(url.Values)
+type RequestAuthorization func(*http.Request)
+
+func AuthorizeBasic(user, password string) RequestAuthorization {
+ return func(req *http.Request) {
+ req.SetBasicAuth(user, password)
+ }
+}
+
+func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) {
+ form := url.Values{}
if err := encoder.Encode(request, form); err != nil {
return nil, err
}
- if !header {
- form["client_id"] = []string{clientID}
- form["client_secret"] = []string{clientSecret}
+ if fn, ok := authFn.(FormAuthorization); ok {
+ fn(form)
}
- body := strings.NewReader(url.Values(form).Encode())
+ body := strings.NewReader(form.Encode())
req, err := http.NewRequest("POST", endpoint, body)
if err != nil {
return nil, err
}
- if header {
- req.SetBasicAuth(clientID, clientSecret)
+ if fn, ok := authFn.(RequestAuthorization); ok {
+ fn(req)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req, nil
diff --git a/pkg/utils/marshal.go b/pkg/utils/marshal.go
index e279341..4f53b4e 100644
--- a/pkg/utils/marshal.go
+++ b/pkg/utils/marshal.go
@@ -1,7 +1,9 @@
package utils
import (
+ "bytes"
"encoding/json"
+ "fmt"
"net/http"
"github.com/sirupsen/logrus"
@@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) {
logrus.Error("error writing response")
}
}
+
+func ConcatenateJSON(first, second []byte) ([]byte, error) {
+ if !bytes.HasSuffix(first, []byte{'}'}) {
+ return nil, fmt.Errorf("jws: invalid JSON %s", first)
+ }
+ if !bytes.HasPrefix(second, []byte{'{'}) {
+ return nil, fmt.Errorf("jws: invalid JSON %s", second)
+ }
+ first[len(first)-1] = ','
+ first = append(first, second[1:]...)
+ return first, nil
+}
diff --git a/pkg/utils/marshal_test.go b/pkg/utils/marshal_test.go
new file mode 100644
index 0000000..f9221f6
--- /dev/null
+++ b/pkg/utils/marshal_test.go
@@ -0,0 +1,60 @@
+package utils
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestConcatenateJSON(t *testing.T) {
+ type args struct {
+ first []byte
+ second []byte
+ }
+ tests := []struct {
+ name string
+ args args
+ want []byte
+ wantErr bool
+ }{
+ {
+ "invalid first part, error",
+ args{
+ []byte(`invalid`),
+ []byte(`{"some": "thing"}`),
+ },
+ nil,
+ true,
+ },
+ {
+ "invalid second part, error",
+ args{
+ []byte(`{"some": "thing"}`),
+ []byte(`invalid`),
+ },
+ nil,
+ true,
+ },
+ {
+ "both valid, merged",
+ args{
+ []byte(`{"some": "thing"}`),
+ []byte(`{"another": "thing"}`),
+ },
+
+ []byte(`{"some": "thing","another": "thing"}`),
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := ConcatenateJSON(tt.args.first, tt.args.second)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("ConcatenateJSON() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !bytes.Equal(got, tt.want) {
+ t.Errorf("ConcatenateJSON() got = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/utils/sign.go b/pkg/utils/sign.go
new file mode 100644
index 0000000..e1efe61
--- /dev/null
+++ b/pkg/utils/sign.go
@@ -0,0 +1,23 @@
+package utils
+
+import (
+ "encoding/json"
+
+ "gopkg.in/square/go-jose.v2"
+)
+
+func Sign(object interface{}, signer jose.Signer) (string, error) {
+ payload, err := json.Marshal(object)
+ if err != nil {
+ return "", err
+ }
+ return SignPayload(payload, signer)
+}
+
+func SignPayload(payload []byte, signer jose.Signer) (string, error) {
+ result, err := signer.Sign(payload)
+ if err != nil {
+ return "", err
+ }
+ return result.CompactSerialize()
+}