refactoring

This commit is contained in:
Livio Amstutz 2020-09-25 16:41:25 +02:00
parent 6cfd02e4c9
commit 542ec6ed7b
26 changed files with 1412 additions and 625 deletions

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"html/template"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"time" "time"
@ -30,7 +32,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail} scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"}
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
rp.WithPKCE(cookieHandler), rp.WithPKCE(cookieHandler),
@ -82,6 +84,62 @@ func main() {
} }
w.Write(data) w.Write(data)
}) })
http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) {
tpl := `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Login</title>
</head>
<body>
<form method="POST" action="/jwt-profile-assertion" enctype="multipart/form-data">
<label for="key">Select a key file:</label>
<input type="file" id="key" name="key">
<button type="submit">Upload</button>
</form>
</body>
</html>`
t, err := template.New("login").Parse(tpl)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = t.Execute(w, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
http.HandleFunc("/jwt-profile-assertion", func(w http.ResponseWriter, r *http.Request) {
r.ParseMultipartForm(32 << 20)
file, handler, err := r.FormFile("key")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer file.Close()
key, err := ioutil.ReadAll(file)
fmt.Println(handler.Header)
assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
token, err := rp.JWTProfileExchange(ctx, assertion, provider)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data, err := json.Marshal(token)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(data)
})
lis := fmt.Sprintf("127.0.0.1:%s", port) lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis) logrus.Infof("listening on http://%s/", lis)
logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil)) logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))

View file

@ -0,0 +1,39 @@
package client
import (
"context"
"fmt"
"os"
"time"
"github.com/sirupsen/logrus"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp"
"github.com/caos/oidc/pkg/utils"
)
var (
callbackPath string = "/auth/callback"
key []byte = []byte("test1234test1234")
)
func main() {
clientID := os.Getenv("CLIENT_ID")
clientSecret := os.Getenv("CLIENT_SECRET")
issuer := os.Getenv("ISSUER")
port := os.Getenv("PORT")
ctx := context.Background()
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"}
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
rp.WithPKCE(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)),
)
if err != nil {
logrus.Fatalf("error creating provider %s", err.Error())
}
}

View file

@ -210,24 +210,24 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
return nil return nil
} }
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.Userinfo, error) { func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.userinfo, error) {
return s.GetUserinfoFromScopes(ctx, "", []string{}) return s.GetUserinfoFromScopes(ctx, "", []string{})
} }
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) { func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.userinfo, error) {
return &oidc.Userinfo{ return &oidc.userinfo{
Subject: a.GetSubject(), Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{ Address: &oidc.UserinfoAddress{
StreetAddress: "Hjkhkj 789\ndsf", StreetAddress: "Hjkhkj 789\ndsf",
}, },
UserinfoEmail: oidc.UserinfoEmail{ userinfoEmail: oidc.userinfoEmail{
Email: "test", Email: "test",
EmailVerified: true, EmailVerified: true,
}, },
UserinfoPhone: oidc.UserinfoPhone{ userinfoPhone: oidc.userinfoPhone{
PhoneNumber: "sadsa", PhoneNumber: "sadsa",
PhoneNumberVerified: true, PhoneNumberVerified: true,
}, },
UserinfoProfile: oidc.UserinfoProfile{ userinfoProfile: oidc.userinfoProfile{
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
}, },
// Claims: map[string]interface{}{ // Claims: map[string]interface{}{

View file

@ -1,15 +1,5 @@
package oidc package oidc
import (
"encoding/json"
"errors"
"strings"
"time"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
const ( const (
//ScopeOpenID defines the scope `openid` //ScopeOpenID defines the scope `openid`
//OpenID Connect requests MUST contain the `openid` scope value //OpenID Connect requests MUST contain the `openid` scope value
@ -64,23 +54,8 @@ const (
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching) //PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
PromptSelectAccount Prompt = "select_account" PromptSelectAccount Prompt = "select_account"
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code"
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
BearerToken = "Bearer"
) )
var displayValues = map[string]Display{
"page": DisplayPage,
"popup": DisplayPopup,
"touch": DisplayTouch,
"wap": DisplayWAP,
}
//AuthRequest according to: //AuthRequest according to:
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct { type AuthRequest struct {
@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType {
func (a *AuthRequest) GetState() string { func (a *AuthRequest) GetState() string {
return a.State return a.State
} }
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType
}
type TokenRequestType GrantType
type AccessTokenRequest struct {
Code string `schema:"code"`
RedirectURI string `schema:"redirect_uri"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
CodeVerifier string `schema:"code_verifier"`
}
func (a *AccessTokenRequest) GrantType() GrantType {
return GrantTypeCode
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
}
type JWTTokenRequest struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Scopes Scopes `json:"scope"`
Audience interface{} `json:"aud"`
IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"`
}
func (j *JWTTokenRequest) GetClientID() string {
return j.Subject
}
func (j *JWTTokenRequest) GetSubject() string {
return j.Subject
}
func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
}
func (j *JWTTokenRequest) GetIssuer() string {
return j.Issuer
}
func (j *JWTTokenRequest) GetAudience() []string {
return audienceFromJSON(j.Audience)
}
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
}
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
func (j *JWTTokenRequest) GetNonce() string {
return ""
}
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
return ""
}
func (j *JWTTokenRequest) GetAuthTime() time.Time {
return time.Time{}
}
func (j *JWTTokenRequest) GetAuthorizedParty() string {
return ""
}
func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {}
type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"`
audience []string `schema:"audience"`
Scope []string `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"`
}
type Scopes []string
func (s *Scopes) UnmarshalText(text []byte) error {
scopes := strings.Split(string(text), " ")
*s = Scopes(scopes)
return nil
}
type ResponseType string
type Display string
func (d *Display) UnmarshalText(text []byte) error {
var ok bool
display := string(text)
*d, ok = displayValues[display]
if !ok {
return errors.New("")
}
return nil
}
type Prompt string
type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error {
locales := strings.Split(string(text), " ")
for _, locale := range locales {
tag, err := language.Parse(locale)
if err == nil && !tag.IsRoot() {
*l = append(*l, tag)
}
}
return nil
}
type GrantType string

View file

@ -3,33 +3,55 @@ package oidc
import ( import (
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"strings"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
const (
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
BearerToken = "Bearer"
)
type Tokens struct { type Tokens struct {
*oauth2.Token *oauth2.Token
IDTokenClaims *IDTokenClaims IDTokenClaims IDTokenClaims
IDToken string IDToken string
} }
type AccessTokenClaims struct { type AccessTokenClaims interface {
Claims
}
type IDTokenClaims interface {
Claims
GetNotBefore() time.Time
GetJWTID() string
GetAccessTokenHash() string
GetCodeHash() string
GetAuthenticationMethodsReferences() []string
GetClientID() string
GetSignatureAlgorithm() jose.SignatureAlgorithm
SetAccessTokenHash(hash string)
SetUserinfo(userinfo UserInfoSetter)
SetCodeHash(hash string)
UserInfo
}
type accessTokenClaims struct {
Issuer string Issuer string
Subject string Subject string
Audiences []string Audience Audience
Expiration time.Time Expiration Time
IssuedAt time.Time IssuedAt Time
NotBefore time.Time NotBefore Time
JWTID string JWTID string
AuthorizedParty string AuthorizedParty string
Nonce string Nonce string
AuthTime time.Time AuthTime Time
CodeHash string CodeHash string
AuthenticationContextClassReference string AuthenticationContextClassReference string
AuthenticationMethodsReferences []string AuthenticationMethodsReferences []string
@ -37,38 +59,155 @@ type AccessTokenClaims struct {
Scopes []string Scopes []string
ClientID string ClientID string
AccessTokenUseNumber int AccessTokenUseNumber int
signatureAlg jose.SignatureAlgorithm
} }
type IDTokenClaims struct { func (a accessTokenClaims) GetIssuer() string {
Issuer string return a.Issuer
Audiences []string }
Expiration time.Time
NotBefore time.Time
IssuedAt time.Time
JWTID string
UpdatedAt time.Time
AuthorizedParty string
Nonce string
AuthTime time.Time
AccessTokenHash string
CodeHash string
AuthenticationContextClassReference string
AuthenticationMethodsReferences []string
ClientID string
Userinfo
Signature jose.SignatureAlgorithm //TODO: ??? func (a accessTokenClaims) GetAudience() []string {
return a.Audience
}
func (a accessTokenClaims) GetExpiration() time.Time {
return time.Time(a.Expiration)
}
func (a accessTokenClaims) GetIssuedAt() time.Time {
return time.Time(a.IssuedAt)
}
func (a accessTokenClaims) GetNonce() string {
return a.Nonce
}
func (a accessTokenClaims) GetAuthenticationContextClassReference() string {
return a.AuthenticationContextClassReference
}
func (a accessTokenClaims) GetAuthTime() time.Time {
return time.Time(a.AuthTime)
}
func (a accessTokenClaims) GetAuthorizedParty() string {
return a.AuthorizedParty
}
func (a accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
a.signatureAlg = algorithm
}
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims {
now := time.Now().UTC()
return &accessTokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(now),
NotBefore: Time(now),
JWTID: id,
}
}
type idTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
ClientID string `json:"client_id,omitempty"`
UserInfo `json:"-"`
signatureAlg jose.SignatureAlgorithm
}
func (t *idTokenClaims) SetAccessTokenHash(hash string) {
t.AccessTokenHash = hash
}
func (t *idTokenClaims) SetUserinfo(info UserInfoSetter) {
t.UserInfo = info
}
func (t *idTokenClaims) SetCodeHash(hash string) {
t.CodeHash = hash
}
func EmptyIDTokenClaims() IDTokenClaims {
return new(idTokenClaims)
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims {
return &idTokenClaims{
Issuer: issuer,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(time.Now().UTC()),
AuthTime: Time(authTime),
Nonce: nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID,
UserInfo: &userinfo{Subject: subject},
}
}
func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
return t.signatureAlg
}
func (t *idTokenClaims) GetNotBefore() time.Time {
return time.Time(t.NotBefore)
}
func (t *idTokenClaims) GetJWTID() string {
return t.JWTID
}
func (t *idTokenClaims) GetAccessTokenHash() string {
return t.AccessTokenHash
}
func (t *idTokenClaims) GetCodeHash() string {
return t.CodeHash
}
func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
return t.AuthenticationMethodsReferences
}
func (t *idTokenClaims) GetClientID() string {
return t.ClientID
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
} }
type JWTProfileAssertion struct { type JWTProfileAssertion struct {
PrivateKeyID string `json:"keyId"` PrivateKeyID string `json:"-"`
PrivateKey []byte `json:"key"` PrivateKey []byte `json:"-"`
Scopes []string `json:"-"` Scopes []string `json:"scopes"`
Issuer string `json:"-"` Issuer string `json:"issuer"`
Subject string `json:"userId"` Subject string `json:"sub"`
Audience []string `json:"-"` Audience Audience `json:"aud"`
Expiration time.Time `json:"-"` Expiration Time `json:"exp"`
IssuedAt time.Time `json:"-"` IssuedAt Time `json:"iat"`
} }
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) { func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) {
@ -76,12 +215,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewJWTProfileAssertionFromFileData(data, audience)
}
func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) {
keyData := new(struct { keyData := new(struct {
KeyID string `json:"keyId"` KeyID string `json:"keyId"`
Key string `json:"key"` Key string `json:"key"`
UserID string `json:"userId"` UserID string `json:"userId"`
}) })
err = json.Unmarshal(data, keyData) err := json.Unmarshal(data, keyData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,241 +238,251 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
Issuer: userID, Issuer: userID,
Scopes: []string{ScopeOpenID}, Scopes: []string{ScopeOpenID},
Subject: userID, Subject: userID,
IssuedAt: time.Now().UTC(), IssuedAt: Time(time.Now().UTC()),
Expiration: time.Now().Add(1 * time.Hour).UTC(), Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience, Audience: audience,
} }
} }
type jsonToken struct { //
Issuer string `json:"iss,omitempty"` //type jsonToken struct {
Subject string `json:"sub,omitempty"` // Issuer string `json:"iss,omitempty"`
Audiences interface{} `json:"aud,omitempty"` // Subject string `json:"sub,omitempty"`
Expiration int64 `json:"exp,omitempty"` // Audiences interface{} `json:"aud,omitempty"`
NotBefore int64 `json:"nbf,omitempty"` // Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"` // NotBefore int64 `json:"nbf,omitempty"`
JWTID string `json:"jti,omitempty"` // IssuedAt int64 `json:"iat,omitempty"`
AuthorizedParty string `json:"azp,omitempty"` // JWTID string `json:"jti,omitempty"`
Nonce string `json:"nonce,omitempty"` // AuthorizedParty string `json:"azp,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"` // Nonce string `json:"nonce,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"` // AuthTime int64 `json:"auth_time,omitempty"`
CodeHash string `json:"c_hash,omitempty"` // AccessTokenHash string `json:"at_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"` // CodeHash string `json:"c_hash,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"` // AuthenticationContextClassReference string `json:"acr,omitempty"`
SessionID string `json:"sid,omitempty"` // AuthenticationMethodsReferences []string `json:"amr,omitempty"`
Actor interface{} `json:"act,omitempty"` //TODO: impl // SessionID string `json:"sid,omitempty"`
Scopes string `json:"scope,omitempty"` // Actor interface{} `json:"act,omitempty"` //TODO: impl
ClientID string `json:"client_id,omitempty"` // Scopes string `json:"scope,omitempty"`
AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl // ClientID string `json:"client_id,omitempty"`
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` // AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
jsonUserinfo // AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
} // jsonUserinfo
//}
func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) { //
j := jsonToken{ //func (t *accessTokenClaims) MarshalJSON() ([]byte, error) {
Issuer: t.Issuer, // j := jsonToken{
Subject: t.Subject, // Issuer: t.Issuer,
Audiences: t.Audiences, // Subject: t.Subject,
Expiration: timeToJSON(t.Expiration), // Audiences: t.Audiences,
NotBefore: timeToJSON(t.NotBefore), // Expiration: timeToJSON(t.Expiration),
IssuedAt: timeToJSON(t.IssuedAt), // NotBefore: timeToJSON(t.NotBefore),
JWTID: t.JWTID, // IssuedAt: timeToJSON(t.IssuedAt),
AuthorizedParty: t.AuthorizedParty, // JWTID: t.JWTID,
Nonce: t.Nonce, // AuthorizedParty: t.AuthorizedParty,
AuthTime: timeToJSON(t.AuthTime), // Nonce: t.Nonce,
CodeHash: t.CodeHash, // AuthTime: timeToJSON(t.AuthTime),
AuthenticationContextClassReference: t.AuthenticationContextClassReference, // CodeHash: t.CodeHash,
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, // AuthenticationContextClassReference: t.AuthenticationContextClassReference,
SessionID: t.SessionID, // AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
Scopes: strings.Join(t.Scopes, " "), // SessionID: t.SessionID,
ClientID: t.ClientID, // Scopes: strings.Join(t.Scopes, " "),
AccessTokenUseNumber: t.AccessTokenUseNumber, // ClientID: t.ClientID,
// AccessTokenUseNumber: t.AccessTokenUseNumber,
// }
// return json.Marshal(j)
//}
//
//func (t *accessTokenClaims) UnmarshalJSON(b []byte) error {
// var j jsonToken
// if err := json.Unmarshal(b, &j); err != nil {
// return err
// }
// t.Issuer = j.Issuer
// t.Subject = j.Subject
// t.Audiences = audienceFromJSON(j.Audiences)
// t.Expiration = time.Unix(j.Expiration, 0).UTC()
// t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
// t.JWTID = j.JWTID
// t.AuthorizedParty = j.AuthorizedParty
// t.Nonce = j.Nonce
// t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
// t.CodeHash = j.CodeHash
// t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
// t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
// t.SessionID = j.SessionID
// t.Scopes = strings.Split(j.Scopes, " ")
// t.ClientID = j.ClientID
// t.AccessTokenUseNumber = j.AccessTokenUseNumber
// return nil
//}
//
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
type Alias idTokenClaims
a := &struct {
*Alias
Expiration int64 `json:"nbf,omitempty"`
IssuedAt int64 `json:"nbf,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"nbf,omitempty"`
}{
Alias: (*Alias)(t),
} }
return json.Marshal(j) if !time.Time(t.Expiration).IsZero() {
a.Expiration = time.Time(t.Expiration).Unix()
}
if !time.Time(t.IssuedAt).IsZero() {
a.IssuedAt = time.Time(t.IssuedAt).Unix()
}
if !time.Time(t.NotBefore).IsZero() {
a.NotBefore = time.Time(t.NotBefore).Unix()
}
if !time.Time(t.AuthTime).IsZero() {
a.AuthTime = time.Time(t.AuthTime).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if t.UserInfo == nil {
return b, nil
}
info, err := json.Marshal(t.UserInfo)
if err != nil {
return nil, err
}
return utils.ConcatenateJSON(b, info)
} }
func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error { func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
var j jsonToken type Alias idTokenClaims
if err := json.Unmarshal(b, &j); err != nil { if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
return err return err
} }
t.Issuer = j.Issuer userinfo := new(userinfo)
t.Subject = j.Subject if err := json.Unmarshal(data, userinfo); err != nil {
t.Audiences = audienceFromJSON(j.Audiences) return err
t.Expiration = time.Unix(j.Expiration, 0).UTC() }
t.NotBefore = time.Unix(j.NotBefore, 0).UTC() t.UserInfo = userinfo
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
t.JWTID = j.JWTID
t.AuthorizedParty = j.AuthorizedParty
t.Nonce = j.Nonce
t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
t.CodeHash = j.CodeHash
t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
t.SessionID = j.SessionID
t.Scopes = strings.Split(j.Scopes, " ")
t.ClientID = j.ClientID
t.AccessTokenUseNumber = j.AccessTokenUseNumber
return nil return nil
} }
func (t *IDTokenClaims) MarshalJSON() ([]byte, error) { func (t *idTokenClaims) GetIssuer() string {
j := jsonToken{
Issuer: t.Issuer,
Subject: t.Subject,
Audiences: t.Audiences,
Expiration: timeToJSON(t.Expiration),
NotBefore: timeToJSON(t.NotBefore),
IssuedAt: timeToJSON(t.IssuedAt),
JWTID: t.JWTID,
AuthorizedParty: t.AuthorizedParty,
Nonce: t.Nonce,
AuthTime: timeToJSON(t.AuthTime),
AccessTokenHash: t.AccessTokenHash,
CodeHash: t.CodeHash,
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
ClientID: t.ClientID,
}
j.setUserinfo(t.Userinfo)
return json.Marshal(j)
}
func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
var i jsonToken
if err := json.Unmarshal(b, &i); err != nil {
return err
}
t.Issuer = i.Issuer
t.Subject = i.Subject
t.Audiences = audienceFromJSON(i.Audiences)
t.Expiration = time.Unix(i.Expiration, 0).UTC()
t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
t.Nonce = i.Nonce
t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
t.AuthorizedParty = i.AuthorizedParty
t.AccessTokenHash = i.AccessTokenHash
t.CodeHash = i.CodeHash
t.UserinfoProfile = i.UnmarshalUserinfoProfile()
t.UserinfoEmail = i.UnmarshalUserinfoEmail()
t.UserinfoPhone = i.UnmarshalUserinfoPhone()
t.Address = i.UnmarshalUserinfoAddress()
return nil
}
func (t *IDTokenClaims) GetIssuer() string {
return t.Issuer return t.Issuer
} }
func (t *IDTokenClaims) GetAudience() []string { func (t *idTokenClaims) GetAudience() []string {
return t.Audiences return t.Audience
} }
func (t *IDTokenClaims) GetExpiration() time.Time { func (t *idTokenClaims) GetExpiration() time.Time {
return t.Expiration return time.Time(t.Expiration)
} }
func (t *IDTokenClaims) GetIssuedAt() time.Time { func (t *idTokenClaims) GetIssuedAt() time.Time {
return t.IssuedAt return time.Time(t.IssuedAt)
} }
func (t *IDTokenClaims) GetNonce() string { func (t *idTokenClaims) GetNonce() string {
return t.Nonce return t.Nonce
} }
func (t *IDTokenClaims) GetAuthenticationContextClassReference() string { func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference return t.AuthenticationContextClassReference
} }
func (t *IDTokenClaims) GetAuthTime() time.Time { func (t *idTokenClaims) GetAuthTime() time.Time {
return t.AuthTime return time.Time(t.AuthTime)
} }
func (t *IDTokenClaims) GetAuthorizedParty() string { func (t *idTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty return t.AuthorizedParty
} }
func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) { func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
t.Signature = alg t.signatureAlg = alg
} }
func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) { //
j := jsonToken{ //func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
Issuer: t.Issuer, // j := jsonToken{
Subject: t.Subject, // Issuer: t.Issuer,
Audiences: t.Audience, // Subject: t.Subject,
Expiration: timeToJSON(t.Expiration), // Audiences: t.Audience,
IssuedAt: timeToJSON(t.IssuedAt), // Expiration: timeToJSON(t.Expiration),
Scopes: strings.Join(t.Scopes, " "), // IssuedAt: timeToJSON(t.IssuedAt),
} // Scopes: strings.Join(t.Scopes, " "),
return json.Marshal(j) // }
} // return json.Marshal(j)
//}
func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error { //func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
var j jsonToken // var j jsonToken
if err := json.Unmarshal(b, &j); err != nil { // if err := json.Unmarshal(b, &j); err != nil {
return err // return err
} // }
//
// t.Issuer = j.Issuer
// t.Subject = j.Subject
// t.Audience = audienceFromJSON(j.Audiences)
// t.Expiration = time.Unix(j.Expiration, 0).UTC()
// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
// t.Scopes = strings.Split(j.Scopes, " ")
//
// return nil
//}
t.Issuer = j.Issuer //
t.Subject = j.Subject //func (j *jsonToken) UnmarshalUserinfoProfile() userInfoProfile {
t.Audience = audienceFromJSON(j.Audiences) // locale, _ := language.Parse(j.Locale)
t.Expiration = time.Unix(j.Expiration, 0).UTC() // return userInfoProfile{
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() // Name: j.Name,
t.Scopes = strings.Split(j.Scopes, " ") // GivenName: j.GivenName,
// FamilyName: j.FamilyName,
return nil // MiddleName: j.MiddleName,
} // Nickname: j.Nickname,
// Profile: j.Profile,
func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile { // Picture: j.Picture,
locale, _ := language.Parse(j.Locale) // Website: j.Website,
return UserinfoProfile{ // Gender: Gender(j.Gender),
Name: j.Name, // Birthdate: j.Birthdate,
GivenName: j.GivenName, // Zoneinfo: j.Zoneinfo,
FamilyName: j.FamilyName, // Locale: locale,
MiddleName: j.MiddleName, // UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
Nickname: j.Nickname, // PreferredUsername: j.PreferredUsername,
Profile: j.Profile, // }
Picture: j.Picture, //}
Website: j.Website, //
Gender: Gender(j.Gender), //func (j *jsonToken) UnmarshalUserinfoEmail() userInfoEmail {
Birthdate: j.Birthdate, // return userInfoEmail{
Zoneinfo: j.Zoneinfo, // Email: j.Email,
Locale: locale, // EmailVerified: j.EmailVerified,
UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(), // }
PreferredUsername: j.PreferredUsername, //}
} //
} //func (j *jsonToken) UnmarshalUserinfoPhone() userInfoPhone {
// return userInfoPhone{
func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail { // PhoneNumber: j.Phone,
return UserinfoEmail{ // PhoneNumberVerified: j.PhoneVerified,
Email: j.Email, // }
EmailVerified: j.EmailVerified, //}
} //
} //func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
// if j.JsonUserinfoAddress == nil {
func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone { // return nil
return UserinfoPhone{ // }
PhoneNumber: j.Phone, // return &UserinfoAddress{
PhoneNumberVerified: j.PhoneVerified, // Country: j.JsonUserinfoAddress.Country,
} // Formatted: j.JsonUserinfoAddress.Formatted,
} // Locality: j.JsonUserinfoAddress.Locality,
// PostalCode: j.JsonUserinfoAddress.PostalCode,
func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress { // Region: j.JsonUserinfoAddress.Region,
if j.JsonUserinfoAddress == nil { // StreetAddress: j.JsonUserinfoAddress.StreetAddress,
return nil // }
} //}
return &UserinfoAddress{
Country: j.JsonUserinfoAddress.Country,
Formatted: j.JsonUserinfoAddress.Formatted,
Locality: j.JsonUserinfoAddress.Locality,
PostalCode: j.JsonUserinfoAddress.PostalCode,
Region: j.JsonUserinfoAddress.Region,
StreetAddress: j.JsonUserinfoAddress.StreetAddress,
}
}
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := utils.GetHashAlgorithm(sigAlgorithm) hash, err := utils.GetHashAlgorithm(sigAlgorithm)

101
pkg/oidc/token_request.go Normal file
View file

@ -0,0 +1,101 @@
package oidc
import (
"time"
"gopkg.in/square/go-jose.v2"
)
const (
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code"
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
)
type GrantType string
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType
}
type TokenRequestType GrantType
type AccessTokenRequest struct {
Code string `schema:"code"`
RedirectURI string `schema:"redirect_uri"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
CodeVerifier string `schema:"code_verifier"`
}
func (a *AccessTokenRequest) GrantType() GrantType {
return GrantTypeCode
}
type JWTTokenRequest struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Scopes Scopes `json:"scope"`
Audience Audience `json:"aud"`
IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"`
}
func (j *JWTTokenRequest) GetClientID() string {
return j.Subject
}
func (j *JWTTokenRequest) GetSubject() string {
return j.Subject
}
func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes
}
func (j *JWTTokenRequest) GetIssuer() string {
return j.Issuer
}
func (j *JWTTokenRequest) GetAudience() []string {
return j.Audience
}
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
}
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
func (j *JWTTokenRequest) GetNonce() string {
return ""
}
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
return ""
}
func (j *JWTTokenRequest) GetAuthTime() time.Time {
return time.Time{}
}
func (j *JWTTokenRequest) GetAuthorizedParty() string {
return ""
}
func (j *JWTTokenRequest) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {}
type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"`
audience Audience `schema:"audience"`
Scope Scopes `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"`
}

113
pkg/oidc/types.go Normal file
View file

@ -0,0 +1,113 @@
package oidc
import (
"encoding/json"
"strings"
"time"
"golang.org/x/text/language"
)
type Audience []string
func (a *Audience) UnmarshalJSON(text []byte) error {
var i interface{}
err := json.Unmarshal(text, &i)
if err != nil {
return err
}
switch aud := i.(type) {
case []interface{}:
*a = make([]string, len(aud))
for i, audience := range aud {
(*a)[i] = audience.(string)
}
case string:
*a = []string{aud}
}
return nil
}
type Display string
func (d *Display) UnmarshalText(text []byte) error {
display := Display(text)
switch display {
case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP:
*d = display
}
return nil
}
type Gender string
type Locale language.Tag
//{
// SetLocale(language.Tag)
// Get() language.Tag
//}
//
//func NewLocale(tag language.Tag) Locale {
// if tag.IsRoot() {
// return nil
// }
// return &locale{Tag: tag}
//}
//
//type locale struct {
// language.Tag
//}
//
//func (l *locale) SetLocale(tag language.Tag) {
// l.Tag = tag
//}
//func (l *locale) Get() language.Tag {
// return l.Tag
//}
//func (l *locale) MarshalJSON() ([]byte, error) {
// if l != nil && !l.IsRoot() {
// return l.MarshalText()
// }
// return []byte("null"), nil
//}
type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error {
locales := strings.Split(string(text), " ")
for _, locale := range locales {
tag, err := language.Parse(locale)
if err == nil && !tag.IsRoot() {
*l = append(*l, tag)
}
}
return nil
}
type Prompt string
type ResponseType string
type Scopes []string
func (s *Scopes) UnmarshalText(text []byte) error {
*s = strings.Split(string(text), " ")
return nil
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
}
func (t *Time) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Time(*t).UTC().Unix())
}

276
pkg/oidc/types_test.go Normal file
View file

@ -0,0 +1,276 @@
package oidc
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
)
func TestAudience_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
audience Audience
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte(`{"aud": "single audience"}`),
},
res{
[]string{"single audience"},
},
false,
},
{
"page",
args{
[]byte(`{"aud": ["multiple", "audience"]}`),
},
res{
[]string{"multiple", "audience"},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := new(struct {
Audience Audience `json:"aud"`
})
if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, a.Audience, tt.res.audience)
})
}
}
func TestDisplay_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
display Display
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{},
false,
},
{
"page",
args{
[]byte("page"),
},
res{DisplayPage},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var d Display
if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
if d != tt.res.display {
t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display)
}
})
}
}
func TestLocales_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
tags []language.Tag
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{},
false,
},
{
"undefined",
args{
[]byte("und"),
},
res{},
false,
},
{
"single language",
args{
[]byte("de"),
},
res{[]language.Tag{language.German}},
false,
},
{
"multiple languages",
args{
[]byte("de en"),
},
res{[]language.Tag{language.German, language.English}},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var locales Locales
if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, locales, tt.res.tags)
})
}
}
func TestScopes_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
scopes []string
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{
[]string{"unknown"},
},
false,
},
{
"struct",
args{
[]byte(`{"unknown":"value"}`),
},
res{
[]string{`{"unknown":"value"}`},
},
false,
},
{
"openid",
args{
[]byte("openid"),
},
res{
[]string{"openid"},
},
false,
},
{
"multiple scopes",
args{
[]byte("openid email custom:scope"),
},
res{
[]string{"openid", "email", "custom:scope"},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var scopes Scopes
if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, scopes, tt.res.scopes)
})
}
}
func TestTime_UnmarshalJSON(t *testing.T) {
type args struct {
text []byte
}
type res struct {
scopes []string
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{
[]string{"unknown"},
},
false,
},
{
"openid",
args{
[]byte("openid"),
},
res{
[]string{"openid"},
},
false,
},
{
"multiple scopes",
args{
[]byte("openid email custom:scope"),
},
res{
[]string{"openid", "email", "custom:scope"},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var scopes Scopes
if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, scopes, tt.res.scopes)
})
}
}

