jwt profile and authorization handling
This commit is contained in:
parent
d368b2d950
commit
0cad2e4652
12 changed files with 128 additions and 309 deletions
|
@ -131,7 +131,7 @@ func main() {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := rp.JWTProfileExchange(ctx, assertion, provider)
|
token, err := rp.JWTProfileAssertionExchange(ctx, assertion, oidc.Scopes{oidc.ScopeOpenID, oidc.ScopeProfile}, provider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,39 +0,0 @@
|
||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
|
||||||
"github.com/caos/oidc/pkg/rp"
|
|
||||||
"github.com/caos/oidc/pkg/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
callbackPath string = "/auth/callback"
|
|
||||||
key []byte = []byte("test1234test1234")
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
clientID := os.Getenv("CLIENT_ID")
|
|
||||||
clientSecret := os.Getenv("CLIENT_SECRET")
|
|
||||||
issuer := os.Getenv("ISSUER")
|
|
||||||
port := os.Getenv("PORT")
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
|
|
||||||
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"}
|
|
||||||
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
|
|
||||||
provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes,
|
|
||||||
rp.WithPKCE(cookieHandler),
|
|
||||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
logrus.Fatalf("error creating provider %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,5 +1,9 @@
|
||||||
package tokenexchange
|
package tokenexchange
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
|
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
|
||||||
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
|
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
|
||||||
|
@ -24,6 +28,18 @@ type TokenExchangeRequest struct {
|
||||||
|
|
||||||
type JWTProfileRequest struct {
|
type JWTProfileRequest struct {
|
||||||
Assertion string `schema:"assertion"`
|
Assertion string `schema:"assertion"`
|
||||||
|
Scope oidc.Scopes `schema:"scope"`
|
||||||
|
GrantType oidc.GrantType `schema:"grant_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
|
||||||
|
//sneding client_id and client_secret as basic auth header
|
||||||
|
func NewJWTProfileRequest(assertion string, scopes ...string) *JWTProfileRequest {
|
||||||
|
return &JWTProfileRequest{
|
||||||
|
GrantType: oidc.GrantTypeBearer,
|
||||||
|
Assertion: assertion,
|
||||||
|
Scope: scopes,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {
|
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {
|
||||||
|
|
|
@ -202,7 +202,6 @@ type AccessTokenResponse struct {
|
||||||
type JWTProfileAssertion struct {
|
type JWTProfileAssertion struct {
|
||||||
PrivateKeyID string `json:"-"`
|
PrivateKeyID string `json:"-"`
|
||||||
PrivateKey []byte `json:"-"`
|
PrivateKey []byte `json:"-"`
|
||||||
Scopes []string `json:"scopes"`
|
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
Subject string `json:"sub"`
|
Subject string `json:"sub"`
|
||||||
Audience Audience `json:"aud"`
|
Audience Audience `json:"aud"`
|
||||||
|
@ -236,7 +235,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
|
||||||
PrivateKey: key,
|
PrivateKey: key,
|
||||||
PrivateKeyID: keyID,
|
PrivateKeyID: keyID,
|
||||||
Issuer: userID,
|
Issuer: userID,
|
||||||
Scopes: []string{ScopeOpenID},
|
|
||||||
Subject: userID,
|
Subject: userID,
|
||||||
IssuedAt: Time(time.Now().UTC()),
|
IssuedAt: Time(time.Now().UTC()),
|
||||||
Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
|
Expiration: Time(time.Now().Add(1 * time.Hour).UTC()),
|
||||||
|
@ -244,80 +242,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
//type jsonToken struct {
|
|
||||||
// Issuer string `json:"iss,omitempty"`
|
|
||||||
// Subject string `json:"sub,omitempty"`
|
|
||||||
// Audiences interface{} `json:"aud,omitempty"`
|
|
||||||
// Expiration int64 `json:"exp,omitempty"`
|
|
||||||
// NotBefore int64 `json:"nbf,omitempty"`
|
|
||||||
// IssuedAt int64 `json:"iat,omitempty"`
|
|
||||||
// JWTID string `json:"jti,omitempty"`
|
|
||||||
// AuthorizedParty string `json:"azp,omitempty"`
|
|
||||||
// Nonce string `json:"nonce,omitempty"`
|
|
||||||
// AuthTime int64 `json:"auth_time,omitempty"`
|
|
||||||
// AccessTokenHash string `json:"at_hash,omitempty"`
|
|
||||||
// CodeHash string `json:"c_hash,omitempty"`
|
|
||||||
// AuthenticationContextClassReference string `json:"acr,omitempty"`
|
|
||||||
// AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
|
||||||
// SessionID string `json:"sid,omitempty"`
|
|
||||||
// Actor interface{} `json:"act,omitempty"` //TODO: impl
|
|
||||||
// Scopes string `json:"scope,omitempty"`
|
|
||||||
// ClientID string `json:"client_id,omitempty"`
|
|
||||||
// AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
|
|
||||||
// AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
|
||||||
// jsonUserinfo
|
|
||||||
//}
|
|
||||||
|
|
||||||
//
|
|
||||||
//func (t *accessTokenClaims) MarshalJSON() ([]byte, error) {
|
|
||||||
// j := jsonToken{
|
|
||||||
// Issuer: t.Issuer,
|
|
||||||
// Subject: t.Subject,
|
|
||||||
// Audiences: t.Audiences,
|
|
||||||
// Expiration: timeToJSON(t.Expiration),
|
|
||||||
// NotBefore: timeToJSON(t.NotBefore),
|
|
||||||
// IssuedAt: timeToJSON(t.IssuedAt),
|
|
||||||
// JWTID: t.JWTID,
|
|
||||||
// AuthorizedParty: t.AuthorizedParty,
|
|
||||||
// Nonce: t.Nonce,
|
|
||||||
// AuthTime: timeToJSON(t.AuthTime),
|
|
||||||
// CodeHash: t.CodeHash,
|
|
||||||
// AuthenticationContextClassReference: t.AuthenticationContextClassReference,
|
|
||||||
// AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
|
||||||
// SessionID: t.SessionID,
|
|
||||||
// Scopes: strings.Join(t.Scopes, " "),
|
|
||||||
// ClientID: t.ClientID,
|
|
||||||
// AccessTokenUseNumber: t.AccessTokenUseNumber,
|
|
||||||
// }
|
|
||||||
// return json.Marshal(j)
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func (t *accessTokenClaims) UnmarshalJSON(b []byte) error {
|
|
||||||
// var j jsonToken
|
|
||||||
// if err := json.Unmarshal(b, &j); err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// t.Issuer = j.Issuer
|
|
||||||
// t.Subject = j.Subject
|
|
||||||
// t.Audiences = audienceFromJSON(j.Audiences)
|
|
||||||
// t.Expiration = time.Unix(j.Expiration, 0).UTC()
|
|
||||||
// t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
|
|
||||||
// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
|
|
||||||
// t.JWTID = j.JWTID
|
|
||||||
// t.AuthorizedParty = j.AuthorizedParty
|
|
||||||
// t.Nonce = j.Nonce
|
|
||||||
// t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
|
|
||||||
// t.CodeHash = j.CodeHash
|
|
||||||
// t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
|
|
||||||
// t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
|
|
||||||
// t.SessionID = j.SessionID
|
|
||||||
// t.Scopes = strings.Split(j.Scopes, " ")
|
|
||||||
// t.ClientID = j.ClientID
|
|
||||||
// t.AccessTokenUseNumber = j.AccessTokenUseNumber
|
|
||||||
// return nil
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
|
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
|
||||||
type Alias idTokenClaims
|
type Alias idTokenClaims
|
||||||
a := &struct {
|
a := &struct {
|
||||||
|
@ -406,84 +330,6 @@ func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
|
||||||
t.signatureAlg = alg
|
t.signatureAlg = alg
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
//func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) {
|
|
||||||
// j := jsonToken{
|
|
||||||
// Issuer: t.Issuer,
|
|
||||||
// Subject: t.Subject,
|
|
||||||
// Audiences: t.Audience,
|
|
||||||
// Expiration: timeToJSON(t.Expiration),
|
|
||||||
// IssuedAt: timeToJSON(t.IssuedAt),
|
|
||||||
// Scopes: strings.Join(t.Scopes, " "),
|
|
||||||
// }
|
|
||||||
// return json.Marshal(j)
|
|
||||||
//}
|
|
||||||
|
|
||||||
//func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error {
|
|
||||||
// var j jsonToken
|
|
||||||
// if err := json.Unmarshal(b, &j); err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// t.Issuer = j.Issuer
|
|
||||||
// t.Subject = j.Subject
|
|
||||||
// t.Audience = audienceFromJSON(j.Audiences)
|
|
||||||
// t.Expiration = time.Unix(j.Expiration, 0).UTC()
|
|
||||||
// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
|
|
||||||
// t.Scopes = strings.Split(j.Scopes, " ")
|
|
||||||
//
|
|
||||||
// return nil
|
|
||||||
//}
|
|
||||||
|
|
||||||
//
|
|
||||||
//func (j *jsonToken) UnmarshalUserinfoProfile() userInfoProfile {
|
|
||||||
// locale, _ := language.Parse(j.Locale)
|
|
||||||
// return userInfoProfile{
|
|
||||||
// Name: j.Name,
|
|
||||||
// GivenName: j.GivenName,
|
|
||||||
// FamilyName: j.FamilyName,
|
|
||||||
// MiddleName: j.MiddleName,
|
|
||||||
// Nickname: j.Nickname,
|
|
||||||
// Profile: j.Profile,
|
|
||||||
// Picture: j.Picture,
|
|
||||||
// Website: j.Website,
|
|
||||||
// Gender: Gender(j.Gender),
|
|
||||||
// Birthdate: j.Birthdate,
|
|
||||||
// Zoneinfo: j.Zoneinfo,
|
|
||||||
// Locale: locale,
|
|
||||||
// UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
|
|
||||||
// PreferredUsername: j.PreferredUsername,
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func (j *jsonToken) UnmarshalUserinfoEmail() userInfoEmail {
|
|
||||||
// return userInfoEmail{
|
|
||||||
// Email: j.Email,
|
|
||||||
// EmailVerified: j.EmailVerified,
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func (j *jsonToken) UnmarshalUserinfoPhone() userInfoPhone {
|
|
||||||
// return userInfoPhone{
|
|
||||||
// PhoneNumber: j.Phone,
|
|
||||||
// PhoneNumberVerified: j.PhoneVerified,
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
|
|
||||||
// if j.JsonUserinfoAddress == nil {
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
// return &UserinfoAddress{
|
|
||||||
// Country: j.JsonUserinfoAddress.Country,
|
|
||||||
// Formatted: j.JsonUserinfoAddress.Formatted,
|
|
||||||
// Locality: j.JsonUserinfoAddress.Locality,
|
|
||||||
// PostalCode: j.JsonUserinfoAddress.PostalCode,
|
|
||||||
// Region: j.JsonUserinfoAddress.Region,
|
|
||||||
// StreetAddress: j.JsonUserinfoAddress.StreetAddress,
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||||
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -492,26 +338,3 @@ func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, erro
|
||||||
|
|
||||||
return utils.HashString(hash, claim, true), nil
|
return utils.HashString(hash, claim, true), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func timeToJSON(t time.Time) int64 {
|
|
||||||
if t.IsZero() {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return t.Unix()
|
|
||||||
}
|
|
||||||
|
|
||||||
func audienceFromJSON(i interface{}) []string {
|
|
||||||
switch aud := i.(type) {
|
|
||||||
case []string:
|
|
||||||
return aud
|
|
||||||
case []interface{}:
|
|
||||||
audience := make([]string, len(aud))
|
|
||||||
for i, a := range aud {
|
|
||||||
audience[i] = a.(string)
|
|
||||||
}
|
|
||||||
return audience
|
|
||||||
case string:
|
|
||||||
return []string{aud}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ func (a *AccessTokenRequest) GrantType() GrantType {
|
||||||
type JWTTokenRequest struct {
|
type JWTTokenRequest struct {
|
||||||
Issuer string `json:"iss"`
|
Issuer string `json:"iss"`
|
||||||
Subject string `json:"sub"`
|
Subject string `json:"sub"`
|
||||||
Scopes Scopes `json:"scope"`
|
Scopes Scopes `json:"-"`
|
||||||
Audience Audience `json:"aud"`
|
Audience Audience `json:"aud"`
|
||||||
IssuedAt Time `json:"iat"`
|
IssuedAt Time `json:"iat"`
|
||||||
ExpiresAt Time `json:"exp"`
|
ExpiresAt Time `json:"exp"`
|
||||||
|
|
|
@ -41,38 +41,6 @@ func (d *Display) UnmarshalText(text []byte) error {
|
||||||
|
|
||||||
type Gender string
|
type Gender string
|
||||||
|
|
||||||
type Locale language.Tag
|
|
||||||
|
|
||||||
//{
|
|
||||||
// SetLocale(language.Tag)
|
|
||||||
// Get() language.Tag
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func NewLocale(tag language.Tag) Locale {
|
|
||||||
// if tag.IsRoot() {
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
// return &locale{Tag: tag}
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//type locale struct {
|
|
||||||
// language.Tag
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func (l *locale) SetLocale(tag language.Tag) {
|
|
||||||
// l.Tag = tag
|
|
||||||
//}
|
|
||||||
//func (l *locale) Get() language.Tag {
|
|
||||||
// return l.Tag
|
|
||||||
//}
|
|
||||||
|
|
||||||
//func (l *locale) MarshalJSON() ([]byte, error) {
|
|
||||||
// if l != nil && !l.IsRoot() {
|
|
||||||
// return l.MarshalText()
|
|
||||||
// }
|
|
||||||
// return []byte("null"), nil
|
|
||||||
//}
|
|
||||||
|
|
||||||
type Locales []language.Tag
|
type Locales []language.Tag
|
||||||
|
|
||||||
func (l *Locales) UnmarshalText(text []byte) error {
|
func (l *Locales) UnmarshalText(text []byte) error {
|
||||||
|
@ -92,11 +60,19 @@ type ResponseType string
|
||||||
|
|
||||||
type Scopes []string
|
type Scopes []string
|
||||||
|
|
||||||
|
func (s *Scopes) Encode() string {
|
||||||
|
return strings.Join(*s, " ")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Scopes) UnmarshalText(text []byte) error {
|
func (s *Scopes) UnmarshalText(text []byte) error {
|
||||||
*s = strings.Split(string(text), " ")
|
*s = strings.Split(string(text), " ")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Scopes) MarshalText() ([]byte, error) {
|
||||||
|
return []byte(s.Encode()), nil
|
||||||
|
}
|
||||||
|
|
||||||
type Time time.Time
|
type Time time.Time
|
||||||
|
|
||||||
func (t *Time) UnmarshalJSON(data []byte) error {
|
func (t *Time) UnmarshalJSON(data []byte) error {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package oidc
|
package oidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -22,7 +23,15 @@ func TestAudience_UnmarshalText(t *testing.T) {
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"unknown value",
|
"invalid value",
|
||||||
|
args{
|
||||||
|
[]byte(`{"aud": {"a": }}}`),
|
||||||
|
},
|
||||||
|
res{},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"single audience",
|
||||||
args{
|
args{
|
||||||
[]byte(`{"aud": "single audience"}`),
|
[]byte(`{"aud": "single audience"}`),
|
||||||
},
|
},
|
||||||
|
@ -32,7 +41,7 @@ func TestAudience_UnmarshalText(t *testing.T) {
|
||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"page",
|
"multiple audience",
|
||||||
args{
|
args{
|
||||||
[]byte(`{"aud": ["multiple", "audience"]}`),
|
[]byte(`{"aud": ["multiple", "audience"]}`),
|
||||||
},
|
},
|
||||||
|
@ -219,13 +228,12 @@ func TestScopes_UnmarshalText(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func TestScopes_MarshalText(t *testing.T) {
|
||||||
func TestTime_UnmarshalJSON(t *testing.T) {
|
|
||||||
type args struct {
|
type args struct {
|
||||||
text []byte
|
scopes Scopes
|
||||||
}
|
}
|
||||||
type res struct {
|
type res struct {
|
||||||
scopes []string
|
scopes []byte
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -236,41 +244,53 @@ func TestTime_UnmarshalJSON(t *testing.T) {
|
||||||
{
|
{
|
||||||
"unknown value",
|
"unknown value",
|
||||||
args{
|
args{
|
||||||
[]byte("unknown"),
|
Scopes{"unknown"},
|
||||||
},
|
},
|
||||||
res{
|
res{
|
||||||
[]string{"unknown"},
|
[]byte("unknown"),
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"struct",
|
||||||
|
args{
|
||||||
|
Scopes{`{"unknown":"value"}`},
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
[]byte(`{"unknown":"value"}`),
|
||||||
},
|
},
|
||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"openid",
|
"openid",
|
||||||
args{
|
args{
|
||||||
[]byte("openid"),
|
Scopes{"openid"},
|
||||||
},
|
},
|
||||||
res{
|
res{
|
||||||
[]string{"openid"},
|
[]byte("openid"),
|
||||||
},
|
},
|
||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"multiple scopes",
|
"multiple scopes",
|
||||||
args{
|
args{
|
||||||
[]byte("openid email custom:scope"),
|
Scopes{"openid", "email", "custom:scope"},
|
||||||
},
|
},
|
||||||
res{
|
res{
|
||||||
[]string{"openid", "email", "custom:scope"},
|
[]byte("openid email custom:scope"),
|
||||||
},
|
},
|
||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
var scopes Scopes
|
text, err := tt.args.scopes.MarshalText()
|
||||||
if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(text, tt.res.scopes) {
|
||||||
|
t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes)
|
||||||
}
|
}
|
||||||
assert.ElementsMatch(t, scopes, tt.res.scopes)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,18 +138,18 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque
|
||||||
}
|
}
|
||||||
|
|
||||||
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
assertion, err := ParseJWTProfileRequest(r, exchanger.Decoder())
|
profileRequest, err := ParseJWTProfileRequest(r, exchanger.Decoder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := VerifyJWTAssertion(r.Context(), assertion, exchanger.JWTProfileVerifier())
|
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest, exchanger.JWTProfileVerifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger)
|
resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
|
@ -157,17 +157,17 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
utils.MarshalJSON(w, resp)
|
utils.MarshalJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (string, error) {
|
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*tokenexchange.JWTProfileRequest, error) {
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrInvalidRequest("error parsing form")
|
return nil, ErrInvalidRequest("error parsing form")
|
||||||
}
|
}
|
||||||
tokenReq := new(tokenexchange.JWTProfileRequest)
|
tokenReq := new(tokenexchange.JWTProfileRequest)
|
||||||
err = decoder.Decode(tokenReq, r.Form)
|
err = decoder.Decode(tokenReq, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ErrInvalidRequest("error decoding form")
|
return nil, ErrInvalidRequest("error decoding form")
|
||||||
}
|
}
|
||||||
return tokenReq.Assertion, nil
|
return tokenReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWTProfileVerifier interface {
|
type JWTProfileVerifier interface {
|
||||||
|
@ -47,9 +48,9 @@ func (v *jwtProfileVerifier) Offset() time.Duration {
|
||||||
return v.offset
|
return v.offset
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
func VerifyJWTAssertion(ctx context.Context, profileRequest *tokenexchange.JWTProfileRequest, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
|
||||||
request := new(oidc.JWTTokenRequest)
|
request := new(oidc.JWTTokenRequest)
|
||||||
payload, err := oidc.ParseToken(assertion, request)
|
payload, err := oidc.ParseToken(profileRequest.Assertion, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -72,9 +73,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif
|
||||||
|
|
||||||
keySet := &jwtProfileKeySet{v.Storage(), request.Subject}
|
keySet := &jwtProfileKeySet{v.Storage(), request.Subject}
|
||||||
|
|
||||||
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
|
if err = oidc.CheckSignature(ctx, profileRequest.Assertion, payload, request, nil, keySet); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
request.Scopes = profileRequest.Scope
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
@ -313,38 +314,45 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc
|
||||||
//ClientCredentials is the `RelayingParty` interface implementation
|
//ClientCredentials is the `RelayingParty` interface implementation
|
||||||
//handling the oauth2 client credentials grant
|
//handling the oauth2 client credentials grant
|
||||||
func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) {
|
func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) {
|
||||||
return CallTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...), rp)
|
return CallTokenEndpointAuthorized(grants.ClientCredentialsGrantBasic(scopes...), rp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CallTokenEndpointAuthorized(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
||||||
|
config := rp.OAuthConfig()
|
||||||
|
var fn interface{} = utils.AuthorizeBasic(config.ClientID, config.ClientSecret)
|
||||||
|
if config.Endpoint.AuthStyle == oauth2.AuthStyleInParams {
|
||||||
|
fn = func(form url.Values) {
|
||||||
|
form.Set("client_id", config.ClientID)
|
||||||
|
form.Set("client_secret", config.ClientSecret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return callTokenEndpoint(request, fn, rp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
||||||
config := rp.OAuthConfig()
|
return callTokenEndpoint(request, nil, rp)
|
||||||
req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, config.ClientID, config.ClientSecret, config.Endpoint.AuthStyle != oauth2.AuthStyleInParams)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
token := new(oauth2.Token)
|
|
||||||
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) {
|
func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) {
|
||||||
form := url.Values{}
|
req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, authFn)
|
||||||
form.Add("assertion", assertion)
|
|
||||||
form.Add("grant_type", jwtProfileKey)
|
|
||||||
req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
var tokenRes struct {
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
token := new(oauth2.Token)
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil {
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
if err := utils.HttpRequest(rp.HttpClient(), req, &tokenRes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return token, nil
|
return &oauth2.Token{
|
||||||
|
AccessToken: tokenRes.AccessToken,
|
||||||
|
TokenType: tokenRes.TokenType,
|
||||||
|
RefreshToken: tokenRes.RefreshToken,
|
||||||
|
Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error {
|
func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error {
|
||||||
|
|
|
@ -43,12 +43,17 @@ func DelegationTokenExchange(ctx context.Context, subjectToken string, rp Relayi
|
||||||
}
|
}
|
||||||
|
|
||||||
//JWTProfileExchange handles the oauth2 jwt profile exchange
|
//JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||||
func JWTProfileExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, rp RelayingParty) (*oauth2.Token, error) {
|
func JWTProfileExchange(ctx context.Context, jwtProfileRequest *tokenexchange.JWTProfileRequest, rp RelayingParty) (*oauth2.Token, error) {
|
||||||
|
return CallTokenEndpoint(jwtProfileRequest, rp)
|
||||||
|
}
|
||||||
|
|
||||||
|
//JWTProfileExchange handles the oauth2 jwt profile exchange
|
||||||
|
func JWTProfileAssertionExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, scopes oidc.Scopes, rp RelayingParty) (*oauth2.Token, error) {
|
||||||
token, err := generateJWTProfileToken(assertion)
|
token, err := generateJWTProfileToken(assertion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return CallJWTProfileEndpoint(token, rp)
|
return JWTProfileExchange(ctx, tokenexchange.NewJWTProfileRequest(token, scopes...), rp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) {
|
func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) {
|
||||||
|
|
|
@ -27,23 +27,31 @@ type Encoder interface {
|
||||||
Encode(src interface{}, dst map[string][]string) error
|
Encode(src interface{}, dst map[string][]string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormRequest(endpoint string, request interface{}, clientID, clientSecret string, header bool) (*http.Request, error) {
|
type FormAuthorization func(url.Values)
|
||||||
form := make(map[string][]string)
|
type RequestAuthorization func(*http.Request)
|
||||||
|
|
||||||
|
func AuthorizeBasic(user, password string) RequestAuthorization {
|
||||||
|
return func(req *http.Request) {
|
||||||
|
req.SetBasicAuth(user, password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FormRequest(endpoint string, request interface{}, authFn interface{}) (*http.Request, error) {
|
||||||
|
form := url.Values{}
|
||||||
encoder := schema.NewEncoder()
|
encoder := schema.NewEncoder()
|
||||||
if err := encoder.Encode(request, form); err != nil {
|
if err := encoder.Encode(request, form); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !header {
|
if fn, ok := authFn.(FormAuthorization); ok {
|
||||||
form["client_id"] = []string{clientID}
|
fn(form)
|
||||||
form["client_secret"] = []string{clientSecret}
|
|
||||||
}
|
}
|
||||||
body := strings.NewReader(url.Values(form).Encode())
|
body := strings.NewReader(form.Encode())
|
||||||
req, err := http.NewRequest("POST", endpoint, body)
|
req, err := http.NewRequest("POST", endpoint, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if header {
|
if fn, ok := authFn.(RequestAuthorization); ok {
|
||||||
req.SetBasicAuth(clientID, clientSecret)
|
fn(req)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
return req, nil
|
return req, nil
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue