Merge pull request #60 from caos/serializing

feat: private claims (incl. serialisation refactoring and jwt profile fix)
This commit is contained in:
Fabi 2020-10-15 15:27:00 +02:00 committed by GitHub
commit c1699a2d93
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 1896 additions and 980 deletions

View file

@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"net/http"
"os"
"time"
@ -30,7 +32,7 @@ func main() {
ctx := context.Background()
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail}
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress}
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
rp.WithPKCE(cookieHandler),
@ -82,6 +84,66 @@ func main() {
}
w.Write(data)
})
http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" {
tpl := `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Login</title>
</head>
<body>
<form method="POST" action="/jwt-profile" enctype="multipart/form-data">
<label for="key">Select a key file:</label>
<input type="file" accept=".json" id="key" name="key">
<button type="submit">Get Token</button>
</form>
</body>
</html>`
t, err := template.New("login").Parse(tpl)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = t.Execute(w, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
} else {
err := r.ParseMultipartForm(4 << 10)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
file, handler, err := r.FormFile("key")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer file.Close()
key, err := ioutil.ReadAll(file)
fmt.Println(handler.Header)
assertion, err := oidc.NewJWTProfileAssertionFromFileData(key, []string{issuer})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
token, err := rp.JWTProfileAssertionExchange(ctx, assertion, scopes, provider)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data, err := json.Marshal(token)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(data)
}
})
lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis)
logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))

View file

@ -45,7 +45,7 @@ func main() {
}
token := cli.CodeFlow(relayingParty, callbackPath, port, state)
client := github.NewClient(relayingParty.Client(ctx, token.Token))
client := github.NewClient(relayingParty.OAuthConfig().Client(ctx, token.Token))
_, _, err = client.Users.Get(ctx, "")
if err != nil {

View file

@ -210,31 +210,21 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st
return nil
}
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (*oidc.Userinfo, error) {
return s.GetUserinfoFromScopes(ctx, "", []string{})
func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _, _ string) (oidc.UserInfo, error) {
return s.GetUserinfoFromScopes(ctx, "", "", []string{})
}
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) {
return &oidc.Userinfo{
Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{
StreetAddress: "Hjkhkj 789\ndsf",
},
UserinfoEmail: oidc.UserinfoEmail{
Email: "test",
EmailVerified: true,
},
UserinfoPhone: oidc.UserinfoPhone{
PhoneNumber: "sadsa",
PhoneNumberVerified: true,
},
UserinfoProfile: oidc.UserinfoProfile{
UpdatedAt: time.Now(),
},
// Claims: map[string]interface{}{
// "test": "test",
// "hkjh": "",
// },
}, nil
func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _, _ string, _ []string) (oidc.UserInfo, error) {
userinfo := oidc.NewUserInfo()
userinfo.SetSubject(a.GetSubject())
userinfo.SetAddress(oidc.NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", ""))
userinfo.SetEmail("test", true)
userinfo.SetPhone("0791234567", true)
userinfo.SetName("Test")
userinfo.AppendClaims("private_claim", "test")
return userinfo, nil
}
func (s *AuthStorage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) {
return map[string]interface{}{"private_claim": "test"}, nil
}
type ConfClient struct {
@ -289,3 +279,15 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool {
return c.devMode
}
func (c *ConfClient) AllowedScopes() []string {
return nil
}
func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
return false
}
func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
return false
}

View file

@ -1,15 +1,5 @@
package oidc
import (
"encoding/json"
"errors"
"strings"
"time"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
)
const (
//ScopeOpenID defines the scope `openid`
//OpenID Connect requests MUST contain the `openid` scope value
@ -64,23 +54,8 @@ const (
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
PromptSelectAccount Prompt = "select_account"
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code"
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
BearerToken = "Bearer"
)
var displayValues = map[string]Display{
"page": DisplayPage,
"popup": DisplayPopup,
"touch": DisplayTouch,
"wap": DisplayWAP,
}
//AuthRequest according to:
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct {
@ -121,146 +96,3 @@ func (a *AuthRequest) GetResponseType() ResponseType {
func (a *AuthRequest) GetState() string {
return a.State
}
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType
}
type TokenRequestType GrantType
type AccessTokenRequest struct {
Code string `schema:"code"`
RedirectURI string `schema:"redirect_uri"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
CodeVerifier string `schema:"code_verifier"`
}
func (a *AccessTokenRequest) GrantType() GrantType {
return GrantTypeCode
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
}
type JWTTokenRequest struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Scopes Scopes `json:"scope"`
Audience interface{} `json:"aud"`
IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"`
}
func (j *JWTTokenRequest) GetClientID() string {
return j.Subject
}
func (j *JWTTokenRequest) GetSubject() string {
return j.Subject
}
func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
}
func (j *JWTTokenRequest) GetIssuer() string {
return j.Issuer
}
func (j *JWTTokenRequest) GetAudience() []string {
return audienceFromJSON(j.Audience)
}
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
}
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
func (j *JWTTokenRequest) GetNonce() string {
return ""
}
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
return ""
}
func (j *JWTTokenRequest) GetAuthTime() time.Time {
return time.Time{}
}
func (j *JWTTokenRequest) GetAuthorizedParty() string {
return ""
}
func (j *JWTTokenRequest) SetSignature(algorithm jose.SignatureAlgorithm) {}
type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"`
audience []string `schema:"audience"`
Scope []string `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"`
}
type Scopes []string
func (s *Scopes) UnmarshalText(text []byte) error {
scopes := strings.Split(string(text), " ")
*s = Scopes(scopes)
return nil
}
type ResponseType string
type Display string
func (d *Display) UnmarshalText(text []byte) error {
var ok bool
display := string(text)
*d, ok = displayValues[display]
if !ok {
return errors.New("")
}
return nil
}
type Prompt string
type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error {
locales := strings.Split(string(text), " ")
for _, locale := range locales {
tag, err := language.Parse(locale)
if err == nil && !tag.IsRoot() {
*l = append(*l, tag)
}
}
return nil
}
type GrantType string

View file

@ -1,5 +1,9 @@
package tokenexchange
import (
"github.com/caos/oidc/pkg/oidc"
)
const (
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
@ -24,6 +28,18 @@ type TokenExchangeRequest struct {
type JWTProfileRequest struct {
Assertion string `schema:"assertion"`
Scope oidc.Scopes `schema:"scope"`
GrantType oidc.GrantType `schema:"grant_type"`
}
//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
//sneding client_id and client_secret as basic auth header
func NewJWTProfileRequest(assertion string, scopes ...string) *JWTProfileRequest {
return &JWTProfileRequest{
GrantType: oidc.GrantTypeBearer,
Assertion: assertion,
Scope: scopes,
}
}
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {

View file

@ -6,21 +6,19 @@ import (
"gopkg.in/square/go-jose.v2"
)
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
// of JSON web tokens. This is expected to be backed by a remote key set through
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
//KeySet represents a set of JSON Web Keys
// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
// - held by the OP itself in storage -> `openIDKeySet`
// - dynamically aggregated by request for OAuth JWT Profile Assertion -> `jwtProfileKeySet`
type KeySet interface {
// VerifySignature parses the JSON web token, verifies the signature, and returns
// the raw payload. Header and claim fields are validated by other parts of the
// package. For example, the KeySet does not need to check values such as signature
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
// independently.
//
// If VerifySignature makes HTTP requests to verify the token, it's expected to
// use any HTTP client associated with the context through ClientContext.
//VerifySignature verifies the signature with the given keyset and returns the raw payload
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
}
//CheckKey searches the given JSON Web Keys for the requested key ID
//and verifies the JSON Web Signature with the found key
//
//will return false but no error if key ID is not found
func CheckKey(keyID string, jws *jose.JSONWebSignature, keys ...jose.JSONWebKey) ([]byte, error, bool) {
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {

View file

@ -1,5 +1,7 @@
package oidc
//EndSessionRequest for the RP-Initiated Logout according to:
//https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
type EndSessionRequest struct {
IdTokenHint string `schema:"id_token_hint"`
PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`

View file

@ -3,72 +3,401 @@ package oidc
import (
"encoding/json"
"io/ioutil"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils"
)
const (
//BearerToken defines the token_type `Bearer`, which is returned in a successful token response
BearerToken = "Bearer"
)
type Tokens struct {
*oauth2.Token
IDTokenClaims *IDTokenClaims
IDTokenClaims IDTokenClaims
IDToken string
}
type AccessTokenClaims struct {
Issuer string
Subject string
Audiences []string
Expiration time.Time
IssuedAt time.Time
NotBefore time.Time
JWTID string
AuthorizedParty string
Nonce string
AuthTime time.Time
CodeHash string
AuthenticationContextClassReference string
AuthenticationMethodsReferences []string
SessionID string
Scopes []string
ClientID string
AccessTokenUseNumber int
type AccessTokenClaims interface {
Claims
GetSubject() string
GetTokenID() string
SetPrivateClaims(map[string]interface{})
}
type IDTokenClaims struct {
Issuer string
Audiences []string
Expiration time.Time
NotBefore time.Time
IssuedAt time.Time
JWTID string
UpdatedAt time.Time
AuthorizedParty string
Nonce string
AuthTime time.Time
AccessTokenHash string
CodeHash string
AuthenticationContextClassReference string
AuthenticationMethodsReferences []string
ClientID string
Userinfo
type IDTokenClaims interface {
Claims
GetNotBefore() time.Time
GetJWTID() string
GetAccessTokenHash() string
GetCodeHash() string
GetAuthenticationMethodsReferences() []string
GetClientID() string
GetSignatureAlgorithm() jose.SignatureAlgorithm
SetAccessTokenHash(hash string)
SetUserinfo(userinfo UserInfo)
SetCodeHash(hash string)
UserInfo
}
Signature jose.SignatureAlgorithm //TODO: ???
func EmptyAccessTokenClaims() AccessTokenClaims {
return new(accessTokenClaims)
}
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id string) AccessTokenClaims {
now := time.Now().UTC()
return &accessTokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(now),
NotBefore: Time(now),
JWTID: id,
}
}
type accessTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
SessionID string `json:"sid,omitempty"`
Scopes []string `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
claims map[string]interface{} `json:"-"`
signatureAlg jose.SignatureAlgorithm `json:"-"`
}
//GetIssuer implements the Claims interface
func (a *accessTokenClaims) GetIssuer() string {
return a.Issuer
}
//GetAudience implements the Claims interface
func (a *accessTokenClaims) GetAudience() []string {
return a.Audience
}
//GetExpiration implements the Claims interface
func (a *accessTokenClaims) GetExpiration() time.Time {
return time.Time(a.Expiration)
}
//GetIssuedAt implements the Claims interface
func (a *accessTokenClaims) GetIssuedAt() time.Time {
return time.Time(a.IssuedAt)
}
//GetNonce implements the Claims interface
func (a *accessTokenClaims) GetNonce() string {
return a.Nonce
}
//GetAuthenticationContextClassReference implements the Claims interface
func (a *accessTokenClaims) GetAuthenticationContextClassReference() string {
return a.AuthenticationContextClassReference
}
//GetAuthTime implements the Claims interface
func (a *accessTokenClaims) GetAuthTime() time.Time {
return time.Time(a.AuthTime)
}
//GetAuthorizedParty implements the Claims interface
func (a *accessTokenClaims) GetAuthorizedParty() string {
return a.AuthorizedParty
}
//SetSignatureAlgorithm implements the Claims interface
func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
a.signatureAlg = algorithm
}
//GetSubject implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetSubject() string {
return a.Subject
}
//GetTokenID implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetTokenID() string {
return a.JWTID
}
//SetPrivateClaims implements the AccessTokenClaims interface
func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) {
a.claims = claims
}
func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
type Alias accessTokenClaims
s := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(a),
}
if !time.Time(a.Expiration).IsZero() {
s.Expiration = time.Time(a.Expiration).Unix()
}
if !time.Time(a.IssuedAt).IsZero() {
s.IssuedAt = time.Time(a.IssuedAt).Unix()
}
if !time.Time(a.NotBefore).IsZero() {
s.NotBefore = time.Time(a.NotBefore).Unix()
}
if !time.Time(a.AuthTime).IsZero() {
s.AuthTime = time.Time(a.AuthTime).Unix()
}
b, err := json.Marshal(s)
if err != nil {
return nil, err
}
if a.claims == nil {
return b, nil
}
info, err := json.Marshal(a.claims)
if err != nil {
return nil, err
}
return utils.ConcatenateJSON(b, info)
}
func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
type Alias accessTokenClaims
if err := json.Unmarshal(data, (*Alias)(a)); err != nil {
return err
}
claims := make(map[string]interface{})
if err := json.Unmarshal(data, &claims); err != nil {
return err
}
a.claims = claims
return nil
}
func EmptyIDTokenClaims() IDTokenClaims {
return new(idTokenClaims)
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string) IDTokenClaims {
return &idTokenClaims{
Issuer: issuer,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(time.Now().UTC()),
AuthTime: Time(authTime),
Nonce: nonce,
AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID,
UserInfo: &userinfo{Subject: subject},
}
}
type idTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
ClientID string `json:"client_id,omitempty"`
UserInfo `json:"-"`
signatureAlg jose.SignatureAlgorithm
}
//GetIssuer implements the Claims interface
func (t *idTokenClaims) GetIssuer() string {
return t.Issuer
}
//GetAudience implements the Claims interface
func (t *idTokenClaims) GetAudience() []string {
return t.Audience
}
//GetExpiration implements the Claims interface
func (t *idTokenClaims) GetExpiration() time.Time {
return time.Time(t.Expiration)
}
//GetIssuedAt implements the Claims interface
func (t *idTokenClaims) GetIssuedAt() time.Time {
return time.Time(t.IssuedAt)
}
//GetNonce implements the Claims interface
func (t *idTokenClaims) GetNonce() string {
return t.Nonce
}
//GetAuthenticationContextClassReference implements the Claims interface
func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
//GetAuthTime implements the Claims interface
func (t *idTokenClaims) GetAuthTime() time.Time {
return time.Time(t.AuthTime)
}
//GetAuthorizedParty implements the Claims interface
func (t *idTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
//SetSignatureAlgorithm implements the Claims interface
func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
t.signatureAlg = alg
}
//GetNotBefore implements the IDTokenClaims interface
func (t *idTokenClaims) GetNotBefore() time.Time {
return time.Time(t.NotBefore)
}
//GetJWTID implements the IDTokenClaims interface
func (t *idTokenClaims) GetJWTID() string {
return t.JWTID
}
//GetAccessTokenHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetAccessTokenHash() string {
return t.AccessTokenHash
}
//GetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetCodeHash() string {
return t.CodeHash
}
//GetAuthenticationMethodsReferences implements the IDTokenClaims interface
func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
return t.AuthenticationMethodsReferences
}
//GetClientID implements the IDTokenClaims interface
func (t *idTokenClaims) GetClientID() string {
return t.ClientID
}
//GetSignatureAlgorithm implements the IDTokenClaims interface
func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
return t.signatureAlg
}
//SetSignatureAlgorithm implements the IDTokenClaims interface
func (t *idTokenClaims) SetAccessTokenHash(hash string) {
t.AccessTokenHash = hash
}
//SetUserinfo implements the IDTokenClaims interface
func (t *idTokenClaims) SetUserinfo(info UserInfo) {
t.UserInfo = info
}
//SetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) SetCodeHash(hash string) {
t.CodeHash = hash
}
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
type Alias idTokenClaims
a := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(t),
}
if !time.Time(t.Expiration).IsZero() {
a.Expiration = time.Time(t.Expiration).Unix()
}
if !time.Time(t.IssuedAt).IsZero() {
a.IssuedAt = time.Time(t.IssuedAt).Unix()
}
if !time.Time(t.NotBefore).IsZero() {
a.NotBefore = time.Time(t.NotBefore).Unix()
}
if !time.Time(t.AuthTime).IsZero() {
a.AuthTime = time.Time(t.AuthTime).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if t.UserInfo == nil {
return b, nil
}
info, err := json.Marshal(t.UserInfo)
if err != nil {
return nil, err
}
return utils.ConcatenateJSON(b, info)
}
func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
type Alias idTokenClaims
if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
return err
}
userinfo := new(userinfo)
if err := json.Unmarshal(data, userinfo); err != nil {
return err
}
t.UserInfo = userinfo
return nil
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
}
type JWTProfileAssertion struct {
PrivateKeyID string `json:"keyId"`
PrivateKey []byte `json:"key"`
Scopes []string `json:"-"`
Issuer string `json:"-"`
Subject string `json:"userId"`
Audience []string `json:"-"`
Expiration time.Time `json:"-"`
IssuedAt time.Time `json:"-"`
PrivateKeyID string `json:"-"`
PrivateKey []byte `json:"-"`
Issuer string `json:"issuer"`
Subject string `json:"sub"`
Audience Audience `json:"aud"`
Expiration Time `json:"exp"`
IssuedAt Time `json:"iat"`
}
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWTProfileAssertion, error) {
@ -76,12 +405,16 @@ func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string) (*JWT
if err != nil {
return nil, err
}
return NewJWTProfileAssertionFromFileData(data, audience)
}
func NewJWTProfileAssertionFromFileData(data []byte, audience []string) (*JWTProfileAssertion, error) {
keyData := new(struct {
KeyID string `json:"keyId"`
Key string `json:"key"`
UserID string `json:"userId"`
})
err = json.Unmarshal(data, keyData)
err := json.Unmarshal(data, keyData)
if err != nil {
return nil, err
}
@ -93,244 +426,13 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
PrivateKey: key,
PrivateKeyID: keyID,
Issuer: userID,
Scopes: []string{ScopeOpenID},
Subject: userID,
IssuedAt: time.Now().UTC(),
Expiration: time.Now().Add(1 * time.Hour).UTC(),
IssuedAt: Time(time.Now().UTC()),
Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience,
}
}
type jsonToken struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audiences interface{} `json:"aud,omitempty"`
Expiration int64 `json:"exp,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
SessionID string `json:"sid,omitempty"`
Actor interface{} `json:"act,omitempty"` //TODO: impl
Scopes string `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
jsonUserinfo
}
func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
j := jsonToken{
Issuer: t.Issuer,
Subject: t.Subject,
Audiences: t.Audiences,
Expiration: timeToJSON(t.Expiration),
NotBefore: timeToJSON(t.NotBefore),
IssuedAt: timeToJSON(t.IssuedAt),
JWTID: t.JWTID,
AuthorizedParty: t.AuthorizedParty,
Nonce: t.Nonce,
AuthTime: timeToJSON(t.AuthTime),
CodeHash: t.CodeHash,
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
SessionID: t.SessionID,
Scopes: strings.Join(t.Scopes, " "),
ClientID: t.ClientID,
AccessTokenUseNumber: t.AccessTokenUseNumber,
}
return json.Marshal(j)
}
func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
var j jsonToken
if err := json.Unmarshal(b, &j); err != nil {
return err
}
t.Issuer = j.Issuer
t.Subject = j.Subject
t.Audiences = audienceFromJSON(j.Audiences)
t.Expiration = time.Unix(j.Expiration, 0).UTC()
t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
t.JWTID = j.JWTID
t.AuthorizedParty = j.AuthorizedParty
t.Nonce = j.Nonce
t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
t.CodeHash = j.CodeHash
t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
t.SessionID = j.SessionID
t.Scopes = strings.Split(j.Scopes, " ")
t.ClientID = j.ClientID
t.AccessTokenUseNumber = j.AccessTokenUseNumber
return nil
}
func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
j := jsonToken{
Issuer: t.Issuer,
Subject: t.Subject,
Audiences: t.Audiences,
Expiration: timeToJSON(t.Expiration),
NotBefore: timeToJSON(t.NotBefore),
IssuedAt: timeToJSON(t.IssuedAt),
JWTID: t.JWTID,
AuthorizedParty: t.AuthorizedParty,
Nonce: t.Nonce,
AuthTime: timeToJSON(t.AuthTime),
AccessTokenHash: t.AccessTokenHash,
CodeHash: t.CodeHash,
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
ClientID: t.ClientID,
}
j.setUserinfo(t.Userinfo)
return json.Marshal(j)
}
func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
var i jsonToken
if err := json.Unmarshal(b, &i); err != nil {
return err
}
t.Issuer = i.Issuer
t.Subject = i.Subject
t.Audiences = audienceFromJSON(i.Audiences)
t.Expiration = time.Unix(i.Expiration, 0).UTC()
t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
t.Nonce = i.Nonce
t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
t.AuthorizedParty = i.AuthorizedParty
t.AccessTokenHash = i.AccessTokenHash
t.CodeHash = i.CodeHash
t.UserinfoProfile = i.UnmarshalUserinfoProfile()
t.UserinfoEmail = i.UnmarshalUserinfoEmail()
t.UserinfoPhone = i.UnmarshalUserinfoPhone()
t.Address = i.UnmarshalUserinfoAddress()
return nil
}
func (t *IDTokenClaims) GetIssuer() string {
return t.Issuer
}
func (t *IDTokenClaims) GetAudience() []string {
return t.Audiences
}
func (t *IDTokenClaims) GetExpiration() time.Time {
return t.Expiration
}
func (t *IDTokenClaims) GetIssuedAt() time.Time {
return t.IssuedAt
}
func (t *IDTokenClaims) GetNonce() string {
return t.Nonce
}
func (t *IDTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
func (t *IDTokenClaims) GetAuthTime() time.Time {
return t.AuthTime
}
func (t *IDTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
func (t *IDTokenClaims) SetSignature(alg jose.SignatureAlgorithm) {
t.Signature = alg
}
func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
j := jsonToken{
Issuer: t.Issuer,
Subject: t.Subject,
Audiences: t.Audience,
Expiration: timeToJSON(t.Expiration),
IssuedAt: timeToJSON(t.IssuedAt),
Scopes: strings.Join(t.Scopes, " "),
}
return json.Marshal(j)
}
func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
var j jsonToken
if err := json.Unmarshal(b, &j); err != nil {
return err
}
t.Issuer = j.Issuer
t.Subject = j.Subject
t.Audience = audienceFromJSON(j.Audiences)
t.Expiration = time.Unix(j.Expiration, 0).UTC()
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
t.Scopes = strings.Split(j.Scopes, " ")
return nil
}
func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
locale, _ := language.Parse(j.Locale)
return UserinfoProfile{
Name: j.Name,
GivenName: j.GivenName,
FamilyName: j.FamilyName,
MiddleName: j.MiddleName,
Nickname: j.Nickname,
Profile: j.Profile,
Picture: j.Picture,
Website: j.Website,
Gender: Gender(j.Gender),
Birthdate: j.Birthdate,
Zoneinfo: j.Zoneinfo,
Locale: locale,
UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
PreferredUsername: j.PreferredUsername,
}
}
func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail {
return UserinfoEmail{
Email: j.Email,
EmailVerified: j.EmailVerified,
}
}
func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone {
return UserinfoPhone{
PhoneNumber: j.Phone,
PhoneNumberVerified: j.PhoneVerified,
}
}
func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
if j.JsonUserinfoAddress == nil {
return nil
}
return &UserinfoAddress{
Country: j.JsonUserinfoAddress.Country,
Formatted: j.JsonUserinfoAddress.Formatted,
Locality: j.JsonUserinfoAddress.Locality,
PostalCode: j.JsonUserinfoAddress.PostalCode,
Region: j.JsonUserinfoAddress.Region,
StreetAddress: j.JsonUserinfoAddress.StreetAddress,
}
}
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
if err != nil {
@ -339,26 +441,3 @@ func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, erro
return utils.HashString(hash, claim, true), nil
}
func timeToJSON(t time.Time) int64 {
if t.IsZero() {
return 0
}
return t.Unix()
}
func audienceFromJSON(i interface{}) []string {
switch aud := i.(type) {
case []string:
return aud
case []interface{}:
audience := make([]string, len(aud))
for i, a := range aud {
audience[i] = a.(string)
}
return audience
case string:
return []string{aud}
}
return nil
}

108
pkg/oidc/token_request.go Normal file
View file

@ -0,0 +1,108 @@
package oidc
import (
"time"
"gopkg.in/square/go-jose.v2"
)
const (
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code"
//GrantTypeBearer define the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
)
type GrantType string
type TokenRequest interface {
// GrantType GrantType `schema:"grant_type"`
GrantType() GrantType
}
type TokenRequestType GrantType
type AccessTokenRequest struct {
Code string `schema:"code"`
RedirectURI string `schema:"redirect_uri"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
CodeVerifier string `schema:"code_verifier"`
}
func (a *AccessTokenRequest) GrantType() GrantType {
return GrantTypeCode
}
type JWTTokenRequest struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Scopes Scopes `json:"-"`
Audience Audience `json:"aud"`
IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"`
}
//GetIssuer implements the Claims interface
func (j *JWTTokenRequest) GetIssuer() string {
return j.Issuer
}
//GetAudience implements the Claims and TokenRequest interfaces
func (j *JWTTokenRequest) GetAudience() []string {
return j.Audience
}
//GetExpiration implements the Claims interface
func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt)
}
//GetIssuedAt implements the Claims interface
func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt)
}
//GetNonce implements the Claims interface
func (j *JWTTokenRequest) GetNonce() string {
return ""
}
//GetAuthenticationContextClassReference implements the Claims interface
func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
return ""
}
//GetAuthTime implements the Claims interface
func (j *JWTTokenRequest) GetAuthTime() time.Time {
return time.Time{}
}
//GetAuthorizedParty implements the Claims interface
func (j *JWTTokenRequest) GetAuthorizedParty() string {
return ""
}
//SetSignatureAlgorithm implements the Claims interface
func (j *JWTTokenRequest) SetSignatureAlgorithm(_ jose.SignatureAlgorithm) {}
//GetSubject implements the TokenRequest interface
func (j *JWTTokenRequest) GetSubject() string {
return j.Subject
}
//GetSubject implements the TokenRequest interface
func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes
}
type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"`
audience Audience `schema:"audience"`
Scope Scopes `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"`
}

89
pkg/oidc/types.go Normal file
View file

@ -0,0 +1,89 @@
package oidc
import (
"encoding/json"
"strings"
"time"
"golang.org/x/text/language"
)
type Audience []string
func (a *Audience) UnmarshalJSON(text []byte) error {
var i interface{}
err := json.Unmarshal(text, &i)
if err != nil {
return err
}
switch aud := i.(type) {
case []interface{}:
*a = make([]string, len(aud))
for i, audience := range aud {
(*a)[i] = audience.(string)
}
case string:
*a = []string{aud}
}
return nil
}
type Display string
func (d *Display) UnmarshalText(text []byte) error {
display := Display(text)
switch display {
case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP:
*d = display
}
return nil
}
type Gender string
type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error {
locales := strings.Split(string(text), " ")
for _, locale := range locales {
tag, err := language.Parse(locale)
if err == nil && !tag.IsRoot() {
*l = append(*l, tag)
}
}
return nil
}
type Prompt string
type ResponseType string
type Scopes []string
func (s Scopes) Encode() string {
return strings.Join(s, " ")
}
func (s *Scopes) UnmarshalText(text []byte) error {
*s = strings.Split(string(text), " ")
return nil
}
func (s *Scopes) MarshalText() ([]byte, error) {
return []byte(s.Encode()), nil
}
type Time time.Time
func (t *Time) UnmarshalJSON(data []byte) error {
var i int64
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
}
func (t *Time) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Time(*t).UTC().Unix())
}

296
pkg/oidc/types_test.go Normal file
View file

@ -0,0 +1,296 @@
package oidc
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
)
func TestAudience_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
audience Audience
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"invalid value",
args{
[]byte(`{"aud": {"a": }}}`),
},
res{},
true,
},
{
"single audience",
args{
[]byte(`{"aud": "single audience"}`),
},
res{
[]string{"single audience"},
},
false,
},
{
"multiple audience",
args{
[]byte(`{"aud": ["multiple", "audience"]}`),
},
res{
[]string{"multiple", "audience"},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := new(struct {
Audience Audience `json:"aud"`
})
if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, a.Audience, tt.res.audience)
})
}
}
func TestDisplay_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
display Display
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{},
false,
},
{
"page",
args{
[]byte("page"),
},
res{DisplayPage},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var d Display
if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
if d != tt.res.display {
t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display)
}
})
}
}
func TestLocales_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
tags []language.Tag
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{},
false,
},
{
"undefined",
args{
[]byte("und"),
},
res{},
false,
},
{
"single language",
args{
[]byte("de"),
},
res{[]language.Tag{language.German}},
false,
},
{
"multiple languages",
args{
[]byte("de en"),
},
res{[]language.Tag{language.German, language.English}},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var locales Locales
if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, locales, tt.res.tags)
})
}
}
func TestScopes_UnmarshalText(t *testing.T) {
type args struct {
text []byte
}
type res struct {
scopes []string
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
[]byte("unknown"),
},
res{
[]string{"unknown"},
},
false,
},
{
"struct",
args{
[]byte(`{"unknown":"value"}`),
},
res{
[]string{`{"unknown":"value"}`},
},
false,
},
{
"openid",
args{
[]byte("openid"),
},
res{
[]string{"openid"},
},
false,
},
{
"multiple scopes",
args{
[]byte("openid email custom:scope"),
},
res{
[]string{"openid", "email", "custom:scope"},
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var scopes Scopes
if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
assert.ElementsMatch(t, scopes, tt.res.scopes)
})
}
}
func TestScopes_MarshalText(t *testing.T) {
type args struct {
scopes Scopes
}
type res struct {
scopes []byte
}
tests := []struct {
name string
args args
res res
wantErr bool
}{
{
"unknown value",
args{
Scopes{"unknown"},
},
res{
[]byte("unknown"),
},
false,
},
{
"struct",
args{
Scopes{`{"unknown":"value"}`},
},
res{
[]byte(`{"unknown":"value"}`),
},
false,
},
{
"openid",
args{
Scopes{"openid"},
},
res{
[]byte("openid"),
},
false,
},
{
"multiple scopes",
args{
Scopes{"openid", "email", "custom:scope"},
},
res{
[]byte("openid email custom:scope"),
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text, err := tt.args.scopes.MarshalText()
if (err != nil) != tt.wantErr {
t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr)
}
if !bytes.Equal(text, tt.res.scopes) {
t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes)
}
})
}
}

View file

@ -2,62 +2,285 @@ package oidc
import (
"encoding/json"
"fmt"
"time"
"golang.org/x/text/language"
"github.com/caos/oidc/pkg/utils"
)
type Userinfo struct {
Subject string
UserinfoProfile
UserinfoEmail
UserinfoPhone
Address *UserinfoAddress
type UserInfo interface {
GetSubject() string
UserInfoProfile
UserInfoEmail
UserInfoPhone
GetAddress() UserInfoAddress
GetClaim(key string) interface{}
}
Authorizations []string
type UserInfoProfile interface {
GetName() string
GetGivenName() string
GetFamilyName() string
GetMiddleName() string
GetNickname() string
GetProfile() string
GetPicture() string
GetWebsite() string
GetGender() Gender
GetBirthdate() string
GetZoneinfo() string
GetLocale() language.Tag
GetPreferredUsername() string
}
type UserInfoEmail interface {
GetEmail() string
IsEmailVerified() bool
}
type UserInfoPhone interface {
GetPhoneNumber() string
IsPhoneNumberVerified() bool
}
type UserInfoAddress interface {
GetFormatted() string
GetStreetAddress() string
GetLocality() string
GetRegion() string
GetPostalCode() string
GetCountry() string
}
type UserInfoSetter interface {
UserInfo
SetSubject(sub string)
UserInfoProfileSetter
SetEmail(email string, verified bool)
SetPhone(phone string, verified bool)
SetAddress(address UserInfoAddress)
AppendClaims(key string, values interface{})
}
type UserInfoProfileSetter interface {
SetName(name string)
SetGivenName(name string)
SetFamilyName(name string)
SetMiddleName(name string)
SetNickname(name string)
SetUpdatedAt(date time.Time)
SetProfile(profile string)
SetPicture(profile string)
SetWebsite(website string)
SetGender(gender Gender)
SetBirthdate(birthdate string)
SetZoneinfo(zoneInfo string)
SetLocale(locale language.Tag)
SetPreferredUsername(name string)
}
func NewUserInfo() UserInfoSetter {
return &userinfo{}
}
type userinfo struct {
Subject string `json:"sub,omitempty"`
userInfoProfile
userInfoEmail
userInfoPhone
Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{}
}
type UserinfoProfile struct {
Name string
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender Gender
Birthdate string
Zoneinfo string
Locale language.Tag
UpdatedAt time.Time
PreferredUsername string
func (u *userinfo) GetSubject() string {
return u.Subject
}
type Gender string
type UserinfoEmail struct {
Email string
EmailVerified bool
func (u *userinfo) GetName() string {
return u.Name
}
type UserinfoPhone struct {
PhoneNumber string
PhoneNumberVerified bool
func (u *userinfo) GetGivenName() string {
return u.GivenName
}
type UserinfoAddress struct {
Formatted string
StreetAddress string
Locality string
Region string
PostalCode string
Country string
func (u *userinfo) GetFamilyName() string {
return u.FamilyName
}
type jsonUserinfoProfile struct {
func (u *userinfo) GetMiddleName() string {
return u.MiddleName
}
func (u *userinfo) GetNickname() string {
return u.Nickname
}
func (u *userinfo) GetProfile() string {
return u.Profile
}
func (u *userinfo) GetPicture() string {
return u.Picture
}
func (u *userinfo) GetWebsite() string {
return u.Website
}
func (u *userinfo) GetGender() Gender {
return u.Gender
}
func (u *userinfo) GetBirthdate() string {
return u.Birthdate
}
func (u *userinfo) GetZoneinfo() string {
return u.Zoneinfo
}
func (u *userinfo) GetLocale() language.Tag {
return u.Locale
}
func (u *userinfo) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *userinfo) GetEmail() string {
return u.Email
}
func (u *userinfo) IsEmailVerified() bool {
return u.EmailVerified
}
func (u *userinfo) GetPhoneNumber() string {
return u.PhoneNumber
}
func (u *userinfo) IsPhoneNumberVerified() bool {
return u.PhoneNumberVerified
}
func (u *userinfo) GetAddress() UserInfoAddress {
return u.Address
}
func (u *userinfo) GetClaim(key string) interface{} {
return u.claims[key]
}
func (u *userinfo) SetSubject(sub string) {
u.Subject = sub
}
func (u *userinfo) SetName(name string) {
u.Name = name
}
func (u *userinfo) SetGivenName(name string) {
u.GivenName = name
}
func (u *userinfo) SetFamilyName(name string) {
u.FamilyName = name
}
func (u *userinfo) SetMiddleName(name string) {
u.MiddleName = name
}
func (u *userinfo) SetNickname(name string) {
u.Nickname = name
}
func (u *userinfo) SetUpdatedAt(date time.Time) {
u.UpdatedAt = Time(date)
}
func (u *userinfo) SetProfile(profile string) {
u.Profile = profile
}
func (u *userinfo) SetPicture(picture string) {
u.Picture = picture
}
func (u *userinfo) SetWebsite(website string) {
u.Website = website
}
func (u *userinfo) SetGender(gender Gender) {
u.Gender = gender
}
func (u *userinfo) SetBirthdate(birthdate string) {
u.Birthdate = birthdate
}
func (u *userinfo) SetZoneinfo(zoneInfo string) {
u.Zoneinfo = zoneInfo
}
func (u *userinfo) SetLocale(locale language.Tag) {
u.Locale = locale
}
func (u *userinfo) SetPreferredUsername(name string) {
u.PreferredUsername = name
}
func (u *userinfo) SetEmail(email string, verified bool) {
u.Email = email
u.EmailVerified = verified
}
func (u *userinfo) SetPhone(phone string, verified bool) {
u.PhoneNumber = phone
u.PhoneNumberVerified = verified
}
func (u *userinfo) SetAddress(address UserInfoAddress) {
u.Address = address
}
func (u *userinfo) AppendClaims(key string, value interface{}) {
if u.claims == nil {
u.claims = make(map[string]interface{})
}
u.claims[key] = value
}
func (u *userInfoAddress) GetFormatted() string {
return u.Formatted
}
func (u *userInfoAddress) GetStreetAddress() string {
return u.StreetAddress
}
func (u *userInfoAddress) GetLocality() string {
return u.Locality
}
func (u *userInfoAddress) GetRegion() string {
return u.Region
}
func (u *userInfoAddress) GetPostalCode() string {
return u.PostalCode
}
func (u *userInfoAddress) GetCountry() string {
return u.Country
}
type userInfoProfile struct {
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
@ -66,25 +289,25 @@ type jsonUserinfoProfile struct {
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"`
Gender Gender `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
Locale language.Tag `json:"locale,omitempty"`
UpdatedAt Time `json:"updated_at,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
}
type jsonUserinfoEmail struct {
type userInfoEmail struct {
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
}
type jsonUserinfoPhone struct {
Phone string `json:"phone_number,omitempty"`
PhoneVerified bool `json:"phone_number_verified,omitempty"`
type userInfoPhone struct {
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
}
type jsonUserinfoAddress struct {
type userInfoAddress struct {
Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"`
@ -93,81 +316,63 @@ type jsonUserinfoAddress struct {
Country string `json:"country,omitempty"`
}
func (i *Userinfo) MarshalJSON() ([]byte, error) {
j := new(jsonUserinfo)
j.Subject = i.Subject
j.setUserinfo(*i)
j.Authorizations = i.Authorizations
return json.Marshal(j)
func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
return &userInfoAddress{
StreetAddress: streetAddress,
Locality: locality,
Region: region,
PostalCode: postalCode,
Country: country,
Formatted: formatted,
}
}
func (i *userinfo) MarshalJSON() ([]byte, error) {
type Alias userinfo
a := &struct {
*Alias
Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if !i.Locale.IsRoot() {
a.Locale = i.Locale
}
if !time.Time(i.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(i.claims) == 0 {
return b, nil
}
claims, err := json.Marshal(i.claims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims)
}
return utils.ConcatenateJSON(b, claims)
}
func (i *Userinfo) UnmmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, i); err != nil {
func (i *userinfo) UnmarshalJSON(data []byte) error {
type Alias userinfo
a := &struct {
*Alias
UpdatedAt int64 `json:"update_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if err := json.Unmarshal(data, &a); err != nil {
return err
}
return json.Unmarshal(data, &i.claims)
}
type jsonUserinfo struct {
Subject string `json:"sub,omitempty"`
jsonUserinfoProfile
jsonUserinfoEmail
jsonUserinfoPhone
JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"`
Authorizations []string `json:"authorizations,omitempty"`
}
i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
func (j *jsonUserinfo) setUserinfo(i Userinfo) {
j.setUserinfoProfile(i.UserinfoProfile)
j.setUserinfoEmail(i.UserinfoEmail)
j.setUserinfoPhone(i.UserinfoPhone)
j.setUserinfoAddress(i.Address)
}
func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) {
j.Name = i.Name
j.GivenName = i.GivenName
j.FamilyName = i.FamilyName
j.MiddleName = i.MiddleName
j.Nickname = i.Nickname
j.Profile = i.Profile
j.Picture = i.Picture
j.Website = i.Website
j.Gender = string(i.Gender)
j.Birthdate = i.Birthdate
j.Zoneinfo = i.Zoneinfo
if i.Locale != language.Und {
j.Locale = i.Locale.String()
}
j.UpdatedAt = timeToJSON(i.UpdatedAt)
j.PreferredUsername = i.PreferredUsername
}
func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) {
j.Email = i.Email
j.EmailVerified = i.EmailVerified
}
func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) {
j.Phone = i.PhoneNumber
j.PhoneVerified = i.PhoneNumberVerified
}
func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) {
if i == nil {
return
}
if i.Country == "" && i.Formatted == "" && i.Locality == "" && i.PostalCode == "" && i.Region == "" && i.StreetAddress == "" {
return
}
j.JsonUserinfoAddress = &jsonUserinfoAddress{
Country: i.Country,
Formatted: i.Formatted,
Locality: i.Locality,
PostalCode: i.PostalCode,
Region: i.Region,
StreetAddress: i.StreetAddress,
}
return nil
}
type UserInfoRequest struct {

View file

@ -24,7 +24,7 @@ type Claims interface {
GetAuthenticationContextClassReference() string
GetAuthTime() time.Time
GetAuthorizedParty() string
SetSignature(algorithm jose.SignatureAlgorithm)
SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
}
var (
@ -140,7 +140,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
return ErrSignatureInvalidPayload
}
claims.SetSignature(jose.SignatureAlgorithm(sig.Header.Algorithm))
claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm))
return nil
}

View file

@ -91,7 +91,8 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
if err != nil {
return "", ErrServerError(err.Error())
}
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
return "", err
}
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
@ -104,14 +105,33 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
}
//ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(scopes []string) error {
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 {
return ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
}
if !utils.Contains(scopes, oidc.ScopeOpenID) {
return ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
openID := false
for i := len(scopes) - 1; i >= 0; i-- {
scope := scopes[i]
if scope == oidc.ScopeOpenID {
openID = true
continue
}
return nil
if !(scope == oidc.ScopeProfile ||
scope == oidc.ScopeEmail ||
scope == oidc.ScopePhone ||
scope == oidc.ScopeAddress ||
scope == oidc.ScopeOfflineAccess) &&
!utils.Contains(client.AllowedScopes(), scope) {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
if !openID {
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
}
return scopes, nil
}
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
@ -168,7 +188,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
}
return claims.Subject, nil
return claims.GetSubject(), nil
}
//RedirectToLogin redirects the end user to the Login UI for authentication

View file

@ -8,6 +8,7 @@ import (
"testing"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc"
@ -193,28 +194,63 @@ func TestValidateAuthRequest(t *testing.T) {
func TestValidateAuthReqScopes(t *testing.T) {
type args struct {
client op.Client
scopes []string
}
type res struct {
err bool
scopes []string
}
tests := []struct {
name string
args args
wantErr bool
res res
}{
{
"scopes missing fails", args{}, true,
"scopes missing fails",
args{},
res{
err: true,
},
},
{
"scope openid missing fails", args{[]string{"email"}}, true,
"scope openid missing fails",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"email"},
},
res{
err: true,
},
},
{
"scope ok", args{[]string{"openid"}}, false,
"scope ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid"},
},
res{
scopes: []string{"openid"},
},
},
{
"scope with drop ok",
args{
mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
[]string{"openid", "email", "unknown"},
},
res{
scopes: []string{"openid", "email"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
if (err != nil) != tt.res.err {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
}
assert.ElementsMatch(t, scopes, tt.res.scopes)
})
}
}

View file

@ -10,7 +10,9 @@ const (
ApplicationTypeWeb ApplicationType = iota
ApplicationTypeUserAgent
ApplicationTypeNative
)
const (
AccessTokenTypeBearer AccessTokenType = iota
AccessTokenTypeJWT
)
@ -32,6 +34,9 @@ type Client interface {
AccessTokenType() AccessTokenType
IDTokenLifetime() time.Duration
DevMode() bool
AllowedScopes() []string
AssertAdditionalIdTokenScopes() bool
AssertAdditionalAccessTokenScopes() bool
}
func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {

View file

@ -7,6 +7,8 @@ import (
"strings"
)
const OidcDevMode = "CAOS_OIDC_DEV"
type Configuration interface {
Issuer() string
AuthorizationEndpoint() Endpoint
@ -42,7 +44,7 @@ func ValidateIssuer(issuer string) error {
}
func devLocalAllowed(url *url.URL) bool {
_, b := os.LookupEnv("CAOS_OIDC_DEV")
_, b := os.LookupEnv(OidcDevMode)
if !b {
return b
}

View file

@ -60,6 +60,8 @@ func TestValidateIssuer(t *testing.T) {
true,
},
}
//ensure env is not set
os.Unsetenv(OidcDevMode)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
@ -84,7 +86,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
false,
},
}
os.Setenv("CAOS_OIDC_DEV", "")
os.Setenv(OidcDevMode, "true")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {

View file

@ -72,18 +72,18 @@ func (v *Verifier) VerifyIDToken(ctx context.Context, idToken string) (*oidc.IDT
return nil, nil
}
type Sig struct{}
type Sig struct {
signer jose.Signer
}
func (s *Sig) Signer() jose.Signer {
return s.signer
}
func (s *Sig) Health(ctx context.Context) error {
return nil
}
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
return jose.HS256
}
@ -92,9 +92,3 @@ func ExpectStorage(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
}
// func NewMockSignerAny(t *testing.T) op.Signer {
// m := NewMockSigner(gomock.NewController(t))
// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
// return m
// }

View file

@ -26,6 +26,7 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
func(id string) string {
return "login?id=" + id
})
m.EXPECT().AllowedScopes().AnyTimes().Return(nil)
return c
}

View file

@ -49,6 +49,20 @@ func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
}
// AllowedScopes mocks base method
func (m *MockClient) AllowedScopes() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AllowedScopes")
ret0, _ := ret[0].([]string)
return ret0
}
// AllowedScopes indicates an expected call of AllowedScopes
func (mr *MockClientMockRecorder) AllowedScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowedScopes", reflect.TypeOf((*MockClient)(nil).AllowedScopes))
}
// ApplicationType mocks base method
func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper()
@ -63,6 +77,34 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
}
// AssertAdditionalAccessTokenScopes mocks base method
func (m *MockClient) AssertAdditionalAccessTokenScopes() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AssertAdditionalAccessTokenScopes")
ret0, _ := ret[0].(bool)
return ret0
}
// AssertAdditionalAccessTokenScopes indicates an expected call of AssertAdditionalAccessTokenScopes
func (mr *MockClientMockRecorder) AssertAdditionalAccessTokenScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalAccessTokenScopes))
}
// AssertAdditionalIdTokenScopes mocks base method
func (m *MockClient) AssertAdditionalIdTokenScopes() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AssertAdditionalIdTokenScopes")
ret0, _ := ret[0].(bool)
return ret0
}
// AssertAdditionalIdTokenScopes indicates an expected call of AssertAdditionalIdTokenScopes
func (mr *MockClientMockRecorder) AssertAdditionalIdTokenScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).AssertAdditionalIdTokenScopes))
}
// AuthMethod mocks base method
func (m *MockClient) AuthMethod() op.AuthMethod {
m.ctrl.T.Helper()

View file

@ -6,7 +6,6 @@ package mock
import (
context "context"
oidc "github.com/caos/oidc/pkg/oidc"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
reflect "reflect"
@ -49,36 +48,6 @@ func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
}
// SignAccessToken mocks base method
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignAccessToken indicates an expected call of SignAccessToken
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
}
// SignIDToken mocks base method
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignIDToken", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignIDToken indicates an expected call of SignIDToken
func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
}
// SignatureAlgorithm mocks base method
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper()
@ -92,3 +61,17 @@ func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
}
// Signer mocks base method
func (m *MockSigner) Signer() jose.Signer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Signer")
ret0, _ := ret[0].(jose.Signer)
return ret0
}
// Signer indicates an expected call of Signer
func (mr *MockSignerMockRecorder) Signer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer))
}

View file

@ -171,6 +171,21 @@ func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
}
// GetPrivateClaimsFromScopes mocks base method
func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPrivateClaimsFromScopes", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(map[string]interface{})
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPrivateClaimsFromScopes indicates an expected call of GetPrivateClaimsFromScopes
func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
}
// GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
m.ctrl.T.Helper()
@ -184,25 +199,25 @@ func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interfac
}
// GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) {
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (oidc.UserInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.Userinfo)
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(oidc.UserInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2, arg3)
}
// GetUserinfoFromToken mocks base method
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (*oidc.Userinfo, error) {
func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2, arg3 string) (oidc.UserInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*oidc.Userinfo)
ret0, _ := ret[0].(oidc.UserInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View file

@ -168,3 +168,12 @@ func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
func (c *ConfClient) DevMode() bool {
return c.devMode
}
func (c *ConfClient) AllowedScopes() []string {
return nil
}
func (c *ConfClient) AssertAdditionalIdTokenScopes() bool {
return false
}
func (c *ConfClient) AssertAdditionalAccessTokenScopes() bool {
return false
}

View file

@ -51,6 +51,7 @@ type OpenIDProvider interface {
Encoder() utils.Encoder
IDTokenHintVerifier() IDTokenHintVerifier
JWTProfileVerifier() JWTProfileVerifier
AccessTokenVerifier() AccessTokenVerifier
Crypto() Crypto
DefaultLogoutRedirectURI() string
Signer() Signer
@ -130,7 +131,7 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
}
keyCh := make(chan jose.SigningKey)
o.signer = NewDefaultSigner(ctx, storage, keyCh)
o.signer = NewSigner(ctx, storage, keyCh)
go EnsureKey(ctx, storage, keyCh, o.timer, o.retry)
o.httpHandler = CreateRouter(o, o.interceptors...)
@ -152,6 +153,8 @@ type openidProvider struct {
signer Signer
idTokenHintVerifier IDTokenHintVerifier
jwtProfileVerifier JWTProfileVerifier
accessTokenVerifier AccessTokenVerifier
keySet *openIDKeySet
crypto Crypto
httpHandler http.Handler
decoder *schema.Decoder
@ -207,7 +210,7 @@ func (o *openidProvider) Encoder() utils.Encoder {
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier {
if o.idTokenHintVerifier == nil {
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), &openIDKeySet{o.Storage()})
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet())
}
return o.idTokenHintVerifier
}
@ -219,6 +222,20 @@ func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier {
return o.jwtProfileVerifier
}
func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier {
if o.accessTokenVerifier == nil {
o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet())
}
return o.accessTokenVerifier
}
func (o *openidProvider) openIDKeySet() oidc.KeySet {
if o.keySet == nil {
o.keySet = &openIDKeySet{o.Storage()}
}
return o.keySet
}
func (o *openidProvider) Crypto() Crypto {
return o.crypto
}

View file

@ -66,8 +66,8 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid")
}
session.UserID = claims.Subject
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty)
session.UserID = claims.GetSubject()
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil {
return nil, ErrServerError("")
}

View file

@ -2,19 +2,15 @@ package op
import (
"context"
"encoding/json"
"errors"
"github.com/caos/logging"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)
type Signer interface {
Health(ctx context.Context) error
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
Signer() jose.Signer
SignatureAlgorithm() jose.SignatureAlgorithm
}
@ -24,7 +20,7 @@ type tokenSigner struct {
alg jose.SignatureAlgorithm
}
func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
s := &tokenSigner{
storage: storage,
}
@ -41,6 +37,10 @@ func (s *tokenSigner) Health(_ context.Context) error {
return nil
}
func (s *tokenSigner) Signer() jose.Signer {
return s.signer
}
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for {
select {
@ -55,30 +55,6 @@ func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.S
}
}
func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *tokenSigner) Sign(payload []byte) (string, error) {
result, err := s.signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.alg
}

View file

@ -1,95 +0,0 @@
package op
import (
"testing"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
)
// func TestNewDefaultSigner(t *testing.T) {
// type args struct {
// storage Storage
// }
// tests := []struct {
// name string
// args args
// want Signer
// wantErr bool
// }{
// {
// "err initialize storage fails",
// args{mock.NewMockStorageSigningKeyError(t)},
// nil,
// true,
// },
// {
// "err initialize storage fails",
// args{mock.NewMockStorageSigningKeyInvalid(t)},
// nil,
// true,
// },
// {
// "initialize ok",
// args{mock.NewMockStorageSigningKey(t)},
// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
// false,
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// got, err := op.NewDefaultSigner(tt.args.storage)
// if (err != nil) != tt.wantErr {
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
// return
// }
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
// }
// })
// }
// }
func Test_idTokenSigner_Sign(t *testing.T) {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
require.NoError(t, err)
type fields struct {
signer jose.Signer
storage Storage
}
type args struct {
payload []byte
}
tests := []struct {
name string
fields fields
args args
want string
wantErr bool
}{
{
"ok",
fields{signer, nil},
args{[]byte("test")},
"eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &tokenSigner{
signer: tt.fields.signer,
storage: tt.fields.storage,
}
got, err := s.Sign(tt.args.payload)
if (err != nil) != tt.wantErr {
t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -26,10 +26,11 @@ type AuthStorage interface {
}
type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error)
GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (*oidc.Userinfo, error)
GetClientByClientID(ctx context.Context, clientID string) (Client, error)
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
GetUserinfoFromScopes(ctx context.Context, userID, clientID string, scopes []string) (oidc.UserInfo, error)
GetUserinfoFromToken(ctx context.Context, tokenID, subject, origin string) (oidc.UserInfo, error)
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
}

View file

@ -5,6 +5,7 @@ import (
"time"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type TokenCreator interface {
@ -25,12 +26,12 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
var validity time.Duration
if createAccessToken {
var err error
accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator)
accessToken, validity, err = CreateAccessToken(ctx, authReq, client.AccessTokenType(), creator, client)
if err != nil {
return nil, err
}
}
idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer())
idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client.AssertAdditionalIdTokenScopes())
if err != nil {
return nil, err
}
@ -50,7 +51,7 @@ func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client
}
func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator)
accessToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, nil)
if err != nil {
return nil, err
}
@ -63,17 +64,17 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
}, nil
}
func CreateAccessToken(ctx context.Context, authReq TokenRequest, accessTokenType AccessTokenType, creator TokenCreator) (token string, validity time.Duration, err error) {
id, exp, err := creator.Storage().CreateToken(ctx, authReq)
func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client) (token string, validity time.Duration, err error) {
id, exp, err := creator.Storage().CreateToken(ctx, tokenRequest)
if err != nil {
return "", 0, err
}
validity = exp.Sub(time.Now().UTC())
if accessTokenType == AccessTokenTypeJWT {
token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer())
token, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage())
return
}
token, err = CreateBearerToken(id, authReq.GetSubject(), creator.Crypto())
token, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
return
}
@ -81,52 +82,79 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
return crypto.Encrypt(tokenID + ":" + subject)
}
func CreateJWT(issuer string, authReq TokenRequest, exp time.Time, id string, signer Signer) (string, error) {
now := time.Now().UTC()
nbf := now
claims := &oidc.AccessTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: now,
NotBefore: nbf,
JWTID: id,
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) {
claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id)
if client != nil && client.AssertAdditionalAccessTokenScopes() {
privateClaims, err := storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(tokenRequest.GetScopes()))
if err != nil {
return "", err
}
return signer.SignAccessToken(claims)
claims.SetPrivateClaims(privateClaims)
}
return utils.Sign(claims, signer.Signer())
}
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) {
var err error
func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, additonalScopes bool) (string, error) {
exp := time.Now().UTC().Add(validity)
userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
if err != nil {
return "", err
}
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
Userinfo: *userinfo,
}
claims := oidc.NewIDTokenClaims(issuer, authReq.GetSubject(), authReq.GetAudience(), exp, authReq.GetAuthTime(), authReq.GetNonce(), authReq.GetACR(), authReq.GetAMR(), authReq.GetClientID())
scopes := authReq.GetScopes()
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
claims.SetAccessTokenHash(atHash)
scopes = removeUserinfoScopes(scopes)
}
if !additonalScopes {
scopes = removeAdditionalScopes(scopes)
}
if len(scopes) > 0 {
userInfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetClientID(), scopes)
if err != nil {
return "", err
}
claims.SetUserinfo(userInfo)
}
if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
claims.SetCodeHash(codeHash)
}
return signer.SignIDToken(claims)
return utils.Sign(claims, signer.Signer())
}
func removeUserinfoScopes(scopes []string) []string {
for i := len(scopes) - 1; i >= 0; i-- {
if scopes[i] == oidc.ScopeProfile ||
scopes[i] == oidc.ScopeEmail ||
scopes[i] == oidc.ScopeAddress ||
scopes[i] == oidc.ScopePhone {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
return scopes
}
func removeAdditionalScopes(scopes []string) []string {
for i := len(scopes) - 1; i >= 0; i-- {
if !(scopes[i] == oidc.ScopeOpenID ||
scopes[i] == oidc.ScopeProfile ||
scopes[i] == oidc.ScopeEmail ||
scopes[i] == oidc.ScopeAddress ||
scopes[i] == oidc.ScopePhone) {
scopes[i] = scopes[len(scopes)-1]
scopes[len(scopes)-1] = ""
scopes = scopes[:len(scopes)-1]
}
}
return scopes
}

View file

@ -138,18 +138,18 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
}
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
assertion, err := ParseJWTProfileRequest(r, exchanger.Decoder())
profileRequest, err := ParseJWTProfileRequest(r, exchanger.Decoder())
if err != nil {
RequestError(w, r, err)
}
claims, err := VerifyJWTAssertion(r.Context(), assertion, exchanger.JWTProfileVerifier())
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest, exchanger.JWTProfileVerifier())
if err != nil {
RequestError(w, r, err)
return
}
resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger)
resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
if err != nil {
RequestError(w, r, err)
return
@ -157,17 +157,17 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp)
}
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (string, error) {
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*tokenexchange.JWTProfileRequest, error) {
err := r.ParseForm()
if err != nil {
return "", ErrInvalidRequest("error parsing form")
return nil, ErrInvalidRequest("error parsing form")
}
tokenReq := new(tokenexchange.JWTProfileRequest)
err = decoder.Decode(tokenReq, r.Form)
if err != nil {
return "", ErrInvalidRequest("error decoding form")
return nil, ErrInvalidRequest("error decoding form")
}
return tokenReq.Assertion, nil
return tokenReq, nil
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {

View file

@ -1,6 +1,7 @@
package op
import (
"context"
"errors"
"net/http"
"strings"
@ -13,6 +14,7 @@ type UserinfoProvider interface {
Decoder() utils.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
}
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
@ -27,17 +29,12 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
http.Error(w, "access token missing", http.StatusUnauthorized)
return
}
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
if err != nil {
http.Error(w, "access token missing", http.StatusUnauthorized)
return
}
splittedToken := strings.Split(tokenIDSubject, ":")
if len(splittedToken) != 2 {
tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken)
if !ok {
http.Error(w, "access token invalid", http.StatusUnauthorized)
return
}
info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), splittedToken[0], splittedToken[1], r.Header.Get("origin"))
info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, subject, r.Header.Get("origin"))
if err != nil {
w.WriteHeader(http.StatusForbidden)
utils.MarshalJSON(w, err)
@ -66,3 +63,19 @@ func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) {
}
return req.AccessToken, nil
}
func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", false
}
return splitToken[0], splitToken[1], true
}
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
if err != nil {
return "", "", false
}
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
}

View file

@ -0,0 +1,85 @@
package op
import (
"context"
"time"
"github.com/caos/oidc/pkg/oidc"
)
type AccessTokenVerifier interface {
oidc.Verifier
SupportedSignAlgs() []string
KeySet() oidc.KeySet
}
type accessTokenVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
supportedSignAlgs []string
maxAge time.Duration
acr oidc.ACRVerifier
keySet oidc.KeySet
}
//Issuer implements oidc.Verifier interface
func (i *accessTokenVerifier) Issuer() string {
return i.issuer
}
//MaxAgeIAT implements oidc.Verifier interface
func (i *accessTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}
//Offset implements oidc.Verifier interface
func (i *accessTokenVerifier) Offset() time.Duration {
return i.offset
}
//SupportedSignAlgs implements AccessTokenVerifier interface
func (i *accessTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}
//KeySet implements AccessTokenVerifier interface
func (i *accessTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier {
verifier := &idTokenHintVerifier{
issuer: issuer,
keySet: keySet,
}
return verifier
}
//VerifyAccessToken validates the access token (issuer, signature and expiration)
func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) {
claims := oidc.EmptyAccessTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {
return nil, err
}
payload, err := oidc.ParseToken(decrypted, claims)
if err != nil {
return nil, err
}
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err
}
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err
}
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err
}
return claims, nil
}

View file

@ -63,8 +63,8 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet) IDTokenHintVerifi
//VerifyIDTokenHint validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (*oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims)
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) {
claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {

View file

@ -8,6 +8,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
)
type JWTProfileVerifier interface {
@ -47,9 +48,9 @@ func (v *jwtProfileVerifier) Offset() time.Duration {
return v.offset
}
func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
func VerifyJWTAssertion(ctx context.Context, profileRequest *tokenexchange.JWTProfileRequest, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
request := new(oidc.JWTTokenRequest)
payload, err := oidc.ParseToken(assertion, request)
payload, err := oidc.ParseToken(profileRequest.Assertion, request)
if err != nil {
return nil, err
}
@ -72,9 +73,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
keySet := &jwtProfileKeySet{v.Storage(), request.Subject}
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
if err = oidc.CheckSignature(ctx, profileRequest.Assertion, payload, request, nil, keySet); err != nil {
return nil, err
}
request.Scopes = profileRequest.Scope
return request, nil
}

View file

@ -1,37 +0,0 @@
package mock
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp"
)
func NewVerifier(t *testing.T) rp.Verifier {
return NewMockVerifier(gomock.NewController(t))
}
func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier {
m := NewVerifier(t)
ExpectVerifyInvalid(m)
return m
}
func ExpectVerifyInvalid(v rp.Verifier) {
mock := v.(*MockVerifier)
mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid"))
}
func NewMockVerifierExpectValid(t *testing.T) rp.Verifier {
m := NewVerifier(t)
ExpectVerifyValid(m)
return m
}
func ExpectVerifyValid(v rp.Verifier) {
mock := v.(*MockVerifier)
mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil)
}

View file

@ -4,9 +4,13 @@ import (
"context"
"errors"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/oidc/grants"
@ -22,6 +26,16 @@ const (
jwtProfileKey = "urn:ietf:params:oauth:grant-type:jwt-bearer"
)
var (
encoder = func() utils.Encoder {
e := schema.NewEncoder()
e.RegisterEncoder(oidc.Scopes{}, func(value reflect.Value) string {
return value.Interface().(oidc.Scopes).Encode()
})
return e
}()
)
//RelayingParty declares the minimal interface for oidc clients
type RelayingParty interface {
//OAuthConfig returns the oauth2 Config
@ -312,38 +326,45 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc
//ClientCredentials is the `RelayingParty` interface implementation
//handling the oauth2 client credentials grant
func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) {
return CallTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...), rp)
return CallTokenEndpointAuthorized(grants.ClientCredentialsGrantBasic(scopes...), rp)
}
func CallTokenEndpointAuthorized(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
config := rp.OAuthConfig()
var fn interface{} = utils.AuthorizeBasic(config.ClientID, config.ClientSecret)
if config.Endpoint.AuthStyle == oauth2.AuthStyleInParams {
fn = func(form url.Values) {
form.Set("client_id", config.ClientID)
form.Set("client_secret", config.ClientSecret)
}
}
return callTokenEndpoint(request, fn, rp)
}
func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
config := rp.OAuthConfig()
req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, config.ClientID, config.ClientSecret, config.Endpoint.AuthStyle != oauth2.AuthStyleInParams)
if err != nil {
return nil, err
}
token := new(oauth2.Token)
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
return nil, err
}
return token, nil
return callTokenEndpoint(request, nil, rp)
}
func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) {
form := make(map[string][]string)
form["assertion"] = []string{assertion}
form["grant_type"] = []string{jwtProfileKey}
req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, nil)
func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, encoder, authFn)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
token := new(oauth2.Token)
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
var tokenRes struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
}
if err := utils.HttpRequest(rp.HttpClient(), req, &tokenRes); err != nil {
return nil, err
}
return token, nil
return &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken,
Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
}, nil
}
func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error {

View file

@ -43,12 +43,17 @@ func DelegationTokenExchange(ctx context.Context, subjectToken string, rp Relayi
}
//JWTProfileExchange handles the oauth2 jwt profile exchange
func JWTProfileExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, rp RelayingParty) (*oauth2.Token, error) {
func JWTProfileExchange(ctx context.Context, jwtProfileRequest *tokenexchange.JWTProfileRequest, rp RelayingParty) (*oauth2.Token, error) {
return CallTokenEndpoint(jwtProfileRequest, rp)
}
//JWTProfileExchange handles the oauth2 jwt profile exchange
func JWTProfileAssertionExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, scopes oidc.Scopes, rp RelayingParty) (*oauth2.Token, error) {
token, err := generateJWTProfileToken(assertion)
if err != nil {
return nil, err
}
return CallJWTProfileEndpoint(token, rp)
return JWTProfileExchange(ctx, tokenexchange.NewJWTProfileRequest(token, scopes...), rp)
}
func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) {

View file

@ -21,12 +21,12 @@ type IDTokenVerifier interface {
//VerifyTokens implement the Token Response Validation as defined in OIDC specification
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v)
if err != nil {
return nil, err
}
if err := VerifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil {
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err
}
return idToken, nil
@ -34,8 +34,8 @@ func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTo
//VerifyIDToken validates the id token according to
//https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (*oidc.IDTokenClaims, error) {
claims := new(oidc.IDTokenClaims)
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) {
claims := oidc.EmptyIDTokenClaims()
decrypted, err := oidc.DecryptToken(token)
if err != nil {

View file

@ -10,8 +10,6 @@ import (
"net/url"
"strings"
"time"
"github.com/gorilla/schema"
)
var (
@ -27,23 +25,30 @@ type Encoder interface {
Encode(src interface{}, dst map[string][]string) error
}
func FormRequest(endpoint string, request interface{}, clientID, clientSecret string, header bool) (*http.Request, error) {
form := make(map[string][]string)
encoder := schema.NewEncoder()
type FormAuthorization func(url.Values)
type RequestAuthorization func(*http.Request)
func AuthorizeBasic(user, password string) RequestAuthorization {
return func(req *http.Request) {
req.SetBasicAuth(user, password)
}
}
func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) {
form := url.Values{}
if err := encoder.Encode(request, form); err != nil {
return nil, err
}
if !header {
form["client_id"] = []string{clientID}
form["client_secret"] = []string{clientSecret}
if fn, ok := authFn.(FormAuthorization); ok {
fn(form)
}
body := strings.NewReader(url.Values(form).Encode())
body := strings.NewReader(form.Encode())
req, err := http.NewRequest("POST", endpoint, body)
if err != nil {
return nil, err
}
if header {
req.SetBasicAuth(clientID, clientSecret)
if fn, ok := authFn.(RequestAuthorization); ok {
fn(req)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req, nil

View file

@ -1,7 +1,9 @@
package utils
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"github.com/sirupsen/logrus"
@ -19,3 +21,15 @@ func MarshalJSON(w http.ResponseWriter, i interface{}) {
logrus.Error("error writing response")
}
}
func ConcatenateJSON(first, second []byte) ([]byte, error) {
if !bytes.HasSuffix(first, []byte{'}'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", first)
}
if !bytes.HasPrefix(second, []byte{'{'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", second)
}
first[len(first)-1] = ','
first = append(first, second[1:]...)
return first, nil
}

60
pkg/utils/marshal_test.go Normal file
View file

@ -0,0 +1,60 @@
package utils
import (
"bytes"
"testing"
)
func TestConcatenateJSON(t *testing.T) {
type args struct {
first []byte
second []byte
}
tests := []struct {
name string
args args
want []byte
wantErr bool
}{
{
"invalid first part, error",
args{
[]byte(`invalid`),
[]byte(`{"some": "thing"}`),
},
nil,
true,
},
{
"invalid second part, error",
args{
[]byte(`{"some": "thing"}`),
[]byte(`invalid`),
},
nil,
true,
},
{
"both valid, merged",
args{
[]byte(`{"some": "thing"}`),
[]byte(`{"another": "thing"}`),
},
[]byte(`{"some": "thing","another": "thing"}`),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ConcatenateJSON(tt.args.first, tt.args.second)
if (err != nil) != tt.wantErr {
t.Errorf("ConcatenateJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !bytes.Equal(got, tt.want) {
t.Errorf("ConcatenateJSON() got = %v, want %v", got, tt.want)
}
})
}
}

23
pkg/utils/sign.go Normal file
View file

@ -0,0 +1,23 @@
package utils
import (
"encoding/json"
"gopkg.in/square/go-jose.v2"
)
func Sign(object interface{}, signer jose.Signer) (string, error) {
payload, err := json.Marshal(object)
if err != nil {
return "", err
}
return SignPayload(payload, signer)
}
func SignPayload(payload []byte, signer jose.Signer) (string, error) {
result, err := signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}