View file

@ -2,89 +2,312 @@ package oidc
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/caos/oidc/pkg/utils"
) )
type Userinfo struct { type UserInfo interface {
Subject string GetSubject() string
UserinfoProfile UserInfoProfile
UserinfoEmail UserInfoEmail
UserinfoPhone UserInfoPhone
Address *UserinfoAddress GetAddress() UserInfoAddress
GetClaim(key string) interface{}
}
Authorizations []string type UserInfoProfile interface {
GetName() string
GetGivenName() string
GetFamilyName() string
GetMiddleName() string
GetNickname() string
GetProfile() string
GetPicture() string
GetWebsite() string
GetGender() Gender
GetBirthdate() string
GetZoneinfo() string
GetLocale() language.Tag
GetPreferredUsername() string
}
type UserInfoEmail interface {
GetEmail() string
IsEmailVerified() bool
}
type UserInfoPhone interface {
GetPhoneNumber() string
IsPhoneNumberVerified() bool
}
type UserInfoAddress interface {
GetFormatted() string
GetStreetAddress() string
GetLocality() string
GetRegion() string
GetPostalCode() string
GetCountry() string
}
type UserInfoSetter interface {
UserInfo
SetSubject(sub string)
UserInfoProfileSetter
SetEmail(email string, verified bool)
SetPhone(phone string, verified bool)
SetAddress(address UserInfoAddress)
AppendClaims(key string, values interface{})
}
type UserInfoProfileSetter interface {
SetName(name string)
SetGivenName(name string)
SetFamilyName(name string)
SetMiddleName(name string)
SetNickname(name string)
SetUpdatedAt(date time.Time)
SetProfile(profile string)
SetPicture(profile string)
SetWebsite(website string)
SetGender(gender Gender)
SetBirthdate(birthdate string)
SetZoneinfo(zoneInfo string)
SetLocale(locale language.Tag)
SetPreferredUsername(name string)
}
func NewUserInfo() UserInfoSetter {
return &userinfo{}
}
type userinfo struct {
Subject string `json:"sub,omitempty"`
userInfoProfile
userInfoEmail
userInfoPhone
Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{} claims map[string]interface{}
} }
type UserinfoProfile struct { func (u *userinfo) GetSubject() string {
Name string return u.Subject
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender Gender
Birthdate string
Zoneinfo string
Locale language.Tag
UpdatedAt time.Time
PreferredUsername string
} }
type Gender string func (u *userinfo) GetName() string {
return u.Name
type UserinfoEmail struct {
Email string
EmailVerified bool
} }
type UserinfoPhone struct { func (u *userinfo) GetGivenName() string {
PhoneNumber string return u.GivenName
PhoneNumberVerified bool
} }
type UserinfoAddress struct { func (u *userinfo) GetFamilyName() string {
Formatted string return u.FamilyName
StreetAddress string
Locality string
Region string
PostalCode string
Country string
} }
type jsonUserinfoProfile struct { func (u *userinfo) GetMiddleName() string {
Name string `json:"name,omitempty"` return u.MiddleName
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
} }
type jsonUserinfoEmail struct { func (u *userinfo) GetNickname() string {
return u.Nickname
}
func (u *userinfo) GetProfile() string {
return u.Profile
}
func (u *userinfo) GetPicture() string {
return u.Picture
}
func (u *userinfo) GetWebsite() string {
return u.Website
}
func (u *userinfo) GetGender() Gender {
return u.Gender
}
func (u *userinfo) GetBirthdate() string {
return u.Birthdate
}
func (u *userinfo) GetZoneinfo() string {
return u.Zoneinfo
}
func (u *userinfo) GetLocale() language.Tag {
return u.Locale
}
func (u *userinfo) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *userinfo) GetEmail() string {
return u.Email
}
func (u *userinfo) IsEmailVerified() bool {
return u.EmailVerified
}
func (u *userinfo) GetPhoneNumber() string {
return u.PhoneNumber
}
func (u *userinfo) IsPhoneNumberVerified() bool {
return u.PhoneNumberVerified
}
func (u *userinfo) GetAddress() UserInfoAddress {
return u.Address
}
func (u *userinfo) GetClaim(key string) interface{} {
return u.claims[key]
}
func (u *userinfo) SetSubject(sub string) {
u.Subject = sub
}
func (u *userinfo) SetName(name string) {
u.Name = name
}
func (u *userinfo) SetGivenName(name string) {
u.GivenName = name
}
func (u *userinfo) SetFamilyName(name string) {
u.FamilyName = name
}
func (u *userinfo) SetMiddleName(name string) {
u.MiddleName = name
}
func (u *userinfo) SetNickname(name string) {
u.Nickname = name
}
func (u *userinfo) SetUpdatedAt(date time.Time) {
u.UpdatedAt = Time(date)
}
func (u *userinfo) SetProfile(profile string) {
u.Profile = profile
}
func (u *userinfo) SetPicture(picture string) {
u.Picture = picture
}
func (u *userinfo) SetWebsite(website string) {
u.Website = website
}
func (u *userinfo) SetGender(gender Gender) {
u.Gender = gender
}
func (u *userinfo) SetBirthdate(birthdate string) {
u.Birthdate = birthdate
}
func (u *userinfo) SetZoneinfo(zoneInfo string) {
u.Zoneinfo = zoneInfo
}
func (u *userinfo) SetLocale(locale language.Tag) {
u.Locale = locale
}
func (u *userinfo) SetPreferredUsername(name string) {
u.PreferredUsername = name
}
func (u *userinfo) SetEmail(email string, verified bool) {
u.Email = email
u.EmailVerified = verified
}
func (u *userinfo) SetPhone(phone string, verified bool) {
u.PhoneNumber = phone
u.PhoneNumberVerified = verified
}
func (u *userinfo) SetAddress(address UserInfoAddress) {
u.Address = address
}
func (u *userinfo) AppendClaims(key string, value interface{}) {
if u.claims == nil {
u.claims = make(map[string]interface{})
}
u.claims[key] = value
}
func (u *userInfoAddress) GetFormatted() string {
panic("implement me")
}
func (u *userInfoAddress) GetStreetAddress() string {
panic("implement me")
}
func (u *userInfoAddress) GetLocality() string {
panic("implement me")
}
func (u *userInfoAddress) GetRegion() string {
panic("implement me")
}
func (u *userInfoAddress) GetPostalCode() string {
panic("implement me")
}
func (u *userInfoAddress) GetCountry() string {
panic("implement me")
}
type userInfoProfile struct {
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender Gender `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale language.Tag `json:"locale,omitempty"`
UpdatedAt Time `json:"updated_at,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
}
type userInfoEmail struct {
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"` EmailVerified bool `json:"email_verified,omitempty"`
} }
type jsonUserinfoPhone struct { type userInfoPhone struct {
Phone string `json:"phone_number,omitempty"` PhoneNumber string `json:"phone_number,omitempty"`
PhoneVerified bool `json:"phone_number_verified,omitempty"` PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
} }
type jsonUserinfoAddress struct { type userInfoAddress struct {
Formatted string `json:"formatted,omitempty"` Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"` StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"` Locality string `json:"locality,omitempty"`
@ -93,81 +316,68 @@ type jsonUserinfoAddress struct {
Country string `json:"country,omitempty"` Country string `json:"country,omitempty"`
} }
func (i *Userinfo) MarshalJSON() ([]byte, error) { func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
j := new(jsonUserinfo) return &userInfoAddress{
j.Subject = i.Subject StreetAddress: streetAddress,
j.setUserinfo(*i) Locality: locality,
j.Authorizations = i.Authorizations Region: region,
return json.Marshal(j) PostalCode: postalCode,
Country: country,
Formatted: formatted,
}
}
func (i *userinfo) MarshalJSON() ([]byte, error) {
type Alias userinfo
a := &struct {
*Alias
Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if !i.Locale.IsRoot() {
a.Locale = i.Locale
}
fmt.Println(time.Time(i.UpdatedAt).String())
if !time.Time(i.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(i.claims) == 0 {
return b, nil
}
claims, err := json.Marshal(i.claims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of private claims %v", i.claims)
}
return utils.ConcatenateJSON(b, claims)
} }
func (i *Userinfo) UnmmarshalJSON(data []byte) error { func (i *userinfo) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, i); err != nil { type Alias userinfo
a := &struct {
*Alias
//Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"update_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if err := json.Unmarshal(data, &a); err != nil {
return err return err
} }
return json.Unmarshal(data, &i.claims) //if !i.Locale.IsRoot() {
} // a.Locale = i.Locale
//}
type jsonUserinfo struct { i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
Subject string `json:"sub,omitempty"`
jsonUserinfoProfile
jsonUserinfoEmail
jsonUserinfoPhone
JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"`
Authorizations []string `json:"authorizations,omitempty"`
}
func (j *jsonUserinfo) setUserinfo(i Userinfo) { return nil
j.setUserinfoProfile(i.UserinfoProfile)
j.setUserinfoEmail(i.UserinfoEmail)
j.setUserinfoPhone(i.UserinfoPhone)
j.setUserinfoAddress(i.Address)
}
func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) {
j.Name = i.Name
j.GivenName = i.GivenName
j.FamilyName = i.FamilyName
j.MiddleName = i.MiddleName
j.Nickname = i.Nickname
j.Profile = i.Profile
j.Picture = i.Picture
j.Website = i.Website
j.Gender = string(i.Gender)
j.Birthdate = i.Birthdate
j.Zoneinfo = i.Zoneinfo
if i.Locale != language.Und {
j.Locale = i.Locale.String()
}
j.UpdatedAt = timeToJSON(i.UpdatedAt)
j.PreferredUsername = i.PreferredUsername
}
func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) {
j.Email = i.Email
j.EmailVerified = i.EmailVerified
}
func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) {
j.Phone = i.PhoneNumber
j.PhoneVerified = i.PhoneNumberVerified
}
func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) {
if i == nil {
return
}
if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" {
return
}
j.JsonUserinfoAddress = &jsonUserinfoAddress{
Country: i.Country,
Formatted: i.Formatted,
Locality: i.Locality,
PostalCode: i.PostalCode,
Region: i.Region,
StreetAddress: i.StreetAddress,
}
} }
type UserInfoRequest struct { type UserInfoRequest struct {

View file

@ -24,7 +24,7 @@ type Claims interface {
GetAuthenticationContextClassReference() string GetAuthenticationContextClassReference() string
GetAuthTime() time.Time GetAuthTime() time.Time
GetAuthorizedParty() string GetAuthorizedParty() string
SetSignature(algorithm jose.SignatureAlgorithm) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
} }
var ( var (
@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
return ErrSignatureInvalidPayload return ErrSignatureInvalidPayload
} }
claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm)) claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm))
return nil return nil
} }

