Merge pull request #60 from caos/serializing
feat: private claims (incl. serialisation refactoring and jwt profile fix)
This commit is contained in:
commit
c1699a2d93
43 changed files with 1896 additions and 980 deletions
|
@ -4,6 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
@ -30,7 +32,7 @@ func main() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
|
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
|
||||||
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}
|
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress}
|
||||||
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
|
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
|
||||||
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
|
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
|
||||||
rp.WithPKCE(cookieHandler),
|
rp.WithPKCE(cookieHandler),
|
||||||
|
@ -82,6 +84,66 @@ func main() {
|
||||||
}
|
}
|
||||||
w.Write(data)
|
w.Write(data)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "GET" {
|
||||||
|
tpl := `
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Login</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<form method="POST" action="/jwt-profile" enctype="multipart/form-data">
|
||||||
|
<label for="key">Select a key file:</label>
|
||||||
|
<input type="file" accept=".json" id="key" name="key">
|
||||||
|
<button type="submit">Get Token</button>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
t, err := template.New("login").Parse(tpl)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = t.Execute(w, nil)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := r.ParseMultipartForm(4 << 10)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, handler, err := r.FormFile("key")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
key, err := ioutil.ReadAll(file)
|
||||||
|
fmt.Println(handler.Header)
|
||||||
|
assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token, err := rp.JWTProfileAssertionExchange(ctx, assertion, scopes, provider)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(token)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(data)
|
||||||
|
}
|
||||||
|
})
|
||||||
lis := fmt.Sprintf("127.0.0.1:%s", port)
|
lis := fmt.Sprintf("127.0.0.1:%s", port)
|
||||||
logrus.Infof("listening on http://%s/", lis)
|
logrus.Infof("listening on http://%s/", lis)
|
||||||
logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))
|
logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))
|
||||||
|
|
|
@ -45,7 +45,7 @@ func main() {
|
||||||
}
|
}
|
||||||
token := cli.CodeFlow(relayingParty, callbackPath, port, state)
|
token := cli.CodeFlow(relayingParty, callbackPath, port, state)
|
||||||
|
|
||||||
client := github.NewClient(relayingParty.Client(ctx, token.Token))
|
client := github.NewClient(relayingParty.OAuthConfig().Client(ctx, token.Token))
|
||||||
|
|
||||||
_, _, err = client.Users.Get(ctx, "")
|
_, _, err = client.Users.Get(ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -210,31 +210,21 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (*oidc.Userinfo, error) {
|
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (oidc.UserInfo, error) {
|
||||||
return s.GetUserinfoFromScopes(ctx, "", []string{})
|
return s.GetUserinfoFromScopes(ctx, "", "", []string{})
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) {
|
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfo, error) {
|
||||||
return &oidc.Userinfo{
|
userinfo := oidc.NewUserInfo()
|
||||||
Subject: a.GetSubject(),
|
userinfo.SetSubject(a.GetSubject())
|
||||||
Address: &oidc.UserinfoAddress{
|
userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
|
||||||
StreetAddress: "Hjkhkj 789\ndsf",
|
userinfo.SetEmail("test", true)
|
||||||
},
|
userinfo.SetPhone("0791234567", true)
|
||||||
UserinfoEmail: oidc.UserinfoEmail{
|
userinfo.SetName("Test")
|
||||||
Email: "test",
|
userinfo.AppendClaims("private_claim", "test")
|
||||||
EmailVerified: true,
|
return userinfo, nil
|
||||||
},
|
}
|
||||||
UserinfoPhone: oidc.UserinfoPhone{
|
func (s *AuthStorage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) {
|
||||||
PhoneNumber: "sadsa",
|
return map[string]interface{}{"private_claim": "test"}, nil
|
||||||
PhoneNumberVerified: true,
|
|
||||||
},
|
|
||||||
UserinfoProfile: oidc.UserinfoProfile{
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
},
|
|
||||||
// Claims: map[string]interface{}{
|
|
||||||
// "test": "test",
|
|
||||||
// "hkjh": "",
|
|
||||||
// },
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfClient struct {
|
type ConfClient struct {
|
||||||
|
@ -289,3 +279,15 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
|
||||||
func (c *ConfClient) DevMode() bool {
|
func (c *ConfClient) DevMode() bool {
|
||||||
return c.devMode
|
return c.devMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) AllowedScopes() []string {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -1,15 +1,5 @@
|
||||||
package oidc
|
package oidc
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/text/language"
|
|
||||||
"gopkg.in/square/go-jose.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
//ScopeOpenID defines the scope `openid`
|
//ScopeOpenID defines the scope `openid`
|
||||||
//OpenID Connect requests MUST contain the `openid` scope value
|
//OpenID Connect requests MUST contain the `openid` scope value
|
||||||
|
@ -64,23 +54,8 @@ const (
|
||||||
|
|
||||||
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
|
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
|
||||||
PromptSelectAccount Prompt = "select_account"
|
PromptSelectAccount Prompt = "select_account"
|
||||||
|
|
||||||
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
|
|
||||||
GrantTypeCode GrantType = "authorization_code"
|
|
||||||
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
|
|
||||||
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
|
||||||
|
|
||||||
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
|
|
||||||
BearerToken = "Bearer"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var displayValues = map[string]Display{
|
|
||||||
"page": DisplayPage,
|
|
||||||
"popup": DisplayPopup,
|
|
||||||
"touch": DisplayTouch,
|
|
||||||
"wap": DisplayWAP,
|
|
||||||
}
|
|
||||||
|
|
||||||
//AuthRequest according to:
|
//AuthRequest according to:
|
||||||
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||||
type AuthRequest struct {
|
type AuthRequest struct {
|
||||||
|
@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType {
|
||||||
func (a *AuthRequest) GetState() string {
|
func (a *AuthRequest) GetState() string {
|
||||||
return a.State
|
return a.State
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest interface {
|
|
||||||
// GrantType GrantType `schema:"grant_type"`
|
|
||||||
GrantType() GrantType
|
|
||||||
}
|
|
||||||
|
|
||||||
type TokenRequestType GrantType
|
|
||||||
|
|
||||||
type AccessTokenRequest struct {
|
|
||||||
Code string `schema:"code"`
|
|
||||||
RedirectURI string `schema:"redirect_uri"`
|
|
||||||
ClientID string `schema:"client_id"`
|
|
||||||
ClientSecret string `schema:"client_secret"`
|
|
||||||
CodeVerifier string `schema:"code_verifier"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AccessTokenRequest) GrantType() GrantType {
|
|
||||||
return GrantTypeCode
|
|
||||||
}
|
|
||||||
|
|
||||||
type AccessTokenResponse struct {
|
|
||||||
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
|
|
||||||
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
|
|
||||||
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
|
|
||||||
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type JWTTokenRequest struct {
|
|
||||||
Issuer string `json:"iss"`
|
|
||||||
Subject string `json:"sub"`
|
|
||||||
Scopes Scopes `json:"scope"`
|
|
||||||
Audience interface{} `json:"aud"`
|
|
||||||
IssuedAt Time `json:"iat"`
|
|
||||||
ExpiresAt Time `json:"exp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetClientID() string {
|
|
||||||
return j.Subject
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetSubject() string {
|
|
||||||
return j.Subject
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetScopes() []string {
|
|
||||||
return j.Scopes
|
|
||||||
}
|
|
||||||
|
|
||||||
type Time time.Time
|
|
||||||
|
|
||||||
func (t *Time) UnmarshalJSON(data []byte) error {
|
|
||||||
var i int64
|
|
||||||
if err := json.Unmarshal(data, &i); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*t = Time(time.Unix(i, 0).UTC())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetIssuer() string {
|
|
||||||
return j.Issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetAudience() []string {
|
|
||||||
return audienceFromJSON(j.Audience)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetExpiration() time.Time {
|
|
||||||
return time.Time(j.ExpiresAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
|
|
||||||
return time.Time(j.IssuedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetNonce() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetAuthTime() time.Time {
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) GetAuthorizedParty() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {}
|
|
||||||
|
|
||||||
type TokenExchangeRequest struct {
|
|
||||||
subjectToken string `schema:"subject_token"`
|
|
||||||
subjectTokenType string `schema:"subject_token_type"`
|
|
||||||
actorToken string `schema:"actor_token"`
|
|
||||||
actorTokenType string `schema:"actor_token_type"`
|
|
||||||
resource []string `schema:"resource"`
|
|
||||||
audience []string `schema:"audience"`
|
|
||||||
Scope []string `schema:"scope"`
|
|
||||||
requestedTokenType string `schema:"requested_token_type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Scopes []string
|
|
||||||
|
|
||||||
func (s *Scopes) UnmarshalText(text []byte) error {
|
|
||||||
scopes := strings.Split(string(text), " ")
|
|
||||||
*s = Scopes(scopes)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseType string
|
|
||||||
|
|
||||||
type Display string
|
|
||||||
|
|
||||||
func (d *Display) UnmarshalText(text []byte) error {
|
|
||||||
var ok bool
|
|
||||||
display := string(text)
|
|
||||||
*d, ok = displayValues[display]
|
|
||||||
if !ok {
|
|
||||||
return errors.New("")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Prompt string
|
|
||||||
|
|
||||||
type Locales []language.Tag
|
|
||||||
|
|
||||||
func (l *Locales) UnmarshalText(text []byte) error {
|
|
||||||
locales := strings.Split(string(text), " ")
|
|
||||||
for _, locale := range locales {
|
|
||||||
tag, err := language.Parse(locale)
|
|
||||||
if err == nil && !tag.IsRoot() {
|
|
||||||
*l = append(*l, tag)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type GrantType string
|
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package tokenexchange
|
package tokenexchange
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
|
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
|
||||||
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
|
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
|
||||||
|
@ -24,6 +28,18 @@ type TokenExchangeRequest struct {
|
||||||
|
|
||||||
type JWTProfileRequest struct {
|
type JWTProfileRequest struct {
|
||||||
Assertion string `schema:"assertion"`
|
Assertion string `schema:"assertion"`
|
||||||
|
Scope oidc.Scopes `schema:"scope"`
|
||||||
|
GrantType oidc.GrantType `schema:"grant_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
|
||||||
|
//sneding client_id and client_secret as basic auth header
|
||||||
|
func NewJWTProfileRequest(assertion string, scopes ...string) *JWTProfileRequest {
|
||||||
|
return &JWTProfileRequest{
|
||||||
|
GrantType: oidc.GrantTypeBearer,
|
||||||
|
Assertion: assertion,
|
||||||
|
Scope: scopes,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {
|
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {
|
||||||
|
|
|
@ -6,21 +6,19 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
|
//KeySet represents a set of JSON Web Keys
|
||||||
// of JSON web tokens. This is expected to be backed by a remote key set through
|
// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
|
||||||
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
|
// - held by the OP itself in storage -> `openIDKeySet`
|
||||||
|
// - dynamically aggregated by request for OAuth JWT Profile Assertion -> `jwtProfileKeySet`
|
||||||
type KeySet interface {
|
type KeySet interface {
|
||||||
// VerifySignature parses the JSON web token, verifies the signature, and returns
|
//VerifySignature verifies the signature with the given keyset and returns the raw payload
|
||||||
// the raw payload. Header and claim fields are validated by other parts of the
|
|
||||||
// package. For example, the KeySet does not need to check values such as signature
|
|
||||||
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
|
|
||||||
// independently.
|
|
||||||
//
|
|
||||||
// If VerifySignature makes HTTP requests to verify the token, it's expected to
|
|
||||||
// use any HTTP client associated with the context through ClientContext.
|
|
||||||
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
|
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//CheckKey searches the given JSON Web Keys for the requested key ID
|
||||||
|
//and verifies the JSON Web Signature with the found key
|
||||||
|
//
|
||||||
|
//will return false but no error if key ID is not found
|
||||||
func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) {
|
func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if keyID == "" || key.KeyID == keyID {
|
if keyID == "" || key.KeyID == keyID {
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package oidc
|
package oidc
|
||||||
|
|
||||||
|
//EndSessionRequest for the RP-Initiated Logout according to:
|
||||||
|
//https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
|
||||||
type EndSessionRequest struct {
|
type EndSessionRequest struct {
|
||||||
IdTokenHint string `schema:"id_token_hint"`
|
IdTokenHint string `schema:"id_token_hint"`
|
||||||
PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`
|
PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`
|
||||||
|
|
|
@ -3,72 +3,401 @@ package oidc
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/text/language"
|
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/utils"
|
"github.com/caos/oidc/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
|
||||||
|
BearerToken = "Bearer"
|
||||||
|
)
|
||||||
|
|
||||||
type Tokens struct {
|
type Tokens struct {
|
||||||
*oauth2.Token
|
*oauth2.Token
|
||||||
IDTokenClaims *IDTokenClaims
|
IDTokenClaims IDTokenClaims
|
||||||
IDToken string
|
IDToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccessTokenClaims struct {
|
type AccessTokenClaims interface {
|
||||||
Issuer string
|
Claims
|
||||||
Subject string
|
GetSubject() string
|
||||||
Audiences []string
|
GetTokenID() string
|
||||||
Expiration time.Time
|
SetPrivateClaims(map[string]interface{})
|
||||||
IssuedAt time.Time
|
|
||||||
NotBefore time.Time
|
|
||||||
JWTID string
|
|
||||||
AuthorizedParty string
|
|
||||||
Nonce string
|
|
||||||
AuthTime time.Time
|
|
||||||
CodeHash string
|
|
||||||
AuthenticationContextClassReference string
|
|
||||||
AuthenticationMethodsReferences []string
|
|
||||||
SessionID string
|
|
||||||
Scopes []string
|
|
||||||
ClientID string
|
|
||||||
AccessTokenUseNumber int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type IDTokenClaims interface {
|
||||||
Issuer string
|
Claims
|
||||||
Audiences []string
|
GetNotBefore() time.Time
|
||||||
Expiration time.Time
|
GetJWTID() string
|
||||||
NotBefore time.Time
|
GetAccessTokenHash() string
|
||||||
IssuedAt time.Time
|
GetCodeHash() string
|
||||||
JWTID string
|
GetAuthenticationMethodsReferences() []string
|
||||||
UpdatedAt time.Time
|
GetClientID() string
|
||||||
AuthorizedParty string
|
GetSignatureAlgorithm() jose.SignatureAlgorithm
|
||||||
Nonce string
|
SetAccessTokenHash(hash string)
|
||||||
AuthTime time.Time
|
SetUserinfo(userinfo UserInfo)
|
||||||
AccessTokenHash string
|
SetCodeHash(hash string)
|
||||||
CodeHash string
|
UserInfo
|
||||||
AuthenticationContextClassReference string
|
}
|
||||||
AuthenticationMethodsReferences []string
|
|
||||||
ClientID string
|
|
||||||
Userinfo
|
|
||||||
|
|
||||||
Signature jose.SignatureAlgorithm //TODO: ???
|
func EmptyAccessTokenClaims() AccessTokenClaims {
|
||||||
|
return new(accessTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
return &accessTokenClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: subject,
|
||||||
|
Audience: audience,
|
||||||
|
Expiration: Time(expiration),
|
||||||
|
IssuedAt: Time(now),
|
||||||
|
NotBefore: Time(now),
|
||||||
|
JWTID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type accessTokenClaims struct {
|
||||||
|
Issuer string `json:"iss,omitempty"`
|
||||||
|
Subject string `json:"sub,omitempty"`
|
||||||
|
Audience Audience `json:"aud,omitempty"`
|
||||||
|
Expiration Time `json:"exp,omitempty"`
|
||||||
|
IssuedAt Time `json:"iat,omitempty"`
|
||||||
|
NotBefore Time `json:"nbf,omitempty"`
|
||||||
|
JWTID string `json:"jti,omitempty"`
|
||||||
|
AuthorizedParty string `json:"azp,omitempty"`
|
||||||
|
Nonce string `json:"nonce,omitempty"`
|
||||||
|
AuthTime Time `json:"auth_time,omitempty"`
|
||||||
|
CodeHash string `json:"c_hash,omitempty"`
|
||||||
|
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||||
|
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||||
|
SessionID string `json:"sid,omitempty"`
|
||||||
|
Scopes []string `json:"scope,omitempty"`
|
||||||
|
ClientID string `json:"client_id,omitempty"`
|
||||||
|
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
||||||
|
|
||||||
|
claims map[string]interface{} `json:"-"`
|
||||||
|
signatureAlg jose.SignatureAlgorithm `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetIssuer implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetIssuer() string {
|
||||||
|
return a.Issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAudience implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetAudience() []string {
|
||||||
|
return a.Audience
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetExpiration implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetExpiration() time.Time {
|
||||||
|
return time.Time(a.Expiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetIssuedAt implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetIssuedAt() time.Time {
|
||||||
|
return time.Time(a.IssuedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetNonce implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetNonce() string {
|
||||||
|
return a.Nonce
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthenticationContextClassReference implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetAuthenticationContextClassReference() string {
|
||||||
|
return a.AuthenticationContextClassReference
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthTime implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetAuthTime() time.Time {
|
||||||
|
return time.Time(a.AuthTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthorizedParty implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) GetAuthorizedParty() string {
|
||||||
|
return a.AuthorizedParty
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetSignatureAlgorithm implements the Claims interface
|
||||||
|
func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
|
||||||
|
a.signatureAlg = algorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetSubject implements the AccessTokenClaims interface
|
||||||
|
func (a *accessTokenClaims) GetSubject() string {
|
||||||
|
return a.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetTokenID implements the AccessTokenClaims interface
|
||||||
|
func (a *accessTokenClaims) GetTokenID() string {
|
||||||
|
return a.JWTID
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetPrivateClaims implements the AccessTokenClaims interface
|
||||||
|
func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) {
|
||||||
|
a.claims = claims
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
|
||||||
|
type Alias accessTokenClaims
|
||||||
|
s := &struct {
|
||||||
|
*Alias
|
||||||
|
Expiration int64 `json:"exp,omitempty"`
|
||||||
|
IssuedAt int64 `json:"iat,omitempty"`
|
||||||
|
NotBefore int64 `json:"nbf,omitempty"`
|
||||||
|
AuthTime int64 `json:"auth_time,omitempty"`
|
||||||
|
}{
|
||||||
|
Alias: (*Alias)(a),
|
||||||
|
}
|
||||||
|
if !time.Time(a.Expiration).IsZero() {
|
||||||
|
s.Expiration = time.Time(a.Expiration).Unix()
|
||||||
|
}
|
||||||
|
if !time.Time(a.IssuedAt).IsZero() {
|
||||||
|
s.IssuedAt = time.Time(a.IssuedAt).Unix()
|
||||||
|
}
|
||||||
|
if !time.Time(a.NotBefore).IsZero() {
|
||||||
|
s.NotBefore = time.Time(a.NotBefore).Unix()
|
||||||
|
}
|
||||||
|
if !time.Time(a.AuthTime).IsZero() {
|
||||||
|
s.AuthTime = time.Time(a.AuthTime).Unix()
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.claims == nil {
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
info, err := json.Marshal(a.claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return utils.ConcatenateJSON(b, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
|
||||||
|
type Alias accessTokenClaims
|
||||||
|
if err := json.Unmarshal(data, (*Alias)(a)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
claims := make(map[string]interface{})
|
||||||
|
if err := json.Unmarshal(data, &claims); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.claims = claims
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmptyIDTokenClaims() IDTokenClaims {
|
||||||
|
return new(idTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims {
|
||||||
|
return &idTokenClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audience: audience,
|
||||||
|
Expiration: Time(expiration),
|
||||||
|
IssuedAt: Time(time.Now().UTC()),
|
||||||
|
AuthTime: Time(authTime),
|
||||||
|
Nonce: nonce,
|
||||||
|
AuthenticationContextClassReference: acr,
|
||||||
|
AuthenticationMethodsReferences: amr,
|
||||||
|
AuthorizedParty: clientID,
|
||||||
|
UserInfo: &userinfo{Subject: subject},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type idTokenClaims struct {
|
||||||
|
Issuer string `json:"iss,omitempty"`
|
||||||
|
Audience Audience `json:"aud,omitempty"`
|
||||||
|
Expiration Time `json:"exp,omitempty"`
|
||||||
|
NotBefore Time `json:"nbf,omitempty"`
|
||||||
|
IssuedAt Time `json:"iat,omitempty"`
|
||||||
|
JWTID string `json:"jti,omitempty"`
|
||||||
|
AuthorizedParty string `json:"azp,omitempty"`
|
||||||
|
Nonce string `json:"nonce,omitempty"`
|
||||||
|
AuthTime Time `json:"auth_time,omitempty"`
|
||||||
|
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||||
|
CodeHash string `json:"c_hash,omitempty"`
|
||||||
|
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||||
|
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||||
|
ClientID string `json:"client_id,omitempty"`
|
||||||
|
UserInfo `json:"-"`
|
||||||
|
|
||||||
|
signatureAlg jose.SignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetIssuer implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetIssuer() string {
|
||||||
|
return t.Issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAudience implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetAudience() []string {
|
||||||
|
return t.Audience
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetExpiration implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetExpiration() time.Time {
|
||||||
|
return time.Time(t.Expiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetIssuedAt implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetIssuedAt() time.Time {
|
||||||
|
return time.Time(t.IssuedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetNonce implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetNonce() string {
|
||||||
|
return t.Nonce
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthenticationContextClassReference implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
|
||||||
|
return t.AuthenticationContextClassReference
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthTime implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetAuthTime() time.Time {
|
||||||
|
return time.Time(t.AuthTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthorizedParty implements the Claims interface
|
||||||
|
func (t *idTokenClaims) GetAuthorizedParty() string {
|
||||||
|
return t.AuthorizedParty
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetSignatureAlgorithm implements the Claims interface
|
||||||
|
func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
|
||||||
|
t.signatureAlg = alg
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetNotBefore implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) GetNotBefore() time.Time {
|
||||||
|
return time.Time(t.NotBefore)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetJWTID implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) GetJWTID() string {
|
||||||
|
return t.JWTID
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAccessTokenHash implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) GetAccessTokenHash() string {
|
||||||
|
return t.AccessTokenHash
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetCodeHash implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) GetCodeHash() string {
|
||||||
|
return t.CodeHash
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthenticationMethodsReferences implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
|
||||||
|
return t.AuthenticationMethodsReferences
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetClientID implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) GetClientID() string {
|
||||||
|
return t.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetSignatureAlgorithm implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
|
||||||
|
return t.signatureAlg
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetSignatureAlgorithm implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) SetAccessTokenHash(hash string) {
|
||||||
|
t.AccessTokenHash = hash
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetUserinfo implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) SetUserinfo(info UserInfo) {
|
||||||
|
t.UserInfo = info
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetCodeHash implements the IDTokenClaims interface
|
||||||
|
func (t *idTokenClaims) SetCodeHash(hash string) {
|
||||||
|
t.CodeHash = hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
|
||||||
|
type Alias idTokenClaims
|
||||||
|
a := &struct {
|
||||||
|
*Alias
|
||||||
|
Expiration int64 `json:"exp,omitempty"`
|
||||||
|
IssuedAt int64 `json:"iat,omitempty"`
|
||||||
|
NotBefore int64 `json:"nbf,omitempty"`
|
||||||
|
AuthTime int64 `json:"auth_time,omitempty"`
|
||||||
|
}{
|
||||||
|
Alias: (*Alias)(t),
|
||||||
|
}
|
||||||
|
if !time.Time(t.Expiration).IsZero() {
|
||||||
|
a.Expiration = time.Time(t.Expiration).Unix()
|
||||||
|
}
|
||||||
|
if !time.Time(t.IssuedAt).IsZero() {
|
||||||
|
a.IssuedAt = time.Time(t.IssuedAt).Unix()
|
||||||
|
}
|
||||||
|
if !time.Time(t.NotBefore).IsZero() {
|
||||||
|
a.NotBefore = time.Time(t.NotBefore).Unix()
|
||||||
|
}
|
||||||
|
if !time.Time(t.AuthTime).IsZero() {
|
||||||
|
a.AuthTime = time.Time(t.AuthTime).Unix()
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(a)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.UserInfo == nil {
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
info, err := json.Marshal(t.UserInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return utils.ConcatenateJSON(b, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
|
||||||
|
type Alias idTokenClaims
|
||||||
|
if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userinfo := new(userinfo)
|
||||||
|
if err := json.Unmarshal(data, userinfo); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.UserInfo = userinfo
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccessTokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
|
||||||
|
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
|
||||||
|
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
|
||||||
|
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWTProfileAssertion struct {
|
type JWTProfileAssertion struct {
|
||||||
PrivateKeyID string `json:"keyId"`
|
PrivateKeyID string `json:"-"`
|
||||||
PrivateKey []byte `json:"key"`
|
PrivateKey []byte `json:"-"`
|
||||||
Scopes []string `json:"-"`
|
Issuer string `json:"issuer"`
|
||||||
Issuer string `json:"-"`
|
Subject string `json:"sub"`
|
||||||
Subject string `json:"userId"`
|
Audience Audience `json:"aud"`
|
||||||
Audience []string `json:"-"`
|
Expiration Time `json:"exp"`
|
||||||
Expiration time.Time `json:"-"`
|
IssuedAt Time `json:"iat"`
|
||||||
IssuedAt time.Time `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) {
|
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) {
|
||||||
|
@ -76,12 +405,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return NewJWTProfileAssertionFromFileData(data, audience)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) {
|
||||||
keyData := new(struct {
|
keyData := new(struct {
|
||||||
KeyID string `json:"keyId"`
|
KeyID string `json:"keyId"`
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
UserID string `json:"userId"`
|
UserID string `json:"userId"`
|
||||||
})
|
})
|
||||||
err = json.Unmarshal(data, keyData)
|
err := json.Unmarshal(data, keyData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -93,244 +426,13 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
|
||||||
PrivateKey: key,
|
PrivateKey: key,
|
||||||
PrivateKeyID: keyID,
|
PrivateKeyID: keyID,
|
||||||
Issuer: userID,
|
Issuer: userID,
|
||||||
Scopes: []string{ScopeOpenID},
|
|
||||||
Subject: userID,
|
Subject: userID,
|
||||||
IssuedAt: time.Now().UTC(),
|
IssuedAt: Time(time.Now().UTC()),
|
||||||
Expiration: time.Now().Add(1 * time.Hour).UTC(),
|
Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
|
||||||
Audience: audience,
|
Audience: audience,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type jsonToken struct {
|
|
||||||
Issuer string `json:"iss,omitempty"`
|
|
||||||
Subject string `json:"sub,omitempty"`
|
|
||||||
Audiences interface{} `json:"aud,omitempty"`
|
|
||||||
Expiration int64 `json:"exp,omitempty"`
|
|
||||||
NotBefore int64 `json:"nbf,omitempty"`
|
|
||||||
IssuedAt int64 `json:"iat,omitempty"`
|
|
||||||
JWTID string `json:"jti,omitempty"`
|
|
||||||
AuthorizedParty string `json:"azp,omitempty"`
|
|
||||||
Nonce string `json:"nonce,omitempty"`
|
|
||||||
AuthTime int64 `json:"auth_time,omitempty"`
|
|
||||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
|
||||||
CodeHash string `json:"c_hash,omitempty"`
|
|
||||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
|
||||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
|
||||||
SessionID string `json:"sid,omitempty"`
|
|
||||||
Actor interface{} `json:"act,omitempty"` //TODO: impl
|
|
||||||
Scopes string `json:"scope,omitempty"`
|
|
||||||
ClientID string `json:"client_id,omitempty"`
|
|
||||||
AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
|
|
||||||
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
|
||||||
jsonUserinfo
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
|
|
||||||
j := jsonToken{
|
|
||||||
Issuer: t.Issuer,
|
|
||||||
Subject: t.Subject,
|
|
||||||
Audiences: t.Audiences,
|
|
||||||
Expiration: timeToJSON(t.Expiration),
|
|
||||||
NotBefore: timeToJSON(t.NotBefore),
|
|
||||||
IssuedAt: timeToJSON(t.IssuedAt),
|
|
||||||
JWTID: t.JWTID,
|
|
||||||
AuthorizedParty: t.AuthorizedParty,
|
|
||||||
Nonce: t.Nonce,
|
|
||||||
AuthTime: timeToJSON(t.AuthTime),
|
|
||||||
CodeHash: t.CodeHash,
|
|
||||||
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
|
|
||||||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
|
||||||
SessionID: t.SessionID,
|
|
||||||
Scopes: strings.Join(t.Scopes, " "),
|
|
||||||
ClientID: t.ClientID,
|
|
||||||
AccessTokenUseNumber: t.AccessTokenUseNumber,
|
|
||||||
}
|
|
||||||
return json.Marshal(j)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
|
|
||||||
var j jsonToken
|
|
||||||
if err := json.Unmarshal(b, &j); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
t.Issuer = j.Issuer
|
|
||||||
t.Subject = j.Subject
|
|
||||||
t.Audiences = audienceFromJSON(j.Audiences)
|
|
||||||
t.Expiration = time.Unix(j.Expiration, 0).UTC()
|
|
||||||
t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
|
|
||||||
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
|
|
||||||
t.JWTID = j.JWTID
|
|
||||||
t.AuthorizedParty = j.AuthorizedParty
|
|
||||||
t.Nonce = j.Nonce
|
|
||||||
t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
|
|
||||||
t.CodeHash = j.CodeHash
|
|
||||||
t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
|
|
||||||
t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
|
|
||||||
t.SessionID = j.SessionID
|
|
||||||
t.Scopes = strings.Split(j.Scopes, " ")
|
|
||||||
t.ClientID = j.ClientID
|
|
||||||
t.AccessTokenUseNumber = j.AccessTokenUseNumber
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
|
|
||||||
j := jsonToken{
|
|
||||||
Issuer: t.Issuer,
|
|
||||||
Subject: t.Subject,
|
|
||||||
Audiences: t.Audiences,
|
|
||||||
Expiration: timeToJSON(t.Expiration),
|
|
||||||
NotBefore: timeToJSON(t.NotBefore),
|
|
||||||
IssuedAt: timeToJSON(t.IssuedAt),
|
|
||||||
JWTID: t.JWTID,
|
|
||||||
AuthorizedParty: t.AuthorizedParty,
|
|
||||||
Nonce: t.Nonce,
|
|
||||||
AuthTime: timeToJSON(t.AuthTime),
|
|
||||||
AccessTokenHash: t.AccessTokenHash,
|
|
||||||
CodeHash: t.CodeHash,
|
|
||||||
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
|
|
||||||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
|
||||||
ClientID: t.ClientID,
|
|
||||||
}
|
|
||||||
j.setUserinfo(t.Userinfo)
|
|
||||||
return json.Marshal(j)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
|
||||||
var i jsonToken
|
|
||||||
if err := json.Unmarshal(b, &i); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
t.Issuer = i.Issuer
|
|
||||||
t.Subject = i.Subject
|
|
||||||
t.Audiences = audienceFromJSON(i.Audiences)
|
|
||||||
t.Expiration = time.Unix(i.Expiration, 0).UTC()
|
|
||||||
t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
|
|
||||||
t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
|
|
||||||
t.Nonce = i.Nonce
|
|
||||||
t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
|
|
||||||
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
|
|
||||||
t.AuthorizedParty = i.AuthorizedParty
|
|
||||||
t.AccessTokenHash = i.AccessTokenHash
|
|
||||||
t.CodeHash = i.CodeHash
|
|
||||||
t.UserinfoProfile = i.UnmarshalUserinfoProfile()
|
|
||||||
t.UserinfoEmail = i.UnmarshalUserinfoEmail()
|
|
||||||
t.UserinfoPhone = i.UnmarshalUserinfoPhone()
|
|
||||||
t.Address = i.UnmarshalUserinfoAddress()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetIssuer() string {
|
|
||||||
return t.Issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetAudience() []string {
|
|
||||||
return t.Audiences
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetExpiration() time.Time {
|
|
||||||
return t.Expiration
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetIssuedAt() time.Time {
|
|
||||||
return t.IssuedAt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetNonce() string {
|
|
||||||
return t.Nonce
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetAuthenticationContextClassReference() string {
|
|
||||||
return t.AuthenticationContextClassReference
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetAuthTime() time.Time {
|
|
||||||
return t.AuthTime
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) GetAuthorizedParty() string {
|
|
||||||
return t.AuthorizedParty
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) {
|
|
||||||
t.Signature = alg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
|
|
||||||
j := jsonToken{
|
|
||||||
Issuer: t.Issuer,
|
|
||||||
Subject: t.Subject,
|
|
||||||
Audiences: t.Audience,
|
|
||||||
Expiration: timeToJSON(t.Expiration),
|
|
||||||
IssuedAt: timeToJSON(t.IssuedAt),
|
|
||||||
Scopes: strings.Join(t.Scopes, " "),
|
|
||||||
}
|
|
||||||
return json.Marshal(j)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
|
|
||||||
var j jsonToken
|
|
||||||
if err := json.Unmarshal(b, &j); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Issuer = j.Issuer
|
|
||||||
t.Subject = j.Subject
|
|
||||||
t.Audience = audienceFromJSON(j.Audiences)
|
|
||||||
t.Expiration = time.Unix(j.Expiration, 0).UTC()
|
|
||||||
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
|
|
||||||
t.Scopes = strings.Split(j.Scopes, " ")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
|
|
||||||
locale, _ := language.Parse(j.Locale)
|
|
||||||
return UserinfoProfile{
|
|
||||||
Name: j.Name,
|
|
||||||
GivenName: j.GivenName,
|
|
||||||
FamilyName: j.FamilyName,
|
|
||||||
MiddleName: j.MiddleName,
|
|
||||||
Nickname: j.Nickname,
|
|
||||||
Profile: j.Profile,
|
|
||||||
Picture: j.Picture,
|
|
||||||
Website: j.Website,
|
|
||||||
Gender: Gender(j.Gender),
|
|
||||||
Birthdate: j.Birthdate,
|
|
||||||
Zoneinfo: j.Zoneinfo,
|
|
||||||
Locale: locale,
|
|
||||||
UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
|
|
||||||
PreferredUsername: j.PreferredUsername,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail {
|
|
||||||
return UserinfoEmail{
|
|
||||||
Email: j.Email,
|
|
||||||
EmailVerified: j.EmailVerified,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone {
|
|
||||||
return UserinfoPhone{
|
|
||||||
PhoneNumber: j.Phone,
|
|
||||||
PhoneNumberVerified: j.PhoneVerified,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
|
|
||||||
if j.JsonUserinfoAddress == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &UserinfoAddress{
|
|
||||||
Country: j.JsonUserinfoAddress.Country,
|
|
||||||
Formatted: j.JsonUserinfoAddress.Formatted,
|
|
||||||
Locality: j.JsonUserinfoAddress.Locality,
|
|
||||||
PostalCode: j.JsonUserinfoAddress.PostalCode,
|
|
||||||
Region: j.JsonUserinfoAddress.Region,
|
|
||||||
StreetAddress: j.JsonUserinfoAddress.StreetAddress,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||||
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -339,26 +441,3 @@ func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, erro
|
||||||
|
|
||||||
return utils.HashString(hash, claim, true), nil
|
return utils.HashString(hash, claim, true), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func timeToJSON(t time.Time) int64 {
|
|
||||||
if t.IsZero() {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return t.Unix()
|
|
||||||
}
|
|
||||||
|
|
||||||
func audienceFromJSON(i interface{}) []string {
|
|
||||||
switch aud := i.(type) {
|
|
||||||
case []string:
|
|
||||||
return aud
|
|
||||||
case []interface{}:
|
|
||||||
audience := make([]string, len(aud))
|
|
||||||
for i, a := range aud {
|
|
||||||
audience[i] = a.(string)
|
|
||||||
}
|
|
||||||
return audience
|
|
||||||
case string:
|
|
||||||
return []string{aud}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
108
pkg/oidc/token_request.go
Normal file
108
pkg/oidc/token_request.go
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
|
||||||
|
GrantTypeCode GrantType = "authorization_code"
|
||||||
|
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
|
||||||
|
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GrantType string
|
||||||
|
|
||||||
|
type TokenRequest interface {
|
||||||
|
// GrantType GrantType `schema:"grant_type"`
|
||||||
|
GrantType() GrantType
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenRequestType GrantType
|
||||||
|
|
||||||
|
type AccessTokenRequest struct {
|
||||||
|
Code string `schema:"code"`
|
||||||
|
RedirectURI string `schema:"redirect_uri"`
|
||||||
|
ClientID string `schema:"client_id"`
|
||||||
|
ClientSecret string `schema:"client_secret"`
|
||||||
|
CodeVerifier string `schema:"code_verifier"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AccessTokenRequest) GrantType() GrantType {
|
||||||
|
return GrantTypeCode
|
||||||
|
}
|
||||||
|
|
||||||
|
type JWTTokenRequest struct {
|
||||||
|
Issuer string `json:"iss"`
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Scopes Scopes `json:"-"`
|
||||||
|
Audience Audience `json:"aud"`
|
||||||
|
IssuedAt Time `json:"iat"`
|
||||||
|
ExpiresAt Time `json:"exp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetIssuer implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) GetIssuer() string {
|
||||||
|
return j.Issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAudience implements the Claims and TokenRequest interfaces
|
||||||
|
func (j *JWTTokenRequest) GetAudience() []string {
|
||||||
|
return j.Audience
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetExpiration implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) GetExpiration() time.Time {
|
||||||
|
return time.Time(j.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetIssuedAt implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
|
||||||
|
return time.Time(j.IssuedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetNonce implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) GetNonce() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthenticationContextClassReference implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthTime implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) GetAuthTime() time.Time {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetAuthorizedParty implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) GetAuthorizedParty() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetSignatureAlgorithm implements the Claims interface
|
||||||
|
func (j *JWTTokenRequest) SetSignatureAlgorithm(_ jose.SignatureAlgorithm) {}
|
||||||
|
|
||||||
|
//GetSubject implements the TokenRequest interface
|
||||||
|
func (j *JWTTokenRequest) GetSubject() string {
|
||||||
|
return j.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetSubject implements the TokenRequest interface
|
||||||
|
func (j *JWTTokenRequest) GetScopes() []string {
|
||||||
|
return j.Scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenExchangeRequest struct {
|
||||||
|
subjectToken string `schema:"subject_token"`
|
||||||
|
subjectTokenType string `schema:"subject_token_type"`
|
||||||
|
actorToken string `schema:"actor_token"`
|
||||||
|
actorTokenType string `schema:"actor_token_type"`
|
||||||
|
resource []string `schema:"resource"`
|
||||||
|
audience Audience `schema:"audience"`
|
||||||
|
Scope Scopes `schema:"scope"`
|
||||||
|
requestedTokenType string `schema:"requested_token_type"`
|
||||||
|
}
|
89
pkg/oidc/types.go
Normal file
89
pkg/oidc/types.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Audience []string
|
||||||
|
|
||||||
|
func (a *Audience) UnmarshalJSON(text []byte) error {
|
||||||
|
var i interface{}
|
||||||
|
err := json.Unmarshal(text, &i)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch aud := i.(type) {
|
||||||
|
case []interface{}:
|
||||||
|
*a = make([]string, len(aud))
|
||||||
|
for i, audience := range aud {
|
||||||
|
(*a)[i] = audience.(string)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
*a = []string{aud}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Display string
|
||||||
|
|
||||||
|
func (d *Display) UnmarshalText(text []byte) error {
|
||||||
|
display := Display(text)
|
||||||
|
switch display {
|
||||||
|
case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP:
|
||||||
|
*d = display
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Gender string
|
||||||
|
|
||||||
|
type Locales []language.Tag
|
||||||
|
|
||||||
|
func (l *Locales) UnmarshalText(text []byte) error {
|
||||||
|
locales := strings.Split(string(text), " ")
|
||||||
|
for _, locale := range locales {
|
||||||
|
tag, err := language.Parse(locale)
|
||||||
|
if err == nil && !tag.IsRoot() {
|
||||||
|
*l = append(*l, tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Prompt string
|
||||||
|
|
||||||
|
type ResponseType string
|
||||||
|
|
||||||
|
type Scopes []string
|
||||||
|
|
||||||
|
func (s Scopes) Encode() string {
|
||||||
|
return strings.Join(s, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scopes) UnmarshalText(text []byte) error {
|
||||||
|
*s = strings.Split(string(text), " ")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scopes) MarshalText() ([]byte, error) {
|
||||||
|
return []byte(s.Encode()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Time time.Time
|
||||||
|
|
||||||
|
func (t *Time) UnmarshalJSON(data []byte) error {
|
||||||
|
var i int64
|
||||||
|
if err := json.Unmarshal(data, &i); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*t = Time(time.Unix(i, 0).UTC())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Time) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(time.Time(*t).UTC().Unix())
|
||||||
|
}
|
296
pkg/oidc/types_test.go
Normal file
296
pkg/oidc/types_test.go
Normal file
|
@ -0,0 +1,296 @@
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAudience_UnmarshalText(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
text []byte
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
audience Audience
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"invalid value",
|
||||||
|
args{
|
||||||
|
[]byte(`{"aud": {"a": }}}`),
|
||||||
|
},
|
||||||
|
res{},
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"single audience",
|
||||||
|
args{
|
||||||
|
[]byte(`{"aud": "single audience"}`),
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]string{"single audience"},
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple audience",
|
||||||
|
args{
|
||||||
|
[]byte(`{"aud": ["multiple", "audience"]}`),
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]string{"multiple", "audience"},
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := new(struct {
|
||||||
|
Audience Audience `json:"aud"`
|
||||||
|
})
|
||||||
|
if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, a.Audience, tt.res.audience)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisplay_UnmarshalText(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
text []byte
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
display Display
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"unknown value",
|
||||||
|
args{
|
||||||
|
[]byte("unknown"),
|
||||||
|
},
|
||||||
|
res{},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"page",
|
||||||
|
args{
|
||||||
|
[]byte("page"),
|
||||||
|
},
|
||||||
|
res{DisplayPage},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var d Display
|
||||||
|
if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if d != tt.res.display {
|
||||||
|
t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocales_UnmarshalText(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
text []byte
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
tags []language.Tag
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"unknown value",
|
||||||
|
args{
|
||||||
|
[]byte("unknown"),
|
||||||
|
},
|
||||||
|
res{},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"undefined",
|
||||||
|
args{
|
||||||
|
[]byte("und"),
|
||||||
|
},
|
||||||
|
res{},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"single language",
|
||||||
|
args{
|
||||||
|
[]byte("de"),
|
||||||
|
},
|
||||||
|
res{[]language.Tag{language.German}},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple languages",
|
||||||
|
args{
|
||||||
|
[]byte("de en"),
|
||||||
|
},
|
||||||
|
res{[]language.Tag{language.German, language.English}},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var locales Locales
|
||||||
|
if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, locales, tt.res.tags)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScopes_UnmarshalText(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
text []byte
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
scopes []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"unknown value",
|
||||||
|
args{
|
||||||
|
[]byte("unknown"),
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]string{"unknown"},
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"struct",
|
||||||
|
args{
|
||||||
|
[]byte(`{"unknown":"value"}`),
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]string{`{"unknown":"value"}`},
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"openid",
|
||||||
|
args{
|
||||||
|
[]byte("openid"),
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]string{"openid"},
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple scopes",
|
||||||
|
args{
|
||||||
|
[]byte("openid email custom:scope"),
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]string{"openid", "email", "custom:scope"},
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var scopes Scopes
|
||||||
|
if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, scopes, tt.res.scopes)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestScopes_MarshalText(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
scopes Scopes
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
scopes []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"unknown value",
|
||||||
|
args{
|
||||||
|
Scopes{"unknown"},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]byte("unknown"),
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"struct",
|
||||||
|
args{
|
||||||
|
Scopes{`{"unknown":"value"}`},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]byte(`{"unknown":"value"}`),
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"openid",
|
||||||
|
args{
|
||||||
|
Scopes{"openid"},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]byte("openid"),
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple scopes",
|
||||||
|
args{
|
||||||
|
Scopes{"openid", "email", "custom:scope"},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]byte("openid email custom:scope"),
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
text, err := tt.args.scopes.MarshalText()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(text, tt.res.scopes) {
|
||||||
|
t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,62 +2,285 @@ package oidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
|
|
||||||
|
"github.com/caos/oidc/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Userinfo struct {
|
type UserInfo interface {
|
||||||
Subject string
|
GetSubject() string
|
||||||
UserinfoProfile
|
UserInfoProfile
|
||||||
UserinfoEmail
|
UserInfoEmail
|
||||||
UserinfoPhone
|
UserInfoPhone
|
||||||
Address *UserinfoAddress
|
GetAddress() UserInfoAddress
|
||||||
|
GetClaim(key string) interface{}
|
||||||
|
}
|
||||||
|
|
||||||
Authorizations []string
|
type UserInfoProfile interface {
|
||||||
|
GetName() string
|
||||||
|
GetGivenName() string
|
||||||
|
GetFamilyName() string
|
||||||
|
GetMiddleName() string
|
||||||
|
GetNickname() string
|
||||||
|
GetProfile() string
|
||||||
|
GetPicture() string
|
||||||
|
GetWebsite() string
|
||||||
|
GetGender() Gender
|
||||||
|
GetBirthdate() string
|
||||||
|
GetZoneinfo() string
|
||||||
|
GetLocale() language.Tag
|
||||||
|
GetPreferredUsername() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserInfoEmail interface {
|
||||||
|
GetEmail() string
|
||||||
|
IsEmailVerified() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserInfoPhone interface {
|
||||||
|
GetPhoneNumber() string
|
||||||
|
IsPhoneNumberVerified() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserInfoAddress interface {
|
||||||
|
GetFormatted() string
|
||||||
|
GetStreetAddress() string
|
||||||
|
GetLocality() string
|
||||||
|
GetRegion() string
|
||||||
|
GetPostalCode() string
|
||||||
|
GetCountry() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserInfoSetter interface {
|
||||||
|
UserInfo
|
||||||
|
SetSubject(sub string)
|
||||||
|
UserInfoProfileSetter
|
||||||
|
SetEmail(email string, verified bool)
|
||||||
|
SetPhone(phone string, verified bool)
|
||||||
|
SetAddress(address UserInfoAddress)
|
||||||
|
AppendClaims(key string, values interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserInfoProfileSetter interface {
|
||||||
|
SetName(name string)
|
||||||
|
SetGivenName(name string)
|
||||||
|
SetFamilyName(name string)
|
||||||
|
SetMiddleName(name string)
|
||||||
|
SetNickname(name string)
|
||||||
|
SetUpdatedAt(date time.Time)
|
||||||
|
SetProfile(profile string)
|
||||||
|
SetPicture(profile string)
|
||||||
|
SetWebsite(website string)
|
||||||
|
SetGender(gender Gender)
|
||||||
|
SetBirthdate(birthdate string)
|
||||||
|
SetZoneinfo(zoneInfo string)
|
||||||
|
SetLocale(locale language.Tag)
|
||||||
|
SetPreferredUsername(name string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserInfo() UserInfoSetter {
|
||||||
|
return &userinfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type userinfo struct {
|
||||||
|
Subject string `json:"sub,omitempty"`
|
||||||
|
userInfoProfile
|
||||||
|
userInfoEmail
|
||||||
|
userInfoPhone
|
||||||
|
Address UserInfoAddress `json:"address,omitempty"`
|
||||||
|
|
||||||
claims map[string]interface{}
|
claims map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserinfoProfile struct {
|
func (u *userinfo) GetSubject() string {
|
||||||
Name string
|
return u.Subject
|
||||||
GivenName string
|
|
||||||
FamilyName string
|
|
||||||
MiddleName string
|
|
||||||
Nickname string
|
|
||||||
Profile string
|
|
||||||
Picture string
|
|
||||||
Website string
|
|
||||||
Gender Gender
|
|
||||||
Birthdate string
|
|
||||||
Zoneinfo string
|
|
||||||
Locale language.Tag
|
|
||||||
UpdatedAt time.Time
|
|
||||||
PreferredUsername string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Gender string
|
func (u *userinfo) GetName() string {
|
||||||
|
return u.Name
|
||||||
type UserinfoEmail struct {
|
|
||||||
Email string
|
|
||||||
EmailVerified bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserinfoPhone struct {
|
func (u *userinfo) GetGivenName() string {
|
||||||
PhoneNumber string
|
return u.GivenName
|
||||||
PhoneNumberVerified bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserinfoAddress struct {
|
func (u *userinfo) GetFamilyName() string {
|
||||||
Formatted string
|
return u.FamilyName
|
||||||
StreetAddress string
|
|
||||||
Locality string
|
|
||||||
Region string
|
|
||||||
PostalCode string
|
|
||||||
Country string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type jsonUserinfoProfile struct {
|
func (u *userinfo) GetMiddleName() string {
|
||||||
|
return u.MiddleName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetNickname() string {
|
||||||
|
return u.Nickname
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetProfile() string {
|
||||||
|
return u.Profile
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetPicture() string {
|
||||||
|
return u.Picture
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetWebsite() string {
|
||||||
|
return u.Website
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetGender() Gender {
|
||||||
|
return u.Gender
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetBirthdate() string {
|
||||||
|
return u.Birthdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetZoneinfo() string {
|
||||||
|
return u.Zoneinfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetLocale() language.Tag {
|
||||||
|
return u.Locale
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetPreferredUsername() string {
|
||||||
|
return u.PreferredUsername
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetEmail() string {
|
||||||
|
return u.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) IsEmailVerified() bool {
|
||||||
|
return u.EmailVerified
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetPhoneNumber() string {
|
||||||
|
return u.PhoneNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) IsPhoneNumberVerified() bool {
|
||||||
|
return u.PhoneNumberVerified
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetAddress() UserInfoAddress {
|
||||||
|
return u.Address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) GetClaim(key string) interface{} {
|
||||||
|
return u.claims[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetSubject(sub string) {
|
||||||
|
u.Subject = sub
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetName(name string) {
|
||||||
|
u.Name = name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetGivenName(name string) {
|
||||||
|
u.GivenName = name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetFamilyName(name string) {
|
||||||
|
u.FamilyName = name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetMiddleName(name string) {
|
||||||
|
u.MiddleName = name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetNickname(name string) {
|
||||||
|
u.Nickname = name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetUpdatedAt(date time.Time) {
|
||||||
|
u.UpdatedAt = Time(date)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetProfile(profile string) {
|
||||||
|
u.Profile = profile
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetPicture(picture string) {
|
||||||
|
u.Picture = picture
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetWebsite(website string) {
|
||||||
|
u.Website = website
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetGender(gender Gender) {
|
||||||
|
u.Gender = gender
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetBirthdate(birthdate string) {
|
||||||
|
u.Birthdate = birthdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetZoneinfo(zoneInfo string) {
|
||||||
|
u.Zoneinfo = zoneInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetLocale(locale language.Tag) {
|
||||||
|
u.Locale = locale
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetPreferredUsername(name string) {
|
||||||
|
u.PreferredUsername = name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetEmail(email string, verified bool) {
|
||||||
|
u.Email = email
|
||||||
|
u.EmailVerified = verified
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetPhone(phone string, verified bool) {
|
||||||
|
u.PhoneNumber = phone
|
||||||
|
u.PhoneNumberVerified = verified
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) SetAddress(address UserInfoAddress) {
|
||||||
|
u.Address = address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userinfo) AppendClaims(key string, value interface{}) {
|
||||||
|
if u.claims == nil {
|
||||||
|
u.claims = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
u.claims[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userInfoAddress) GetFormatted() string {
|
||||||
|
return u.Formatted
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userInfoAddress) GetStreetAddress() string {
|
||||||
|
return u.StreetAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userInfoAddress) GetLocality() string {
|
||||||
|
return u.Locality
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userInfoAddress) GetRegion() string {
|
||||||
|
return u.Region
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userInfoAddress) GetPostalCode() string {
|
||||||
|
return u.PostalCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userInfoAddress) GetCountry() string {
|
||||||
|
return u.Country
|
||||||
|
}
|
||||||
|
|
||||||
|
type userInfoProfile struct {
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
GivenName string `json:"given_name,omitempty"`
|
GivenName string `json:"given_name,omitempty"`
|
||||||
FamilyName string `json:"family_name,omitempty"`
|
FamilyName string `json:"family_name,omitempty"`
|
||||||
|
@ -66,25 +289,25 @@ type jsonUserinfoProfile struct {
|
||||||
Profile string `json:"profile,omitempty"`
|
Profile string `json:"profile,omitempty"`
|
||||||
Picture string `json:"picture,omitempty"`
|
Picture string `json:"picture,omitempty"`
|
||||||
Website string `json:"website,omitempty"`
|
Website string `json:"website,omitempty"`
|
||||||
Gender string `json:"gender,omitempty"`
|
Gender Gender `json:"gender,omitempty"`
|
||||||
Birthdate string `json:"birthdate,omitempty"`
|
Birthdate string `json:"birthdate,omitempty"`
|
||||||
Zoneinfo string `json:"zoneinfo,omitempty"`
|
Zoneinfo string `json:"zoneinfo,omitempty"`
|
||||||
Locale string `json:"locale,omitempty"`
|
Locale language.Tag `json:"locale,omitempty"`
|
||||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
UpdatedAt Time `json:"updated_at,omitempty"`
|
||||||
PreferredUsername string `json:"preferred_username,omitempty"`
|
PreferredUsername string `json:"preferred_username,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type jsonUserinfoEmail struct {
|
type userInfoEmail struct {
|
||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
EmailVerified bool `json:"email_verified,omitempty"`
|
EmailVerified bool `json:"email_verified,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type jsonUserinfoPhone struct {
|
type userInfoPhone struct {
|
||||||
Phone string `json:"phone_number,omitempty"`
|
PhoneNumber string `json:"phone_number,omitempty"`
|
||||||
PhoneVerified bool `json:"phone_number_verified,omitempty"`
|
PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type jsonUserinfoAddress struct {
|
type userInfoAddress struct {
|
||||||
Formatted string `json:"formatted,omitempty"`
|
Formatted string `json:"formatted,omitempty"`
|
||||||
StreetAddress string `json:"street_address,omitempty"`
|
StreetAddress string `json:"street_address,omitempty"`
|
||||||
Locality string `json:"locality,omitempty"`
|
Locality string `json:"locality,omitempty"`
|
||||||
|
@ -93,81 +316,63 @@ type jsonUserinfoAddress struct {
|
||||||
Country string `json:"country,omitempty"`
|
Country string `json:"country,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Userinfo) MarshalJSON() ([]byte, error) {
|
func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
|
||||||
j := new(jsonUserinfo)
|
return &userInfoAddress{
|
||||||
j.Subject = i.Subject
|
StreetAddress: streetAddress,
|
||||||
j.setUserinfo(*i)
|
Locality: locality,
|
||||||
j.Authorizations = i.Authorizations
|
Region: region,
|
||||||
return json.Marshal(j)
|
PostalCode: postalCode,
|
||||||
|
Country: country,
|
||||||
|
Formatted: formatted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (i *userinfo) MarshalJSON() ([]byte, error) {
|
||||||
|
type Alias userinfo
|
||||||
|
a := &struct {
|
||||||
|
*Alias
|
||||||
|
Locale interface{} `json:"locale,omitempty"`
|
||||||
|
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||||
|
}{
|
||||||
|
Alias: (*Alias)(i),
|
||||||
|
}
|
||||||
|
if !i.Locale.IsRoot() {
|
||||||
|
a.Locale = i.Locale
|
||||||
|
}
|
||||||
|
if !time.Time(i.UpdatedAt).IsZero() {
|
||||||
|
a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Userinfo) UnmmarshalJSON(data []byte) error {
|
b, err := json.Marshal(a)
|
||||||
if err := json.Unmarshal(data, i); err != nil {
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(i.claims) == 0 {
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := json.Marshal(i.claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims)
|
||||||
|
}
|
||||||
|
return utils.ConcatenateJSON(b, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *userinfo) UnmarshalJSON(data []byte) error {
|
||||||
|
type Alias userinfo
|
||||||
|
a := &struct {
|
||||||
|
*Alias
|
||||||
|
UpdatedAt int64 `json:"update_at,omitempty"`
|
||||||
|
}{
|
||||||
|
Alias: (*Alias)(i),
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &a); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return json.Unmarshal(data, &i.claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
type jsonUserinfo struct {
|
i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
|
||||||
Subject string `json:"sub,omitempty"`
|
|
||||||
jsonUserinfoProfile
|
|
||||||
jsonUserinfoEmail
|
|
||||||
jsonUserinfoPhone
|
|
||||||
JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"`
|
|
||||||
Authorizations []string `json:"authorizations,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonUserinfo) setUserinfo(i Userinfo) {
|
return nil
|
||||||
j.setUserinfoProfile(i.UserinfoProfile)
|
|
||||||
j.setUserinfoEmail(i.UserinfoEmail)
|
|
||||||
j.setUserinfoPhone(i.UserinfoPhone)
|
|
||||||
j.setUserinfoAddress(i.Address)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) {
|
|
||||||
j.Name = i.Name
|
|
||||||
j.GivenName = i.GivenName
|
|
||||||
j.FamilyName = i.FamilyName
|
|
||||||
j.MiddleName = i.MiddleName
|
|
||||||
j.Nickname = i.Nickname
|
|
||||||
j.Profile = i.Profile
|
|
||||||
j.Picture = i.Picture
|
|
||||||
j.Website = i.Website
|
|
||||||
j.Gender = string(i.Gender)
|
|
||||||
j.Birthdate = i.Birthdate
|
|
||||||
j.Zoneinfo = i.Zoneinfo
|
|
||||||
if i.Locale != language.Und {
|
|
||||||
j.Locale = i.Locale.String()
|
|
||||||
}
|
|
||||||
j.UpdatedAt = timeToJSON(i.UpdatedAt)
|
|
||||||
j.PreferredUsername = i.PreferredUsername
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) {
|
|
||||||
j.Email = i.Email
|
|
||||||
j.EmailVerified = i.EmailVerified
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) {
|
|
||||||
j.Phone = i.PhoneNumber
|
|
||||||
j.PhoneVerified = i.PhoneNumberVerified
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) {
|
|
||||||
if i == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
j.JsonUserinfoAddress = &jsonUserinfoAddress{
|
|
||||||
Country: i.Country,
|
|
||||||
Formatted: i.Formatted,
|
|
||||||
Locality: i.Locality,
|
|
||||||
PostalCode: i.PostalCode,
|
|
||||||
Region: i.Region,
|
|
||||||
StreetAddress: i.StreetAddress,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserInfoRequest struct {
|
type UserInfoRequest struct {
|
||||||
|
|
|
@ -24,7 +24,7 @@ type Claims interface {
|
||||||
GetAuthenticationContextClassReference() string
|
GetAuthenticationContextClassReference() string
|
||||||
GetAuthTime() time.Time
|
GetAuthTime() time.Time
|
||||||
GetAuthorizedParty() string
|
GetAuthorizedParty() string
|
||||||
SetSignature(algorithm jose.SignatureAlgorithm)
|
SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
|
||||||
return ErrSignatureInvalidPayload
|
return ErrSignatureInvalidPayload
|
||||||
}
|
}
|
||||||
|
|
||||||
claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm))
|
claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrServerError(err.Error())
|
return "", ErrServerError(err.Error())
|
||||||
}
|
}
|
||||||
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
|
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
|
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
|
||||||
|
@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
|
||||||
}
|
}
|
||||||
|
|
||||||
//ValidateAuthReqScopes validates the passed scopes
|
//ValidateAuthReqScopes validates the passed scopes
|
||||||
func ValidateAuthReqScopes(scopes []string) error {
|
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
|
||||||
if len(scopes) == 0 {
|
if len(scopes) == 0 {
|
||||||
return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
|
return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
|
||||||
}
|
}
|
||||||
if !utils.Contains(scopes, oidc.ScopeOpenID) {
|
openID := false
|
||||||
return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
|
for i := len(scopes) - 1; i >= 0; i-- {
|
||||||
|
scope := scopes[i]
|
||||||
|
if scope == oidc.ScopeOpenID {
|
||||||
|
openID = true
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
return nil
|
if !(scope == oidc.ScopeProfile ||
|
||||||
|
scope == oidc.ScopeEmail ||
|
||||||
|
scope == oidc.ScopePhone ||
|
||||||
|
scope == oidc.ScopeAddress ||
|
||||||
|
scope == oidc.ScopeOfflineAccess) &&
|
||||||
|
!utils.Contains(client.AllowedScopes(), scope) {
|
||||||
|
scopes[i] = scopes[len(scopes)-1]
|
||||||
|
scopes[len(scopes)-1] = ""
|
||||||
|
scopes = scopes[:len(scopes)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !openID {
|
||||||
|
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return scopes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
|
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
|
||||||
|
@ -168,7 +188,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
|
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
|
||||||
}
|
}
|
||||||
return claims.Subject, nil
|
return claims.GetSubject(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//RedirectToLogin redirects the end user to the Login UI for authentication
|
//RedirectToLogin redirects the end user to the Login UI for authentication
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gorilla/schema"
|
"github.com/gorilla/schema"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) {
|
||||||
|
|
||||||
func TestValidateAuthReqScopes(t *testing.T) {
|
func TestValidateAuthReqScopes(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
|
client op.Client
|
||||||
|
scopes []string
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
err bool
|
||||||
scopes []string
|
scopes []string
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
wantErr bool
|
res res
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"scopes missing fails", args{}, true,
|
"scopes missing fails",
|
||||||
|
args{},
|
||||||
|
res{
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"scope openid missing fails", args{[]string{"email"}}, true,
|
"scope openid missing fails",
|
||||||
|
args{
|
||||||
|
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
|
||||||
|
[]string{"email"},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"scope ok", args{[]string{"openid"}}, false,
|
"scope ok",
|
||||||
|
args{
|
||||||
|
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
|
||||||
|
[]string{"openid"},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
scopes: []string{"openid"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"scope with drop ok",
|
||||||
|
args{
|
||||||
|
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
|
||||||
|
[]string{"openid", "email", "unknown"},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
scopes: []string{"openid", "email"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
|
scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
|
||||||
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
|
if (err != nil) != tt.res.err {
|
||||||
|
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
|
||||||
}
|
}
|
||||||
|
assert.ElementsMatch(t, scopes, tt.res.scopes)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,9 @@ const (
|
||||||
ApplicationTypeWeb ApplicationType = iota
|
ApplicationTypeWeb ApplicationType = iota
|
||||||
ApplicationTypeUserAgent
|
ApplicationTypeUserAgent
|
||||||
ApplicationTypeNative
|
ApplicationTypeNative
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
AccessTokenTypeBearer AccessTokenType = iota
|
AccessTokenTypeBearer AccessTokenType = iota
|
||||||
AccessTokenTypeJWT
|
AccessTokenTypeJWT
|
||||||
)
|
)
|
||||||
|
@ -32,6 +34,9 @@ type Client interface {
|
||||||
AccessTokenType() AccessTokenType
|
AccessTokenType() AccessTokenType
|
||||||
IDTokenLifetime() time.Duration
|
IDTokenLifetime() time.Duration
|
||||||
DevMode() bool
|
DevMode() bool
|
||||||
|
AllowedScopes() []string
|
||||||
|
AssertAdditionalIdTokenScopes() bool
|
||||||
|
AssertAdditionalAccessTokenScopes() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {
|
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const OidcDevMode = "CAOS_OIDC_DEV"
|
||||||
|
|
||||||
type Configuration interface {
|
type Configuration interface {
|
||||||
Issuer() string
|
Issuer() string
|
||||||
AuthorizationEndpoint() Endpoint
|
AuthorizationEndpoint() Endpoint
|
||||||
|
@ -42,7 +44,7 @@ func ValidateIssuer(issuer string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func devLocalAllowed(url *url.URL) bool {
|
func devLocalAllowed(url *url.URL) bool {
|
||||||
_, b := os.LookupEnv("CAOS_OIDC_DEV")
|
_, b := os.LookupEnv(OidcDevMode)
|
||||||
if !b {
|
if !b {
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,6 +60,8 @@ func TestValidateIssuer(t *testing.T) {
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
//ensure env is not set
|
||||||
|
os.Unsetenv(OidcDevMode)
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||||
|
@ -84,7 +86,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
|
||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
os.Setenv("CAOS_OIDC_DEV", "")
|
os.Setenv(OidcDevMode, "true")
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||||
|
|
|
@ -72,18 +72,18 @@ func (v *Verifier) VerifyIDToken(ctx context.Context, idToken string) (*oidc.IDT
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Sig struct{}
|
type Sig struct {
|
||||||
|
signer jose.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sig) Signer() jose.Signer {
|
||||||
|
return s.signer
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Sig) Health(ctx context.Context) error {
|
func (s *Sig) Health(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
|
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||||
return jose.HS256
|
return jose.HS256
|
||||||
}
|
}
|
||||||
|
@ -92,9 +92,3 @@ func ExpectStorage(a op.Authorizer, t *testing.T) {
|
||||||
mockA := a.(*MockAuthorizer)
|
mockA := a.(*MockAuthorizer)
|
||||||
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
|
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
// func NewMockSignerAny(t *testing.T) op.Signer {
|
|
||||||
// m := NewMockSigner(gomock.NewController(t))
|
|
||||||
// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
|
|
||||||
// return m
|
|
||||||
// }
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
|
||||||
func(id string) string {
|
func(id string) string {
|
||||||
return "login?id=" + id
|
return "login?id=" + id
|
||||||
})
|
})
|
||||||
|
m.EXPECT().AllowedScopes().AnyTimes().Return(nil)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedScopes mocks base method
|
||||||
|
func (m *MockClient) AllowedScopes() []string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AllowedScopes")
|
||||||
|
ret0, _ := ret[0].([]string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowedScopes indicates an expected call of AllowedScopes
|
||||||
|
func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes))
|
||||||
|
}
|
||||||
|
|
||||||
// ApplicationType mocks base method
|
// ApplicationType mocks base method
|
||||||
func (m *MockClient) ApplicationType() op.ApplicationType {
|
func (m *MockClient) ApplicationType() op.ApplicationType {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -63,6 +77,34 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AssertAdditionalAccessTokenScopes mocks base method
|
||||||
|
func (m *MockClient) AssertAdditionalAccessTokenScopes() bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes")
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes
|
||||||
|
func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertAdditionalIdTokenScopes mocks base method
|
||||||
|
func (m *MockClient) AssertAdditionalIdTokenScopes() bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes")
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes
|
||||||
|
func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes))
|
||||||
|
}
|
||||||
|
|
||||||
// AuthMethod mocks base method
|
// AuthMethod mocks base method
|
||||||
func (m *MockClient) AuthMethod() op.AuthMethod {
|
func (m *MockClient) AuthMethod() op.AuthMethod {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -6,7 +6,6 @@ package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
oidc "github.com/caos/oidc/pkg/oidc"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
@ -49,36 +48,6 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignAccessToken mocks base method
|
|
||||||
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
|
|
||||||
ret0, _ := ret[0].(string)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignAccessToken indicates an expected call of SignAccessToken
|
|
||||||
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignIDToken mocks base method
|
|
||||||
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "SignIDToken", arg0)
|
|
||||||
ret0, _ := ret[0].(string)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignIDToken indicates an expected call of SignIDToken
|
|
||||||
func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignatureAlgorithm mocks base method
|
// SignatureAlgorithm mocks base method
|
||||||
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -92,3 +61,17 @@ func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Signer mocks base method
|
||||||
|
func (m *MockSigner) Signer() jose.Signer {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Signer")
|
||||||
|
ret0, _ := ret[0].(jose.Signer)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signer indicates an expected call of Signer
|
||||||
|
func (mr *MockSignerMockRecorder) Signer() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer))
|
||||||
|
}
|
||||||
|
|
|
@ -171,6 +171,21 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrivateClaimsFromScopes mocks base method
|
||||||
|
func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetPrivateClaimsFromScopes", arg0, arg1, arg2, arg3)
|
||||||
|
ret0, _ := ret[0].(map[string]interface{})
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrivateClaimsFromScopes indicates an expected call of GetPrivateClaimsFromScopes
|
||||||
|
func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
// GetSigningKey mocks base method
|
// GetSigningKey mocks base method
|
||||||
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
|
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -184,25 +199,25 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserinfoFromScopes mocks base method
|
// GetUserinfoFromScopes mocks base method
|
||||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) {
|
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfo, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
|
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(*oidc.Userinfo)
|
ret0, _ := ret[0].(oidc.UserInfo)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
||||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserinfoFromToken mocks base method
|
// GetUserinfoFromToken mocks base method
|
||||||
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (*oidc.Userinfo, error) {
|
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (oidc.UserInfo, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2, arg3)
|
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(*oidc.Userinfo)
|
ret0, _ := ret[0].(oidc.UserInfo)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
|
@ -168,3 +168,12 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
|
||||||
func (c *ConfClient) DevMode() bool {
|
func (c *ConfClient) DevMode() bool {
|
||||||
return c.devMode
|
return c.devMode
|
||||||
}
|
}
|
||||||
|
func (c *ConfClient) AllowedScopes() []string {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
21
pkg/op/op.go
21
pkg/op/op.go
|
@ -51,6 +51,7 @@ type OpenIDProvider interface {
|
||||||
Encoder() utils.Encoder
|
Encoder() utils.Encoder
|
||||||
IDTokenHintVerifier() IDTokenHintVerifier
|
IDTokenHintVerifier() IDTokenHintVerifier
|
||||||
JWTProfileVerifier() JWTProfileVerifier
|
JWTProfileVerifier() JWTProfileVerifier
|
||||||
|
AccessTokenVerifier() AccessTokenVerifier
|
||||||
Crypto() Crypto
|
Crypto() Crypto
|
||||||
DefaultLogoutRedirectURI() string
|
DefaultLogoutRedirectURI() string
|
||||||
Signer() Signer
|
Signer() Signer
|
||||||
|
@ -130,7 +131,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
|
||||||
}
|
}
|
||||||
|
|
||||||
keyCh := make(chan jose.SigningKey)
|
keyCh := make(chan jose.SigningKey)
|
||||||
o.signer = NewDefaultSigner(ctx, storage, keyCh)
|
o.signer = NewSigner(ctx, storage, keyCh)
|
||||||
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
|
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
|
||||||
|
|
||||||
o.httpHandler = CreateRouter(o, o.interceptors...)
|
o.httpHandler = CreateRouter(o, o.interceptors...)
|
||||||
|
@ -152,6 +153,8 @@ type openidProvider struct {
|
||||||
signer Signer
|
signer Signer
|
||||||
idTokenHintVerifier IDTokenHintVerifier
|
idTokenHintVerifier IDTokenHintVerifier
|
||||||
jwtProfileVerifier JWTProfileVerifier
|
jwtProfileVerifier JWTProfileVerifier
|
||||||
|
accessTokenVerifier AccessTokenVerifier
|
||||||
|
keySet *openIDKeySet
|
||||||
crypto Crypto
|
crypto Crypto
|
||||||
httpHandler http.Handler
|
httpHandler http.Handler
|
||||||
decoder *schema.Decoder
|
decoder *schema.Decoder
|
||||||
|
@ -207,7 +210,7 @@ func (o *openidProvider) Encoder() utils.Encoder {
|
||||||
|
|
||||||
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier {
|
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier {
|
||||||
if o.idTokenHintVerifier == nil {
|
if o.idTokenHintVerifier == nil {
|
||||||
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), &openIDKeySet{o.Storage()})
|
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet())
|
||||||
}
|
}
|
||||||
return o.idTokenHintVerifier
|
return o.idTokenHintVerifier
|
||||||
}
|
}
|
||||||
|
@ -219,6 +222,20 @@ func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier {
|
||||||
return o.jwtProfileVerifier
|
return o.jwtProfileVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier {
|
||||||
|
if o.accessTokenVerifier == nil {
|
||||||
|
o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet())
|
||||||
|
}
|
||||||
|
return o.accessTokenVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *openidProvider) openIDKeySet() oidc.KeySet {
|
||||||
|
if o.keySet == nil {
|
||||||
|
o.keySet = &openIDKeySet{o.Storage()}
|
||||||
|
}
|
||||||
|
return o.keySet
|
||||||
|
}
|
||||||
|
|
||||||
func (o *openidProvider) Crypto() Crypto {
|
func (o *openidProvider) Crypto() Crypto {
|
||||||
return o.crypto
|
return o.crypto
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidRequest("id_token_hint invalid")
|
return nil, ErrInvalidRequest("id_token_hint invalid")
|
||||||
}
|
}
|
||||||
session.UserID = claims.Subject
|
session.UserID = claims.GetSubject()
|
||||||
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty)
|
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrServerError("")
|
return nil, ErrServerError("")
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,19 +2,15 @@ package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/caos/logging"
|
"github.com/caos/logging"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Signer interface {
|
type Signer interface {
|
||||||
Health(ctx context.Context) error
|
Health(ctx context.Context) error
|
||||||
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
Signer() jose.Signer
|
||||||
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
|
|
||||||
SignatureAlgorithm() jose.SignatureAlgorithm
|
SignatureAlgorithm() jose.SignatureAlgorithm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,7 +20,7 @@ type tokenSigner struct {
|
||||||
alg jose.SignatureAlgorithm
|
alg jose.SignatureAlgorithm
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
|
func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
|
||||||
s := &tokenSigner{
|
s := &tokenSigner{
|
||||||
storage: storage,
|
storage: storage,
|
||||||
}
|
}
|
||||||
|
@ -41,6 +37,10 @@ func (s *tokenSigner) Health(_ context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *tokenSigner) Signer() jose.Signer {
|
||||||
|
return s.signer
|
||||||
|
}
|
||||||
|
|
||||||
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
|
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -55,30 +55,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
|
|
||||||
payload, err := json.Marshal(claims)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return s.Sign(payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
|
|
||||||
payload, err := json.Marshal(claims)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return s.Sign(payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tokenSigner) Sign(payload []byte) (string, error) {
|
|
||||||
result, err := s.signer.Sign(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return result.CompactSerialize()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||||
return s.alg
|
return s.alg
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,95 +0,0 @@
|
||||||
package op
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"gopkg.in/square/go-jose.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// func TestNewDefaultSigner(t *testing.T) {
|
|
||||||
// type args struct {
|
|
||||||
// storage Storage
|
|
||||||
// }
|
|
||||||
// tests := []struct {
|
|
||||||
// name string
|
|
||||||
// args args
|
|
||||||
// want Signer
|
|
||||||
// wantErr bool
|
|
||||||
// }{
|
|
||||||
// {
|
|
||||||
// "err initialize storage fails",
|
|
||||||
// args{mock.NewMockStorageSigningKeyError(t)},
|
|
||||||
// nil,
|
|
||||||
// true,
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// "err initialize storage fails",
|
|
||||||
// args{mock.NewMockStorageSigningKeyInvalid(t)},
|
|
||||||
// nil,
|
|
||||||
// true,
|
|
||||||
// },
|
|
||||||
// {
|
|
||||||
// "initialize ok",
|
|
||||||
// args{mock.NewMockStorageSigningKey(t)},
|
|
||||||
// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
|
|
||||||
// false,
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
// for _, tt := range tests {
|
|
||||||
// t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// got, err := op.NewDefaultSigner(tt.args.storage)
|
|
||||||
// if (err != nil) != tt.wantErr {
|
|
||||||
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
func Test_idTokenSigner_Sign(t *testing.T) {
|
|
||||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
type fields struct {
|
|
||||||
signer jose.Signer
|
|
||||||
storage Storage
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
payload []byte
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"ok",
|
|
||||||
fields{signer, nil},
|
|
||||||
args{[]byte("test")},
|
|
||||||
"eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &tokenSigner{
|
|
||||||
signer: tt.fields.signer,
|
|
||||||
storage: tt.fields.storage,
|
|
||||||
}
|
|
||||||
got, err := s.Sign(tt.args.payload)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -26,10 +26,11 @@ type AuthStorage interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type OPStorage interface {
|
type OPStorage interface {
|
||||||
GetClientByClientID(context.Context, string) (Client, error)
|
GetClientByClientID(ctx context.Context, clientID string) (Client, error)
|
||||||
AuthorizeClientIDSecret(context.Context, string, string) error
|
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
|
||||||
GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error)
|
GetUserinfoFromScopes(ctx context.Context, userID, clientID string, scopes []string) (oidc.UserInfo, error)
|
||||||
GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (*oidc.Userinfo, error)
|
GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (oidc.UserInfo, error)
|
||||||
|
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
|
||||||
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
108
pkg/op/token.go
108
pkg/op/token.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
"github.com/caos/oidc/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenCreator interface {
|
type TokenCreator interface {
|
||||||
|
@ -25,12 +26,12 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
|
||||||
var validity time.Duration
|
var validity time.Duration
|
||||||
if createAccessToken {
|
if createAccessToken {
|
||||||
var err error
|
var err error
|
||||||
accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator)
|
accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer())
|
idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -50,7 +51,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
|
func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
|
||||||
accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator)
|
accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -63,17 +64,17 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) {
|
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client) (token string, validity time.Duration, err error) {
|
||||||
id, exp, err := creator.Storage().CreateToken(ctx, authReq)
|
id, exp, err := creator.Storage().CreateToken(ctx, tokenRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
validity = exp.Sub(time.Now().UTC())
|
validity = exp.Sub(time.Now().UTC())
|
||||||
if accessTokenType == AccessTokenTypeJWT {
|
if accessTokenType == AccessTokenTypeJWT {
|
||||||
token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer())
|
token, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err = CreateBearerToken(id, authReq.GetSubject(), creator.Crypto())
|
token, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,52 +82,79 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
|
||||||
return crypto.Encrypt(tokenID + ":" + subject)
|
return crypto.Encrypt(tokenID + ":" + subject)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) {
|
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) {
|
||||||
now := time.Now().UTC()
|
claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id)
|
||||||
nbf := now
|
if client != nil && client.AssertAdditionalAccessTokenScopes() {
|
||||||
claims := &oidc.AccessTokenClaims{
|
privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes()))
|
||||||
Issuer: issuer,
|
if err != nil {
|
||||||
Subject: authReq.GetSubject(),
|
return "", err
|
||||||
Audiences: authReq.GetAudience(),
|
|
||||||
Expiration: exp,
|
|
||||||
IssuedAt: now,
|
|
||||||
NotBefore: nbf,
|
|
||||||
JWTID: id,
|
|
||||||
}
|
}
|
||||||
return signer.SignAccessToken(claims)
|
claims.SetPrivateClaims(privateClaims)
|
||||||
|
}
|
||||||
|
return utils.Sign(claims, signer.Signer())
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) {
|
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) {
|
||||||
var err error
|
|
||||||
exp := time.Now().UTC().Add(validity)
|
exp := time.Now().UTC().Add(validity)
|
||||||
userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
|
claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID())
|
||||||
if err != nil {
|
scopes := authReq.GetScopes()
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
claims := &oidc.IDTokenClaims{
|
|
||||||
Issuer: issuer,
|
|
||||||
Audiences: authReq.GetAudience(),
|
|
||||||
Expiration: exp,
|
|
||||||
IssuedAt: time.Now().UTC(),
|
|
||||||
AuthTime: authReq.GetAuthTime(),
|
|
||||||
Nonce: authReq.GetNonce(),
|
|
||||||
AuthenticationContextClassReference: authReq.GetACR(),
|
|
||||||
AuthenticationMethodsReferences: authReq.GetAMR(),
|
|
||||||
AuthorizedParty: authReq.GetClientID(),
|
|
||||||
Userinfo: *userinfo,
|
|
||||||
}
|
|
||||||
if accessToken != "" {
|
if accessToken != "" {
|
||||||
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
|
atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
claims.SetAccessTokenHash(atHash)
|
||||||
|
scopes = removeUserinfoScopes(scopes)
|
||||||
|
}
|
||||||
|
if !additonalScopes {
|
||||||
|
scopes = removeAdditionalScopes(scopes)
|
||||||
|
}
|
||||||
|
if len(scopes) > 0 {
|
||||||
|
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
claims.SetUserinfo(userInfo)
|
||||||
}
|
}
|
||||||
if code != "" {
|
if code != "" {
|
||||||
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
|
codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
claims.SetCodeHash(codeHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
return signer.SignIDToken(claims)
|
return utils.Sign(claims, signer.Signer())
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeUserinfoScopes(scopes []string) []string {
|
||||||
|
for i := len(scopes) - 1; i >= 0; i-- {
|
||||||
|
if scopes[i] == oidc.ScopeProfile ||
|
||||||
|
scopes[i] == oidc.ScopeEmail ||
|
||||||
|
scopes[i] == oidc.ScopeAddress ||
|
||||||
|
scopes[i] == oidc.ScopePhone {
|
||||||
|
|
||||||
|
scopes[i] = scopes[len(scopes)-1]
|
||||||
|
scopes[len(scopes)-1] = ""
|
||||||
|
scopes = scopes[:len(scopes)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeAdditionalScopes(scopes []string) []string {
|
||||||
|
for i := len(scopes) - 1; i >= 0; i-- {
|
||||||
|
if !(scopes[i] == oidc.ScopeOpenID ||
|
||||||
|
scopes[i] == oidc.ScopeProfile ||
|
||||||
|
scopes[i] == oidc.ScopeEmail ||
|
||||||
|
scopes[i] == oidc.ScopeAddress ||
|
||||||
|
scopes[i] == oidc.ScopePhone) {
|
||||||
|
|
||||||
|
scopes[i] = scopes[len(scopes)-1]
|
||||||
|
scopes[len(scopes)-1] = ""
|
||||||
|
scopes = scopes[:len(scopes)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scopes
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,18 +138,18 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
|
||||||
}
|
}
|
||||||
|
|
||||||
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
assertion, err := ParseJWTProfileRequest(r, exchanger.Decoder())
|
profileRequest, err := ParseJWTProfileRequest(r, exchanger.Decoder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := VerifyJWTAssertion(r.Context(), assertion, exchanger.JWTProfileVerifier())
|
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest, exchanger.JWTProfileVerifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger)
|
resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
|
@ -157,17 +157,17 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
utils.MarshalJSON(w, resp)
|
utils.MarshalJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (string, error) {
|
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*tokenexchange.JWTProfileRequest, error) {
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrInvalidRequest("error parsing form")
|
return nil, ErrInvalidRequest("error parsing form")
|
||||||
}
|
}
|
||||||
tokenReq := new(tokenexchange.JWTProfileRequest)
|
tokenReq := new(tokenexchange.JWTProfileRequest)
|
||||||
err = decoder.Decode(tokenReq, r.Form)
|
err = decoder.Decode(tokenReq, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrInvalidRequest("error decoding form")
|
return nil, ErrInvalidRequest("error decoding form")
|
||||||
}
|
}
|
||||||
return tokenReq.Assertion, nil
|
return tokenReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -13,6 +14,7 @@ type UserinfoProvider interface {
|
||||||
Decoder() utils.Decoder
|
Decoder() utils.Decoder
|
||||||
Crypto() Crypto
|
Crypto() Crypto
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
|
AccessTokenVerifier() AccessTokenVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
|
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
|
||||||
|
@ -27,17 +29,12 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
|
||||||
http.Error(w, "access token missing", http.StatusUnauthorized)
|
http.Error(w, "access token missing", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
|
tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken)
|
||||||
if err != nil {
|
if !ok {
|
||||||
http.Error(w, "access token missing", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
splittedToken := strings.Split(tokenIDSubject, ":")
|
|
||||||
if len(splittedToken) != 2 {
|
|
||||||
http.Error(w, "access token invalid", http.StatusUnauthorized)
|
http.Error(w, "access token invalid", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), splittedToken[0], splittedToken[1], r.Header.Get("origin"))
|
info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, subject, r.Header.Get("origin"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusForbidden)
|
w.WriteHeader(http.StatusForbidden)
|
||||||
utils.MarshalJSON(w, err)
|
utils.MarshalJSON(w, err)
|
||||||
|
@ -66,3 +63,19 @@ func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) {
|
||||||
}
|
}
|
||||||
return req.AccessToken, nil
|
return req.AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
|
||||||
|
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
|
||||||
|
if err == nil {
|
||||||
|
splitToken := strings.Split(tokenIDSubject, ":")
|
||||||
|
if len(splitToken) != 2 {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return splitToken[0], splitToken[1], true
|
||||||
|
}
|
||||||
|
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
|
||||||
|
if err != nil {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
|
||||||
|
}
|
||||||
|
|
85
pkg/op/verifier_access_token.go
Normal file
85
pkg/op/verifier_access_token.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package op
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccessTokenVerifier interface {
|
||||||
|
oidc.Verifier
|
||||||
|
SupportedSignAlgs() []string
|
||||||
|
KeySet() oidc.KeySet
|
||||||
|
}
|
||||||
|
|
||||||
|
type accessTokenVerifier struct {
|
||||||
|
issuer string
|
||||||
|
maxAgeIAT time.Duration
|
||||||
|
offset time.Duration
|
||||||
|
supportedSignAlgs []string
|
||||||
|
maxAge time.Duration
|
||||||
|
acr oidc.ACRVerifier
|
||||||
|
keySet oidc.KeySet
|
||||||
|
}
|
||||||
|
|
||||||
|
//Issuer implements oidc.Verifier interface
|
||||||
|
func (i *accessTokenVerifier) Issuer() string {
|
||||||
|
return i.issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
//MaxAgeIAT implements oidc.Verifier interface
|
||||||
|
func (i *accessTokenVerifier) MaxAgeIAT() time.Duration {
|
||||||
|
return i.maxAgeIAT
|
||||||
|
}
|
||||||
|
|
||||||
|
//Offset implements oidc.Verifier interface
|
||||||
|
func (i *accessTokenVerifier) Offset() time.Duration {
|
||||||
|
return i.offset
|
||||||
|
}
|
||||||
|
|
||||||
|
//SupportedSignAlgs implements AccessTokenVerifier interface
|
||||||
|
func (i *accessTokenVerifier) SupportedSignAlgs() []string {
|
||||||
|
return i.supportedSignAlgs
|
||||||
|
}
|
||||||
|
|
||||||
|
//KeySet implements AccessTokenVerifier interface
|
||||||
|
func (i *accessTokenVerifier) KeySet() oidc.KeySet {
|
||||||
|
return i.keySet
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier {
|
||||||
|
verifier := &idTokenHintVerifier{
|
||||||
|
issuer: issuer,
|
||||||
|
keySet: keySet,
|
||||||
|
}
|
||||||
|
return verifier
|
||||||
|
}
|
||||||
|
|
||||||
|
//VerifyAccessToken validates the access token (issuer, signature and expiration)
|
||||||
|
func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) {
|
||||||
|
claims := oidc.EmptyAccessTokenClaims()
|
||||||
|
|
||||||
|
decrypted, err := oidc.DecryptToken(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload, err := oidc.ParseToken(decrypted, claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
|
@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi
|
||||||
|
|
||||||
//VerifyIDTokenHint validates the id token according to
|
//VerifyIDTokenHint validates the id token according to
|
||||||
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||||
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) {
|
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
|
||||||
claims := new(oidc.IDTokenClaims)
|
claims := oidc.EmptyIDTokenClaims()
|
||||||
|
|
||||||
decrypted, err := oidc.DecryptToken(token)
|
decrypted, err := oidc.DecryptToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWTProfileVerifier interface {
|
type JWTProfileVerifier interface {
|
||||||
|
@ -47,9 +48,9 @@ func (v *jwtProfileVerifier) Offset() time.Duration {
|
||||||
return v.offset
|
return v.offset
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
func VerifyJWTAssertion(ctx context.Context, profileRequest *tokenexchange.JWTProfileRequest, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
||||||
request := new(oidc.JWTTokenRequest)
|
request := new(oidc.JWTTokenRequest)
|
||||||
payload, err := oidc.ParseToken(assertion, request)
|
payload, err := oidc.ParseToken(profileRequest.Assertion, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -72,9 +73,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
|
||||||
|
|
||||||
keySet := &jwtProfileKeySet{v.Storage(), request.Subject}
|
keySet := &jwtProfileKeySet{v.Storage(), request.Subject}
|
||||||
|
|
||||||
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
|
if err = oidc.CheckSignature(ctx, profileRequest.Assertion, payload, request, nil, keySet); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
request.Scopes = profileRequest.Scope
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
package mock
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
|
||||||
"github.com/caos/oidc/pkg/rp"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewVerifier(t *testing.T) rp.Verifier {
|
|
||||||
return NewMockVerifier(gomock.NewController(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier {
|
|
||||||
m := NewVerifier(t)
|
|
||||||
ExpectVerifyInvalid(m)
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExpectVerifyInvalid(v rp.Verifier) {
|
|
||||||
mock := v.(*MockVerifier)
|
|
||||||
mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMockVerifierExpectValid(t *testing.T) rp.Verifier {
|
|
||||||
m := NewVerifier(t)
|
|
||||||
ExpectVerifyValid(m)
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExpectVerifyValid(v rp.Verifier) {
|
|
||||||
mock := v.(*MockVerifier)
|
|
||||||
mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil)
|
|
||||||
}
|
|
|
@ -4,9 +4,13 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/schema"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
"github.com/caos/oidc/pkg/oidc/grants"
|
"github.com/caos/oidc/pkg/oidc/grants"
|
||||||
|
@ -22,6 +26,16 @@ const (
|
||||||
jwtProfileKey = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
jwtProfileKey = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
encoder = func() utils.Encoder {
|
||||||
|
e := schema.NewEncoder()
|
||||||
|
e.RegisterEncoder(oidc.Scopes{}, func(value reflect.Value) string {
|
||||||
|
return value.Interface().(oidc.Scopes).Encode()
|
||||||
|
})
|
||||||
|
return e
|
||||||
|
}()
|
||||||
|
)
|
||||||
|
|
||||||
//RelayingParty declares the minimal interface for oidc clients
|
//RelayingParty declares the minimal interface for oidc clients
|
||||||
type RelayingParty interface {
|
type RelayingParty interface {
|
||||||
//OAuthConfig returns the oauth2 Config
|
//OAuthConfig returns the oauth2 Config
|
||||||
|
@ -312,38 +326,45 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc
|
||||||
//ClientCredentials is the `RelayingParty` interface implementation
|
//ClientCredentials is the `RelayingParty` interface implementation
|
||||||
//handling the oauth2 client credentials grant
|
//handling the oauth2 client credentials grant
|
||||||
func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) {
|
func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) {
|
||||||
return CallTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...), rp)
|
return CallTokenEndpointAuthorized(grants.ClientCredentialsGrantBasic(scopes...), rp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallTokenEndpointAuthorized(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
||||||
|
config := rp.OAuthConfig()
|
||||||
|
var fn interface{} = utils.AuthorizeBasic(config.ClientID, config.ClientSecret)
|
||||||
|
if config.Endpoint.AuthStyle == oauth2.AuthStyleInParams {
|
||||||
|
fn = func(form url.Values) {
|
||||||
|
form.Set("client_id", config.ClientID)
|
||||||
|
form.Set("client_secret", config.ClientSecret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return callTokenEndpoint(request, fn, rp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
||||||
config := rp.OAuthConfig()
|
return callTokenEndpoint(request, nil, rp)
|
||||||
req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, config.ClientID, config.ClientSecret, config.Endpoint.AuthStyle != oauth2.AuthStyleInParams)
|
}
|
||||||
|
|
||||||
|
func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
||||||
|
req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, encoder, authFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
token := new(oauth2.Token)
|
var tokenRes struct {
|
||||||
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
if err := utils.HttpRequest(rp.HttpClient(), req, &tokenRes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return token, nil
|
return &oauth2.Token{
|
||||||
}
|
AccessToken: tokenRes.AccessToken,
|
||||||
|
TokenType: tokenRes.TokenType,
|
||||||
func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) {
|
RefreshToken: tokenRes.RefreshToken,
|
||||||
form := make(map[string][]string)
|
Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
|
||||||
form["assertion"] = []string{assertion}
|
}, nil
|
||||||
form["grant_type"] = []string{jwtProfileKey}
|
|
||||||
req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
token := new(oauth2.Token)
|
|
||||||
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error {
|
func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error {
|
||||||
|
|
|
@ -43,12 +43,17 @@ func DelegationTokenExchange(ctx context.Context, subjectToken string, rp Relayi
|
||||||
}
|
}
|
||||||
|
|
||||||
//JWTProfileExchange handles the oauth2 jwt profile exchange
|
//JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||||
func JWTProfileExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, rp RelayingParty) (*oauth2.Token, error) {
|
func JWTProfileExchange(ctx context.Context, jwtProfileRequest *tokenexchange.JWTProfileRequest, rp RelayingParty) (*oauth2.Token, error) {
|
||||||
|
return CallTokenEndpoint(jwtProfileRequest, rp)
|
||||||
|
}
|
||||||
|
|
||||||
|
//JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||||
|
func JWTProfileAssertionExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, scopes oidc.Scopes, rp RelayingParty) (*oauth2.Token, error) {
|
||||||
token, err := generateJWTProfileToken(assertion)
|
token, err := generateJWTProfileToken(assertion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return CallJWTProfileEndpoint(token, rp)
|
return JWTProfileExchange(ctx, tokenexchange.NewJWTProfileRequest(token, scopes...), rp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) {
|
func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) {
|
||||||
|
|
|
@ -21,12 +21,12 @@ type IDTokenVerifier interface {
|
||||||
|
|
||||||
//VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
//VerifyTokens implement the Token Response Validation as defined in OIDC specification
|
||||||
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||||
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
|
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||||
idToken, err := VerifyIDToken(ctx, idTokenString, v)
|
idToken, err := VerifyIDToken(ctx, idTokenString, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil {
|
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return idToken, nil
|
return idToken, nil
|
||||||
|
@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo
|
||||||
|
|
||||||
//VerifyIDToken validates the id token according to
|
//VerifyIDToken validates the id token according to
|
||||||
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||||
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
|
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
|
||||||
claims := new(oidc.IDTokenClaims)
|
claims := oidc.EmptyIDTokenClaims()
|
||||||
|
|
||||||
decrypted, err := oidc.DecryptToken(token)
|
decrypted, err := oidc.DecryptToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -10,8 +10,6 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/schema"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -27,23 +25,30 @@ type Encoder interface {
|
||||||
Encode(src interface{}, dst map[string][]string) error
|
Encode(src interface{}, dst map[string][]string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormRequest(endpoint string, request interface{}, clientID, clientSecret string, header bool) (*http.Request, error) {
|
type FormAuthorization func(url.Values)
|
||||||
form := make(map[string][]string)
|
type RequestAuthorization func(*http.Request)
|
||||||
encoder := schema.NewEncoder()
|
|
||||||
|
func AuthorizeBasic(user, password string) RequestAuthorization {
|
||||||
|
return func(req *http.Request) {
|
||||||
|
req.SetBasicAuth(user, password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) {
|
||||||
|
form := url.Values{}
|
||||||
if err := encoder.Encode(request, form); err != nil {
|
if err := encoder.Encode(request, form); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !header {
|
if fn, ok := authFn.(FormAuthorization); ok {
|
||||||
form["client_id"] = []string{clientID}
|
fn(form)
|
||||||
form["client_secret"] = []string{clientSecret}
|
|
||||||
}
|
}
|
||||||
body := strings.NewReader(url.Values(form).Encode())
|
body := strings.NewReader(form.Encode())
|
||||||
req, err := http.NewRequest("POST", endpoint, body)
|
req, err := http.NewRequest("POST", endpoint, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if header {
|
if fn, ok := authFn.(RequestAuthorization); ok {
|
||||||
req.SetBasicAuth(clientID, clientSecret)
|
fn(req)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
return req, nil
|
return req, nil
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) {
|
||||||
logrus.Error("error writing response")
|
logrus.Error("error writing response")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ConcatenateJSON(first, second []byte) ([]byte, error) {
|
||||||
|
if !bytes.HasSuffix(first, []byte{'}'}) {
|
||||||
|
return nil, fmt.Errorf("jws: invalid JSON %s", first)
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(second, []byte{'{'}) {
|
||||||
|
return nil, fmt.Errorf("jws: invalid JSON %s", second)
|
||||||
|
}
|
||||||
|
first[len(first)-1] = ','
|
||||||
|
first = append(first, second[1:]...)
|
||||||
|
return first, nil
|
||||||
|
}
|
||||||
|
|
60
pkg/utils/marshal_test.go
Normal file
60
pkg/utils/marshal_test.go
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConcatenateJSON(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
first []byte
|
||||||
|
second []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"invalid first part, error",
|
||||||
|
args{
|
||||||
|
[]byte(`invalid`),
|
||||||
|
[]byte(`{"some": "thing"}`),
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid second part, error",
|
||||||
|
args{
|
||||||
|
[]byte(`{"some": "thing"}`),
|
||||||
|
[]byte(`invalid`),
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"both valid, merged",
|
||||||
|
args{
|
||||||
|
[]byte(`{"some": "thing"}`),
|
||||||
|
[]byte(`{"another": "thing"}`),
|
||||||
|
},
|
||||||
|
|
||||||
|
[]byte(`{"some": "thing","another": "thing"}`),
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := ConcatenateJSON(tt.args.first, tt.args.second)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ConcatenateJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, tt.want) {
|
||||||
|
t.Errorf("ConcatenateJSON() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
23
pkg/utils/sign.go
Normal file
23
pkg/utils/sign.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Sign(object interface{}, signer jose.Signer) (string, error) {
|
||||||
|
payload, err := json.Marshal(object)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return SignPayload(payload, signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignPayload(payload []byte, signer jose.Signer) (string, error) {
|
||||||
|
result, err := signer.Sign(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return result.CompactSerialize()
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue