Merge pull request #60 from caos/serializing

feat: private claims (incl. serialisation refactoring and jwt profile fix)
This commit is contained in:
Fabi 2020-10-15 15:27:00 +02:00 committed by GitHub
commit c1699a2d93
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 1896 additions and 980 deletions

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"html/template"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"time" "time"
@ -30,7 +32,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail} scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress}
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
rp.WithPKCE(cookieHandler), rp.WithPKCE(cookieHandler),
@ -82,6 +84,66 @@ func main() {
} }
w.Write(data) w.Write(data)
}) })
http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" {
tpl := `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Login</title>
</head>
<body>
<form method="POST" action="/jwt-profile" enctype="multipart/form-data">
<label for="key">Select a key file:</label>
<input type="file" accept=".json" id="key" name="key">
<button type="submit">Get Token</button>
</form>
</body>
</html>`
t, err := template.New("login").Parse(tpl)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = t.Execute(w, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
} else {
err := r.ParseMultipartForm(4 << 10)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
file, handler, err := r.FormFile("key")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer file.Close()
key, err := ioutil.ReadAll(file)
fmt.Println(handler.Header)
assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
token, err := rp.JWTProfileAssertionExchange(ctx, assertion, scopes, provider)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data, err := json.Marshal(token)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(data)
}
})
lis := fmt.Sprintf("127.0.0.1:%s", port) lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis) logrus.Infof("listening on http://%s/", lis)
logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil)) logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))

View file

@ -45,7 +45,7 @@ func main() {
} }
token := cli.CodeFlow(relayingParty, callbackPath, port, state) token := cli.CodeFlow(relayingParty, callbackPath, port, state)
client := github.NewClient(relayingParty.Client(ctx, token.Token)) client := github.NewClient(relayingParty.OAuthConfig().Client(ctx, token.Token))
_, _, err = client.Users.Get(ctx, "") _, _, err = client.Users.Get(ctx, "")
if err != nil { if err != nil {

View file

@ -210,31 +210,21 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
return nil return nil
} }
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (*oidc.Userinfo, error) { func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (oidc.UserInfo, error) {
return s.GetUserinfoFromScopes(ctx, "", []string{}) return s.GetUserinfoFromScopes(ctx, "", "", []string{})
} }
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) { func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfo, error) {
return &oidc.Userinfo{ userinfo := oidc.NewUserInfo()
Subject: a.GetSubject(), userinfo.SetSubject(a.GetSubject())
Address: &oidc.UserinfoAddress{ userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
StreetAddress: "Hjkhkj 789\ndsf", userinfo.SetEmail("test", true)
}, userinfo.SetPhone("0791234567", true)
UserinfoEmail: oidc.UserinfoEmail{ userinfo.SetName("Test")
Email: "test", userinfo.AppendClaims("private_claim", "test")
EmailVerified: true, return userinfo, nil
}, }
UserinfoPhone: oidc.UserinfoPhone{ func (s *AuthStorage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) {
PhoneNumber: "sadsa", return map[string]interface{}{"private_claim": "test"}, nil
PhoneNumberVerified: true,
},
UserinfoProfile: oidc.UserinfoProfile{
UpdatedAt: time.Now(),
},
// Claims: map[string]interface{}{
// "test": "test",
// "hkjh": "",
// },
}, nil
} }
type ConfClient struct { type ConfClient struct {
@ -289,3 +279,15 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool { func (c *ConfClient) DevMode() bool {
return c.devMode return c.devMode
} }
func (c *ConfClient) AllowedScopes() []string {
return nil
}
func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
return false
}
func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
return false
}

View file

@ -1,15 +1,5 @@
package oidc package oidc
import (
"encoding/json"
"errors"
"strings"
"time"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
const ( const (
//ScopeOpenID defines the scope `openid` //ScopeOpenID defines the scope `openid`
//OpenID Connect requests MUST contain the `openid` scope value //OpenID Connect requests MUST contain the `openid` scope value
@ -64,23 +54,8 @@ const (
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching) //PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
PromptSelectAccount Prompt = "select_account" PromptSelectAccount Prompt = "select_account"
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code"
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
BearerToken = "Bearer"
) )
var displayValues = map[string]Display{
"page": DisplayPage,
"popup": DisplayPopup,
"touch": DisplayTouch,
"wap": DisplayWAP,
}
//AuthRequest according to: //AuthRequest according to:
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct { type AuthRequest struct {
@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType {
func (a *AuthRequest) GetState() string { func (a *AuthRequest) GetState() string {
return a.State return a.State
} }
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType
}
type TokenRequestType GrantType
type AccessTokenRequest struct {
Code string `schema:"code"`
RedirectURI string `schema:"redirect_uri"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
CodeVerifier string `schema:"code_verifier"`
}
func (a *AccessTokenRequest) GrantType() GrantType {
return GrantTypeCode
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
}
type JWTTokenRequest struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Scopes Scopes `json:"scope"`
Audience interface{} `json:"aud"`
IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"`
}
func (j *JWTTokenRequest) GetClientID() string {
return j.Subject
}
func (j *JWTTokenRequest) GetSubject() string {
return j.Subject
}
func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
}
func (j *JWTTokenRequest) GetIssuer() string {
return j.Issuer
}
func (j *JWTTokenRequest) GetAudience() []string {
return audienceFromJSON(j.Audience)
}
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
}
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
func (j *JWTTokenRequest) GetNonce() string {
return ""
}
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
return ""
}
func (j *JWTTokenRequest) GetAuthTime() time.Time {
return time.Time{}
}
func (j *JWTTokenRequest) GetAuthorizedParty() string {
return ""
}
func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {}
type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"`
audience []string `schema:"audience"`
Scope []string `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"`
}
type Scopes []string
func (s *Scopes) UnmarshalText(text []byte) error {
scopes := strings.Split(string(text), " ")
*s = Scopes(scopes)
return nil
}
type ResponseType string
type Display string
func (d *Display) UnmarshalText(text []byte) error {
var ok bool
display := string(text)
*d, ok = displayValues[display]
if !ok {
return errors.New("")
}
return nil
}
type Prompt string
type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error {
locales := strings.Split(string(text), " ")
for _, locale := range locales {
tag, err := language.Parse(locale)
if err == nil && !tag.IsRoot() {
*l = append(*l, tag)
}
}
return nil
}
type GrantType string

View file

@ -1,5 +1,9 @@
package tokenexchange package tokenexchange
import (
"github.com/caos/oidc/pkg/oidc"
)
const ( const (
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token" AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token" RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
@ -24,6 +28,18 @@ type TokenExchangeRequest struct {
type JWTProfileRequest struct { type JWTProfileRequest struct {
Assertion string `schema:"assertion"` Assertion string `schema:"assertion"`
Scope oidc.Scopes `schema:"scope"`
GrantType oidc.GrantType `schema:"grant_type"`
}
//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
//sneding client_id and client_secret as basic auth header
func NewJWTProfileRequest(assertion string, scopes ...string) *JWTProfileRequest {
return &JWTProfileRequest{
GrantType: oidc.GrantTypeBearer,
Assertion: assertion,
Scope: scopes,
}
} }
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest { func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {

View file

@ -6,21 +6,19 @@ import (
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature //KeySet represents a set of JSON Web Keys
// of JSON web tokens. This is expected to be backed by a remote key set through // - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
// provider metadata discovery or an in-memory set of keys delivered out-of-band. // - held by the OP itself in storage -> `openIDKeySet`
// - dynamically aggregated by request for OAuth JWT Profile Assertion -> `jwtProfileKeySet`
type KeySet interface { type KeySet interface {
// VerifySignature parses the JSON web token, verifies the signature, and returns //VerifySignature verifies the signature with the given keyset and returns the raw payload
// the raw payload. Header and claim fields are validated by other parts of the
// package. For example, the KeySet does not need to check values such as signature
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
// independently.
//
// If VerifySignature makes HTTP requests to verify the token, it's expected to
// use any HTTP client associated with the context through ClientContext.
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
} }
//CheckKey searches the given JSON Web Keys for the requested key ID
//and verifies the JSON Web Signature with the found key
//
//will return false but no error if key ID is not found
func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) { func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) {
for _, key := range keys { for _, key := range keys {
if keyID == "" || key.KeyID == keyID { if keyID == "" || key.KeyID == keyID {

View file

@ -1,5 +1,7 @@
package oidc package oidc
//EndSessionRequest for the RP-Initiated Logout according to:
//https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
type EndSessionRequest struct { type EndSessionRequest struct {
IdTokenHint string `schema:"id_token_hint"` IdTokenHint string `schema:"id_token_hint"`
PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"` PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`

View file

@ -3,72 +3,401 @@ package oidc
import ( import (
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"strings"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
const (
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
BearerToken = "Bearer"
)
type Tokens struct { type Tokens struct {
*oauth2.Token *oauth2.Token
IDTokenClaims *IDTokenClaims IDTokenClaims IDTokenClaims
IDToken string IDToken string
} }
type AccessTokenClaims struct { type AccessTokenClaims interface {
Issuer string Claims
Subject string GetSubject() string
Audiences []string GetTokenID() string
Expiration time.Time SetPrivateClaims(map[string]interface{})
IssuedAt time.Time
NotBefore time.Time
JWTID string
AuthorizedParty string
Nonce string
AuthTime time.Time
CodeHash string
AuthenticationContextClassReference string
AuthenticationMethodsReferences []string
SessionID string
Scopes []string
ClientID string
AccessTokenUseNumber int
} }
type IDTokenClaims struct { type IDTokenClaims interface {
Issuer string Claims
Audiences []string GetNotBefore() time.Time
Expiration time.Time GetJWTID() string
NotBefore time.Time GetAccessTokenHash() string
IssuedAt time.Time GetCodeHash() string
JWTID string GetAuthenticationMethodsReferences() []string
UpdatedAt time.Time GetClientID() string
AuthorizedParty string GetSignatureAlgorithm() jose.SignatureAlgorithm
Nonce string SetAccessTokenHash(hash string)
AuthTime time.Time SetUserinfo(userinfo UserInfo)
AccessTokenHash string SetCodeHash(hash string)
CodeHash string UserInfo
AuthenticationContextClassReference string }
AuthenticationMethodsReferences []string
ClientID string
Userinfo
Signature jose.SignatureAlgorithm //TODO: ??? func EmptyAccessTokenClaims() AccessTokenClaims {
return new(accessTokenClaims)
}
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims {
now := time.Now().UTC()
return &accessTokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(now),
NotBefore: Time(now),
JWTID: id,
}
}
type accessTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
SessionID string `json:"sid,omitempty"`
Scopes []string `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
claims map[string]interface{} `json:"-"`
signatureAlg jose.SignatureAlgorithm `json:"-"`
}
//GetIssuer implements the Claims interface
func (a *accessTokenClaims) GetIssuer() string {
return a.Issuer
}
//GetAudience implements the Claims interface
func (a *accessTokenClaims) GetAudience() []string {
return a.Audience
}
//GetExpiration implements the Claims interface
func (a *accessTokenClaims) GetExpiration() time.Time {
return time.Time(a.Expiration)
}
//GetIssuedAt implements the Claims interface
func (a *accessTokenClaims) GetIssuedAt() time.Time {
return time.Time(a.IssuedAt)
}
//GetNonce implements the Claims interface
func (a *accessTokenClaims) GetNonce() string {
return a.Nonce
}
//GetAuthenticationContextClassReference implements the Claims interface
func (a *accessTokenClaims) GetAuthenticationContextClassReference() string {
return a.AuthenticationContextClassReference
}
//GetAuthTime implements the Claims interface
func (a *accessTokenClaims) GetAuthTime() time.Time {
return time.Time(a.AuthTime)
}
//GetAuthorizedParty implements the Claims interface
func (a *accessTokenClaims) GetAuthorizedParty() string {
return a.AuthorizedParty
}
//SetSignatureAlgorithm implements the Claims interface
func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
a.signatureAlg = algorithm
}
//GetSubject implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetSubject() string {
return a.Subject
}
//GetTokenID implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetTokenID() string {
return a.JWTID
}
//SetPrivateClaims implements the AccessTokenClaims interface
func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) {
a.claims = claims
}
func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
type Alias accessTokenClaims
s := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(a),
}
if !time.Time(a.Expiration).IsZero() {
s.Expiration = time.Time(a.Expiration).Unix()
}
if !time.Time(a.IssuedAt).IsZero() {
s.IssuedAt = time.Time(a.IssuedAt).Unix()
}
if !time.Time(a.NotBefore).IsZero() {
s.NotBefore = time.Time(a.NotBefore).Unix()
}
if !time.Time(a.AuthTime).IsZero() {
s.AuthTime = time.Time(a.AuthTime).Unix()
}
b, err := json.Marshal(s)
if err != nil {
return nil, err
}
if a.claims == nil {
return b, nil
}
info, err := json.Marshal(a.claims)
if err != nil {
return nil, err
}
return utils.ConcatenateJSON(b, info)
}
func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
type Alias accessTokenClaims
if err := json.Unmarshal(data, (*Alias)(a)); err != nil {
return err
}
claims := make(map[string]interface{})
if err := json.Unmarshal(data, &claims); err != nil {
return err
}
a.claims = claims
return nil
}
func EmptyIDTokenClaims() IDTokenClaims {
return new(idTokenClaims)
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims {
return &idTokenClaims{
Issuer: issuer,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(time.Now().UTC()),
AuthTime: Time(authTime),
Nonce: nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID,
UserInfo: &userinfo{Subject: subject},
}
}
type idTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
ClientID string `json:"client_id,omitempty"`
UserInfo `json:"-"`
signatureAlg jose.SignatureAlgorithm
}
//GetIssuer implements the Claims interface
func (t *idTokenClaims) GetIssuer() string {
return t.Issuer
}
//GetAudience implements the Claims interface
func (t *idTokenClaims) GetAudience() []string {
return t.Audience
}
//GetExpiration implements the Claims interface
func (t *idTokenClaims) GetExpiration() time.Time {
return time.Time(t.Expiration)
}
//GetIssuedAt implements the Claims interface
func (t *idTokenClaims) GetIssuedAt() time.Time {
return time.Time(t.IssuedAt)
}
//GetNonce implements the Claims interface
func (t *idTokenClaims) GetNonce() string {
return t.Nonce
}
//GetAuthenticationContextClassReference implements the Claims interface
func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
//GetAuthTime implements the Claims interface
func (t *idTokenClaims) GetAuthTime() time.Time {
return time.Time(t.AuthTime)
}
//GetAuthorizedParty implements the Claims interface
func (t *idTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
//SetSignatureAlgorithm implements the Claims interface
func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
t.signatureAlg = alg
}
//GetNotBefore implements the IDTokenClaims interface
func (t *idTokenClaims) GetNotBefore() time.Time {
return time.Time(t.NotBefore)
}
//GetJWTID implements the IDTokenClaims interface
func (t *idTokenClaims) GetJWTID() string {
return t.JWTID
}
//GetAccessTokenHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetAccessTokenHash() string {
return t.AccessTokenHash
}
//GetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetCodeHash() string {
return t.CodeHash
}
//GetAuthenticationMethodsReferences implements the IDTokenClaims interface
func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
return t.AuthenticationMethodsReferences
}
//GetClientID implements the IDTokenClaims interface
func (t *idTokenClaims) GetClientID() string {
return t.ClientID
}
//GetSignatureAlgorithm implements the IDTokenClaims interface
func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
return t.signatureAlg
}
//SetSignatureAlgorithm implements the IDTokenClaims interface
func (t *idTokenClaims) SetAccessTokenHash(hash string) {
t.AccessTokenHash = hash
}
//SetUserinfo implements the IDTokenClaims interface
func (t *idTokenClaims) SetUserinfo(info UserInfo) {
t.UserInfo = info
}
//SetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) SetCodeHash(hash string) {
t.CodeHash = hash
}
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
type Alias idTokenClaims
a := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(t),
}
if !time.Time(t.Expiration).IsZero() {
a.Expiration = time.Time(t.Expiration).Unix()
}
if !time.Time(t.IssuedAt).IsZero() {
a.IssuedAt = time.Time(t.IssuedAt).Unix()
}
if !time.Time(t.NotBefore).IsZero() {
a.NotBefore = time.Time(t.NotBefore).Unix()
}
if !time.Time(t.AuthTime).IsZero() {
a.AuthTime = time.Time(t.AuthTime).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if t.UserInfo == nil {
return b, nil
}
info, err := json.Marshal(t.UserInfo)
if err != nil {
return nil, err
}
return utils.ConcatenateJSON(b, info)
}
func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
type Alias idTokenClaims
if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
return err
}
userinfo := new(userinfo)
if err := json.Unmarshal(data, userinfo); err != nil {
return err
}
t.UserInfo = userinfo
return nil
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
} }
type JWTProfileAssertion struct { type JWTProfileAssertion struct {
PrivateKeyID string `json:"keyId"` PrivateKeyID string `json:"-"`
PrivateKey []byte `json:"key"` PrivateKey []byte `json:"-"`
Scopes []string `json:"-"` Issuer string `json:"issuer"`
Issuer string `json:"-"` Subject string `json:"sub"`
Subject string `json:"userId"` Audience Audience `json:"aud"`
Audience []string `json:"-"` Expiration Time `json:"exp"`
Expiration time.Time `json:"-"` IssuedAt Time `json:"iat"`
IssuedAt time.Time `json:"-"`
} }
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) { func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) {
@ -76,12 +405,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewJWTProfileAssertionFromFileData(data, audience)
}
func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) {
keyData := new(struct { keyData := new(struct {
KeyID string `json:"keyId"` KeyID string `json:"keyId"`
Key string `json:"key"` Key string `json:"key"`
UserID string `json:"userId"` UserID string `json:"userId"`
}) })
err = json.Unmarshal(data, keyData) err := json.Unmarshal(data, keyData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -93,244 +426,13 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
PrivateKey: key, PrivateKey: key,
PrivateKeyID: keyID, PrivateKeyID: keyID,
Issuer: userID, Issuer: userID,
Scopes: []string{ScopeOpenID},
Subject: userID, Subject: userID,
IssuedAt: time.Now().UTC(), IssuedAt: Time(time.Now().UTC()),
Expiration: time.Now().Add(1 * time.Hour).UTC(), Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience, Audience: audience,
} }
} }
type jsonToken struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audiences interface{} `json:"aud,omitempty"`
Expiration int64 `json:"exp,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
SessionID string `json:"sid,omitempty"`
Actor interface{} `json:"act,omitempty"` //TODO: impl
Scopes string `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
jsonUserinfo
}
func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
j := jsonToken{
Issuer: t.Issuer,
Subject: t.Subject,
Audiences: t.Audiences,
Expiration: timeToJSON(t.Expiration),
NotBefore: timeToJSON(t.NotBefore),
IssuedAt: timeToJSON(t.IssuedAt),
JWTID: t.JWTID,
AuthorizedParty: t.AuthorizedParty,
Nonce: t.Nonce,
AuthTime: timeToJSON(t.AuthTime),
CodeHash: t.CodeHash,
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
SessionID: t.SessionID,
Scopes: strings.Join(t.Scopes, " "),
ClientID: t.ClientID,
AccessTokenUseNumber: t.AccessTokenUseNumber,
}
return json.Marshal(j)
}
func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
var j jsonToken
if err := json.Unmarshal(b, &j); err != nil {
return err
}
t.Issuer = j.Issuer
t.Subject = j.Subject
t.Audiences = audienceFromJSON(j.Audiences)
t.Expiration = time.Unix(j.Expiration, 0).UTC()
t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
t.JWTID = j.JWTID
t.AuthorizedParty = j.AuthorizedParty
t.Nonce = j.Nonce
t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
t.CodeHash = j.CodeHash
t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
t.SessionID = j.SessionID
t.Scopes = strings.Split(j.Scopes, " ")
t.ClientID = j.ClientID
t.AccessTokenUseNumber = j.AccessTokenUseNumber
return nil
}
func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
j := jsonToken{
Issuer: t.Issuer,
Subject: t.Subject,
Audiences: t.Audiences,
Expiration: timeToJSON(t.Expiration),
NotBefore: timeToJSON(t.NotBefore),
IssuedAt: timeToJSON(t.IssuedAt),
JWTID: t.JWTID,
AuthorizedParty: t.AuthorizedParty,
Nonce: t.Nonce,
AuthTime: timeToJSON(t.AuthTime),
AccessTokenHash: t.AccessTokenHash,
CodeHash: t.CodeHash,
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
ClientID: t.ClientID,
}
j.setUserinfo(t.Userinfo)
return json.Marshal(j)
}
func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
var i jsonToken
if err := json.Unmarshal(b, &i); err != nil {
return err
}
t.Issuer = i.Issuer
t.Subject = i.Subject
t.Audiences = audienceFromJSON(i.Audiences)
t.Expiration = time.Unix(i.Expiration, 0).UTC()
t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
t.Nonce = i.Nonce
t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
t.AuthorizedParty = i.AuthorizedParty
t.AccessTokenHash = i.AccessTokenHash
t.CodeHash = i.CodeHash
t.UserinfoProfile = i.UnmarshalUserinfoProfile()
t.UserinfoEmail = i.UnmarshalUserinfoEmail()
t.UserinfoPhone = i.UnmarshalUserinfoPhone()
t.Address = i.UnmarshalUserinfoAddress()
return nil
}
func (t *IDTokenClaims) GetIssuer() string {
return t.Issuer
}
func (t *IDTokenClaims) GetAudience() []string {
return t.Audiences
}
func (t *IDTokenClaims) GetExpiration() time.Time {
return t.Expiration
}
func (t *IDTokenClaims) GetIssuedAt() time.Time {
return t.IssuedAt
}
func (t *IDTokenClaims) GetNonce() string {
return t.Nonce
}
func (t *IDTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
func (t *IDTokenClaims) GetAuthTime() time.Time {
return t.AuthTime
}
func (t *IDTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) {
t.Signature = alg
}
func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
j := jsonToken{
Issuer: t.Issuer,
Subject: t.Subject,
Audiences: t.Audience,
Expiration: timeToJSON(t.Expiration),
IssuedAt: timeToJSON(t.IssuedAt),
Scopes: strings.Join(t.Scopes, " "),
}
return json.Marshal(j)
}
func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
var j jsonToken
if err := json.Unmarshal(b, &j); err != nil {
return err
}
t.Issuer = j.Issuer
t.Subject = j.Subject
t.Audience = audienceFromJSON(j.Audiences)
t.Expiration = time.Unix(j.Expiration, 0).UTC()
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
t.Scopes = strings.Split(j.Scopes, " ")
return nil
}
func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
locale, _ := language.Parse(j.Locale)
return UserinfoProfile{
Name: j.Name,
GivenName: j.GivenName,
FamilyName: j.FamilyName,
MiddleName: j.MiddleName,
Nickname: j.Nickname,
Profile: j.Profile,
Picture: j.Picture,
Website: j.Website,
Gender: Gender(j.Gender),
Birthdate: j.Birthdate,
Zoneinfo: j.Zoneinfo,
Locale: locale,
UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
PreferredUsername: j.PreferredUsername,
}
}
func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail {
return UserinfoEmail{
Email: j.Email,
EmailVerified: j.EmailVerified,
}
}
func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone {
return UserinfoPhone{
PhoneNumber: j.Phone,
PhoneNumberVerified: j.PhoneVerified,
}
}
func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
if j.JsonUserinfoAddress == nil {
return nil
}
return &UserinfoAddress{
Country: j.JsonUserinfoAddress.Country,
Formatted: j.JsonUserinfoAddress.Formatted,
Locality: j.JsonUserinfoAddress.Locality,
PostalCode: j.JsonUserinfoAddress.PostalCode,
Region: j.JsonUserinfoAddress.Region,
StreetAddress: j.JsonUserinfoAddress.StreetAddress,
}
}
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := utils.GetHashAlgorithm(sigAlgorithm) hash, err := utils.GetHashAlgorithm(sigAlgorithm)
if err != nil { if err != nil {
@ -339,26 +441,3 @@ func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, erro
return utils.HashString(hash, claim, true), nil return utils.HashString(hash, claim, true), nil
} }
func timeToJSON(t time.Time) int64 {
if t.IsZero() {
return 0
}
return t.Unix()
}
func audienceFromJSON(i interface{}) []string {
switch aud := i.(type) {
case []string:
return aud
case []interface{}:
audience := make([]string, len(aud))
for i, a := range aud {
audience[i] = a.(string)
}
return audience
case string:
return []string{aud}
}
return nil
}

108
pkg/oidc/token_request.go Normal file
View file

@ -0,0 +1,108 @@
package oidc
import (
"time"
"gopkg.in/square/go-jose.v2"
)
const (
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code"
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
)
type GrantType string
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType
}
type TokenRequestType GrantType
type AccessTokenRequest struct {
Code string `schema:"code"`
RedirectURI string `schema:"redirect_uri"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
CodeVerifier string `schema:"code_verifier"`
}
func (a *AccessTokenRequest) GrantType() GrantType {
return GrantTypeCode
}
type JWTTokenRequest struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Scopes Scopes `json:"-"`
Audience Audience `json:"aud"`
IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"`
}
//GetIssuer implements the Claims interface
func (j *JWTTokenRequest) GetIssuer() string {
return j.Issuer
}
//GetAudience implements the Claims and TokenRequest interfaces
func (j *JWTTokenRequest) GetAudience() []string {
return j.Audience
}
//GetExpiration implements the Claims interface
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
}
//GetIssuedAt implements the Claims interface
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
//GetNonce implements the Claims interface
func (j *JWTTokenRequest) GetNonce() string {
return ""
}
//GetAuthenticationContextClassReference implements the Claims interface
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
return ""
}
//GetAuthTime implements the Claims interface
func (j *JWTTokenRequest) GetAuthTime() time.Time {
return time.Time{}
}
//GetAuthorizedParty implements the Claims interface
func (j *JWTTokenRequest) GetAuthorizedParty() string {
return ""
}
//SetSignatureAlgorithm implements the Claims interface
func (j *JWTTokenRequest) SetSignatureAlgorithm(_ jose.SignatureAlgorithm) {}
//GetSubject implements the TokenRequest interface
func (j *JWTTokenRequest) GetSubject() string {
return j.Subject
}
//GetSubject implements the TokenRequest interface
func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes
}
type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"`
audience Audience `schema:"audience"`
Scope Scopes `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"`
}

89
pkg/oidc/types.go Normal file
View file

@ -0,0 +1,89 @@
package oidc
import (
"encoding/json"
"strings"
"time"
"golang.org/x/text/language"
)
type Audience []string
func (a *Audience) UnmarshalJSON(text []byte) error {
var i interface{}
err := json.Unmarshal(text, &i)
if err != nil {
return err
}
switch aud := i.(type) {
case []interface{}:
*a = make([]string, len(aud))
for i, audience := range aud {
(*a)[i] = audience.(string)
}
case string:
*a = []string{aud}
}
return nil
}
type Display string
func (d *Display) UnmarshalText(text []byte) error {
display := Display(text)
switch display {
case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP:
*d = display
}
return nil
}
type Gender string
type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error {
locales := strings.Split(string(text), " ")
for _, locale := range locales {
tag, err := language.Parse(locale)
if err == nil && !tag.IsRoot() {
*l = append(*l, tag)
}
}
return nil
}
type Prompt string
type ResponseType string
type Scopes []string
func (s Scopes) Encode() string {
return strings.Join(s, " ")
}
func (s *Scopes) UnmarshalText(text []byte) error {
*s = strings.Split(string(text), " ")
return nil
}
func (s *Scopes) MarshalText() ([]byte, error) {
return []byte(s.Encode()), nil
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
}
func (t *Time) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Time(*t).UTC().Unix())
}

296
pkg/oidc/types_test.go Normal file
View file

@ -0,0 +1,296 @@
package oidc
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
)
func TestAudience_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
audience Audience
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"invalid value",
args{
[]byte(`{"aud": {"a": }}}`),
},
res{},
true,
},
{
"single audience",
args{
[]byte(`{"aud": "single audience"}`),
},
res{
[]string{"single audience"},
},
false,
},
{
"multiple audience",
args{
[]byte(`{"aud": ["multiple", "audience"]}`),
},
res{
[]string{"multiple", "audience"},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := new(struct {
Audience Audience `json:"aud"`
})
if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, a.Audience, tt.res.audience)
})
}
}
func TestDisplay_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
display Display
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{},
false,
},
{
"page",
args{
[]byte("page"),
},
res{DisplayPage},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var d Display
if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
if d != tt.res.display {
t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display)
}
})
}
}
func TestLocales_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
tags []language.Tag
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{},
false,
},
{
"undefined",
args{
[]byte("und"),
},
res{},
false,
},
{
"single language",
args{
[]byte("de"),
},
res{[]language.Tag{language.German}},
false,
},
{
"multiple languages",
args{
[]byte("de en"),
},
res{[]language.Tag{language.German, language.English}},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var locales Locales
if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, locales, tt.res.tags)
})
}
}
func TestScopes_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
scopes []string
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{
[]string{"unknown"},
},
false,
},
{
"struct",
args{
[]byte(`{"unknown":"value"}`),
},
res{
[]string{`{"unknown":"value"}`},
},
false,
},
{
"openid",
args{
[]byte("openid"),
},
res{
[]string{"openid"},
},
false,
},
{
"multiple scopes",
args{
[]byte("openid email custom:scope"),
},
res{
[]string{"openid", "email", "custom:scope"},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var scopes Scopes
if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, scopes, tt.res.scopes)
})
}
}
func TestScopes_MarshalText(t *testing.T) {
type args struct {
scopes Scopes
}
type res struct {
scopes []byte
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
Scopes{"unknown"},
},
res{
[]byte("unknown"),
},
false,
},
{
"struct",
args{
Scopes{`{"unknown":"value"}`},
},
res{
[]byte(`{"unknown":"value"}`),
},
false,
},
{
"openid",
args{
Scopes{"openid"},
},
res{
[]byte("openid"),
},
false,
},
{
"multiple scopes",
args{
Scopes{"openid", "email", "custom:scope"},
},
res{
[]byte("openid email custom:scope"),
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text, err := tt.args.scopes.MarshalText()
if (err != nil) != tt.wantErr {
t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
if !bytes.Equal(text, tt.res.scopes) {
t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes)
}
})
}
}

View file

@ -2,62 +2,285 @@ package oidc
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/caos/oidc/pkg/utils"
) )
type Userinfo struct { type UserInfo interface {
Subject string GetSubject() string
UserinfoProfile UserInfoProfile
UserinfoEmail UserInfoEmail
UserinfoPhone UserInfoPhone
Address *UserinfoAddress GetAddress() UserInfoAddress
GetClaim(key string) interface{}
}
Authorizations []string type UserInfoProfile interface {
GetName() string
GetGivenName() string
GetFamilyName() string
GetMiddleName() string
GetNickname() string
GetProfile() string
GetPicture() string
GetWebsite() string
GetGender() Gender
GetBirthdate() string
GetZoneinfo() string
GetLocale() language.Tag
GetPreferredUsername() string
}
type UserInfoEmail interface {
GetEmail() string
IsEmailVerified() bool
}
type UserInfoPhone interface {
GetPhoneNumber() string
IsPhoneNumberVerified() bool
}
type UserInfoAddress interface {
GetFormatted() string
GetStreetAddress() string
GetLocality() string
GetRegion() string
GetPostalCode() string
GetCountry() string
}
type UserInfoSetter interface {
UserInfo
SetSubject(sub string)
UserInfoProfileSetter
SetEmail(email string, verified bool)
SetPhone(phone string, verified bool)
SetAddress(address UserInfoAddress)
AppendClaims(key string, values interface{})
}
type UserInfoProfileSetter interface {
SetName(name string)
SetGivenName(name string)
SetFamilyName(name string)
SetMiddleName(name string)
SetNickname(name string)
SetUpdatedAt(date time.Time)
SetProfile(profile string)
SetPicture(profile string)
SetWebsite(website string)
SetGender(gender Gender)
SetBirthdate(birthdate string)
SetZoneinfo(zoneInfo string)
SetLocale(locale language.Tag)
SetPreferredUsername(name string)
}
func NewUserInfo() UserInfoSetter {
return &userinfo{}
}
type userinfo struct {
Subject string `json:"sub,omitempty"`
userInfoProfile
userInfoEmail
userInfoPhone
Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{} claims map[string]interface{}
} }
type UserinfoProfile struct { func (u *userinfo) GetSubject() string {
Name string return u.Subject
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender Gender
Birthdate string
Zoneinfo string
Locale language.Tag
UpdatedAt time.Time
PreferredUsername string
} }
type Gender string func (u *userinfo) GetName() string {
return u.Name
type UserinfoEmail struct {
Email string
EmailVerified bool
} }
type UserinfoPhone struct { func (u *userinfo) GetGivenName() string {
PhoneNumber string return u.GivenName
PhoneNumberVerified bool
} }
type UserinfoAddress struct { func (u *userinfo) GetFamilyName() string {
Formatted string return u.FamilyName
StreetAddress string
Locality string
Region string
PostalCode string
Country string
} }
type jsonUserinfoProfile struct { func (u *userinfo) GetMiddleName() string {
return u.MiddleName
}
func (u *userinfo) GetNickname() string {
return u.Nickname
}
func (u *userinfo) GetProfile() string {
return u.Profile
}
func (u *userinfo) GetPicture() string {
return u.Picture
}
func (u *userinfo) GetWebsite() string {
return u.Website
}
func (u *userinfo) GetGender() Gender {
return u.Gender
}
func (u *userinfo) GetBirthdate() string {
return u.Birthdate
}
func (u *userinfo) GetZoneinfo() string {
return u.Zoneinfo
}
func (u *userinfo) GetLocale() language.Tag {
return u.Locale
}
func (u *userinfo) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *userinfo) GetEmail() string {
return u.Email
}
func (u *userinfo) IsEmailVerified() bool {
return u.EmailVerified
}
func (u *userinfo) GetPhoneNumber() string {
return u.PhoneNumber
}
func (u *userinfo) IsPhoneNumberVerified() bool {
return u.PhoneNumberVerified
}
func (u *userinfo) GetAddress() UserInfoAddress {
return u.Address
}
func (u *userinfo) GetClaim(key string) interface{} {
return u.claims[key]
}
func (u *userinfo) SetSubject(sub string) {
u.Subject = sub
}
func (u *userinfo) SetName(name string) {
u.Name = name
}
func (u *userinfo) SetGivenName(name string) {
u.GivenName = name
}
func (u *userinfo) SetFamilyName(name string) {
u.FamilyName = name
}
func (u *userinfo) SetMiddleName(name string) {
u.MiddleName = name
}
func (u *userinfo) SetNickname(name string) {
u.Nickname = name
}
func (u *userinfo) SetUpdatedAt(date time.Time) {
u.UpdatedAt = Time(date)
}
func (u *userinfo) SetProfile(profile string) {
u.Profile = profile
}
func (u *userinfo) SetPicture(picture string) {
u.Picture = picture
}
func (u *userinfo) SetWebsite(website string) {
u.Website = website
}
func (u *userinfo) SetGender(gender Gender) {
u.Gender = gender
}
func (u *userinfo) SetBirthdate(birthdate string) {
u.Birthdate = birthdate
}
func (u *userinfo) SetZoneinfo(zoneInfo string) {
u.Zoneinfo = zoneInfo
}
func (u *userinfo) SetLocale(locale language.Tag) {
u.Locale = locale
}
func (u *userinfo) SetPreferredUsername(name string) {
u.PreferredUsername = name
}
func (u *userinfo) SetEmail(email string, verified bool) {
u.Email = email
u.EmailVerified = verified
}
func (u *userinfo) SetPhone(phone string, verified bool) {
u.PhoneNumber = phone
u.PhoneNumberVerified = verified
}
func (u *userinfo) SetAddress(address UserInfoAddress) {
u.Address = address
}
func (u *userinfo) AppendClaims(key string, value interface{}) {
if u.claims == nil {
u.claims = make(map[string]interface{})
}
u.claims[key] = value
}
func (u *userInfoAddress) GetFormatted() string {
return u.Formatted
}
func (u *userInfoAddress) GetStreetAddress() string {
return u.StreetAddress
}
func (u *userInfoAddress) GetLocality() string {
return u.Locality
}
func (u *userInfoAddress) GetRegion() string {
return u.Region
}
func (u *userInfoAddress) GetPostalCode() string {
return u.PostalCode
}
func (u *userInfoAddress) GetCountry() string {
return u.Country
}
type userInfoProfile struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"` GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"` FamilyName string `json:"family_name,omitempty"`
@ -66,25 +289,25 @@ type jsonUserinfoProfile struct {
Profile string `json:"profile,omitempty"` Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"` Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"` Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"` Gender Gender `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"` Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"` Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"` Locale language.Tag `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"` UpdatedAt Time `json:"updated_at,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"` PreferredUsername string `json:"preferred_username,omitempty"`
} }
type jsonUserinfoEmail struct { type userInfoEmail struct {
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"` EmailVerified bool `json:"email_verified,omitempty"`
} }
type jsonUserinfoPhone struct { type userInfoPhone struct {
Phone string `json:"phone_number,omitempty"` PhoneNumber string `json:"phone_number,omitempty"`
PhoneVerified bool `json:"phone_number_verified,omitempty"` PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
} }
type jsonUserinfoAddress struct { type userInfoAddress struct {
Formatted string `json:"formatted,omitempty"` Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"` StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"` Locality string `json:"locality,omitempty"`
@ -93,81 +316,63 @@ type jsonUserinfoAddress struct {
Country string `json:"country,omitempty"` Country string `json:"country,omitempty"`
} }
func (i *Userinfo) MarshalJSON() ([]byte, error) { func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
j := new(jsonUserinfo) return &userInfoAddress{
j.Subject = i.Subject StreetAddress: streetAddress,
j.setUserinfo(*i) Locality: locality,
j.Authorizations = i.Authorizations Region: region,
return json.Marshal(j) PostalCode: postalCode,
Country: country,
Formatted: formatted,
}
}
func (i *userinfo) MarshalJSON() ([]byte, error) {
type Alias userinfo
a := &struct {
*Alias
Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if !i.Locale.IsRoot() {
a.Locale = i.Locale
}
if !time.Time(i.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(i.claims) == 0 {
return b, nil
}
claims, err := json.Marshal(i.claims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims)
}
return utils.ConcatenateJSON(b, claims)
} }
func (i *Userinfo) UnmmarshalJSON(data []byte) error { func (i *userinfo) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, i); err != nil { type Alias userinfo
a := &struct {
*Alias
UpdatedAt int64 `json:"update_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if err := json.Unmarshal(data, &a); err != nil {
return err return err
} }
return json.Unmarshal(data, &i.claims)
}
type jsonUserinfo struct { i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
Subject string `json:"sub,omitempty"`
jsonUserinfoProfile
jsonUserinfoEmail
jsonUserinfoPhone
JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"`
Authorizations []string `json:"authorizations,omitempty"`
}
func (j *jsonUserinfo) setUserinfo(i Userinfo) { return nil
j.setUserinfoProfile(i.UserinfoProfile)
j.setUserinfoEmail(i.UserinfoEmail)
j.setUserinfoPhone(i.UserinfoPhone)
j.setUserinfoAddress(i.Address)
}
func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) {
j.Name = i.Name
j.GivenName = i.GivenName
j.FamilyName = i.FamilyName
j.MiddleName = i.MiddleName
j.Nickname = i.Nickname
j.Profile = i.Profile
j.Picture = i.Picture
j.Website = i.Website
j.Gender = string(i.Gender)
j.Birthdate = i.Birthdate
j.Zoneinfo = i.Zoneinfo
if i.Locale != language.Und {
j.Locale = i.Locale.String()
}
j.UpdatedAt = timeToJSON(i.UpdatedAt)
j.PreferredUsername = i.PreferredUsername
}
func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) {
j.Email = i.Email
j.EmailVerified = i.EmailVerified
}
func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) {
j.Phone = i.PhoneNumber
j.PhoneVerified = i.PhoneNumberVerified
}
func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) {
if i == nil {
return
}
if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" {
return
}
j.JsonUserinfoAddress = &jsonUserinfoAddress{
Country: i.Country,
Formatted: i.Formatted,
Locality: i.Locality,
PostalCode: i.PostalCode,
Region: i.Region,
StreetAddress: i.StreetAddress,
}
} }
type UserInfoRequest struct { type UserInfoRequest struct {

View file

@ -24,7 +24,7 @@ type Claims interface {
GetAuthenticationContextClassReference() string GetAuthenticationContextClassReference() string
GetAuthTime() time.Time GetAuthTime() time.Time
GetAuthorizedParty() string GetAuthorizedParty() string
SetSignature(algorithm jose.SignatureAlgorithm) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
} }
var ( var (
@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
return ErrSignatureInvalidPayload return ErrSignatureInvalidPayload
} }
claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm)) claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm))
return nil return nil
} }

View file

@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
if err != nil { if err != nil {
return "", ErrServerError(err.Error()) return "", ErrServerError(err.Error())
} }
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
return "", err return "", err
} }
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil { if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
} }
//ValidateAuthReqScopes validates the passed scopes //ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(scopes []string) error { func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 { if len(scopes) == 0 {
return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
} }
if !utils.Contains(scopes, oidc.ScopeOpenID) { openID := false
return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") for i := len(scopes) - 1; i >= 0; i-- {
scope := scopes[i]
if scope == oidc.ScopeOpenID {
openID = true
continue
} }
return nil if !(scope == oidc.ScopeProfile ||
scope == oidc.ScopeEmail ||
scope == oidc.ScopePhone ||
scope == oidc.ScopeAddress ||
scope == oidc.ScopeOfflineAccess) &&
!utils.Contains(client.AllowedScopes(), scope) {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
if !openID {
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
}
return scopes, nil
} }
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type //ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
@ -168,7 +188,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if err != nil { if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
} }
return claims.Subject, nil return claims.GetSubject(), nil
} }
//RedirectToLogin redirects the end user to the Login UI for authentication //RedirectToLogin redirects the end user to the Login UI for authentication

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/gorilla/schema" "github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) {
func TestValidateAuthReqScopes(t *testing.T) { func TestValidateAuthReqScopes(t *testing.T) {
type args struct { type args struct {
client op.Client
scopes []string
}
type res struct {
err bool
scopes []string scopes []string
} }
tests := []struct { tests := []struct {
name string name string
args args args args
wantErr bool res res
}{ }{
{ {
"scopes missing fails", args{}, true, "scopes missing fails",
args{},
res{
err: true,
},
}, },
{ {
"scope openid missing fails", args{[]string{"email"}}, true, "scope openid missing fails",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"email"},
},
res{
err: true,
},
}, },
{ {
"scope ok", args{[]string{"openid"}}, false, "scope ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid"},
},
res{
scopes: []string{"openid"},
},
},
{
"scope with drop ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid", "email", "unknown"},
},
res{
scopes: []string{"openid", "email"},
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr { scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr) if (err != nil) != tt.res.err {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
} }
assert.ElementsMatch(t, scopes, tt.res.scopes)
}) })
} }
} }

View file

@ -10,7 +10,9 @@ const (
ApplicationTypeWeb ApplicationType = iota ApplicationTypeWeb ApplicationType = iota
ApplicationTypeUserAgent ApplicationTypeUserAgent
ApplicationTypeNative ApplicationTypeNative
)
const (
AccessTokenTypeBearer AccessTokenType = iota AccessTokenTypeBearer AccessTokenType = iota
AccessTokenTypeJWT AccessTokenTypeJWT
) )
@ -32,6 +34,9 @@ type Client interface {
AccessTokenType() AccessTokenType AccessTokenType() AccessTokenType
IDTokenLifetime() time.Duration IDTokenLifetime() time.Duration
DevMode() bool DevMode() bool
AllowedScopes() []string
AssertAdditionalIdTokenScopes() bool
AssertAdditionalAccessTokenScopes() bool
} }
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool { func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {

View file

@ -7,6 +7,8 @@ import (
"strings" "strings"
) )
const OidcDevMode = "CAOS_OIDC_DEV"
type Configuration interface { type Configuration interface {
Issuer() string Issuer() string
AuthorizationEndpoint() Endpoint AuthorizationEndpoint() Endpoint
@ -42,7 +44,7 @@ func ValidateIssuer(issuer string) error {
} }
func devLocalAllowed(url *url.URL) bool { func devLocalAllowed(url *url.URL) bool {
_, b := os.LookupEnv("CAOS_OIDC_DEV") _, b := os.LookupEnv(OidcDevMode)
if !b { if !b {
return b return b
} }

View file

@ -60,6 +60,8 @@ func TestValidateIssuer(t *testing.T) {
true, true,
}, },
} }
//ensure env is not set
os.Unsetenv(OidcDevMode)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
@ -84,7 +86,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
false, false,
}, },
} }
os.Setenv("CAOS_OIDC_DEV", "") os.Setenv(OidcDevMode, "true")
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {

View file

@ -72,18 +72,18 @@ func (v *Verifier) VerifyIDToken(ctx context.Context, idToken string) (*oidc.IDT
return nil, nil return nil, nil
} }
type Sig struct{} type Sig struct {
signer jose.Signer
}
func (s *Sig) Signer() jose.Signer {
return s.signer
}
func (s *Sig) Health(ctx context.Context) error { func (s *Sig) Health(ctx context.Context) error {
return nil return nil
} }
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
return jose.HS256 return jose.HS256
} }
@ -92,9 +92,3 @@ func ExpectStorage(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer) mockA := a.(*MockAuthorizer)
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t)) mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
} }
// func NewMockSignerAny(t *testing.T) op.Signer {
// m := NewMockSigner(gomock.NewController(t))
// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
// return m
// }

View file

@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
func(id string) string { func(id string) string {
return "login?id=" + id return "login?id=" + id
}) })
m.EXPECT().AllowedScopes().AnyTimes().Return(nil)
return c return c
} }

View file

@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
} }
// AllowedScopes mocks base method
func (m *MockClient) AllowedScopes() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AllowedScopes")
ret0, _ := ret[0].([]string)
return ret0
}
// AllowedScopes indicates an expected call of AllowedScopes
func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes))
}
// ApplicationType mocks base method // ApplicationType mocks base method
func (m *MockClient) ApplicationType() op.ApplicationType { func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -63,6 +77,34 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
} }
// AssertAdditionalAccessTokenScopes mocks base method
func (m *MockClient) AssertAdditionalAccessTokenScopes() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes")
ret0, _ := ret[0].(bool)
return ret0
}
// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes
func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes))
}
// AssertAdditionalIdTokenScopes mocks base method
func (m *MockClient) AssertAdditionalIdTokenScopes() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes")
ret0, _ := ret[0].(bool)
return ret0
}
// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes
func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes))
}
// AuthMethod mocks base method // AuthMethod mocks base method
func (m *MockClient) AuthMethod() op.AuthMethod { func (m *MockClient) AuthMethod() op.AuthMethod {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -6,7 +6,6 @@ package mock
import ( import (
context "context" context "context"
oidc "github.com/caos/oidc/pkg/oidc"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
reflect "reflect" reflect "reflect"
@ -49,36 +48,6 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
} }
// SignAccessToken mocks base method
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignAccessToken indicates an expected call of SignAccessToken
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
}
// SignIDToken mocks base method
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignIDToken", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignIDToken indicates an expected call of SignIDToken
func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
}
// SignatureAlgorithm mocks base method // SignatureAlgorithm mocks base method
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm { func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -92,3 +61,17 @@ func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
} }
// Signer mocks base method
func (m *MockSigner) Signer() jose.Signer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Signer")
ret0, _ := ret[0].(jose.Signer)
return ret0
}
// Signer indicates an expected call of Signer
func (mr *MockSignerMockRecorder) Signer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer))
}

View file

@ -171,6 +171,21 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
} }
// GetPrivateClaimsFromScopes mocks base method
func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPrivateClaimsFromScopes", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(map[string]interface{})
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPrivateClaimsFromScopes indicates an expected call of GetPrivateClaimsFromScopes
func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
}
// GetSigningKey mocks base method // GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) { func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -184,25 +199,25 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
} }
// GetUserinfoFromScopes mocks base method // GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) { func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2) ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*oidc.Userinfo) ret0, _ := ret[0].(oidc.UserInfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes // GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3)
} }
// GetUserinfoFromToken mocks base method // GetUserinfoFromToken mocks base method
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (*oidc.Userinfo, error) { func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (oidc.UserInfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*oidc.Userinfo) ret0, _ := ret[0].(oidc.UserInfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }

View file

@ -168,3 +168,12 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool { func (c *ConfClient) DevMode() bool {
return c.devMode return c.devMode
} }
func (c *ConfClient) AllowedScopes() []string {
return nil
}
func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
return false
}
func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
return false
}

View file

@ -51,6 +51,7 @@ type OpenIDProvider interface {
Encoder() utils.Encoder Encoder() utils.Encoder
IDTokenHintVerifier() IDTokenHintVerifier IDTokenHintVerifier() IDTokenHintVerifier
JWTProfileVerifier() JWTProfileVerifier JWTProfileVerifier() JWTProfileVerifier
AccessTokenVerifier() AccessTokenVerifier
Crypto() Crypto Crypto() Crypto
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
Signer() Signer Signer() Signer
@ -130,7 +131,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
} }
keyCh := make(chan jose.SigningKey) keyCh := make(chan jose.SigningKey)
o.signer = NewDefaultSigner(ctx, storage, keyCh) o.signer = NewSigner(ctx, storage, keyCh)
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry) go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
o.httpHandler = CreateRouter(o, o.interceptors...) o.httpHandler = CreateRouter(o, o.interceptors...)
@ -152,6 +153,8 @@ type openidProvider struct {
signer Signer signer Signer
idTokenHintVerifier IDTokenHintVerifier idTokenHintVerifier IDTokenHintVerifier
jwtProfileVerifier JWTProfileVerifier jwtProfileVerifier JWTProfileVerifier
accessTokenVerifier AccessTokenVerifier
keySet *openIDKeySet
crypto Crypto crypto Crypto
httpHandler http.Handler httpHandler http.Handler
decoder *schema.Decoder decoder *schema.Decoder
@ -207,7 +210,7 @@ func (o *openidProvider) Encoder() utils.Encoder {
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier { func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier {
if o.idTokenHintVerifier == nil { if o.idTokenHintVerifier == nil {
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), &openIDKeySet{o.Storage()}) o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet())
} }
return o.idTokenHintVerifier return o.idTokenHintVerifier
} }
@ -219,6 +222,20 @@ func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier {
return o.jwtProfileVerifier return o.jwtProfileVerifier
} }
func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier {
if o.accessTokenVerifier == nil {
o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet())
}
return o.accessTokenVerifier
}
func (o *openidProvider) openIDKeySet() oidc.KeySet {
if o.keySet == nil {
o.keySet = &openIDKeySet{o.Storage()}
}
return o.keySet
}
func (o *openidProvider) Crypto() Crypto { func (o *openidProvider) Crypto() Crypto {
return o.crypto return o.crypto
} }

View file

@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if err != nil { if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid") return nil, ErrInvalidRequest("id_token_hint invalid")
} }
session.UserID = claims.Subject session.UserID = claims.GetSubject()
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty) session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil { if err != nil {
return nil, ErrServerError("") return nil, ErrServerError("")
} }

View file

@ -2,19 +2,15 @@ package op
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"github.com/caos/logging" "github.com/caos/logging"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
) )
type Signer interface { type Signer interface {
Health(ctx context.Context) error Health(ctx context.Context) error
SignIDToken(claims *oidc.IDTokenClaims) (string, error) Signer() jose.Signer
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
SignatureAlgorithm() jose.SignatureAlgorithm SignatureAlgorithm() jose.SignatureAlgorithm
} }
@ -24,7 +20,7 @@ type tokenSigner struct {
alg jose.SignatureAlgorithm alg jose.SignatureAlgorithm
} }
func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
s := &tokenSigner{ s := &tokenSigner{
storage: storage, storage: storage,
} }
@ -41,6 +37,10 @@ func (s *tokenSigner) Health(_ context.Context) error {
return nil return nil
} }
func (s *tokenSigner) Signer() jose.Signer {
return s.signer
}
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) { func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for { for {
select { select {
@ -55,30 +55,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S
} }
} }
func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) Sign(payload []byte) (string, error) {
result, err := s.signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.alg return s.alg
} }

View file

@ -1,95 +0,0 @@
package op
import (
"testing"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
)
// func TestNewDefaultSigner(t *testing.T) {
// type args struct {
// storage Storage
// }
// tests := []struct {
// name string
// args args
// want Signer
// wantErr bool
// }{
// {
// "err initialize storage fails",
// args{mock.NewMockStorageSigningKeyError(t)},
// nil,
// true,
// },
// {
// "err initialize storage fails",
// args{mock.NewMockStorageSigningKeyInvalid(t)},
// nil,
// true,
// },
// {
// "initialize ok",
// args{mock.NewMockStorageSigningKey(t)},
// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
// false,
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// got, err := op.NewDefaultSigner(tt.args.storage)
// if (err != nil) != tt.wantErr {
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
// return
// }
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
// }
// })
// }
// }
func Test_idTokenSigner_Sign(t *testing.T) {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
require.NoError(t, err)
type fields struct {
signer jose.Signer
storage Storage
}
type args struct {
payload []byte
}
tests := []struct {
name string
fields fields
args args
want string
wantErr bool
}{
{
"ok",
fields{signer, nil},
args{[]byte("test")},
"eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &tokenSigner{
signer: tt.fields.signer,
storage: tt.fields.storage,
}
got, err := s.Sign(tt.args.payload)
if (err != nil) != tt.wantErr {
t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -26,10 +26,11 @@ type AuthStorage interface {
} }
type OPStorage interface { type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error) GetClientByClientID(ctx context.Context, clientID string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error) GetUserinfoFromScopes(ctx context.Context, userID, clientID string, scopes []string) (oidc.UserInfo, error)
GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (*oidc.Userinfo, error) GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (oidc.UserInfo, error)
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
} }

View file

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type TokenCreator interface { type TokenCreator interface {
@ -25,12 +26,12 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
var validity time.Duration var validity time.Duration
if createAccessToken { if createAccessToken {
var err error var err error
accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator) accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator, client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer()) idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -50,7 +51,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
} }
func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) { func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator) accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,17 +64,17 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
}, nil }, nil
} }
func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) { func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client) (token string, validity time.Duration, err error) {
id, exp, err := creator.Storage().CreateToken(ctx, authReq) id, exp, err := creator.Storage().CreateToken(ctx, tokenRequest)
if err != nil { if err != nil {
return "", 0, err return "", 0, err
} }
validity = exp.Sub(time.Now().UTC()) validity = exp.Sub(time.Now().UTC())
if accessTokenType == AccessTokenTypeJWT { if accessTokenType == AccessTokenTypeJWT {
token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer()) token, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage())
return return
} }
token, err = CreateBearerToken(id, authReq.GetSubject(), creator.Crypto()) token, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
return return
} }
@ -81,52 +82,79 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
return crypto.Encrypt(tokenID + ":" + subject) return crypto.Encrypt(tokenID + ":" + subject)
} }
func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) { func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) {
now := time.Now().UTC() claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id)
nbf := now if client != nil && client.AssertAdditionalAccessTokenScopes() {
claims := &oidc.AccessTokenClaims{ privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes()))
Issuer: issuer, if err != nil {
Subject: authReq.GetSubject(), return "", err
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: now,
NotBefore: nbf,
JWTID: id,
} }
return signer.SignAccessToken(claims) claims.SetPrivateClaims(privateClaims)
}
return utils.Sign(claims, signer.Signer())
} }
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) { func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) {
var err error
exp := time.Now().UTC().Add(validity) exp := time.Now().UTC().Add(validity)
userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes()) claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID())
if err != nil { scopes := authReq.GetScopes()
return "", err
}
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
Userinfo: *userinfo,
}
if accessToken != "" { if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetAccessTokenHash(atHash)
scopes = removeUserinfoScopes(scopes)
}
if !additonalScopes {
scopes = removeAdditionalScopes(scopes)
}
if len(scopes) > 0 {
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes)
if err != nil {
return "", err
}
claims.SetUserinfo(userInfo)
} }
if code != "" { if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm()) codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetCodeHash(codeHash)
} }
return signer.SignIDToken(claims) return utils.Sign(claims, signer.Signer())
}
func removeUserinfoScopes(scopes []string) []string {
for i := len(scopes) - 1; i >= 0; i-- {
if scopes[i] == oidc.ScopeProfile ||
scopes[i] == oidc.ScopeEmail ||
scopes[i] == oidc.ScopeAddress ||
scopes[i] == oidc.ScopePhone {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
return scopes
}
func removeAdditionalScopes(scopes []string) []string {
for i := len(scopes) - 1; i >= 0; i-- {
if !(scopes[i] == oidc.ScopeOpenID ||
scopes[i] == oidc.ScopeProfile ||
scopes[i] == oidc.ScopeEmail ||
scopes[i] == oidc.ScopeAddress ||
scopes[i] == oidc.ScopePhone) {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
return scopes
} }

View file

@ -138,18 +138,18 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
} }
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
assertion, err := ParseJWTProfileRequest(r, exchanger.Decoder()) profileRequest, err := ParseJWTProfileRequest(r, exchanger.Decoder())
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
} }
claims, err := VerifyJWTAssertion(r.Context(), assertion, exchanger.JWTProfileVerifier()) tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest, exchanger.JWTProfileVerifier())
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger) resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return
@ -157,17 +157,17 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp) utils.MarshalJSON(w, resp)
} }
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (string, error) { func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*tokenexchange.JWTProfileRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return "", ErrInvalidRequest("error parsing form") return nil, ErrInvalidRequest("error parsing form")
} }
tokenReq := new(tokenexchange.JWTProfileRequest) tokenReq := new(tokenexchange.JWTProfileRequest)
err = decoder.Decode(tokenReq, r.Form) err = decoder.Decode(tokenReq, r.Form)
if err != nil { if err != nil {
return "", ErrInvalidRequest("error decoding form") return nil, ErrInvalidRequest("error decoding form")
} }
return tokenReq.Assertion, nil return tokenReq, nil
} }
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {

View file

@ -1,6 +1,7 @@
package op package op
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"strings" "strings"
@ -13,6 +14,7 @@ type UserinfoProvider interface {
Decoder() utils.Decoder Decoder() utils.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
} }
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
@ -27,17 +29,12 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
http.Error(w, "access token missing", http.StatusUnauthorized) http.Error(w, "access token missing", http.StatusUnauthorized)
return return
} }
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken)
if err != nil { if !ok {
http.Error(w, "access token missing", http.StatusUnauthorized)
return
}
splittedToken := strings.Split(tokenIDSubject, ":")
if len(splittedToken) != 2 {
http.Error(w, "access token invalid", http.StatusUnauthorized) http.Error(w, "access token invalid", http.StatusUnauthorized)
return return
} }
info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), splittedToken[0], splittedToken[1], r.Header.Get("origin")) info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, subject, r.Header.Get("origin"))
if err != nil { if err != nil {
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
utils.MarshalJSON(w, err) utils.MarshalJSON(w, err)
@ -66,3 +63,19 @@ func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) {
} }
return req.AccessToken, nil return req.AccessToken, nil
} }
func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", false
}
return splitToken[0], splitToken[1], true
}
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
if err != nil {
return "", "", false
}
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
}

View file

@ -0,0 +1,85 @@
package op
import (
"context"
"time"
"github.com/caos/oidc/pkg/oidc"
)
type AccessTokenVerifier interface {
oidc.Verifier
SupportedSignAlgs() []string
KeySet() oidc.KeySet
}
type accessTokenVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
supportedSignAlgs []string
maxAge time.Duration
acr oidc.ACRVerifier
keySet oidc.KeySet
}
//Issuer implements oidc.Verifier interface
func (i *accessTokenVerifier) Issuer() string {
return i.issuer
}
//MaxAgeIAT implements oidc.Verifier interface
func (i *accessTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
//Offset implements oidc.Verifier interface
func (i *accessTokenVerifier) Offset() time.Duration {
return i.offset
}
//SupportedSignAlgs implements AccessTokenVerifier interface
func (i *accessTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
//KeySet implements AccessTokenVerifier interface
func (i *accessTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier {
verifier := &idTokenHintVerifier{
issuer: issuer,
keySet: keySet,
}
return verifier
}
//VerifyAccessToken validates the access token (issuer, signature and expiration)
func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) {
claims := oidc.EmptyAccessTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
}
payload, err := oidc.ParseToken(decrypted, claims)
if err != nil {
return nil, err
}
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err
}
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err
}
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err
}
return claims, nil
}

View file

