diff --git a/example/client/app/app.go b/example/client/app/app.go index 4c0831b..ea1e6e7 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "html/template" + "io/ioutil" "net/http" "os" "time" @@ -30,7 +32,7 @@ func main() { ctx := context.Background() redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) - scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail} + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"} cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, rp.WithPKCE(cookieHandler), @@ -82,6 +84,62 @@ func main() { } w.Write(data) }) + + http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) { + tpl := ` + + + + + Login + + +
+ + + +
+ + ` + t, err := template.New("login").Parse(tpl) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = t.Execute(w, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }) + + http.HandleFunc("/jwt-profile-assertion", func(w http.ResponseWriter, r *http.Request) { + r.ParseMultipartForm(32 << 20) + file, handler, err := r.FormFile("key") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer file.Close() + + key, err := ioutil.ReadAll(file) + fmt.Println(handler.Header) + assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + token, err := rp.JWTProfileExchange(ctx, assertion, provider) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + data, err := json.Marshal(token) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(data) + }) lis := fmt.Sprintf("127.0.0.1:%s", port) logrus.Infof("listening on http://%s/", lis) logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil)) diff --git a/example/client/jwt_profile.go b/example/client/jwt_profile.go new file mode 100644 index 0000000..6dcd11b --- /dev/null +++ b/example/client/jwt_profile.go @@ -0,0 +1,39 @@ +package client + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/sirupsen/logrus" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/utils" +) + +var ( + callbackPath string = "/auth/callback" + key []byte = []byte("test1234test1234") +) + +func main() { + clientID := os.Getenv("CLIENT_ID") + clientSecret := os.Getenv("CLIENT_SECRET") + issuer := os.Getenv("ISSUER") + port := os.Getenv("PORT") + + ctx := context.Background() + + redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) + scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"} + cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) + provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, + rp.WithPKCE(cookieHandler), + rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)), + ) + if err != nil { + logrus.Fatalf("error creating provider %s", err.Error()) + } +} diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index f20fb9b..74e0ed7 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -210,24 +210,24 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st return nil } -func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.Userinfo, error) { +func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.userinfo, error) { return s.GetUserinfoFromScopes(ctx, "", []string{}) } -func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) { - return &oidc.Userinfo{ +func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.userinfo, error) { + return &oidc.userinfo{ Subject: a.GetSubject(), Address: &oidc.UserinfoAddress{ StreetAddress: "Hjkhkj 789\ndsf", }, - UserinfoEmail: oidc.UserinfoEmail{ + userinfoEmail: oidc.userinfoEmail{ Email: "test", EmailVerified: true, }, - UserinfoPhone: oidc.UserinfoPhone{ + userinfoPhone: oidc.userinfoPhone{ PhoneNumber: "sadsa", PhoneNumberVerified: true, }, - UserinfoProfile: oidc.UserinfoProfile{ + userinfoProfile: oidc.userinfoProfile{ UpdatedAt: time.Now(), }, // Claims: map[string]interface{}{ diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index 35da2fb..71776af 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,15 +1,5 @@ package oidc -import ( - "encoding/json" - "errors" - "strings" - "time" - - "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" -) - const ( //ScopeOpenID defines the scope `openid` //OpenID Connect requests MUST contain the `openid` scope value @@ -64,23 +54,8 @@ const ( //PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching) PromptSelectAccount Prompt = "select_account" - - //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow - GrantTypeCode GrantType = "authorization_code" - //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant - GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" - - //BearerToken defines the token_type `Bearer`, which is returned in a successful token response - BearerToken = "Bearer" ) -var displayValues = map[string]Display{ - "page": DisplayPage, - "popup": DisplayPopup, - "touch": DisplayTouch, - "wap": DisplayWAP, -} - //AuthRequest according to: //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type AuthRequest struct { @@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType { func (a *AuthRequest) GetState() string { return a.State } - -type TokenRequest interface { - // GrantType GrantType `schema:"grant_type"` - GrantType() GrantType -} - -type TokenRequestType GrantType - -type AccessTokenRequest struct { - Code string `schema:"code"` - RedirectURI string `schema:"redirect_uri"` - ClientID string `schema:"client_id"` - ClientSecret string `schema:"client_secret"` - CodeVerifier string `schema:"code_verifier"` -} - -func (a *AccessTokenRequest) GrantType() GrantType { - return GrantTypeCode -} - -type AccessTokenResponse struct { - AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` - TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` - RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` - ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` - IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` -} - -type JWTTokenRequest struct { - Issuer string `json:"iss"` - Subject string `json:"sub"` - Scopes Scopes `json:"scope"` - Audience interface{} `json:"aud"` - IssuedAt Time `json:"iat"` - ExpiresAt Time `json:"exp"` -} - -func (j *JWTTokenRequest) GetClientID() string { - return j.Subject -} - -func (j *JWTTokenRequest) GetSubject() string { - return j.Subject -} - -func (j *JWTTokenRequest) GetScopes() []string { - return j.Scopes -} - -type Time time.Time - -func (t *Time) UnmarshalJSON(data []byte) error { - var i int64 - if err := json.Unmarshal(data, &i); err != nil { - return err - } - *t = Time(time.Unix(i, 0).UTC()) - return nil -} - -func (j *JWTTokenRequest) GetIssuer() string { - return j.Issuer -} - -func (j *JWTTokenRequest) GetAudience() []string { - return audienceFromJSON(j.Audience) -} - -func (j *JWTTokenRequest) GetExpiration() time.Time { - return time.Time(j.ExpiresAt) -} - -func (j *JWTTokenRequest) GetIssuedAt() time.Time { - return time.Time(j.IssuedAt) -} - -func (j *JWTTokenRequest) GetNonce() string { - return "" -} - -func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string { - return "" -} - -func (j *JWTTokenRequest) GetAuthTime() time.Time { - return time.Time{} -} - -func (j *JWTTokenRequest) GetAuthorizedParty() string { - return "" -} - -func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {} - -type TokenExchangeRequest struct { - subjectToken string `schema:"subject_token"` - subjectTokenType string `schema:"subject_token_type"` - actorToken string `schema:"actor_token"` - actorTokenType string `schema:"actor_token_type"` - resource []string `schema:"resource"` - audience []string `schema:"audience"` - Scope []string `schema:"scope"` - requestedTokenType string `schema:"requested_token_type"` -} - -type Scopes []string - -func (s *Scopes) UnmarshalText(text []byte) error { - scopes := strings.Split(string(text), " ") - *s = Scopes(scopes) - return nil -} - -type ResponseType string - -type Display string - -func (d *Display) UnmarshalText(text []byte) error { - var ok bool - display := string(text) - *d, ok = displayValues[display] - if !ok { - return errors.New("") - } - return nil -} - -type Prompt string - -type Locales []language.Tag - -func (l *Locales) UnmarshalText(text []byte) error { - locales := strings.Split(string(text), " ") - for _, locale := range locales { - tag, err := language.Parse(locale) - if err == nil && !tag.IsRoot() { - *l = append(*l, tag) - } - } - return nil -} - -type GrantType string diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 21b0419..e20dd4a 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -3,33 +3,55 @@ package oidc import ( "encoding/json" "io/ioutil" - "strings" "time" "golang.org/x/oauth2" - "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/utils" ) +const ( + //BearerToken defines the token_type `Bearer`, which is returned in a successful token response + BearerToken = "Bearer" +) + type Tokens struct { *oauth2.Token - IDTokenClaims *IDTokenClaims + IDTokenClaims IDTokenClaims IDToken string } -type AccessTokenClaims struct { +type AccessTokenClaims interface { + Claims +} + +type IDTokenClaims interface { + Claims + GetNotBefore() time.Time + GetJWTID() string + GetAccessTokenHash() string + GetCodeHash() string + GetAuthenticationMethodsReferences() []string + GetClientID() string + GetSignatureAlgorithm() jose.SignatureAlgorithm + SetAccessTokenHash(hash string) + SetUserinfo(userinfo UserInfoSetter) + SetCodeHash(hash string) + UserInfo +} + +type accessTokenClaims struct { Issuer string Subject string - Audiences []string - Expiration time.Time - IssuedAt time.Time - NotBefore time.Time + Audience Audience + Expiration Time + IssuedAt Time + NotBefore Time JWTID string AuthorizedParty string Nonce string - AuthTime time.Time + AuthTime Time CodeHash string AuthenticationContextClassReference string AuthenticationMethodsReferences []string @@ -37,38 +59,155 @@ type AccessTokenClaims struct { Scopes []string ClientID string AccessTokenUseNumber int + + signatureAlg jose.SignatureAlgorithm } -type IDTokenClaims struct { - Issuer string - Audiences []string - Expiration time.Time - NotBefore time.Time - IssuedAt time.Time - JWTID string - UpdatedAt time.Time - AuthorizedParty string - Nonce string - AuthTime time.Time - AccessTokenHash string - CodeHash string - AuthenticationContextClassReference string - AuthenticationMethodsReferences []string - ClientID string - Userinfo +func (a accessTokenClaims) GetIssuer() string { + return a.Issuer +} - Signature jose.SignatureAlgorithm //TODO: ??? +func (a accessTokenClaims) GetAudience() []string { + return a.Audience +} + +func (a accessTokenClaims) GetExpiration() time.Time { + return time.Time(a.Expiration) +} + +func (a accessTokenClaims) GetIssuedAt() time.Time { + return time.Time(a.IssuedAt) +} + +func (a accessTokenClaims) GetNonce() string { + return a.Nonce +} + +func (a accessTokenClaims) GetAuthenticationContextClassReference() string { + return a.AuthenticationContextClassReference +} + +func (a accessTokenClaims) GetAuthTime() time.Time { + return time.Time(a.AuthTime) +} + +func (a accessTokenClaims) GetAuthorizedParty() string { + return a.AuthorizedParty +} + +func (a accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { + a.signatureAlg = algorithm +} + +func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims { + now := time.Now().UTC() + return &accessTokenClaims{ + Issuer: issuer, + Subject: subject, + Audience: audience, + Expiration: Time(expiration), + IssuedAt: Time(now), + NotBefore: Time(now), + JWTID: id, + } +} + +type idTokenClaims struct { + Issuer string `json:"iss,omitempty"` + Audience Audience `json:"aud,omitempty"` + Expiration Time `json:"exp,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + JWTID string `json:"jti,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + AccessTokenHash string `json:"at_hash,omitempty"` + CodeHash string `json:"c_hash,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + ClientID string `json:"client_id,omitempty"` + UserInfo `json:"-"` + + signatureAlg jose.SignatureAlgorithm +} + +func (t *idTokenClaims) SetAccessTokenHash(hash string) { + t.AccessTokenHash = hash +} + +func (t *idTokenClaims) SetUserinfo(info UserInfoSetter) { + t.UserInfo = info +} + +func (t *idTokenClaims) SetCodeHash(hash string) { + t.CodeHash = hash +} + +func EmptyIDTokenClaims() IDTokenClaims { + return new(idTokenClaims) +} + +func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims { + return &idTokenClaims{ + Issuer: issuer, + Audience: audience, + Expiration: Time(expiration), + IssuedAt: Time(time.Now().UTC()), + AuthTime: Time(authTime), + Nonce: nonce, + AuthenticationContextClassReference: acr, + AuthenticationMethodsReferences: amr, + AuthorizedParty: clientID, + UserInfo: &userinfo{Subject: subject}, + } +} + +func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { + return t.signatureAlg +} + +func (t *idTokenClaims) GetNotBefore() time.Time { + return time.Time(t.NotBefore) +} + +func (t *idTokenClaims) GetJWTID() string { + return t.JWTID +} + +func (t *idTokenClaims) GetAccessTokenHash() string { + return t.AccessTokenHash +} + +func (t *idTokenClaims) GetCodeHash() string { + return t.CodeHash +} + +func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string { + return t.AuthenticationMethodsReferences +} + +func (t *idTokenClaims) GetClientID() string { + return t.ClientID +} + +type AccessTokenResponse struct { + AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` + ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` + IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` } type JWTProfileAssertion struct { - PrivateKeyID string `json:"keyId"` - PrivateKey []byte `json:"key"` - Scopes []string `json:"-"` - Issuer string `json:"-"` - Subject string `json:"userId"` - Audience []string `json:"-"` - Expiration time.Time `json:"-"` - IssuedAt time.Time `json:"-"` + PrivateKeyID string `json:"-"` + PrivateKey []byte `json:"-"` + Scopes []string `json:"scopes"` + Issuer string `json:"issuer"` + Subject string `json:"sub"` + Audience Audience `json:"aud"` + Expiration Time `json:"exp"` + IssuedAt Time `json:"iat"` } func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) { @@ -76,12 +215,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT if err != nil { return nil, err } + return NewJWTProfileAssertionFromFileData(data, audience) +} + +func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) { keyData := new(struct { KeyID string `json:"keyId"` Key string `json:"key"` UserID string `json:"userId"` }) - err = json.Unmarshal(data, keyData) + err := json.Unmarshal(data, keyData) if err != nil { return nil, err } @@ -95,241 +238,251 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) Issuer: userID, Scopes: []string{ScopeOpenID}, Subject: userID, - IssuedAt: time.Now().UTC(), - Expiration: time.Now().Add(1 * time.Hour).UTC(), + IssuedAt: Time(time.Now().UTC()), + Expiration: Time(time.Now().Add(1 * time.Hour).UTC()), Audience: audience, } } -type jsonToken struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Audiences interface{} `json:"aud,omitempty"` - Expiration int64 `json:"exp,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - IssuedAt int64 `json:"iat,omitempty"` - JWTID string `json:"jti,omitempty"` - AuthorizedParty string `json:"azp,omitempty"` - Nonce string `json:"nonce,omitempty"` - AuthTime int64 `json:"auth_time,omitempty"` - AccessTokenHash string `json:"at_hash,omitempty"` - CodeHash string `json:"c_hash,omitempty"` - AuthenticationContextClassReference string `json:"acr,omitempty"` - AuthenticationMethodsReferences []string `json:"amr,omitempty"` - SessionID string `json:"sid,omitempty"` - Actor interface{} `json:"act,omitempty"` //TODO: impl - Scopes string `json:"scope,omitempty"` - ClientID string `json:"client_id,omitempty"` - AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl - AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` - jsonUserinfo -} +// +//type jsonToken struct { +// Issuer string `json:"iss,omitempty"` +// Subject string `json:"sub,omitempty"` +// Audiences interface{} `json:"aud,omitempty"` +// Expiration int64 `json:"exp,omitempty"` +// NotBefore int64 `json:"nbf,omitempty"` +// IssuedAt int64 `json:"iat,omitempty"` +// JWTID string `json:"jti,omitempty"` +// AuthorizedParty string `json:"azp,omitempty"` +// Nonce string `json:"nonce,omitempty"` +// AuthTime int64 `json:"auth_time,omitempty"` +// AccessTokenHash string `json:"at_hash,omitempty"` +// CodeHash string `json:"c_hash,omitempty"` +// AuthenticationContextClassReference string `json:"acr,omitempty"` +// AuthenticationMethodsReferences []string `json:"amr,omitempty"` +// SessionID string `json:"sid,omitempty"` +// Actor interface{} `json:"act,omitempty"` //TODO: impl +// Scopes string `json:"scope,omitempty"` +// ClientID string `json:"client_id,omitempty"` +// AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl +// AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` +// jsonUserinfo +//} -func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) { - j := jsonToken{ - Issuer: t.Issuer, - Subject: t.Subject, - Audiences: t.Audiences, - Expiration: timeToJSON(t.Expiration), - NotBefore: timeToJSON(t.NotBefore), - IssuedAt: timeToJSON(t.IssuedAt), - JWTID: t.JWTID, - AuthorizedParty: t.AuthorizedParty, - Nonce: t.Nonce, - AuthTime: timeToJSON(t.AuthTime), - CodeHash: t.CodeHash, - AuthenticationContextClassReference: t.AuthenticationContextClassReference, - AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, - SessionID: t.SessionID, - Scopes: strings.Join(t.Scopes, " "), - ClientID: t.ClientID, - AccessTokenUseNumber: t.AccessTokenUseNumber, +// +//func (t *accessTokenClaims) MarshalJSON() ([]byte, error) { +// j := jsonToken{ +// Issuer: t.Issuer, +// Subject: t.Subject, +// Audiences: t.Audiences, +// Expiration: timeToJSON(t.Expiration), +// NotBefore: timeToJSON(t.NotBefore), +// IssuedAt: timeToJSON(t.IssuedAt), +// JWTID: t.JWTID, +// AuthorizedParty: t.AuthorizedParty, +// Nonce: t.Nonce, +// AuthTime: timeToJSON(t.AuthTime), +// CodeHash: t.CodeHash, +// AuthenticationContextClassReference: t.AuthenticationContextClassReference, +// AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, +// SessionID: t.SessionID, +// Scopes: strings.Join(t.Scopes, " "), +// ClientID: t.ClientID, +// AccessTokenUseNumber: t.AccessTokenUseNumber, +// } +// return json.Marshal(j) +//} +// +//func (t *accessTokenClaims) UnmarshalJSON(b []byte) error { +// var j jsonToken +// if err := json.Unmarshal(b, &j); err != nil { +// return err +// } +// t.Issuer = j.Issuer +// t.Subject = j.Subject +// t.Audiences = audienceFromJSON(j.Audiences) +// t.Expiration = time.Unix(j.Expiration, 0).UTC() +// t.NotBefore = time.Unix(j.NotBefore, 0).UTC() +// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() +// t.JWTID = j.JWTID +// t.AuthorizedParty = j.AuthorizedParty +// t.Nonce = j.Nonce +// t.AuthTime = time.Unix(j.AuthTime, 0).UTC() +// t.CodeHash = j.CodeHash +// t.AuthenticationContextClassReference = j.AuthenticationContextClassReference +// t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences +// t.SessionID = j.SessionID +// t.Scopes = strings.Split(j.Scopes, " ") +// t.ClientID = j.ClientID +// t.AccessTokenUseNumber = j.AccessTokenUseNumber +// return nil +//} +// +func (t *idTokenClaims) MarshalJSON() ([]byte, error) { + type Alias idTokenClaims + a := &struct { + *Alias + Expiration int64 `json:"nbf,omitempty"` + IssuedAt int64 `json:"nbf,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + AuthTime int64 `json:"nbf,omitempty"` + }{ + Alias: (*Alias)(t), } - return json.Marshal(j) + if !time.Time(t.Expiration).IsZero() { + a.Expiration = time.Time(t.Expiration).Unix() + } + if !time.Time(t.IssuedAt).IsZero() { + a.IssuedAt = time.Time(t.IssuedAt).Unix() + } + if !time.Time(t.NotBefore).IsZero() { + a.NotBefore = time.Time(t.NotBefore).Unix() + } + if !time.Time(t.AuthTime).IsZero() { + a.AuthTime = time.Time(t.AuthTime).Unix() + } + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if t.UserInfo == nil { + return b, nil + } + info, err := json.Marshal(t.UserInfo) + if err != nil { + return nil, err + } + return utils.ConcatenateJSON(b, info) } -func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error { - var j jsonToken - if err := json.Unmarshal(b, &j); err != nil { +func (t *idTokenClaims) UnmarshalJSON(data []byte) error { + type Alias idTokenClaims + if err := json.Unmarshal(data, (*Alias)(t)); err != nil { return err } - t.Issuer = j.Issuer - t.Subject = j.Subject - t.Audiences = audienceFromJSON(j.Audiences) - t.Expiration = time.Unix(j.Expiration, 0).UTC() - t.NotBefore = time.Unix(j.NotBefore, 0).UTC() - t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() - t.JWTID = j.JWTID - t.AuthorizedParty = j.AuthorizedParty - t.Nonce = j.Nonce - t.AuthTime = time.Unix(j.AuthTime, 0).UTC() - t.CodeHash = j.CodeHash - t.AuthenticationContextClassReference = j.AuthenticationContextClassReference - t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences - t.SessionID = j.SessionID - t.Scopes = strings.Split(j.Scopes, " ") - t.ClientID = j.ClientID - t.AccessTokenUseNumber = j.AccessTokenUseNumber + userinfo := new(userinfo) + if err := json.Unmarshal(data, userinfo); err != nil { + return err + } + t.UserInfo = userinfo + return nil } -func (t *IDTokenClaims) MarshalJSON() ([]byte, error) { - j := jsonToken{ - Issuer: t.Issuer, - Subject: t.Subject, - Audiences: t.Audiences, - Expiration: timeToJSON(t.Expiration), - NotBefore: timeToJSON(t.NotBefore), - IssuedAt: timeToJSON(t.IssuedAt), - JWTID: t.JWTID, - AuthorizedParty: t.AuthorizedParty, - Nonce: t.Nonce, - AuthTime: timeToJSON(t.AuthTime), - AccessTokenHash: t.AccessTokenHash, - CodeHash: t.CodeHash, - AuthenticationContextClassReference: t.AuthenticationContextClassReference, - AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, - ClientID: t.ClientID, - } - j.setUserinfo(t.Userinfo) - return json.Marshal(j) -} - -func (t *IDTokenClaims) UnmarshalJSON(b []byte) error { - var i jsonToken - if err := json.Unmarshal(b, &i); err != nil { - return err - } - t.Issuer = i.Issuer - t.Subject = i.Subject - t.Audiences = audienceFromJSON(i.Audiences) - t.Expiration = time.Unix(i.Expiration, 0).UTC() - t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC() - t.AuthTime = time.Unix(i.AuthTime, 0).UTC() - t.Nonce = i.Nonce - t.AuthenticationContextClassReference = i.AuthenticationContextClassReference - t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences - t.AuthorizedParty = i.AuthorizedParty - t.AccessTokenHash = i.AccessTokenHash - t.CodeHash = i.CodeHash - t.UserinfoProfile = i.UnmarshalUserinfoProfile() - t.UserinfoEmail = i.UnmarshalUserinfoEmail() - t.UserinfoPhone = i.UnmarshalUserinfoPhone() - t.Address = i.UnmarshalUserinfoAddress() - return nil -} - -func (t *IDTokenClaims) GetIssuer() string { +func (t *idTokenClaims) GetIssuer() string { return t.Issuer } -func (t *IDTokenClaims) GetAudience() []string { - return t.Audiences +func (t *idTokenClaims) GetAudience() []string { + return t.Audience } -func (t *IDTokenClaims) GetExpiration() time.Time { - return t.Expiration +func (t *idTokenClaims) GetExpiration() time.Time { + return time.Time(t.Expiration) } -func (t *IDTokenClaims) GetIssuedAt() time.Time { - return t.IssuedAt +func (t *idTokenClaims) GetIssuedAt() time.Time { + return time.Time(t.IssuedAt) } -func (t *IDTokenClaims) GetNonce() string { +func (t *idTokenClaims) GetNonce() string { return t.Nonce } -func (t *IDTokenClaims) GetAuthenticationContextClassReference() string { +func (t *idTokenClaims) GetAuthenticationContextClassReference() string { return t.AuthenticationContextClassReference } -func (t *IDTokenClaims) GetAuthTime() time.Time { - return t.AuthTime +func (t *idTokenClaims) GetAuthTime() time.Time { + return time.Time(t.AuthTime) } -func (t *IDTokenClaims) GetAuthorizedParty() string { +func (t *idTokenClaims) GetAuthorizedParty() string { return t.AuthorizedParty } -func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) { - t.Signature = alg +func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { + t.signatureAlg = alg } -func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) { - j := jsonToken{ - Issuer: t.Issuer, - Subject: t.Subject, - Audiences: t.Audience, - Expiration: timeToJSON(t.Expiration), - IssuedAt: timeToJSON(t.IssuedAt), - Scopes: strings.Join(t.Scopes, " "), - } - return json.Marshal(j) -} +// +//func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) { +// j := jsonToken{ +// Issuer: t.Issuer, +// Subject: t.Subject, +// Audiences: t.Audience, +// Expiration: timeToJSON(t.Expiration), +// IssuedAt: timeToJSON(t.IssuedAt), +// Scopes: strings.Join(t.Scopes, " "), +// } +// return json.Marshal(j) +//} -func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error { - var j jsonToken - if err := json.Unmarshal(b, &j); err != nil { - return err - } +//func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error { +// var j jsonToken +// if err := json.Unmarshal(b, &j); err != nil { +// return err +// } +// +// t.Issuer = j.Issuer +// t.Subject = j.Subject +// t.Audience = audienceFromJSON(j.Audiences) +// t.Expiration = time.Unix(j.Expiration, 0).UTC() +// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() +// t.Scopes = strings.Split(j.Scopes, " ") +// +// return nil +//} - t.Issuer = j.Issuer - t.Subject = j.Subject - t.Audience = audienceFromJSON(j.Audiences) - t.Expiration = time.Unix(j.Expiration, 0).UTC() - t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() - t.Scopes = strings.Split(j.Scopes, " ") - - return nil -} - -func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile { - locale, _ := language.Parse(j.Locale) - return UserinfoProfile{ - Name: j.Name, - GivenName: j.GivenName, - FamilyName: j.FamilyName, - MiddleName: j.MiddleName, - Nickname: j.Nickname, - Profile: j.Profile, - Picture: j.Picture, - Website: j.Website, - Gender: Gender(j.Gender), - Birthdate: j.Birthdate, - Zoneinfo: j.Zoneinfo, - Locale: locale, - UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(), - PreferredUsername: j.PreferredUsername, - } -} - -func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail { - return UserinfoEmail{ - Email: j.Email, - EmailVerified: j.EmailVerified, - } -} - -func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone { - return UserinfoPhone{ - PhoneNumber: j.Phone, - PhoneNumberVerified: j.PhoneVerified, - } -} - -func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress { - if j.JsonUserinfoAddress == nil { - return nil - } - return &UserinfoAddress{ - Country: j.JsonUserinfoAddress.Country, - Formatted: j.JsonUserinfoAddress.Formatted, - Locality: j.JsonUserinfoAddress.Locality, - PostalCode: j.JsonUserinfoAddress.PostalCode, - Region: j.JsonUserinfoAddress.Region, - StreetAddress: j.JsonUserinfoAddress.StreetAddress, - } -} +// +//func (j *jsonToken) UnmarshalUserinfoProfile() userInfoProfile { +// locale, _ := language.Parse(j.Locale) +// return userInfoProfile{ +// Name: j.Name, +// GivenName: j.GivenName, +// FamilyName: j.FamilyName, +// MiddleName: j.MiddleName, +// Nickname: j.Nickname, +// Profile: j.Profile, +// Picture: j.Picture, +// Website: j.Website, +// Gender: Gender(j.Gender), +// Birthdate: j.Birthdate, +// Zoneinfo: j.Zoneinfo, +// Locale: locale, +// UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(), +// PreferredUsername: j.PreferredUsername, +// } +//} +// +//func (j *jsonToken) UnmarshalUserinfoEmail() userInfoEmail { +// return userInfoEmail{ +// Email: j.Email, +// EmailVerified: j.EmailVerified, +// } +//} +// +//func (j *jsonToken) UnmarshalUserinfoPhone() userInfoPhone { +// return userInfoPhone{ +// PhoneNumber: j.Phone, +// PhoneNumberVerified: j.PhoneVerified, +// } +//} +// +//func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress { +// if j.JsonUserinfoAddress == nil { +// return nil +// } +// return &UserinfoAddress{ +// Country: j.JsonUserinfoAddress.Country, +// Formatted: j.JsonUserinfoAddress.Formatted, +// Locality: j.JsonUserinfoAddress.Locality, +// PostalCode: j.JsonUserinfoAddress.PostalCode, +// Region: j.JsonUserinfoAddress.Region, +// StreetAddress: j.JsonUserinfoAddress.StreetAddress, +// } +//} func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { hash, err := utils.GetHashAlgorithm(sigAlgorithm) diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go new file mode 100644 index 0000000..c04dfb4 --- /dev/null +++ b/pkg/oidc/token_request.go @@ -0,0 +1,101 @@ +package oidc + +import ( + "time" + + "gopkg.in/square/go-jose.v2" +) + +const ( + //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow + GrantTypeCode GrantType = "authorization_code" + //GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant + GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" +) + +type GrantType string + +type TokenRequest interface { + // GrantType GrantType `schema:"grant_type"` + GrantType() GrantType +} + +type TokenRequestType GrantType + +type AccessTokenRequest struct { + Code string `schema:"code"` + RedirectURI string `schema:"redirect_uri"` + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` + CodeVerifier string `schema:"code_verifier"` +} + +func (a *AccessTokenRequest) GrantType() GrantType { + return GrantTypeCode +} + +type JWTTokenRequest struct { + Issuer string `json:"iss"` + Subject string `json:"sub"` + Scopes Scopes `json:"scope"` + Audience Audience `json:"aud"` + IssuedAt Time `json:"iat"` + ExpiresAt Time `json:"exp"` +} + +func (j *JWTTokenRequest) GetClientID() string { + return j.Subject +} + +func (j *JWTTokenRequest) GetSubject() string { + return j.Subject +} + +func (j *JWTTokenRequest) GetScopes() []string { + return j.Scopes +} + +func (j *JWTTokenRequest) GetIssuer() string { + return j.Issuer +} + +func (j *JWTTokenRequest) GetAudience() []string { + return j.Audience +} + +func (j *JWTTokenRequest) GetExpiration() time.Time { + return time.Time(j.ExpiresAt) +} + +func (j *JWTTokenRequest) GetIssuedAt() time.Time { + return time.Time(j.IssuedAt) +} + +func (j *JWTTokenRequest) GetNonce() string { + return "" +} + +func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string { + return "" +} + +func (j *JWTTokenRequest) GetAuthTime() time.Time { + return time.Time{} +} + +func (j *JWTTokenRequest) GetAuthorizedParty() string { + return "" +} + +func (j *JWTTokenRequest) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {} + +type TokenExchangeRequest struct { + subjectToken string `schema:"subject_token"` + subjectTokenType string `schema:"subject_token_type"` + actorToken string `schema:"actor_token"` + actorTokenType string `schema:"actor_token_type"` + resource []string `schema:"resource"` + audience Audience `schema:"audience"` + Scope Scopes `schema:"scope"` + requestedTokenType string `schema:"requested_token_type"` +} diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go new file mode 100644 index 0000000..ad19684 --- /dev/null +++ b/pkg/oidc/types.go @@ -0,0 +1,113 @@ +package oidc + +import ( + "encoding/json" + "strings" + "time" + + "golang.org/x/text/language" +) + +type Audience []string + +func (a *Audience) UnmarshalJSON(text []byte) error { + var i interface{} + err := json.Unmarshal(text, &i) + if err != nil { + return err + } + switch aud := i.(type) { + case []interface{}: + *a = make([]string, len(aud)) + for i, audience := range aud { + (*a)[i] = audience.(string) + } + case string: + *a = []string{aud} + } + return nil +} + +type Display string + +func (d *Display) UnmarshalText(text []byte) error { + display := Display(text) + switch display { + case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP: + *d = display + } + return nil +} + +type Gender string + +type Locale language.Tag + +//{ +// SetLocale(language.Tag) +// Get() language.Tag +//} +// +//func NewLocale(tag language.Tag) Locale { +// if tag.IsRoot() { +// return nil +// } +// return &locale{Tag: tag} +//} +// +//type locale struct { +// language.Tag +//} +// +//func (l *locale) SetLocale(tag language.Tag) { +// l.Tag = tag +//} +//func (l *locale) Get() language.Tag { +// return l.Tag +//} + +//func (l *locale) MarshalJSON() ([]byte, error) { +// if l != nil && !l.IsRoot() { +// return l.MarshalText() +// } +// return []byte("null"), nil +//} + +type Locales []language.Tag + +func (l *Locales) UnmarshalText(text []byte) error { + locales := strings.Split(string(text), " ") + for _, locale := range locales { + tag, err := language.Parse(locale) + if err == nil && !tag.IsRoot() { + *l = append(*l, tag) + } + } + return nil +} + +type Prompt string + +type ResponseType string + +type Scopes []string + +func (s *Scopes) UnmarshalText(text []byte) error { + *s = strings.Split(string(text), " ") + return nil +} + +type Time time.Time + +func (t *Time) UnmarshalJSON(data []byte) error { + var i int64 + if err := json.Unmarshal(data, &i); err != nil { + return err + } + *t = Time(time.Unix(i, 0).UTC()) + return nil +} + +func (t *Time) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Time(*t).UTC().Unix()) +} diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go new file mode 100644 index 0000000..c451f8c --- /dev/null +++ b/pkg/oidc/types_test.go @@ -0,0 +1,276 @@ +package oidc + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/text/language" +) + +func TestAudience_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + audience Audience + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte(`{"aud": "single audience"}`), + }, + res{ + []string{"single audience"}, + }, + false, + }, + { + "page", + args{ + []byte(`{"aud": ["multiple", "audience"]}`), + }, + res{ + []string{"multiple", "audience"}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := new(struct { + Audience Audience `json:"aud"` + }) + if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, a.Audience, tt.res.audience) + }) + } +} + +func TestDisplay_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + display Display + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{}, + false, + }, + { + "page", + args{ + []byte("page"), + }, + res{DisplayPage}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var d Display + if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + if d != tt.res.display { + t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display) + } + }) + } +} + +func TestLocales_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + tags []language.Tag + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{}, + false, + }, + { + "undefined", + args{ + []byte("und"), + }, + res{}, + false, + }, + { + "single language", + args{ + []byte("de"), + }, + res{[]language.Tag{language.German}}, + false, + }, + { + "multiple languages", + args{ + []byte("de en"), + }, + res{[]language.Tag{language.German, language.English}}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var locales Locales + if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, locales, tt.res.tags) + }) + } +} + +func TestScopes_UnmarshalText(t *testing.T) { + type args struct { + text []byte + } + type res struct { + scopes []string + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{ + []string{"unknown"}, + }, + false, + }, + { + "struct", + args{ + []byte(`{"unknown":"value"}`), + }, + res{ + []string{`{"unknown":"value"}`}, + }, + false, + }, + { + "openid", + args{ + []byte("openid"), + }, + res{ + []string{"openid"}, + }, + false, + }, + { + "multiple scopes", + args{ + []byte("openid email custom:scope"), + }, + res{ + []string{"openid", "email", "custom:scope"}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var scopes Scopes + if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, scopes, tt.res.scopes) + }) + } +} + +func TestTime_UnmarshalJSON(t *testing.T) { + type args struct { + text []byte + } + type res struct { + scopes []string + } + tests := []struct { + name string + args args + res res + wantErr bool + }{ + { + "unknown value", + args{ + []byte("unknown"), + }, + res{ + []string{"unknown"}, + }, + false, + }, + { + "openid", + args{ + []byte("openid"), + }, + res{ + []string{"openid"}, + }, + false, + }, + { + "multiple scopes", + args{ + []byte("openid email custom:scope"), + }, + res{ + []string{"openid", "email", "custom:scope"}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var scopes Scopes + if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + assert.ElementsMatch(t, scopes, tt.res.scopes) + }) + } +} diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index f8b4e6c..31de85f 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -2,89 +2,312 @@ package oidc import ( "encoding/json" + "fmt" "time" "golang.org/x/text/language" + + "github.com/caos/oidc/pkg/utils" ) -type Userinfo struct { - Subject string - UserinfoProfile - UserinfoEmail - UserinfoPhone - Address *UserinfoAddress +type UserInfo interface { + GetSubject() string + UserInfoProfile + UserInfoEmail + UserInfoPhone + GetAddress() UserInfoAddress + GetClaim(key string) interface{} +} - Authorizations []string +type UserInfoProfile interface { + GetName() string + GetGivenName() string + GetFamilyName() string + GetMiddleName() string + GetNickname() string + GetProfile() string + GetPicture() string + GetWebsite() string + GetGender() Gender + GetBirthdate() string + GetZoneinfo() string + GetLocale() language.Tag + GetPreferredUsername() string +} + +type UserInfoEmail interface { + GetEmail() string + IsEmailVerified() bool +} + +type UserInfoPhone interface { + GetPhoneNumber() string + IsPhoneNumberVerified() bool +} + +type UserInfoAddress interface { + GetFormatted() string + GetStreetAddress() string + GetLocality() string + GetRegion() string + GetPostalCode() string + GetCountry() string +} + +type UserInfoSetter interface { + UserInfo + SetSubject(sub string) + UserInfoProfileSetter + SetEmail(email string, verified bool) + SetPhone(phone string, verified bool) + SetAddress(address UserInfoAddress) + AppendClaims(key string, values interface{}) +} + +type UserInfoProfileSetter interface { + SetName(name string) + SetGivenName(name string) + SetFamilyName(name string) + SetMiddleName(name string) + SetNickname(name string) + SetUpdatedAt(date time.Time) + SetProfile(profile string) + SetPicture(profile string) + SetWebsite(website string) + SetGender(gender Gender) + SetBirthdate(birthdate string) + SetZoneinfo(zoneInfo string) + SetLocale(locale language.Tag) + SetPreferredUsername(name string) +} + +func NewUserInfo() UserInfoSetter { + return &userinfo{} +} + +type userinfo struct { + Subject string `json:"sub,omitempty"` + userInfoProfile + userInfoEmail + userInfoPhone + Address UserInfoAddress `json:"address,omitempty"` claims map[string]interface{} } -type UserinfoProfile struct { - Name string - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender Gender - Birthdate string - Zoneinfo string - Locale language.Tag - UpdatedAt time.Time - PreferredUsername string +func (u *userinfo) GetSubject() string { + return u.Subject } -type Gender string - -type UserinfoEmail struct { - Email string - EmailVerified bool +func (u *userinfo) GetName() string { + return u.Name } -type UserinfoPhone struct { - PhoneNumber string - PhoneNumberVerified bool +func (u *userinfo) GetGivenName() string { + return u.GivenName } -type UserinfoAddress struct { - Formatted string - StreetAddress string - Locality string - Region string - PostalCode string - Country string +func (u *userinfo) GetFamilyName() string { + return u.FamilyName } -type jsonUserinfoProfile struct { - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale string `json:"locale,omitempty"` - UpdatedAt int64 `json:"updated_at,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` +func (u *userinfo) GetMiddleName() string { + return u.MiddleName } -type jsonUserinfoEmail struct { +func (u *userinfo) GetNickname() string { + return u.Nickname +} + +func (u *userinfo) GetProfile() string { + return u.Profile +} + +func (u *userinfo) GetPicture() string { + return u.Picture +} + +func (u *userinfo) GetWebsite() string { + return u.Website +} + +func (u *userinfo) GetGender() Gender { + return u.Gender +} + +func (u *userinfo) GetBirthdate() string { + return u.Birthdate +} + +func (u *userinfo) GetZoneinfo() string { + return u.Zoneinfo +} + +func (u *userinfo) GetLocale() language.Tag { + return u.Locale +} + +func (u *userinfo) GetPreferredUsername() string { + return u.PreferredUsername +} + +func (u *userinfo) GetEmail() string { + return u.Email +} + +func (u *userinfo) IsEmailVerified() bool { + return u.EmailVerified +} + +func (u *userinfo) GetPhoneNumber() string { + return u.PhoneNumber +} + +func (u *userinfo) IsPhoneNumberVerified() bool { + return u.PhoneNumberVerified +} + +func (u *userinfo) GetAddress() UserInfoAddress { + return u.Address +} + +func (u *userinfo) GetClaim(key string) interface{} { + return u.claims[key] +} + +func (u *userinfo) SetSubject(sub string) { + u.Subject = sub +} + +func (u *userinfo) SetName(name string) { + u.Name = name +} + +func (u *userinfo) SetGivenName(name string) { + u.GivenName = name +} + +func (u *userinfo) SetFamilyName(name string) { + u.FamilyName = name +} + +func (u *userinfo) SetMiddleName(name string) { + u.MiddleName = name +} + +func (u *userinfo) SetNickname(name string) { + u.Nickname = name +} + +func (u *userinfo) SetUpdatedAt(date time.Time) { + u.UpdatedAt = Time(date) +} + +func (u *userinfo) SetProfile(profile string) { + u.Profile = profile +} + +func (u *userinfo) SetPicture(picture string) { + u.Picture = picture +} + +func (u *userinfo) SetWebsite(website string) { + u.Website = website +} + +func (u *userinfo) SetGender(gender Gender) { + u.Gender = gender +} + +func (u *userinfo) SetBirthdate(birthdate string) { + u.Birthdate = birthdate +} + +func (u *userinfo) SetZoneinfo(zoneInfo string) { + u.Zoneinfo = zoneInfo +} + +func (u *userinfo) SetLocale(locale language.Tag) { + u.Locale = locale +} + +func (u *userinfo) SetPreferredUsername(name string) { + u.PreferredUsername = name +} + +func (u *userinfo) SetEmail(email string, verified bool) { + u.Email = email + u.EmailVerified = verified +} + +func (u *userinfo) SetPhone(phone string, verified bool) { + u.PhoneNumber = phone + u.PhoneNumberVerified = verified +} + +func (u *userinfo) SetAddress(address UserInfoAddress) { + u.Address = address +} + +func (u *userinfo) AppendClaims(key string, value interface{}) { + if u.claims == nil { + u.claims = make(map[string]interface{}) + } + u.claims[key] = value +} + +func (u *userInfoAddress) GetFormatted() string { + panic("implement me") +} + +func (u *userInfoAddress) GetStreetAddress() string { + panic("implement me") +} + +func (u *userInfoAddress) GetLocality() string { + panic("implement me") +} + +func (u *userInfoAddress) GetRegion() string { + panic("implement me") +} + +func (u *userInfoAddress) GetPostalCode() string { + panic("implement me") +} + +func (u *userInfoAddress) GetCountry() string { + panic("implement me") +} + +type userInfoProfile struct { + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Gender Gender `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale language.Tag `json:"locale,omitempty"` + UpdatedAt Time `json:"updated_at,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` +} + +type userInfoEmail struct { Email string `json:"email,omitempty"` EmailVerified bool `json:"email_verified,omitempty"` } -type jsonUserinfoPhone struct { - Phone string `json:"phone_number,omitempty"` - PhoneVerified bool `json:"phone_number_verified,omitempty"` +type userInfoPhone struct { + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` } -type jsonUserinfoAddress struct { +type userInfoAddress struct { Formatted string `json:"formatted,omitempty"` StreetAddress string `json:"street_address,omitempty"` Locality string `json:"locality,omitempty"` @@ -93,81 +316,68 @@ type jsonUserinfoAddress struct { Country string `json:"country,omitempty"` } -func (i *Userinfo) MarshalJSON() ([]byte, error) { - j := new(jsonUserinfo) - j.Subject = i.Subject - j.setUserinfo(*i) - j.Authorizations = i.Authorizations - return json.Marshal(j) +func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress { + return &userInfoAddress{ + StreetAddress: streetAddress, + Locality: locality, + Region: region, + PostalCode: postalCode, + Country: country, + Formatted: formatted, + } +} +func (i *userinfo) MarshalJSON() ([]byte, error) { + type Alias userinfo + a := &struct { + *Alias + Locale interface{} `json:"locale,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + }{ + Alias: (*Alias)(i), + } + if !i.Locale.IsRoot() { + a.Locale = i.Locale + } + fmt.Println(time.Time(i.UpdatedAt).String()) + if !time.Time(i.UpdatedAt).IsZero() { + a.UpdatedAt = time.Time(i.UpdatedAt).Unix() + } + + b, err := json.Marshal(a) + if err != nil { + return nil, err + } + + if len(i.claims) == 0 { + return b, nil + } + + claims, err := json.Marshal(i.claims) + if err != nil { + return nil, fmt.Errorf("jws: invalid map of private claims %v", i.claims) + } + return utils.ConcatenateJSON(b, claims) } -func (i *Userinfo) UnmmarshalJSON(data []byte) error { - if err := json.Unmarshal(data, i); err != nil { +func (i *userinfo) UnmarshalJSON(data []byte) error { + type Alias userinfo + a := &struct { + *Alias + //Locale interface{} `json:"locale,omitempty"` + UpdatedAt int64 `json:"update_at,omitempty"` + }{ + Alias: (*Alias)(i), + } + if err := json.Unmarshal(data, &a); err != nil { return err } - return json.Unmarshal(data, &i.claims) -} + //if !i.Locale.IsRoot() { + // a.Locale = i.Locale + //} -type jsonUserinfo struct { - Subject string `json:"sub,omitempty"` - jsonUserinfoProfile - jsonUserinfoEmail - jsonUserinfoPhone - JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"` - Authorizations []string `json:"authorizations,omitempty"` -} + i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC()) -func (j *jsonUserinfo) setUserinfo(i Userinfo) { - j.setUserinfoProfile(i.UserinfoProfile) - j.setUserinfoEmail(i.UserinfoEmail) - j.setUserinfoPhone(i.UserinfoPhone) - j.setUserinfoAddress(i.Address) -} - -func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) { - j.Name = i.Name - j.GivenName = i.GivenName - j.FamilyName = i.FamilyName - j.MiddleName = i.MiddleName - j.Nickname = i.Nickname - j.Profile = i.Profile - j.Picture = i.Picture - j.Website = i.Website - j.Gender = string(i.Gender) - j.Birthdate = i.Birthdate - j.Zoneinfo = i.Zoneinfo - if i.Locale != language.Und { - j.Locale = i.Locale.String() - } - j.UpdatedAt = timeToJSON(i.UpdatedAt) - j.PreferredUsername = i.PreferredUsername -} - -func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) { - j.Email = i.Email - j.EmailVerified = i.EmailVerified -} - -func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) { - j.Phone = i.PhoneNumber - j.PhoneVerified = i.PhoneNumberVerified -} - -func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) { - if i == nil { - return - } - if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" { - return - } - j.JsonUserinfoAddress = &jsonUserinfoAddress{ - Country: i.Country, - Formatted: i.Formatted, - Locality: i.Locality, - PostalCode: i.PostalCode, - Region: i.Region, - StreetAddress: i.StreetAddress, - } + return nil } type UserInfoRequest struct { diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 492664b..06470a0 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -24,7 +24,7 @@ type Claims interface { GetAuthenticationContextClassReference() string GetAuthTime() time.Time GetAuthorizedParty() string - SetSignature(algorithm jose.SignatureAlgorithm) + SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) } var ( @@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl return ErrSignatureInvalidPayload } - claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm)) + claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm)) return nil } diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index cf40e62..cee6184 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -168,7 +168,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie if err != nil { return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") } - return claims.Subject, nil + return claims.GetSubject(), nil } //RedirectToLogin redirects the end user to the Login UI for authentication diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 0a6b6a5..7dfcfff 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -81,7 +81,7 @@ func (s *Sig) Health(ctx context.Context) error { func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { return "", nil } -func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) { +func (s *Sig) SignAccessToken(*oidc.accessTokenClaims) (string, error) { return "", nil } func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index a7d909c..16592a7 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -50,7 +50,7 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call { } // SignAccessToken mocks base method -func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) { +func (m *MockSigner) SignAccessToken(arg0 *oidc.accessTokenClaims) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SignAccessToken", arg0) ret0, _ := ret[0].(string) diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index a7ca4cb..bcc04da 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -184,10 +184,10 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.userinfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) - ret0, _ := ret[0].(*oidc.Userinfo) + ret0, _ := ret[0].(*oidc.userinfo) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -199,10 +199,10 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interf } // GetUserinfoFromToken mocks base method -func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.Userinfo, error) { +func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.userinfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) - ret0, _ := ret[0].(*oidc.Userinfo) + ret0, _ := ret[0].(*oidc.userinfo) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/op/op.go b/pkg/op/op.go index d913c7f..7e8279a 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -130,7 +130,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO } keyCh := make(chan jose.SigningKey) - o.signer = NewDefaultSigner(ctx, storage, keyCh) + o.signer = NewSigner(ctx, storage, keyCh) go EnsureKey(ctx, storage, keyCh, o.timer, o.retry) o.httpHandler = CreateRouter(o, o.interceptors...) diff --git a/pkg/op/session.go b/pkg/op/session.go index d04e361..19ebab4 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, if err != nil { return nil, ErrInvalidRequest("id_token_hint invalid") } - session.UserID = claims.Subject - session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty) + session.UserID = claims.GetSubject() + session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty()) if err != nil { return nil, ErrServerError("") } diff --git a/pkg/op/signer.go b/pkg/op/signer.go index e9926cd..5cf585e 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -2,19 +2,17 @@ package op import ( "context" - "encoding/json" "errors" "github.com/caos/logging" "gopkg.in/square/go-jose.v2" - - "github.com/caos/oidc/pkg/oidc" ) type Signer interface { Health(ctx context.Context) error - SignIDToken(claims *oidc.IDTokenClaims) (string, error) - SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) + //SignIDToken(claims *oidc.IDTokenClaims) (string, error) + //SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) + Signer() jose.Signer SignatureAlgorithm() jose.SignatureAlgorithm } @@ -24,7 +22,7 @@ type tokenSigner struct { alg jose.SignatureAlgorithm } -func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { +func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { s := &tokenSigner{ storage: storage, } @@ -41,6 +39,15 @@ func (s *tokenSigner) Health(_ context.Context) error { return nil } +func (s *tokenSigner) Signer() jose.Signer { + return s.signer +} + +// +//func (s *tokenSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { +// return s.signer.Sign(payload) +//} + func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) { for { select { @@ -55,30 +62,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S } } -func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) { - payload, err := json.Marshal(claims) - if err != nil { - return "", err - } - return s.Sign(payload) -} - -func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) { - payload, err := json.Marshal(claims) - if err != nil { - return "", err - } - return s.Sign(payload) -} - -func (s *tokenSigner) Sign(payload []byte) (string, error) { - result, err := s.signer.Sign(payload) - if err != nil { - return "", err - } - return result.CompactSerialize() -} - func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { return s.alg } diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go index 75e184b..c751c76 100644 --- a/pkg/op/signer_test.go +++ b/pkg/op/signer_test.go @@ -38,13 +38,13 @@ import ( // } // for _, tt := range tests { // t.Run(tt.name, func(t *testing.T) { -// got, err := op.NewDefaultSigner(tt.args.storage) +// got, err := op.NewSigner(tt.args.storage) // if (err != nil) != tt.wantErr { -// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr) +// t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) // return // } // if !reflect.DeepEqual(got, tt.want) { -// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want) +// t.Errorf("NewSigner() = %v, want %v", got, tt.want) // } // }) // } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 669b08e..69784ee 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -28,8 +28,8 @@ type AuthStorage interface { type OPStorage interface { GetClientByClientID(context.Context, string) (Client, error) AuthorizeClientIDSecret(context.Context, string, string) error - GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error) - GetUserinfoFromToken(context.Context, string, string) (*oidc.Userinfo, error) + GetUserinfoFromScopes(context.Context, string, []string) (oidc.UserInfoSetter, error) + GetUserinfoFromToken(context.Context, string, string) (oidc.UserInfoSetter, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } diff --git a/pkg/op/token.go b/pkg/op/token.go index 87494b9..bb2b3c5 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -5,6 +5,7 @@ import ( "time" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" ) type TokenCreator interface { @@ -82,51 +83,34 @@ func CreateBearerToken(id string, crypto Crypto) (string, error) { } func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) { - now := time.Now().UTC() - nbf := now - claims := &oidc.AccessTokenClaims{ - Issuer: issuer, - Subject: authReq.GetSubject(), - Audiences: authReq.GetAudience(), - Expiration: exp, - IssuedAt: now, - NotBefore: nbf, - JWTID: id, - } - return signer.SignAccessToken(claims) + claims := oidc.NewAccessTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, id) + return utils.Sign(claims, signer.Signer()) } func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) { - var err error exp := time.Now().UTC().Add(validity) - userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) - if err != nil { - return "", err - } - claims := &oidc.IDTokenClaims{ - Issuer: issuer, - Audiences: authReq.GetAudience(), - Expiration: exp, - IssuedAt: time.Now().UTC(), - AuthTime: authReq.GetAuthTime(), - Nonce: authReq.GetNonce(), - AuthenticationContextClassReference: authReq.GetACR(), - AuthenticationMethodsReferences: authReq.GetAMR(), - AuthorizedParty: authReq.GetClientID(), - Userinfo: *userinfo, - } + claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID()) + if accessToken != "" { - claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) + atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) if err != nil { return "", err } + claims.SetAccessTokenHash(atHash) + } else { + userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) + if err != nil { + return "", err + } + claims.SetUserinfo(userInfo) } if code != "" { - claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm()) + codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm()) if err != nil { return "", err } + claims.SetCodeHash(codeHash) } - return signer.SignIDToken(claims) + return utils.Sign(claims, signer.Signer()) } diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index 3268a5e..7baa075 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi //VerifyIDTokenHint validates the id token according to //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) { - claims := new(oidc.IDTokenClaims) +func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) { + claims := oidc.EmptyIDTokenClaims() decrypted, err := oidc.DecryptToken(token) if err != nil { diff --git a/pkg/rp/mock/verifier.mock.impl.go b/pkg/rp/mock/verifier.mock.impl.go index 53b2f03..0b6dd1c 100644 --- a/pkg/rp/mock/verifier.mock.impl.go +++ b/pkg/rp/mock/verifier.mock.impl.go @@ -33,5 +33,5 @@ func NewMockVerifierExpectValid(t *testing.T) rp.Verifier { func ExpectVerifyValid(v rp.Verifier) { mock := v.(*MockVerifier) - mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil) + mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.userinfo{Subject: "id"}}, nil) } diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go index 3fe8b4b..fd9ee95 100644 --- a/pkg/rp/relaying_party.go +++ b/pkg/rp/relaying_party.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "net/url" "strings" "github.com/google/uuid" @@ -329,10 +330,10 @@ func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2. } func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) { - form := make(map[string][]string) - form["assertion"] = []string{assertion} - form["grant_type"] = []string{jwtProfileKey} - req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil) + form := url.Values{} + form.Add("assertion", assertion) + form.Add("grant_type", jwtProfileKey) + req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, strings.NewReader(form.Encode())) if err != nil { return nil, err } diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go index ef2cf87..a156f6d 100644 --- a/pkg/rp/verifier.go +++ b/pkg/rp/verifier.go @@ -21,12 +21,12 @@ type IDTokenVerifier interface { //VerifyTokens implement the Token Response Validation as defined in OIDC specification //https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { +func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) { idToken, err := VerifyIDToken(ctx, idTokenString, v) if err != nil { return nil, err } - if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { + if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil { return nil, err } return idToken, nil @@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo //VerifyIDToken validates the id token according to //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { - claims := new(oidc.IDTokenClaims) +func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) { + claims := oidc.EmptyIDTokenClaims() decrypted, err := oidc.DecryptToken(token) if err != nil { diff --git a/pkg/utils/marshal.go b/pkg/utils/marshal.go index e279341..4f53b4e 100644 --- a/pkg/utils/marshal.go +++ b/pkg/utils/marshal.go @@ -1,7 +1,9 @@ package utils import ( + "bytes" "encoding/json" + "fmt" "net/http" "github.com/sirupsen/logrus" @@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) { logrus.Error("error writing response") } } + +func ConcatenateJSON(first, second []byte) ([]byte, error) { + if !bytes.HasSuffix(first, []byte{'}'}) { + return nil, fmt.Errorf("jws: invalid JSON %s", first) + } + if !bytes.HasPrefix(second, []byte{'{'}) { + return nil, fmt.Errorf("jws: invalid JSON %s", second) + } + first[len(first)-1] = ',' + first = append(first, second[1:]...) + return first, nil +} diff --git a/pkg/utils/sign.go b/pkg/utils/sign.go new file mode 100644 index 0000000..e1efe61 --- /dev/null +++ b/pkg/utils/sign.go @@ -0,0 +1,23 @@ +package utils + +import ( + "encoding/json" + + "gopkg.in/square/go-jose.v2" +) + +func Sign(object interface{}, signer jose.Signer) (string, error) { + payload, err := json.Marshal(object) + if err != nil { + return "", err + } + return SignPayload(payload, signer) +} + +func SignPayload(payload []byte, signer jose.Signer) (string, error) { + result, err := signer.Sign(payload) + if err != nil { + return "", err + } + return result.CompactSerialize() +}