View file

@ -168,7 +168,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if err != nil { if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
} }
return claims.Subject, nil return claims.GetSubject(), nil
} }
//RedirectToLogin redirects the end user to the Login UI for authentication //RedirectToLogin redirects the end user to the Login UI for authentication

View file

@ -81,7 +81,7 @@ func (s *Sig) Health(ctx context.Context) error {
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
return "", nil return "", nil
} }
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) { func (s *Sig) SignAccessToken(*oidc.accessTokenClaims) (string, error) {
return "", nil return "", nil
} }
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {

View file

@ -50,7 +50,7 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
} }
// SignAccessToken mocks base method // SignAccessToken mocks base method
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) { func (m *MockSigner) SignAccessToken(arg0 *oidc.accessTokenClaims) (string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignAccessToken", arg0) ret := m.ctrl.Call(m, "SignAccessToken", arg0)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)

View file

@ -184,10 +184,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
} }
// GetUserinfoFromScopes mocks base method // GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) { func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.userinfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.Userinfo) ret0, _ := ret[0].(*oidc.userinfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@ -199,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interf
} }
// GetUserinfoFromToken mocks base method // GetUserinfoFromToken mocks base method
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.Userinfo, error) { func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.userinfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.Userinfo) ret0, _ := ret[0].(*oidc.userinfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }

View file

@ -130,7 +130,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
} }
keyCh := make(chan jose.SigningKey) keyCh := make(chan jose.SigningKey)
o.signer = NewDefaultSigner(ctx, storage, keyCh) o.signer = NewSigner(ctx, storage, keyCh)
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry) go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
o.httpHandler = CreateRouter(o, o.interceptors...) o.httpHandler = CreateRouter(o, o.interceptors...)

View file

@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if err != nil { if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid") return nil, ErrInvalidRequest("id_token_hint invalid")
} }
session.UserID = claims.Subject session.UserID = claims.GetSubject()
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty) session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil { if err != nil {
return nil, ErrServerError("") return nil, ErrServerError("")
} }

View file

@ -2,19 +2,17 @@ package op
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"github.com/caos/logging" "github.com/caos/logging"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
) )
type Signer interface { type Signer interface {
Health(ctx context.Context) error Health(ctx context.Context) error
SignIDToken(claims *oidc.IDTokenClaims) (string, error) //SignIDToken(claims *oidc.IDTokenClaims) (string, error)
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) //SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
Signer() jose.Signer
SignatureAlgorithm() jose.SignatureAlgorithm SignatureAlgorithm() jose.SignatureAlgorithm
} }
@ -24,7 +22,7 @@ type tokenSigner struct {
alg jose.SignatureAlgorithm alg jose.SignatureAlgorithm
} }
func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
s := &tokenSigner{ s := &tokenSigner{
storage: storage, storage: storage,
} }
@ -41,6 +39,15 @@ func (s *tokenSigner) Health(_ context.Context) error {
return nil return nil
} }
func (s *tokenSigner) Signer() jose.Signer {
return s.signer
}
//
//func (s *tokenSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) {
// return s.signer.Sign(payload)
//}
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) { func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for { for {
select { select {
@ -55,30 +62,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S
} }
} }
func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) Sign(payload []byte) (string, error) {
result, err := s.signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.alg return s.alg
} }

View file

@ -38,13 +38,13 @@ import (
// } // }
// for _, tt := range tests { // for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) { // t.Run(tt.name, func(t *testing.T) {
// got, err := op.NewDefaultSigner(tt.args.storage) // got, err := op.NewSigner(tt.args.storage)
// if (err != nil) != tt.wantErr { // if (err != nil) != tt.wantErr {
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr) // t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
// return // return
// } // }
// if !reflect.DeepEqual(got, tt.want) { // if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want) // t.Errorf("NewSigner() = %v, want %v", got, tt.want)
// } // }
// }) // })
// } // }

View file

@ -28,8 +28,8 @@ type AuthStorage interface {
type OPStorage interface { type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error) GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error) GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error)
GetUserinfoFromToken(context.Context, string, string) (*oidc.Userinfo, error) GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
} }

View file

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type TokenCreator interface { type TokenCreator interface {
@ -82,51 +83,34 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) {
} }
func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) { func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) {
now := time.Now().UTC() claims := oidc.NewAccessTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, id)
nbf := now return utils.Sign(claims, signer.Signer())
claims := &oidc.AccessTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: now,
NotBefore: nbf,
JWTID: id,
}
return signer.SignAccessToken(claims)
} }
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) { func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity) exp := time.Now().UTC().Add(validity)
userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID())
if err != nil {
return "", err
}
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
Userinfo: *userinfo,
}
if accessToken != "" { if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetAccessTokenHash(atHash)
} else {
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
if err != nil {
return "", err
}
claims.SetUserinfo(userInfo)
} }
if code != "" { if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm()) codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetCodeHash(codeHash)
} }
return signer.SignIDToken(claims) return utils.Sign(claims, signer.Signer())
} }

View file

@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi
//VerifyIDTokenHint validates the id token according to //VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) { func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims) claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
if err != nil { if err != nil {

View file

@ -33,5 +33,5 @@ func NewMockVerifierExpectValid(t *testing.T) rp.Verifier {
func ExpectVerifyValid(v rp.Verifier) { func ExpectVerifyValid(v rp.Verifier) {
mock := v.(*MockVerifier) mock := v.(*MockVerifier)
mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil) mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.userinfo{Subject: "id"}}, nil)
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/google/uuid" "github.com/google/uuid"
@ -329,10 +330,10 @@ func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.
} }
func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) { func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) {
form := make(map[string][]string) form := url.Values{}
form["assertion"] = []string{assertion} form.Add("assertion", assertion)
form["grant_type"] = []string{jwtProfileKey} form.Add("grant_type", jwtProfileKey)
req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil) req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, strings.NewReader(form.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -21,12 +21,12 @@ type IDTokenVerifier interface {
//VerifyTokens implement the Token Response Validation as defined in OIDC specification //VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation //https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v) idToken, err := VerifyIDToken(ctx, idTokenString, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err return nil, err
} }
return idToken, nil return idToken, nil
@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo
//VerifyIDToken validates the id token according to //VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims) claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
if err != nil { if err != nil {

View file

@ -1,7 +1,9 @@
package utils package utils
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) {
logrus.Error("error writing response") logrus.Error("error writing response")
} }
} }
func ConcatenateJSON(first, second []byte) ([]byte, error) {
if !bytes.HasSuffix(first, []byte{'}'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", first)
}
if !bytes.HasPrefix(second, []byte{'{'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", second)
}
first[len(first)-1] = ','
first = append(first, second[1:]...)
return first, nil
}

23
pkg/utils/sign.go Normal file
View file

@ -0,0 +1,23 @@
package utils
import (
"encoding/json"
"gopkg.in/square/go-jose.v2"
)
func Sign(object interface{}, signer jose.Signer) (string, error) {
payload, err := json.Marshal(object)
if err != nil {
return "", err
}
return SignPayload(payload, signer)
}
func SignPayload(payload []byte, signer jose.Signer) (string, error) {
result, err := signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}