@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi
//VerifyIDTokenHint validates the id token according to //VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) { func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims) claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
if err != nil { if err != nil {

View file

@ -8,6 +8,7 @@ import (
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
) )
type JWTProfileVerifier interface { type JWTProfileVerifier interface {
@ -47,9 +48,9 @@ func (v *jwtProfileVerifier) Offset() time.Duration {
return v.offset return v.offset
} }
func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { func VerifyJWTAssertion(ctx context.Context, profileRequest *tokenexchange.JWTProfileRequest, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
request := new(oidc.JWTTokenRequest) request := new(oidc.JWTTokenRequest)
payload, err := oidc.ParseToken(assertion, request) payload, err := oidc.ParseToken(profileRequest.Assertion, request)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,9 +73,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
keySet := &jwtProfileKeySet{v.Storage(), request.Subject} keySet := &jwtProfileKeySet{v.Storage(), request.Subject}
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { if err = oidc.CheckSignature(ctx, profileRequest.Assertion, payload, request, nil, keySet); err != nil {
return nil, err return nil, err
} }
request.Scopes = profileRequest.Scope
return request, nil return request, nil
} }

View file

@ -1,37 +0,0 @@
package mock
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp"
)
func NewVerifier(t *testing.T) rp.Verifier {
return NewMockVerifier(gomock.NewController(t))
}
func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier {
m := NewVerifier(t)
ExpectVerifyInvalid(m)
return m
}
func ExpectVerifyInvalid(v rp.Verifier) {
mock := v.(*MockVerifier)
mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid"))
}
func NewMockVerifierExpectValid(t *testing.T) rp.Verifier {
m := NewVerifier(t)
ExpectVerifyValid(m)
return m
}
func ExpectVerifyValid(v rp.Verifier) {
mock := v.(*MockVerifier)
mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil)
}

View file

@ -4,9 +4,13 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"net/url"
"reflect"
"strings" "strings"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/oidc/grants" "github.com/caos/oidc/pkg/oidc/grants"
@ -22,6 +26,16 @@ const (
jwtProfileKey = "urn:ietf:params:oauth:grant-type:jwt-bearer" jwtProfileKey = "urn:ietf:params:oauth:grant-type:jwt-bearer"
) )
var (
encoder = func() utils.Encoder {
e := schema.NewEncoder()
e.RegisterEncoder(oidc.Scopes{}, func(value reflect.Value) string {
return value.Interface().(oidc.Scopes).Encode()
})
return e
}()
)
//RelayingParty declares the minimal interface for oidc clients //RelayingParty declares the minimal interface for oidc clients
type RelayingParty interface { type RelayingParty interface {
//OAuthConfig returns the oauth2 Config //OAuthConfig returns the oauth2 Config
@ -312,38 +326,45 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc
//ClientCredentials is the `RelayingParty` interface implementation //ClientCredentials is the `RelayingParty` interface implementation
//handling the oauth2 client credentials grant //handling the oauth2 client credentials grant
func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) { func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) {
return CallTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...), rp) return CallTokenEndpointAuthorized(grants.ClientCredentialsGrantBasic(scopes...), rp)
}
func CallTokenEndpointAuthorized(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
config := rp.OAuthConfig()
var fn interface{} = utils.AuthorizeBasic(config.ClientID, config.ClientSecret)
if config.Endpoint.AuthStyle == oauth2.AuthStyleInParams {
fn = func(form url.Values) {
form.Set("client_id", config.ClientID)
form.Set("client_secret", config.ClientSecret)
}
}
return callTokenEndpoint(request, fn, rp)
} }
func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
config := rp.OAuthConfig() return callTokenEndpoint(request, nil, rp)
req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, config.ClientID, config.ClientSecret, config.Endpoint.AuthStyle != oauth2.AuthStyleInParams)
if err != nil {
return nil, err
}
token := new(oauth2.Token)
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
return nil, err
}
return token, nil
} }
func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) { func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
form := make(map[string][]string) req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, encoder, authFn)
form["assertion"] = []string{assertion}
form["grant_type"] = []string{jwtProfileKey}
req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var tokenRes struct {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
token := new(oauth2.Token) ExpiresIn int64 `json:"expires_in"`
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil { RefreshToken string `json:"refresh_token"`
}
if err := utils.HttpRequest(rp.HttpClient(), req, &tokenRes); err != nil {
return nil, err return nil, err
} }
return token, nil return &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken,
Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
}, nil
} }
func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error { func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error {

View file

@ -43,12 +43,17 @@ func DelegationTokenExchange(ctx context.Context, subjectToken string, rp Relayi
} }
//JWTProfileExchange handles the oauth2 jwt profile exchange //JWTProfileExchange handles the oauth2 jwt profile exchange
func JWTProfileExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, rp RelayingParty) (*oauth2.Token, error) { func JWTProfileExchange(ctx context.Context, jwtProfileRequest *tokenexchange.JWTProfileRequest, rp RelayingParty) (*oauth2.Token, error) {
return CallTokenEndpoint(jwtProfileRequest, rp)
}
//JWTProfileExchange handles the oauth2 jwt profile exchange
func JWTProfileAssertionExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, scopes oidc.Scopes, rp RelayingParty) (*oauth2.Token, error) {
token, err := generateJWTProfileToken(assertion) token, err := generateJWTProfileToken(assertion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return CallJWTProfileEndpoint(token, rp) return JWTProfileExchange(ctx, tokenexchange.NewJWTProfileRequest(token, scopes...), rp)
} }
func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) { func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) {

View file

@ -21,12 +21,12 @@ type IDTokenVerifier interface {
//VerifyTokens implement the Token Response Validation as defined in OIDC specification //VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation //https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v) idToken, err := VerifyIDToken(ctx, idTokenString, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err return nil, err
} }
return idToken, nil return idToken, nil
@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo
//VerifyIDToken validates the id token according to //VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation //https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) { func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims) claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
if err != nil { if err != nil {

View file

@ -10,8 +10,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/gorilla/schema"
) )
var ( var (
@ -27,23 +25,30 @@ type Encoder interface {
Encode(src interface{}, dst map[string][]string) error Encode(src interface{}, dst map[string][]string) error
} }
func FormRequest(endpoint string, request interface{}, clientID, clientSecret string, header bool) (*http.Request, error) { type FormAuthorization func(url.Values)
form := make(map[string][]string) type RequestAuthorization func(*http.Request)
encoder := schema.NewEncoder()
func AuthorizeBasic(user, password string) RequestAuthorization {
return func(req *http.Request) {
req.SetBasicAuth(user, password)
}
}
func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) {
form := url.Values{}
if err := encoder.Encode(request, form); err != nil { if err := encoder.Encode(request, form); err != nil {
return nil, err return nil, err
} }
if !header { if fn, ok := authFn.(FormAuthorization); ok {
form["client_id"] = []string{clientID} fn(form)
form["client_secret"] = []string{clientSecret}
} }
body := strings.NewReader(url.Values(form).Encode()) body := strings.NewReader(form.Encode())
req, err := http.NewRequest("POST", endpoint, body) req, err := http.NewRequest("POST", endpoint, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if header { if fn, ok := authFn.(RequestAuthorization); ok {
req.SetBasicAuth(clientID, clientSecret) fn(req)
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req, nil return req, nil

View file

@ -1,7 +1,9 @@
package utils package utils
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) {
logrus.Error("error writing response") logrus.Error("error writing response")
} }
} }
func ConcatenateJSON(first, second []byte) ([]byte, error) {
if !bytes.HasSuffix(first, []byte{'}'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", first)
}
if !bytes.HasPrefix(second, []byte{'{'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", second)
}
first[len(first)-1] = ','
first = append(first, second[1:]...)
return first, nil
}

60
pkg/utils/marshal_test.go Normal file
View file

@ -0,0 +1,60 @@
package utils
import (
"bytes"
"testing"
)
func TestConcatenateJSON(t *testing.T) {
type args struct {
first []byte
second []byte
}
tests := []struct {
name string
args args
want []byte
wantErr bool
}{
{
"invalid first part, error",
args{
[]byte(`invalid`),
[]byte(`{"some": "thing"}`),
},
nil,
true,
},
{
"invalid second part, error",
args{
[]byte(`{"some": "thing"}`),
[]byte(`invalid`),
},
nil,
true,
},
{
"both valid, merged",
args{
[]byte(`{"some": "thing"}`),
[]byte(`{"another": "thing"}`),
},
[]byte(`{"some": "thing","another": "thing"}`),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ConcatenateJSON(tt.args.first, tt.args.second)
if (err != nil) != tt.wantErr {
t.Errorf("ConcatenateJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !bytes.Equal(got, tt.want) {
t.Errorf("ConcatenateJSON() got = %v, want %v", got, tt.want)
}
})
}
}

23
pkg/utils/sign.go Normal file
View file

@ -0,0 +1,23 @@
package utils
import (
"encoding/json"
"gopkg.in/square/go-jose.v2"
)
func Sign(object interface{}, signer jose.Signer) (string, error) {
payload, err := json.Marshal(object)
if err != nil {
return "", err
}
return SignPayload(payload, signer)
}
func SignPayload(payload []byte, signer jose.Signer) (string, error) {
result, err := signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}