diff --git a/example/client/app/app.go b/example/client/app/app.go
index 4c0831b..ea1e6e7 100644
--- a/example/client/app/app.go
+++ b/example/client/app/app.go
@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
+ "html/template"
+ "io/ioutil"
"net/http"
"os"
"time"
@@ -30,7 +32,7 @@ func main() {
ctx := context.Background()
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
- scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}
+ scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"}
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
rp.WithPKCE(cookieHandler),
@@ -82,6 +84,62 @@ func main() {
}
w.Write(data)
})
+
+ http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) {
+ tpl := `
+
+
+
+
+ Login
+
+
+
+
+ `
+ t, err := template.New("login").Parse(tpl)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ err = t.Execute(w, nil)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ })
+
+ http.HandleFunc("/jwt-profile-assertion", func(w http.ResponseWriter, r *http.Request) {
+ r.ParseMultipartForm(32 << 20)
+ file, handler, err := r.FormFile("key")
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ defer file.Close()
+
+ key, err := ioutil.ReadAll(file)
+ fmt.Println(handler.Header)
+ assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer})
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ token, err := rp.JWTProfileExchange(ctx, assertion, provider)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ data, err := json.Marshal(token)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ w.Write(data)
+ })
lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis)
logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))
diff --git a/example/client/jwt_profile.go b/example/client/jwt_profile.go
new file mode 100644
index 0000000..6dcd11b
--- /dev/null
+++ b/example/client/jwt_profile.go
@@ -0,0 +1,39 @@
+package client
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/caos/oidc/pkg/oidc"
+ "github.com/caos/oidc/pkg/rp"
+ "github.com/caos/oidc/pkg/utils"
+)
+
+var (
+ callbackPath string = "/auth/callback"
+ key []byte = []byte("test1234test1234")
+)
+
+func main() {
+ clientID := os.Getenv("CLIENT_ID")
+ clientSecret := os.Getenv("CLIENT_SECRET")
+ issuer := os.Getenv("ISSUER")
+ port := os.Getenv("PORT")
+
+ ctx := context.Background()
+
+ redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
+ scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"}
+ cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
+ provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
+ rp.WithPKCE(cookieHandler),
+ rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)),
+ )
+ if err != nil {
+ logrus.Fatalf("error creating provider %s", err.Error())
+ }
+}
diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go
index f20fb9b..74e0ed7 100644
--- a/example/internal/mock/storage.go
+++ b/example/internal/mock/storage.go
@@ -210,24 +210,24 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
return nil
}
-func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.Userinfo, error) {
+func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.userinfo, error) {
return s.GetUserinfoFromScopes(ctx, "", []string{})
}
-func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) {
- return &oidc.Userinfo{
+func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.userinfo, error) {
+ return &oidc.userinfo{
Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{
StreetAddress: "Hjkhkj 789\ndsf",
},
- UserinfoEmail: oidc.UserinfoEmail{
+ userinfoEmail: oidc.userinfoEmail{
Email: "test",
EmailVerified: true,
},
- UserinfoPhone: oidc.UserinfoPhone{
+ userinfoPhone: oidc.userinfoPhone{
PhoneNumber: "sadsa",
PhoneNumberVerified: true,
},
- UserinfoProfile: oidc.UserinfoProfile{
+ userinfoProfile: oidc.userinfoProfile{
UpdatedAt: time.Now(),
},
// Claims: map[string]interface{}{
diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go
index 35da2fb..71776af 100644
--- a/pkg/oidc/authorization.go
+++ b/pkg/oidc/authorization.go
@@ -1,15 +1,5 @@
package oidc
-import (
- "encoding/json"
- "errors"
- "strings"
- "time"
-
- "golang.org/x/text/language"
- "gopkg.in/square/go-jose.v2"
-)
-
const (
//ScopeOpenID defines the scope `openid`
//OpenID Connect requests MUST contain the `openid` scope value
@@ -64,23 +54,8 @@ const (
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
PromptSelectAccount Prompt = "select_account"
-
- //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
- GrantTypeCode GrantType = "authorization_code"
- //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
- GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
-
- //BearerToken defines the token_type `Bearer`, which is returned in a successful token response
- BearerToken = "Bearer"
)
-var displayValues = map[string]Display{
- "page": DisplayPage,
- "popup": DisplayPopup,
- "touch": DisplayTouch,
- "wap": DisplayWAP,
-}
-
//AuthRequest according to:
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct {
@@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType {
func (a *AuthRequest) GetState() string {
return a.State
}
-
-type TokenRequest interface {
- // GrantType GrantType `schema:"grant_type"`
- GrantType() GrantType
-}
-
-type TokenRequestType GrantType
-
-type AccessTokenRequest struct {
- Code string `schema:"code"`
- RedirectURI string `schema:"redirect_uri"`
- ClientID string `schema:"client_id"`
- ClientSecret string `schema:"client_secret"`
- CodeVerifier string `schema:"code_verifier"`
-}
-
-func (a *AccessTokenRequest) GrantType() GrantType {
- return GrantTypeCode
-}
-
-type AccessTokenResponse struct {
- AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
- TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
- RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
- ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
- IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
-}
-
-type JWTTokenRequest struct {
- Issuer string `json:"iss"`
- Subject string `json:"sub"`
- Scopes Scopes `json:"scope"`
- Audience interface{} `json:"aud"`
- IssuedAt Time `json:"iat"`
- ExpiresAt Time `json:"exp"`
-}
-
-func (j *JWTTokenRequest) GetClientID() string {
- return j.Subject
-}
-
-func (j *JWTTokenRequest) GetSubject() string {
- return j.Subject
-}
-
-func (j *JWTTokenRequest) GetScopes() []string {
- return j.Scopes
-}
-
-type Time time.Time
-
-func (t *Time) UnmarshalJSON(data []byte) error {
- var i int64
- if err := json.Unmarshal(data, &i); err != nil {
- return err
- }
- *t = Time(time.Unix(i, 0).UTC())
- return nil
-}
-
-func (j *JWTTokenRequest) GetIssuer() string {
- return j.Issuer
-}
-
-func (j *JWTTokenRequest) GetAudience() []string {
- return audienceFromJSON(j.Audience)
-}
-
-func (j *JWTTokenRequest) GetExpiration() time.Time {
- return time.Time(j.ExpiresAt)
-}
-
-func (j *JWTTokenRequest) GetIssuedAt() time.Time {
- return time.Time(j.IssuedAt)
-}
-
-func (j *JWTTokenRequest) GetNonce() string {
- return ""
-}
-
-func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
- return ""
-}
-
-func (j *JWTTokenRequest) GetAuthTime() time.Time {
- return time.Time{}
-}
-
-func (j *JWTTokenRequest) GetAuthorizedParty() string {
- return ""
-}
-
-func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {}
-
-type TokenExchangeRequest struct {
- subjectToken string `schema:"subject_token"`
- subjectTokenType string `schema:"subject_token_type"`
- actorToken string `schema:"actor_token"`
- actorTokenType string `schema:"actor_token_type"`
- resource []string `schema:"resource"`
- audience []string `schema:"audience"`
- Scope []string `schema:"scope"`
- requestedTokenType string `schema:"requested_token_type"`
-}
-
-type Scopes []string
-
-func (s *Scopes) UnmarshalText(text []byte) error {
- scopes := strings.Split(string(text), " ")
- *s = Scopes(scopes)
- return nil
-}
-
-type ResponseType string
-
-type Display string
-
-func (d *Display) UnmarshalText(text []byte) error {
- var ok bool
- display := string(text)
- *d, ok = displayValues[display]
- if !ok {
- return errors.New("")
- }
- return nil
-}
-
-type Prompt string
-
-type Locales []language.Tag
-
-func (l *Locales) UnmarshalText(text []byte) error {
- locales := strings.Split(string(text), " ")
- for _, locale := range locales {
- tag, err := language.Parse(locale)
- if err == nil && !tag.IsRoot() {
- *l = append(*l, tag)
- }
- }
- return nil
-}
-
-type GrantType string
diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go
index 21b0419..e20dd4a 100644
--- a/pkg/oidc/token.go
+++ b/pkg/oidc/token.go
@@ -3,33 +3,55 @@ package oidc
import (
"encoding/json"
"io/ioutil"
- "strings"
"time"
"golang.org/x/oauth2"
- "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils"
)
+const (
+ //BearerToken defines the token_type `Bearer`, which is returned in a successful token response
+ BearerToken = "Bearer"
+)
+
type Tokens struct {
*oauth2.Token
- IDTokenClaims *IDTokenClaims
+ IDTokenClaims IDTokenClaims
IDToken string
}
-type AccessTokenClaims struct {
+type AccessTokenClaims interface {
+ Claims
+}
+
+type IDTokenClaims interface {
+ Claims
+ GetNotBefore() time.Time
+ GetJWTID() string
+ GetAccessTokenHash() string
+ GetCodeHash() string
+ GetAuthenticationMethodsReferences() []string
+ GetClientID() string
+ GetSignatureAlgorithm() jose.SignatureAlgorithm
+ SetAccessTokenHash(hash string)
+ SetUserinfo(userinfo UserInfoSetter)
+ SetCodeHash(hash string)
+ UserInfo
+}
+
+type accessTokenClaims struct {
Issuer string
Subject string
- Audiences []string
- Expiration time.Time
- IssuedAt time.Time
- NotBefore time.Time
+ Audience Audience
+ Expiration Time
+ IssuedAt Time
+ NotBefore Time
JWTID string
AuthorizedParty string
Nonce string
- AuthTime time.Time
+ AuthTime Time
CodeHash string
AuthenticationContextClassReference string
AuthenticationMethodsReferences []string
@@ -37,38 +59,155 @@ type AccessTokenClaims struct {
Scopes []string
ClientID string
AccessTokenUseNumber int
+
+ signatureAlg jose.SignatureAlgorithm
}
-type IDTokenClaims struct {
- Issuer string
- Audiences []string
- Expiration time.Time
- NotBefore time.Time
- IssuedAt time.Time
- JWTID string
- UpdatedAt time.Time
- AuthorizedParty string
- Nonce string
- AuthTime time.Time
- AccessTokenHash string
- CodeHash string
- AuthenticationContextClassReference string
- AuthenticationMethodsReferences []string
- ClientID string
- Userinfo
+func (a accessTokenClaims) GetIssuer() string {
+ return a.Issuer
+}
- Signature jose.SignatureAlgorithm //TODO: ???
+func (a accessTokenClaims) GetAudience() []string {
+ return a.Audience
+}
+
+func (a accessTokenClaims) GetExpiration() time.Time {
+ return time.Time(a.Expiration)
+}
+
+func (a accessTokenClaims) GetIssuedAt() time.Time {
+ return time.Time(a.IssuedAt)
+}
+
+func (a accessTokenClaims) GetNonce() string {
+ return a.Nonce
+}
+
+func (a accessTokenClaims) GetAuthenticationContextClassReference() string {
+ return a.AuthenticationContextClassReference
+}
+
+func (a accessTokenClaims) GetAuthTime() time.Time {
+ return time.Time(a.AuthTime)
+}
+
+func (a accessTokenClaims) GetAuthorizedParty() string {
+ return a.AuthorizedParty
+}
+
+func (a accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
+ a.signatureAlg = algorithm
+}
+
+func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims {
+ now := time.Now().UTC()
+ return &accessTokenClaims{
+ Issuer: issuer,
+ Subject: subject,
+ Audience: audience,
+ Expiration: Time(expiration),
+ IssuedAt: Time(now),
+ NotBefore: Time(now),
+ JWTID: id,
+ }
+}
+
+type idTokenClaims struct {
+ Issuer string `json:"iss,omitempty"`
+ Audience Audience `json:"aud,omitempty"`
+ Expiration Time `json:"exp,omitempty"`
+ NotBefore Time `json:"nbf,omitempty"`
+ IssuedAt Time `json:"iat,omitempty"`
+ JWTID string `json:"jti,omitempty"`
+ AuthorizedParty string `json:"azp,omitempty"`
+ Nonce string `json:"nonce,omitempty"`
+ AuthTime Time `json:"auth_time,omitempty"`
+ AccessTokenHash string `json:"at_hash,omitempty"`
+ CodeHash string `json:"c_hash,omitempty"`
+ AuthenticationContextClassReference string `json:"acr,omitempty"`
+ AuthenticationMethodsReferences []string `json:"amr,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ UserInfo `json:"-"`
+
+ signatureAlg jose.SignatureAlgorithm
+}
+
+func (t *idTokenClaims) SetAccessTokenHash(hash string) {
+ t.AccessTokenHash = hash
+}
+
+func (t *idTokenClaims) SetUserinfo(info UserInfoSetter) {
+ t.UserInfo = info
+}
+
+func (t *idTokenClaims) SetCodeHash(hash string) {
+ t.CodeHash = hash
+}
+
+func EmptyIDTokenClaims() IDTokenClaims {
+ return new(idTokenClaims)
+}
+
+func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims {
+ return &idTokenClaims{
+ Issuer: issuer,
+ Audience: audience,
+ Expiration: Time(expiration),
+ IssuedAt: Time(time.Now().UTC()),
+ AuthTime: Time(authTime),
+ Nonce: nonce,
+ AuthenticationContextClassReference: acr,
+ AuthenticationMethodsReferences: amr,
+ AuthorizedParty: clientID,
+ UserInfo: &userinfo{Subject: subject},
+ }
+}
+
+func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
+ return t.signatureAlg
+}
+
+func (t *idTokenClaims) GetNotBefore() time.Time {
+ return time.Time(t.NotBefore)
+}
+
+func (t *idTokenClaims) GetJWTID() string {
+ return t.JWTID
+}
+
+func (t *idTokenClaims) GetAccessTokenHash() string {
+ return t.AccessTokenHash
+}
+
+func (t *idTokenClaims) GetCodeHash() string {
+ return t.CodeHash
+}
+
+func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
+ return t.AuthenticationMethodsReferences
+}
+
+func (t *idTokenClaims) GetClientID() string {
+ return t.ClientID
+}
+
+type AccessTokenResponse struct {
+ AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
+ TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
+ ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
+ IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
}
type JWTProfileAssertion struct {
- PrivateKeyID string `json:"keyId"`
- PrivateKey []byte `json:"key"`
- Scopes []string `json:"-"`
- Issuer string `json:"-"`
- Subject string `json:"userId"`
- Audience []string `json:"-"`
- Expiration time.Time `json:"-"`
- IssuedAt time.Time `json:"-"`
+ PrivateKeyID string `json:"-"`
+ PrivateKey []byte `json:"-"`
+ Scopes []string `json:"scopes"`
+ Issuer string `json:"issuer"`
+ Subject string `json:"sub"`
+ Audience Audience `json:"aud"`
+ Expiration Time `json:"exp"`
+ IssuedAt Time `json:"iat"`
}
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) {
@@ -76,12 +215,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT
if err != nil {
return nil, err
}
+ return NewJWTProfileAssertionFromFileData(data, audience)
+}
+
+func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) {
keyData := new(struct {
KeyID string `json:"keyId"`
Key string `json:"key"`
UserID string `json:"userId"`
})
- err = json.Unmarshal(data, keyData)
+ err := json.Unmarshal(data, keyData)
if err != nil {
return nil, err
}
@@ -95,241 +238,251 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
Issuer: userID,
Scopes: []string{ScopeOpenID},
Subject: userID,
- IssuedAt: time.Now().UTC(),
- Expiration: time.Now().Add(1 * time.Hour).UTC(),
+ IssuedAt: Time(time.Now().UTC()),
+ Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience,
}
}
-type jsonToken struct {
- Issuer string `json:"iss,omitempty"`
- Subject string `json:"sub,omitempty"`
- Audiences interface{} `json:"aud,omitempty"`
- Expiration int64 `json:"exp,omitempty"`
- NotBefore int64 `json:"nbf,omitempty"`
- IssuedAt int64 `json:"iat,omitempty"`
- JWTID string `json:"jti,omitempty"`
- AuthorizedParty string `json:"azp,omitempty"`
- Nonce string `json:"nonce,omitempty"`
- AuthTime int64 `json:"auth_time,omitempty"`
- AccessTokenHash string `json:"at_hash,omitempty"`
- CodeHash string `json:"c_hash,omitempty"`
- AuthenticationContextClassReference string `json:"acr,omitempty"`
- AuthenticationMethodsReferences []string `json:"amr,omitempty"`
- SessionID string `json:"sid,omitempty"`
- Actor interface{} `json:"act,omitempty"` //TODO: impl
- Scopes string `json:"scope,omitempty"`
- ClientID string `json:"client_id,omitempty"`
- AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
- AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
- jsonUserinfo
-}
+//
+//type jsonToken struct {
+// Issuer string `json:"iss,omitempty"`
+// Subject string `json:"sub,omitempty"`
+// Audiences interface{} `json:"aud,omitempty"`
+// Expiration int64 `json:"exp,omitempty"`
+// NotBefore int64 `json:"nbf,omitempty"`
+// IssuedAt int64 `json:"iat,omitempty"`
+// JWTID string `json:"jti,omitempty"`
+// AuthorizedParty string `json:"azp,omitempty"`
+// Nonce string `json:"nonce,omitempty"`
+// AuthTime int64 `json:"auth_time,omitempty"`
+// AccessTokenHash string `json:"at_hash,omitempty"`
+// CodeHash string `json:"c_hash,omitempty"`
+// AuthenticationContextClassReference string `json:"acr,omitempty"`
+// AuthenticationMethodsReferences []string `json:"amr,omitempty"`
+// SessionID string `json:"sid,omitempty"`
+// Actor interface{} `json:"act,omitempty"` //TODO: impl
+// Scopes string `json:"scope,omitempty"`
+// ClientID string `json:"client_id,omitempty"`
+// AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
+// AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
+// jsonUserinfo
+//}
-func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audiences,
- Expiration: timeToJSON(t.Expiration),
- NotBefore: timeToJSON(t.NotBefore),
- IssuedAt: timeToJSON(t.IssuedAt),
- JWTID: t.JWTID,
- AuthorizedParty: t.AuthorizedParty,
- Nonce: t.Nonce,
- AuthTime: timeToJSON(t.AuthTime),
- CodeHash: t.CodeHash,
- AuthenticationContextClassReference: t.AuthenticationContextClassReference,
- AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
- SessionID: t.SessionID,
- Scopes: strings.Join(t.Scopes, " "),
- ClientID: t.ClientID,
- AccessTokenUseNumber: t.AccessTokenUseNumber,
+//
+//func (t *accessTokenClaims) MarshalJSON() ([]byte, error) {
+// j := jsonToken{
+// Issuer: t.Issuer,
+// Subject: t.Subject,
+// Audiences: t.Audiences,
+// Expiration: timeToJSON(t.Expiration),
+// NotBefore: timeToJSON(t.NotBefore),
+// IssuedAt: timeToJSON(t.IssuedAt),
+// JWTID: t.JWTID,
+// AuthorizedParty: t.AuthorizedParty,
+// Nonce: t.Nonce,
+// AuthTime: timeToJSON(t.AuthTime),
+// CodeHash: t.CodeHash,
+// AuthenticationContextClassReference: t.AuthenticationContextClassReference,
+// AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
+// SessionID: t.SessionID,
+// Scopes: strings.Join(t.Scopes, " "),
+// ClientID: t.ClientID,
+// AccessTokenUseNumber: t.AccessTokenUseNumber,
+// }
+// return json.Marshal(j)
+//}
+//
+//func (t *accessTokenClaims) UnmarshalJSON(b []byte) error {
+// var j jsonToken
+// if err := json.Unmarshal(b, &j); err != nil {
+// return err
+// }
+// t.Issuer = j.Issuer
+// t.Subject = j.Subject
+// t.Audiences = audienceFromJSON(j.Audiences)
+// t.Expiration = time.Unix(j.Expiration, 0).UTC()
+// t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
+// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
+// t.JWTID = j.JWTID
+// t.AuthorizedParty = j.AuthorizedParty
+// t.Nonce = j.Nonce
+// t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
+// t.CodeHash = j.CodeHash
+// t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
+// t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
+// t.SessionID = j.SessionID
+// t.Scopes = strings.Split(j.Scopes, " ")
+// t.ClientID = j.ClientID
+// t.AccessTokenUseNumber = j.AccessTokenUseNumber
+// return nil
+//}
+//
+func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
+ type Alias idTokenClaims
+ a := &struct {
+ *Alias
+ Expiration int64 `json:"nbf,omitempty"`
+ IssuedAt int64 `json:"nbf,omitempty"`
+ NotBefore int64 `json:"nbf,omitempty"`
+ AuthTime int64 `json:"nbf,omitempty"`
+ }{
+ Alias: (*Alias)(t),
}
- return json.Marshal(j)
+ if !time.Time(t.Expiration).IsZero() {
+ a.Expiration = time.Time(t.Expiration).Unix()
+ }
+ if !time.Time(t.IssuedAt).IsZero() {
+ a.IssuedAt = time.Time(t.IssuedAt).Unix()
+ }
+ if !time.Time(t.NotBefore).IsZero() {
+ a.NotBefore = time.Time(t.NotBefore).Unix()
+ }
+ if !time.Time(t.AuthTime).IsZero() {
+ a.AuthTime = time.Time(t.AuthTime).Unix()
+ }
+ b, err := json.Marshal(a)
+ if err != nil {
+ return nil, err
+ }
+
+ if t.UserInfo == nil {
+ return b, nil
+ }
+ info, err := json.Marshal(t.UserInfo)
+ if err != nil {
+ return nil, err
+ }
+ return utils.ConcatenateJSON(b, info)
}
-func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
- var j jsonToken
- if err := json.Unmarshal(b, &j); err != nil {
+func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
+ type Alias idTokenClaims
+ if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
return err
}
- t.Issuer = j.Issuer
- t.Subject = j.Subject
- t.Audiences = audienceFromJSON(j.Audiences)
- t.Expiration = time.Unix(j.Expiration, 0).UTC()
- t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
- t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
- t.JWTID = j.JWTID
- t.AuthorizedParty = j.AuthorizedParty
- t.Nonce = j.Nonce
- t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
- t.CodeHash = j.CodeHash
- t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
- t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
- t.SessionID = j.SessionID
- t.Scopes = strings.Split(j.Scopes, " ")
- t.ClientID = j.ClientID
- t.AccessTokenUseNumber = j.AccessTokenUseNumber
+ userinfo := new(userinfo)
+ if err := json.Unmarshal(data, userinfo); err != nil {
+ return err
+ }
+ t.UserInfo = userinfo
+
return nil
}
-func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audiences,
- Expiration: timeToJSON(t.Expiration),
- NotBefore: timeToJSON(t.NotBefore),
- IssuedAt: timeToJSON(t.IssuedAt),
- JWTID: t.JWTID,
- AuthorizedParty: t.AuthorizedParty,
- Nonce: t.Nonce,
- AuthTime: timeToJSON(t.AuthTime),
- AccessTokenHash: t.AccessTokenHash,
- CodeHash: t.CodeHash,
- AuthenticationContextClassReference: t.AuthenticationContextClassReference,
- AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
- ClientID: t.ClientID,
- }
- j.setUserinfo(t.Userinfo)
- return json.Marshal(j)
-}
-
-func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
- var i jsonToken
- if err := json.Unmarshal(b, &i); err != nil {
- return err
- }
- t.Issuer = i.Issuer
- t.Subject = i.Subject
- t.Audiences = audienceFromJSON(i.Audiences)
- t.Expiration = time.Unix(i.Expiration, 0).UTC()
- t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
- t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
- t.Nonce = i.Nonce
- t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
- t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
- t.AuthorizedParty = i.AuthorizedParty
- t.AccessTokenHash = i.AccessTokenHash
- t.CodeHash = i.CodeHash
- t.UserinfoProfile = i.UnmarshalUserinfoProfile()
- t.UserinfoEmail = i.UnmarshalUserinfoEmail()
- t.UserinfoPhone = i.UnmarshalUserinfoPhone()
- t.Address = i.UnmarshalUserinfoAddress()
- return nil
-}
-
-func (t *IDTokenClaims) GetIssuer() string {
+func (t *idTokenClaims) GetIssuer() string {
return t.Issuer
}
-func (t *IDTokenClaims) GetAudience() []string {
- return t.Audiences
+func (t *idTokenClaims) GetAudience() []string {
+ return t.Audience
}
-func (t *IDTokenClaims) GetExpiration() time.Time {
- return t.Expiration
+func (t *idTokenClaims) GetExpiration() time.Time {
+ return time.Time(t.Expiration)
}
-func (t *IDTokenClaims) GetIssuedAt() time.Time {
- return t.IssuedAt
+func (t *idTokenClaims) GetIssuedAt() time.Time {
+ return time.Time(t.IssuedAt)
}
-func (t *IDTokenClaims) GetNonce() string {
+func (t *idTokenClaims) GetNonce() string {
return t.Nonce
}
-func (t *IDTokenClaims) GetAuthenticationContextClassReference() string {
+func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
-func (t *IDTokenClaims) GetAuthTime() time.Time {
- return t.AuthTime
+func (t *idTokenClaims) GetAuthTime() time.Time {
+ return time.Time(t.AuthTime)
}
-func (t *IDTokenClaims) GetAuthorizedParty() string {
+func (t *idTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
-func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) {
- t.Signature = alg
+func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
+ t.signatureAlg = alg
}
-func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audience,
- Expiration: timeToJSON(t.Expiration),
- IssuedAt: timeToJSON(t.IssuedAt),
- Scopes: strings.Join(t.Scopes, " "),
- }
- return json.Marshal(j)
-}
+//
+//func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
+// j := jsonToken{
+// Issuer: t.Issuer,
+// Subject: t.Subject,
+// Audiences: t.Audience,
+// Expiration: timeToJSON(t.Expiration),
+// IssuedAt: timeToJSON(t.IssuedAt),
+// Scopes: strings.Join(t.Scopes, " "),
+// }
+// return json.Marshal(j)
+//}
-func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
- var j jsonToken
- if err := json.Unmarshal(b, &j); err != nil {
- return err
- }
+//func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
+// var j jsonToken
+// if err := json.Unmarshal(b, &j); err != nil {
+// return err
+// }
+//
+// t.Issuer = j.Issuer
+// t.Subject = j.Subject
+// t.Audience = audienceFromJSON(j.Audiences)
+// t.Expiration = time.Unix(j.Expiration, 0).UTC()
+// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
+// t.Scopes = strings.Split(j.Scopes, " ")
+//
+// return nil
+//}
- t.Issuer = j.Issuer
- t.Subject = j.Subject
- t.Audience = audienceFromJSON(j.Audiences)
- t.Expiration = time.Unix(j.Expiration, 0).UTC()
- t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
- t.Scopes = strings.Split(j.Scopes, " ")
-
- return nil
-}
-
-func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
- locale, _ := language.Parse(j.Locale)
- return UserinfoProfile{
- Name: j.Name,
- GivenName: j.GivenName,
- FamilyName: j.FamilyName,
- MiddleName: j.MiddleName,
- Nickname: j.Nickname,
- Profile: j.Profile,
- Picture: j.Picture,
- Website: j.Website,
- Gender: Gender(j.Gender),
- Birthdate: j.Birthdate,
- Zoneinfo: j.Zoneinfo,
- Locale: locale,
- UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
- PreferredUsername: j.PreferredUsername,
- }
-}
-
-func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail {
- return UserinfoEmail{
- Email: j.Email,
- EmailVerified: j.EmailVerified,
- }
-}
-
-func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone {
- return UserinfoPhone{
- PhoneNumber: j.Phone,
- PhoneNumberVerified: j.PhoneVerified,
- }
-}
-
-func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
- if j.JsonUserinfoAddress == nil {
- return nil
- }
- return &UserinfoAddress{
- Country: j.JsonUserinfoAddress.Country,
- Formatted: j.JsonUserinfoAddress.Formatted,
- Locality: j.JsonUserinfoAddress.Locality,
- PostalCode: j.JsonUserinfoAddress.PostalCode,
- Region: j.JsonUserinfoAddress.Region,
- StreetAddress: j.JsonUserinfoAddress.StreetAddress,
- }
-}
+//
+//func (j *jsonToken) UnmarshalUserinfoProfile() userInfoProfile {
+// locale, _ := language.Parse(j.Locale)
+// return userInfoProfile{
+// Name: j.Name,
+// GivenName: j.GivenName,
+// FamilyName: j.FamilyName,
+// MiddleName: j.MiddleName,
+// Nickname: j.Nickname,
+// Profile: j.Profile,
+// Picture: j.Picture,
+// Website: j.Website,
+// Gender: Gender(j.Gender),
+// Birthdate: j.Birthdate,
+// Zoneinfo: j.Zoneinfo,
+// Locale: locale,
+// UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
+// PreferredUsername: j.PreferredUsername,
+// }
+//}
+//
+//func (j *jsonToken) UnmarshalUserinfoEmail() userInfoEmail {
+// return userInfoEmail{
+// Email: j.Email,
+// EmailVerified: j.EmailVerified,
+// }
+//}
+//
+//func (j *jsonToken) UnmarshalUserinfoPhone() userInfoPhone {
+// return userInfoPhone{
+// PhoneNumber: j.Phone,
+// PhoneNumberVerified: j.PhoneVerified,
+// }
+//}
+//
+//func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
+// if j.JsonUserinfoAddress == nil {
+// return nil
+// }
+// return &UserinfoAddress{
+// Country: j.JsonUserinfoAddress.Country,
+// Formatted: j.JsonUserinfoAddress.Formatted,
+// Locality: j.JsonUserinfoAddress.Locality,
+// PostalCode: j.JsonUserinfoAddress.PostalCode,
+// Region: j.JsonUserinfoAddress.Region,
+// StreetAddress: j.JsonUserinfoAddress.StreetAddress,
+// }
+//}
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go
new file mode 100644
index 0000000..c04dfb4
--- /dev/null
+++ b/pkg/oidc/token_request.go
@@ -0,0 +1,101 @@
+package oidc
+
+import (
+ "time"
+
+ "gopkg.in/square/go-jose.v2"
+)
+
+const (
+ //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
+ GrantTypeCode GrantType = "authorization_code"
+ //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
+ GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
+)
+
+type GrantType string
+
+type TokenRequest interface {
+ // GrantType GrantType `schema:"grant_type"`
+ GrantType() GrantType
+}
+
+type TokenRequestType GrantType
+
+type AccessTokenRequest struct {
+ Code string `schema:"code"`
+ RedirectURI string `schema:"redirect_uri"`
+ ClientID string `schema:"client_id"`
+ ClientSecret string `schema:"client_secret"`
+ CodeVerifier string `schema:"code_verifier"`
+}
+
+func (a *AccessTokenRequest) GrantType() GrantType {
+ return GrantTypeCode
+}
+
+type JWTTokenRequest struct {
+ Issuer string `json:"iss"`
+ Subject string `json:"sub"`
+ Scopes Scopes `json:"scope"`
+ Audience Audience `json:"aud"`
+ IssuedAt Time `json:"iat"`
+ ExpiresAt Time `json:"exp"`
+}
+
+func (j *JWTTokenRequest) GetClientID() string {
+ return j.Subject
+}
+
+func (j *JWTTokenRequest) GetSubject() string {
+ return j.Subject
+}
+
+func (j *JWTTokenRequest) GetScopes() []string {
+ return j.Scopes
+}
+
+func (j *JWTTokenRequest) GetIssuer() string {
+ return j.Issuer
+}
+
+func (j *JWTTokenRequest) GetAudience() []string {
+ return j.Audience
+}
+
+func (j *JWTTokenRequest) GetExpiration() time.Time {
+ return time.Time(j.ExpiresAt)
+}
+
+func (j *JWTTokenRequest) GetIssuedAt() time.Time {
+ return time.Time(j.IssuedAt)
+}
+
+func (j *JWTTokenRequest) GetNonce() string {
+ return ""
+}
+
+func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
+ return ""
+}
+
+func (j *JWTTokenRequest) GetAuthTime() time.Time {
+ return time.Time{}
+}
+
+func (j *JWTTokenRequest) GetAuthorizedParty() string {
+ return ""
+}
+
+func (j *JWTTokenRequest) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {}
+
+type TokenExchangeRequest struct {
+ subjectToken string `schema:"subject_token"`
+ subjectTokenType string `schema:"subject_token_type"`
+ actorToken string `schema:"actor_token"`
+ actorTokenType string `schema:"actor_token_type"`
+ resource []string `schema:"resource"`
+ audience Audience `schema:"audience"`
+ Scope Scopes `schema:"scope"`
+ requestedTokenType string `schema:"requested_token_type"`
+}
diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go
new file mode 100644
index 0000000..ad19684
--- /dev/null
+++ b/pkg/oidc/types.go
@@ -0,0 +1,113 @@
+package oidc
+
+import (
+ "encoding/json"
+ "strings"
+ "time"
+
+ "golang.org/x/text/language"
+)
+
+type Audience []string
+
+func (a *Audience) UnmarshalJSON(text []byte) error {
+ var i interface{}
+ err := json.Unmarshal(text, &i)
+ if err != nil {
+ return err
+ }
+ switch aud := i.(type) {
+ case []interface{}:
+ *a = make([]string, len(aud))
+ for i, audience := range aud {
+ (*a)[i] = audience.(string)
+ }
+ case string:
+ *a = []string{aud}
+ }
+ return nil
+}
+
+type Display string
+
+func (d *Display) UnmarshalText(text []byte) error {
+ display := Display(text)
+ switch display {
+ case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP:
+ *d = display
+ }
+ return nil
+}
+
+type Gender string
+
+type Locale language.Tag
+
+//{
+// SetLocale(language.Tag)
+// Get() language.Tag
+//}
+//
+//func NewLocale(tag language.Tag) Locale {
+// if tag.IsRoot() {
+// return nil
+// }
+// return &locale{Tag: tag}
+//}
+//
+//type locale struct {
+// language.Tag
+//}
+//
+//func (l *locale) SetLocale(tag language.Tag) {
+// l.Tag = tag
+//}
+//func (l *locale) Get() language.Tag {
+// return l.Tag
+//}
+
+//func (l *locale) MarshalJSON() ([]byte, error) {
+// if l != nil && !l.IsRoot() {
+// return l.MarshalText()
+// }
+// return []byte("null"), nil
+//}
+
+type Locales []language.Tag
+
+func (l *Locales) UnmarshalText(text []byte) error {
+ locales := strings.Split(string(text), " ")
+ for _, locale := range locales {
+ tag, err := language.Parse(locale)
+ if err == nil && !tag.IsRoot() {
+ *l = append(*l, tag)
+ }
+ }
+ return nil
+}
+
+type Prompt string
+
+type ResponseType string
+
+type Scopes []string
+
+func (s *Scopes) UnmarshalText(text []byte) error {
+ *s = strings.Split(string(text), " ")
+ return nil
+}
+
+type Time time.Time
+
+func (t *Time) UnmarshalJSON(data []byte) error {
+ var i int64
+ if err := json.Unmarshal(data, &i); err != nil {
+ return err
+ }
+ *t = Time(time.Unix(i, 0).UTC())
+ return nil
+}
+
+func (t *Time) MarshalJSON() ([]byte, error) {
+ return json.Marshal(time.Time(*t).UTC().Unix())
+}
diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go
new file mode 100644
index 0000000..c451f8c
--- /dev/null
+++ b/pkg/oidc/types_test.go
@@ -0,0 +1,276 @@
+package oidc
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/text/language"
+)
+
+func TestAudience_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ audience Audience
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte(`{"aud": "single audience"}`),
+ },
+ res{
+ []string{"single audience"},
+ },
+ false,
+ },
+ {
+ "page",
+ args{
+ []byte(`{"aud": ["multiple", "audience"]}`),
+ },
+ res{
+ []string{"multiple", "audience"},
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := new(struct {
+ Audience Audience `json:"aud"`
+ })
+ if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, a.Audience, tt.res.audience)
+ })
+ }
+}
+
+func TestDisplay_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ display Display
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{},
+ false,
+ },
+ {
+ "page",
+ args{
+ []byte("page"),
+ },
+ res{DisplayPage},
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var d Display
+ if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if d != tt.res.display {
+ t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display)
+ }
+ })
+ }
+}
+
+func TestLocales_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ tags []language.Tag
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{},
+ false,
+ },
+ {
+ "undefined",
+ args{
+ []byte("und"),
+ },
+ res{},
+ false,
+ },
+ {
+ "single language",
+ args{
+ []byte("de"),
+ },
+ res{[]language.Tag{language.German}},
+ false,
+ },
+ {
+ "multiple languages",
+ args{
+ []byte("de en"),
+ },
+ res{[]language.Tag{language.German, language.English}},
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var locales Locales
+ if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, locales, tt.res.tags)
+ })
+ }
+}
+
+func TestScopes_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ scopes []string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{
+ []string{"unknown"},
+ },
+ false,
+ },
+ {
+ "struct",
+ args{
+ []byte(`{"unknown":"value"}`),
+ },
+ res{
+ []string{`{"unknown":"value"}`},
+ },
+ false,
+ },
+ {
+ "openid",
+ args{
+ []byte("openid"),
+ },
+ res{
+ []string{"openid"},
+ },
+ false,
+ },
+ {
+ "multiple scopes",
+ args{
+ []byte("openid email custom:scope"),
+ },
+ res{
+ []string{"openid", "email", "custom:scope"},
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var scopes Scopes
+ if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, scopes, tt.res.scopes)
+ })
+ }
+}
+
+func TestTime_UnmarshalJSON(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ scopes []string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{
+ []string{"unknown"},
+ },
+ false,
+ },
+ {
+ "openid",
+ args{
+ []byte("openid"),
+ },
+ res{
+ []string{"openid"},
+ },
+ false,
+ },
+ {
+ "multiple scopes",
+ args{
+ []byte("openid email custom:scope"),
+ },
+ res{
+ []string{"openid", "email", "custom:scope"},
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var scopes Scopes
+ if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, scopes, tt.res.scopes)
+ })
+ }
+}
diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go
index f8b4e6c..31de85f 100644
--- a/pkg/oidc/userinfo.go
+++ b/pkg/oidc/userinfo.go
@@ -2,89 +2,312 @@ package oidc
import (
"encoding/json"
+ "fmt"
"time"
"golang.org/x/text/language"
+
+ "github.com/caos/oidc/pkg/utils"
)
-type Userinfo struct {
- Subject string
- UserinfoProfile
- UserinfoEmail
- UserinfoPhone
- Address *UserinfoAddress
+type UserInfo interface {
+ GetSubject() string
+ UserInfoProfile
+ UserInfoEmail
+ UserInfoPhone
+ GetAddress() UserInfoAddress
+ GetClaim(key string) interface{}
+}
- Authorizations []string
+type UserInfoProfile interface {
+ GetName() string
+ GetGivenName() string
+ GetFamilyName() string
+ GetMiddleName() string
+ GetNickname() string
+ GetProfile() string
+ GetPicture() string
+ GetWebsite() string
+ GetGender() Gender
+ GetBirthdate() string
+ GetZoneinfo() string
+ GetLocale() language.Tag
+ GetPreferredUsername() string
+}
+
+type UserInfoEmail interface {
+ GetEmail() string
+ IsEmailVerified() bool
+}
+
+type UserInfoPhone interface {
+ GetPhoneNumber() string
+ IsPhoneNumberVerified() bool
+}
+
+type UserInfoAddress interface {
+ GetFormatted() string
+ GetStreetAddress() string
+ GetLocality() string
+ GetRegion() string
+ GetPostalCode() string
+ GetCountry() string
+}
+
+type UserInfoSetter interface {
+ UserInfo
+ SetSubject(sub string)
+ UserInfoProfileSetter
+ SetEmail(email string, verified bool)
+ SetPhone(phone string, verified bool)
+ SetAddress(address UserInfoAddress)
+ AppendClaims(key string, values interface{})
+}
+
+type UserInfoProfileSetter interface {
+ SetName(name string)
+ SetGivenName(name string)
+ SetFamilyName(name string)
+ SetMiddleName(name string)
+ SetNickname(name string)
+ SetUpdatedAt(date time.Time)
+ SetProfile(profile string)
+ SetPicture(profile string)
+ SetWebsite(website string)
+ SetGender(gender Gender)
+ SetBirthdate(birthdate string)
+ SetZoneinfo(zoneInfo string)
+ SetLocale(locale language.Tag)
+ SetPreferredUsername(name string)
+}
+
+func NewUserInfo() UserInfoSetter {
+ return &userinfo{}
+}
+
+type userinfo struct {
+ Subject string `json:"sub,omitempty"`
+ userInfoProfile
+ userInfoEmail
+ userInfoPhone
+ Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{}
}
-type UserinfoProfile struct {
- Name string
- GivenName string
- FamilyName string
- MiddleName string
- Nickname string
- Profile string
- Picture string
- Website string
- Gender Gender
- Birthdate string
- Zoneinfo string
- Locale language.Tag
- UpdatedAt time.Time
- PreferredUsername string
+func (u *userinfo) GetSubject() string {
+ return u.Subject
}
-type Gender string
-
-type UserinfoEmail struct {
- Email string
- EmailVerified bool
+func (u *userinfo) GetName() string {
+ return u.Name
}
-type UserinfoPhone struct {
- PhoneNumber string
- PhoneNumberVerified bool
+func (u *userinfo) GetGivenName() string {
+ return u.GivenName
}
-type UserinfoAddress struct {
- Formatted string
- StreetAddress string
- Locality string
- Region string
- PostalCode string
- Country string
+func (u *userinfo) GetFamilyName() string {
+ return u.FamilyName
}
-type jsonUserinfoProfile struct {
- Name string `json:"name,omitempty"`
- GivenName string `json:"given_name,omitempty"`
- FamilyName string `json:"family_name,omitempty"`
- MiddleName string `json:"middle_name,omitempty"`
- Nickname string `json:"nickname,omitempty"`
- Profile string `json:"profile,omitempty"`
- Picture string `json:"picture,omitempty"`
- Website string `json:"website,omitempty"`
- Gender string `json:"gender,omitempty"`
- Birthdate string `json:"birthdate,omitempty"`
- Zoneinfo string `json:"zoneinfo,omitempty"`
- Locale string `json:"locale,omitempty"`
- UpdatedAt int64 `json:"updated_at,omitempty"`
- PreferredUsername string `json:"preferred_username,omitempty"`
+func (u *userinfo) GetMiddleName() string {
+ return u.MiddleName
}
-type jsonUserinfoEmail struct {
+func (u *userinfo) GetNickname() string {
+ return u.Nickname
+}
+
+func (u *userinfo) GetProfile() string {
+ return u.Profile
+}
+
+func (u *userinfo) GetPicture() string {
+ return u.Picture
+}
+
+func (u *userinfo) GetWebsite() string {
+ return u.Website
+}
+
+func (u *userinfo) GetGender() Gender {
+ return u.Gender
+}
+
+func (u *userinfo) GetBirthdate() string {
+ return u.Birthdate
+}
+
+func (u *userinfo) GetZoneinfo() string {
+ return u.Zoneinfo
+}
+
+func (u *userinfo) GetLocale() language.Tag {
+ return u.Locale
+}
+
+func (u *userinfo) GetPreferredUsername() string {
+ return u.PreferredUsername
+}
+
+func (u *userinfo) GetEmail() string {
+ return u.Email
+}
+
+func (u *userinfo) IsEmailVerified() bool {
+ return u.EmailVerified
+}
+
+func (u *userinfo) GetPhoneNumber() string {
+ return u.PhoneNumber
+}
+
+func (u *userinfo) IsPhoneNumberVerified() bool {
+ return u.PhoneNumberVerified
+}
+
+func (u *userinfo) GetAddress() UserInfoAddress {
+ return u.Address
+}
+
+func (u *userinfo) GetClaim(key string) interface{} {
+ return u.claims[key]
+}
+
+func (u *userinfo) SetSubject(sub string) {
+ u.Subject = sub
+}
+
+func (u *userinfo) SetName(name string) {
+ u.Name = name
+}
+
+func (u *userinfo) SetGivenName(name string) {
+ u.GivenName = name
+}
+
+func (u *userinfo) SetFamilyName(name string) {
+ u.FamilyName = name
+}
+
+func (u *userinfo) SetMiddleName(name string) {
+ u.MiddleName = name
+}
+
+func (u *userinfo) SetNickname(name string) {
+ u.Nickname = name
+}
+
+func (u *userinfo) SetUpdatedAt(date time.Time) {
+ u.UpdatedAt = Time(date)
+}
+
+func (u *userinfo) SetProfile(profile string) {
+ u.Profile = profile
+}
+
+func (u *userinfo) SetPicture(picture string) {
+ u.Picture = picture
+}
+
+func (u *userinfo) SetWebsite(website string) {
+ u.Website = website
+}
+
+func (u *userinfo) SetGender(gender Gender) {
+ u.Gender = gender
+}
+
+func (u *userinfo) SetBirthdate(birthdate string) {
+ u.Birthdate = birthdate
+}
+
+func (u *userinfo) SetZoneinfo(zoneInfo string) {
+ u.Zoneinfo = zoneInfo
+}
+
+func (u *userinfo) SetLocale(locale language.Tag) {
+ u.Locale = locale
+}
+
+func (u *userinfo) SetPreferredUsername(name string) {
+ u.PreferredUsername = name
+}
+
+func (u *userinfo) SetEmail(email string, verified bool) {
+ u.Email = email
+ u.EmailVerified = verified
+}
+
+func (u *userinfo) SetPhone(phone string, verified bool) {
+ u.PhoneNumber = phone
+ u.PhoneNumberVerified = verified
+}
+
+func (u *userinfo) SetAddress(address UserInfoAddress) {
+ u.Address = address
+}
+
+func (u *userinfo) AppendClaims(key string, value interface{}) {
+ if u.claims == nil {
+ u.claims = make(map[string]interface{})
+ }
+ u.claims[key] = value
+}
+
+func (u *userInfoAddress) GetFormatted() string {
+ panic("implement me")
+}
+
+func (u *userInfoAddress) GetStreetAddress() string {
+ panic("implement me")
+}
+
+func (u *userInfoAddress) GetLocality() string {
+ panic("implement me")
+}
+
+func (u *userInfoAddress) GetRegion() string {
+ panic("implement me")
+}
+
+func (u *userInfoAddress) GetPostalCode() string {
+ panic("implement me")
+}
+
+func (u *userInfoAddress) GetCountry() string {
+ panic("implement me")
+}
+
+type userInfoProfile struct {
+ Name string `json:"name,omitempty"`
+ GivenName string `json:"given_name,omitempty"`
+ FamilyName string `json:"family_name,omitempty"`
+ MiddleName string `json:"middle_name,omitempty"`
+ Nickname string `json:"nickname,omitempty"`
+ Profile string `json:"profile,omitempty"`
+ Picture string `json:"picture,omitempty"`
+ Website string `json:"website,omitempty"`
+ Gender Gender `json:"gender,omitempty"`
+ Birthdate string `json:"birthdate,omitempty"`
+ Zoneinfo string `json:"zoneinfo,omitempty"`
+ Locale language.Tag `json:"locale,omitempty"`
+ UpdatedAt Time `json:"updated_at,omitempty"`
+ PreferredUsername string `json:"preferred_username,omitempty"`
+}
+
+type userInfoEmail struct {
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
}
-type jsonUserinfoPhone struct {
- Phone string `json:"phone_number,omitempty"`
- PhoneVerified bool `json:"phone_number_verified,omitempty"`
+type userInfoPhone struct {
+ PhoneNumber string `json:"phone_number,omitempty"`
+ PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
}
-type jsonUserinfoAddress struct {
+type userInfoAddress struct {
Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"`
@@ -93,81 +316,68 @@ type jsonUserinfoAddress struct {
Country string `json:"country,omitempty"`
}
-func (i *Userinfo) MarshalJSON() ([]byte, error) {
- j := new(jsonUserinfo)
- j.Subject = i.Subject
- j.setUserinfo(*i)
- j.Authorizations = i.Authorizations
- return json.Marshal(j)
+func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
+ return &userInfoAddress{
+ StreetAddress: streetAddress,
+ Locality: locality,
+ Region: region,
+ PostalCode: postalCode,
+ Country: country,
+ Formatted: formatted,
+ }
+}
+func (i *userinfo) MarshalJSON() ([]byte, error) {
+ type Alias userinfo
+ a := &struct {
+ *Alias
+ Locale interface{} `json:"locale,omitempty"`
+ UpdatedAt int64 `json:"updated_at,omitempty"`
+ }{
+ Alias: (*Alias)(i),
+ }
+ if !i.Locale.IsRoot() {
+ a.Locale = i.Locale
+ }
+ fmt.Println(time.Time(i.UpdatedAt).String())
+ if !time.Time(i.UpdatedAt).IsZero() {
+ a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
+ }
+
+ b, err := json.Marshal(a)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(i.claims) == 0 {
+ return b, nil
+ }
+
+ claims, err := json.Marshal(i.claims)
+ if err != nil {
+ return nil, fmt.Errorf("jws: invalid map of private claims %v", i.claims)
+ }
+ return utils.ConcatenateJSON(b, claims)
}
-func (i *Userinfo) UnmmarshalJSON(data []byte) error {
- if err := json.Unmarshal(data, i); err != nil {
+func (i *userinfo) UnmarshalJSON(data []byte) error {
+ type Alias userinfo
+ a := &struct {
+ *Alias
+ //Locale interface{} `json:"locale,omitempty"`
+ UpdatedAt int64 `json:"update_at,omitempty"`
+ }{
+ Alias: (*Alias)(i),
+ }
+ if err := json.Unmarshal(data, &a); err != nil {
return err
}
- return json.Unmarshal(data, &i.claims)
-}
+ //if !i.Locale.IsRoot() {
+ // a.Locale = i.Locale
+ //}
-type jsonUserinfo struct {
- Subject string `json:"sub,omitempty"`
- jsonUserinfoProfile
- jsonUserinfoEmail
- jsonUserinfoPhone
- JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"`
- Authorizations []string `json:"authorizations,omitempty"`
-}
+ i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
-func (j *jsonUserinfo) setUserinfo(i Userinfo) {
- j.setUserinfoProfile(i.UserinfoProfile)
- j.setUserinfoEmail(i.UserinfoEmail)
- j.setUserinfoPhone(i.UserinfoPhone)
- j.setUserinfoAddress(i.Address)
-}
-
-func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) {
- j.Name = i.Name
- j.GivenName = i.GivenName
- j.FamilyName = i.FamilyName
- j.MiddleName = i.MiddleName
- j.Nickname = i.Nickname
- j.Profile = i.Profile
- j.Picture = i.Picture
- j.Website = i.Website
- j.Gender = string(i.Gender)
- j.Birthdate = i.Birthdate
- j.Zoneinfo = i.Zoneinfo
- if i.Locale != language.Und {
- j.Locale = i.Locale.String()
- }
- j.UpdatedAt = timeToJSON(i.UpdatedAt)
- j.PreferredUsername = i.PreferredUsername
-}
-
-func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) {
- j.Email = i.Email
- j.EmailVerified = i.EmailVerified
-}
-
-func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) {
- j.Phone = i.PhoneNumber
- j.PhoneVerified = i.PhoneNumberVerified
-}
-
-func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) {
- if i == nil {
- return
- }
- if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" {
- return
- }
- j.JsonUserinfoAddress = &jsonUserinfoAddress{
- Country: i.Country,
- Formatted: i.Formatted,
- Locality: i.Locality,
- PostalCode: i.PostalCode,
- Region: i.Region,
- StreetAddress: i.StreetAddress,
- }
+ return nil
}
type UserInfoRequest struct {
diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go
index 492664b..06470a0 100644
--- a/pkg/oidc/verifier.go
+++ b/pkg/oidc/verifier.go
@@ -24,7 +24,7 @@ type Claims interface {
GetAuthenticationContextClassReference() string
GetAuthTime() time.Time
GetAuthorizedParty() string
- SetSignature(algorithm jose.SignatureAlgorithm)
+ SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
}
var (
@@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
return ErrSignatureInvalidPayload
}
- claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm))
+ claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm))
return nil
}
diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go
index cf40e62..cee6184 100644
--- a/pkg/op/authrequest.go
+++ b/pkg/op/authrequest.go
@@ -168,7 +168,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
}
- return claims.Subject, nil
+ return claims.GetSubject(), nil
}
//RedirectToLogin redirects the end user to the Login UI for authentication
diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go
index 0a6b6a5..7dfcfff 100644
--- a/pkg/op/mock/authorizer.mock.impl.go
+++ b/pkg/op/mock/authorizer.mock.impl.go
@@ -81,7 +81,7 @@ func (s *Sig) Health(ctx context.Context) error {
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
return "", nil
}
-func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
+func (s *Sig) SignAccessToken(*oidc.accessTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go
index a7d909c..16592a7 100644
--- a/pkg/op/mock/signer.mock.go
+++ b/pkg/op/mock/signer.mock.go
@@ -50,7 +50,7 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
}
// SignAccessToken mocks base method
-func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
+func (m *MockSigner) SignAccessToken(arg0 *oidc.accessTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
ret0, _ := ret[0].(string)
diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go
index a7ca4cb..bcc04da 100644
--- a/pkg/op/mock/storage.mock.go
+++ b/pkg/op/mock/storage.mock.go
@@ -184,10 +184,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
}
// GetUserinfoFromScopes mocks base method
-func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) {
+func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.userinfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
- ret0, _ := ret[0].(*oidc.Userinfo)
+ ret0, _ := ret[0].(*oidc.userinfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -199,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interf
}
// GetUserinfoFromToken mocks base method
-func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.Userinfo, error) {
+func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.userinfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2)
- ret0, _ := ret[0].(*oidc.Userinfo)
+ ret0, _ := ret[0].(*oidc.userinfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
diff --git a/pkg/op/op.go b/pkg/op/op.go
index d913c7f..7e8279a 100644
--- a/pkg/op/op.go
+++ b/pkg/op/op.go
@@ -130,7 +130,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
}
keyCh := make(chan jose.SigningKey)
- o.signer = NewDefaultSigner(ctx, storage, keyCh)
+ o.signer = NewSigner(ctx, storage, keyCh)
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
o.httpHandler = CreateRouter(o, o.interceptors...)
diff --git a/pkg/op/session.go b/pkg/op/session.go
index d04e361..19ebab4 100644
--- a/pkg/op/session.go
+++ b/pkg/op/session.go
@@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid")
}
- session.UserID = claims.Subject
- session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty)
+ session.UserID = claims.GetSubject()
+ session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil {
return nil, ErrServerError("")
}
diff --git a/pkg/op/signer.go b/pkg/op/signer.go
index e9926cd..5cf585e 100644
--- a/pkg/op/signer.go
+++ b/pkg/op/signer.go
@@ -2,19 +2,17 @@ package op
import (
"context"
- "encoding/json"
"errors"
"github.com/caos/logging"
"gopkg.in/square/go-jose.v2"
-
- "github.com/caos/oidc/pkg/oidc"
)
type Signer interface {
Health(ctx context.Context) error
- SignIDToken(claims *oidc.IDTokenClaims) (string, error)
- SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
+ //SignIDToken(claims *oidc.IDTokenClaims) (string, error)
+ //SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
+ Signer() jose.Signer
SignatureAlgorithm() jose.SignatureAlgorithm
}
@@ -24,7 +22,7 @@ type tokenSigner struct {
alg jose.SignatureAlgorithm
}
-func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
+func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
s := &tokenSigner{
storage: storage,
}
@@ -41,6 +39,15 @@ func (s *tokenSigner) Health(_ context.Context) error {
return nil
}
+func (s *tokenSigner) Signer() jose.Signer {
+ return s.signer
+}
+
+//
+//func (s *tokenSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) {
+// return s.signer.Sign(payload)
+//}
+
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for {
select {
@@ -55,30 +62,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S
}
}
-func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
- payload, err := json.Marshal(claims)
- if err != nil {
- return "", err
- }
- return s.Sign(payload)
-}
-
-func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
- payload, err := json.Marshal(claims)
- if err != nil {
- return "", err
- }
- return s.Sign(payload)
-}
-
-func (s *tokenSigner) Sign(payload []byte) (string, error) {
- result, err := s.signer.Sign(payload)
- if err != nil {
- return "", err
- }
- return result.CompactSerialize()
-}
-
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.alg
}
diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go
index 75e184b..c751c76 100644
--- a/pkg/op/signer_test.go
+++ b/pkg/op/signer_test.go
@@ -38,13 +38,13 @@ import (
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
-// got, err := op.NewDefaultSigner(tt.args.storage)
+// got, err := op.NewSigner(tt.args.storage)
// if (err != nil) != tt.wantErr {
-// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
+// t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
// return
// }
// if !reflect.DeepEqual(got, tt.want) {
-// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
+// t.Errorf("NewSigner() = %v, want %v", got, tt.want)
// }
// })
// }
diff --git a/pkg/op/storage.go b/pkg/op/storage.go
index 669b08e..69784ee 100644
--- a/pkg/op/storage.go
+++ b/pkg/op/storage.go
@@ -28,8 +28,8 @@ type AuthStorage interface {
type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error
- GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error)
- GetUserinfoFromToken(context.Context, string, string) (*oidc.Userinfo, error)
+ GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error)
+ GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
}
diff --git a/pkg/op/token.go b/pkg/op/token.go
index 87494b9..bb2b3c5 100644
--- a/pkg/op/token.go
+++ b/pkg/op/token.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/caos/oidc/pkg/oidc"
+ "github.com/caos/oidc/pkg/utils"
)
type TokenCreator interface {
@@ -82,51 +83,34 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) {
}
func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) {
- now := time.Now().UTC()
- nbf := now
- claims := &oidc.AccessTokenClaims{
- Issuer: issuer,
- Subject: authReq.GetSubject(),
- Audiences: authReq.GetAudience(),
- Expiration: exp,
- IssuedAt: now,
- NotBefore: nbf,
- JWTID: id,
- }
- return signer.SignAccessToken(claims)
+ claims := oidc.NewAccessTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, id)
+ return utils.Sign(claims, signer.Signer())
}
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) {
- var err error
exp := time.Now().UTC().Add(validity)
- userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
- if err != nil {
- return "", err
- }
- claims := &oidc.IDTokenClaims{
- Issuer: issuer,
- Audiences: authReq.GetAudience(),
- Expiration: exp,
- IssuedAt: time.Now().UTC(),
- AuthTime: authReq.GetAuthTime(),
- Nonce: authReq.GetNonce(),
- AuthenticationContextClassReference: authReq.GetACR(),
- AuthenticationMethodsReferences: authReq.GetAMR(),
- AuthorizedParty: authReq.GetClientID(),
- Userinfo: *userinfo,
- }
+ claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID())
+
if accessToken != "" {
- claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
+ atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
+ claims.SetAccessTokenHash(atHash)
+ } else {
+ userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
+ if err != nil {
+ return "", err
+ }
+ claims.SetUserinfo(userInfo)
}
if code != "" {
- claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
+ codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
+ claims.SetCodeHash(codeHash)
}
- return signer.SignIDToken(claims)
+ return utils.Sign(claims, signer.Signer())
}
diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go
index 3268a5e..7baa075 100644
--- a/pkg/op/verifier_id_token_hint.go
+++ b/pkg/op/verifier_id_token_hint.go
@@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi
//VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
-func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) {
- claims := new(oidc.IDTokenClaims)
+func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
+ claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {
diff --git a/pkg/rp/mock/verifier.mock.impl.go b/pkg/rp/mock/verifier.mock.impl.go
index 53b2f03..0b6dd1c 100644
--- a/pkg/rp/mock/verifier.mock.impl.go
+++ b/pkg/rp/mock/verifier.mock.impl.go
@@ -33,5 +33,5 @@ func NewMockVerifierExpectValid(t *testing.T) rp.Verifier {
func ExpectVerifyValid(v rp.Verifier) {
mock := v.(*MockVerifier)
- mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil)
+ mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.userinfo{Subject: "id"}}, nil)
}
diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go
index 3fe8b4b..fd9ee95 100644
--- a/pkg/rp/relaying_party.go
+++ b/pkg/rp/relaying_party.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
+ "net/url"
"strings"
"github.com/google/uuid"
@@ -329,10 +330,10 @@ func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.
}
func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) {
- form := make(map[string][]string)
- form["assertion"] = []string{assertion}
- form["grant_type"] = []string{jwtProfileKey}
- req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil)
+ form := url.Values{}
+ form.Add("assertion", assertion)
+ form.Add("grant_type", jwtProfileKey)
+ req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go
index ef2cf87..a156f6d 100644
--- a/pkg/rp/verifier.go
+++ b/pkg/rp/verifier.go
@@ -21,12 +21,12 @@ type IDTokenVerifier interface {
//VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
-func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
+func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v)
if err != nil {
return nil, err
}
- if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil {
+ if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err
}
return idToken, nil
@@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo
//VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
-func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
- claims := new(oidc.IDTokenClaims)
+func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
+ claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {
diff --git a/pkg/utils/marshal.go b/pkg/utils/marshal.go
index e279341..4f53b4e 100644
--- a/pkg/utils/marshal.go
+++ b/pkg/utils/marshal.go
@@ -1,7 +1,9 @@
package utils
import (
+ "bytes"
"encoding/json"
+ "fmt"
"net/http"
"github.com/sirupsen/logrus"
@@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) {
logrus.Error("error writing response")
}
}
+
+func ConcatenateJSON(first, second []byte) ([]byte, error) {
+ if !bytes.HasSuffix(first, []byte{'}'}) {
+ return nil, fmt.Errorf("jws: invalid JSON %s", first)
+ }
+ if !bytes.HasPrefix(second, []byte{'{'}) {
+ return nil, fmt.Errorf("jws: invalid JSON %s", second)
+ }
+ first[len(first)-1] = ','
+ first = append(first, second[1:]...)
+ return first, nil
+}
diff --git a/pkg/utils/sign.go b/pkg/utils/sign.go
new file mode 100644
index 0000000..e1efe61
--- /dev/null
+++ b/pkg/utils/sign.go
@@ -0,0 +1,23 @@
+package utils
+
+import (
+ "encoding/json"
+
+ "gopkg.in/square/go-jose.v2"
+)
+
+func Sign(object interface{}, signer jose.Signer) (string, error) {
+ payload, err := json.Marshal(object)
+ if err != nil {
+ return "", err
+ }
+ return SignPayload(payload, signer)
+}
+
+func SignPayload(payload []byte, signer jose.Signer) (string, error) {
+ result, err := signer.Sign(payload)
+ if err != nil {
+ return "", err
+ }
+ return result.CompactSerialize()
+}