From 0cad2e4652b30ac4d31cf35c989b372ad23558f1 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 28 Sep 2020 13:55:22 +0200 Subject: [PATCH] jwt profile and authorization handling --- example/client/app/app.go | 2 +- example/client/jwt_profile.go | 39 ---- .../grants/tokenexchange/tokenexchange.go | 18 +- pkg/oidc/token.go | 177 ------------------ pkg/oidc/token_request.go | 2 +- pkg/oidc/types.go | 40 +--- pkg/oidc/types_test.go | 52 +++-- pkg/op/tokenrequest.go | 14 +- pkg/op/verifier_jwt_profile.go | 8 +- pkg/rp/relaying_party.go | 52 ++--- pkg/rp/tockenexchange.go | 9 +- pkg/utils/http.go | 24 ++- 12 files changed, 128 insertions(+), 309 deletions(-) delete mode 100644 example/client/jwt_profile.go diff --git a/example/client/app/app.go b/example/client/app/app.go index 3a96830..a2fff44 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -131,7 +131,7 @@ func main() { http.Error(w, err.Error(), http.StatusInternalServerError) return } - token, err := rp.JWTProfileExchange(ctx, assertion, provider) + token, err := rp.JWTProfileAssertionExchange(ctx, assertion, oidc.Scopes{oidc.ScopeOpenID, oidc.ScopeProfile}, provider) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/example/client/jwt_profile.go b/example/client/jwt_profile.go deleted file mode 100644 index 6dcd11b..0000000 --- a/example/client/jwt_profile.go +++ /dev/null @@ -1,39 +0,0 @@ -package client - -import ( - "context" - "fmt" - "os" - "time" - - "github.com/sirupsen/logrus" - - "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/rp" - "github.com/caos/oidc/pkg/utils" -) - -var ( - callbackPath string = "/auth/callback" - key []byte = []byte("test1234test1234") -) - -func main() { - clientID := os.Getenv("CLIENT_ID") - clientSecret := os.Getenv("CLIENT_SECRET") - issuer := os.Getenv("ISSUER") - port := os.Getenv("PORT") - - ctx := context.Background() - - redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) - scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeAddress, "hodor"} - cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) - provider, err := rp.NewRelayingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, - rp.WithPKCE(cookieHandler), - rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)), - ) - if err != nil { - logrus.Fatalf("error creating provider %s", err.Error()) - } -} diff --git a/pkg/oidc/grants/tokenexchange/tokenexchange.go b/pkg/oidc/grants/tokenexchange/tokenexchange.go index 9464605..5cb6e79 100644 --- a/pkg/oidc/grants/tokenexchange/tokenexchange.go +++ b/pkg/oidc/grants/tokenexchange/tokenexchange.go @@ -1,5 +1,9 @@ package tokenexchange +import ( + "github.com/caos/oidc/pkg/oidc" +) + const ( AccessTokenType = "urn:ietf:params:oauth:token-type:access_token" RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token" @@ -23,7 +27,19 @@ type TokenExchangeRequest struct { } type JWTProfileRequest struct { - Assertion string `schema:"assertion"` + Assertion string `schema:"assertion"` + Scope oidc.Scopes `schema:"scope"` + GrantType oidc.GrantType `schema:"grant_type"` +} + +//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant +//sneding client_id and client_secret as basic auth header +func NewJWTProfileRequest(assertion string, scopes ...string) *JWTProfileRequest { + return &JWTProfileRequest{ + GrantType: oidc.GrantTypeBearer, + Assertion: assertion, + Scope: scopes, + } } func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest { diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index e20dd4a..e445e7e 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -202,7 +202,6 @@ type AccessTokenResponse struct { type JWTProfileAssertion struct { PrivateKeyID string `json:"-"` PrivateKey []byte `json:"-"` - Scopes []string `json:"scopes"` Issuer string `json:"issuer"` Subject string `json:"sub"` Audience Audience `json:"aud"` @@ -236,7 +235,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) PrivateKey: key, PrivateKeyID: keyID, Issuer: userID, - Scopes: []string{ScopeOpenID}, Subject: userID, IssuedAt: Time(time.Now().UTC()), Expiration: Time(time.Now().Add(1 * time.Hour).UTC()), @@ -244,80 +242,6 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte) } } -// -//type jsonToken struct { -// Issuer string `json:"iss,omitempty"` -// Subject string `json:"sub,omitempty"` -// Audiences interface{} `json:"aud,omitempty"` -// Expiration int64 `json:"exp,omitempty"` -// NotBefore int64 `json:"nbf,omitempty"` -// IssuedAt int64 `json:"iat,omitempty"` -// JWTID string `json:"jti,omitempty"` -// AuthorizedParty string `json:"azp,omitempty"` -// Nonce string `json:"nonce,omitempty"` -// AuthTime int64 `json:"auth_time,omitempty"` -// AccessTokenHash string `json:"at_hash,omitempty"` -// CodeHash string `json:"c_hash,omitempty"` -// AuthenticationContextClassReference string `json:"acr,omitempty"` -// AuthenticationMethodsReferences []string `json:"amr,omitempty"` -// SessionID string `json:"sid,omitempty"` -// Actor interface{} `json:"act,omitempty"` //TODO: impl -// Scopes string `json:"scope,omitempty"` -// ClientID string `json:"client_id,omitempty"` -// AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl -// AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` -// jsonUserinfo -//} - -// -//func (t *accessTokenClaims) MarshalJSON() ([]byte, error) { -// j := jsonToken{ -// Issuer: t.Issuer, -// Subject: t.Subject, -// Audiences: t.Audiences, -// Expiration: timeToJSON(t.Expiration), -// NotBefore: timeToJSON(t.NotBefore), -// IssuedAt: timeToJSON(t.IssuedAt), -// JWTID: t.JWTID, -// AuthorizedParty: t.AuthorizedParty, -// Nonce: t.Nonce, -// AuthTime: timeToJSON(t.AuthTime), -// CodeHash: t.CodeHash, -// AuthenticationContextClassReference: t.AuthenticationContextClassReference, -// AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, -// SessionID: t.SessionID, -// Scopes: strings.Join(t.Scopes, " "), -// ClientID: t.ClientID, -// AccessTokenUseNumber: t.AccessTokenUseNumber, -// } -// return json.Marshal(j) -//} -// -//func (t *accessTokenClaims) UnmarshalJSON(b []byte) error { -// var j jsonToken -// if err := json.Unmarshal(b, &j); err != nil { -// return err -// } -// t.Issuer = j.Issuer -// t.Subject = j.Subject -// t.Audiences = audienceFromJSON(j.Audiences) -// t.Expiration = time.Unix(j.Expiration, 0).UTC() -// t.NotBefore = time.Unix(j.NotBefore, 0).UTC() -// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() -// t.JWTID = j.JWTID -// t.AuthorizedParty = j.AuthorizedParty -// t.Nonce = j.Nonce -// t.AuthTime = time.Unix(j.AuthTime, 0).UTC() -// t.CodeHash = j.CodeHash -// t.AuthenticationContextClassReference = j.AuthenticationContextClassReference -// t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences -// t.SessionID = j.SessionID -// t.Scopes = strings.Split(j.Scopes, " ") -// t.ClientID = j.ClientID -// t.AccessTokenUseNumber = j.AccessTokenUseNumber -// return nil -//} -// func (t *idTokenClaims) MarshalJSON() ([]byte, error) { type Alias idTokenClaims a := &struct { @@ -406,84 +330,6 @@ func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) { t.signatureAlg = alg } -// -//func (t *JWTProfileAssertion) MarshalJSON() ([]byte, error) { -// j := jsonToken{ -// Issuer: t.Issuer, -// Subject: t.Subject, -// Audiences: t.Audience, -// Expiration: timeToJSON(t.Expiration), -// IssuedAt: timeToJSON(t.IssuedAt), -// Scopes: strings.Join(t.Scopes, " "), -// } -// return json.Marshal(j) -//} - -//func (t *JWTProfileAssertion) UnmarshalJSON(b []byte) error { -// var j jsonToken -// if err := json.Unmarshal(b, &j); err != nil { -// return err -// } -// -// t.Issuer = j.Issuer -// t.Subject = j.Subject -// t.Audience = audienceFromJSON(j.Audiences) -// t.Expiration = time.Unix(j.Expiration, 0).UTC() -// t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() -// t.Scopes = strings.Split(j.Scopes, " ") -// -// return nil -//} - -// -//func (j *jsonToken) UnmarshalUserinfoProfile() userInfoProfile { -// locale, _ := language.Parse(j.Locale) -// return userInfoProfile{ -// Name: j.Name, -// GivenName: j.GivenName, -// FamilyName: j.FamilyName, -// MiddleName: j.MiddleName, -// Nickname: j.Nickname, -// Profile: j.Profile, -// Picture: j.Picture, -// Website: j.Website, -// Gender: Gender(j.Gender), -// Birthdate: j.Birthdate, -// Zoneinfo: j.Zoneinfo, -// Locale: locale, -// UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(), -// PreferredUsername: j.PreferredUsername, -// } -//} -// -//func (j *jsonToken) UnmarshalUserinfoEmail() userInfoEmail { -// return userInfoEmail{ -// Email: j.Email, -// EmailVerified: j.EmailVerified, -// } -//} -// -//func (j *jsonToken) UnmarshalUserinfoPhone() userInfoPhone { -// return userInfoPhone{ -// PhoneNumber: j.Phone, -// PhoneNumberVerified: j.PhoneVerified, -// } -//} -// -//func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress { -// if j.JsonUserinfoAddress == nil { -// return nil -// } -// return &UserinfoAddress{ -// Country: j.JsonUserinfoAddress.Country, -// Formatted: j.JsonUserinfoAddress.Formatted, -// Locality: j.JsonUserinfoAddress.Locality, -// PostalCode: j.JsonUserinfoAddress.PostalCode, -// Region: j.JsonUserinfoAddress.Region, -// StreetAddress: j.JsonUserinfoAddress.StreetAddress, -// } -//} - func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { hash, err := utils.GetHashAlgorithm(sigAlgorithm) if err != nil { @@ -492,26 +338,3 @@ func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, erro return utils.HashString(hash, claim, true), nil } - -func timeToJSON(t time.Time) int64 { - if t.IsZero() { - return 0 - } - return t.Unix() -} - -func audienceFromJSON(i interface{}) []string { - switch aud := i.(type) { - case []string: - return aud - case []interface{}: - audience := make([]string, len(aud)) - for i, a := range aud { - audience[i] = a.(string) - } - return audience - case string: - return []string{aud} - } - return nil -} diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index c04dfb4..1d958df 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -37,7 +37,7 @@ func (a *AccessTokenRequest) GrantType() GrantType { type JWTTokenRequest struct { Issuer string `json:"iss"` Subject string `json:"sub"` - Scopes Scopes `json:"scope"` + Scopes Scopes `json:"-"` Audience Audience `json:"aud"` IssuedAt Time `json:"iat"` ExpiresAt Time `json:"exp"` diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index ad19684..8423cff 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -41,38 +41,6 @@ func (d *Display) UnmarshalText(text []byte) error { type Gender string -type Locale language.Tag - -//{ -// SetLocale(language.Tag) -// Get() language.Tag -//} -// -//func NewLocale(tag language.Tag) Locale { -// if tag.IsRoot() { -// return nil -// } -// return &locale{Tag: tag} -//} -// -//type locale struct { -// language.Tag -//} -// -//func (l *locale) SetLocale(tag language.Tag) { -// l.Tag = tag -//} -//func (l *locale) Get() language.Tag { -// return l.Tag -//} - -//func (l *locale) MarshalJSON() ([]byte, error) { -// if l != nil && !l.IsRoot() { -// return l.MarshalText() -// } -// return []byte("null"), nil -//} - type Locales []language.Tag func (l *Locales) UnmarshalText(text []byte) error { @@ -92,11 +60,19 @@ type ResponseType string type Scopes []string +func (s *Scopes) Encode() string { + return strings.Join(*s, " ") +} + func (s *Scopes) UnmarshalText(text []byte) error { *s = strings.Split(string(text), " ") return nil } +func (s *Scopes) MarshalText() ([]byte, error) { + return []byte(s.Encode()), nil +} + type Time time.Time func (t *Time) UnmarshalJSON(data []byte) error { diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index c451f8c..830fb02 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -1,6 +1,7 @@ package oidc import ( + "bytes" "encoding/json" "testing" @@ -22,7 +23,15 @@ func TestAudience_UnmarshalText(t *testing.T) { wantErr bool }{ { - "unknown value", + "invalid value", + args{ + []byte(`{"aud": {"a": }}}`), + }, + res{}, + false, + }, + { + "single audience", args{ []byte(`{"aud": "single audience"}`), }, @@ -32,7 +41,7 @@ func TestAudience_UnmarshalText(t *testing.T) { false, }, { - "page", + "multiple audience", args{ []byte(`{"aud": ["multiple", "audience"]}`), }, @@ -219,13 +228,12 @@ func TestScopes_UnmarshalText(t *testing.T) { }) } } - -func TestTime_UnmarshalJSON(t *testing.T) { +func TestScopes_MarshalText(t *testing.T) { type args struct { - text []byte + scopes Scopes } type res struct { - scopes []string + scopes []byte } tests := []struct { name string @@ -236,41 +244,53 @@ func TestTime_UnmarshalJSON(t *testing.T) { { "unknown value", args{ - []byte("unknown"), + Scopes{"unknown"}, }, res{ - []string{"unknown"}, + []byte("unknown"), + }, + false, + }, + { + "struct", + args{ + Scopes{`{"unknown":"value"}`}, + }, + res{ + []byte(`{"unknown":"value"}`), }, false, }, { "openid", args{ - []byte("openid"), + Scopes{"openid"}, }, res{ - []string{"openid"}, + []byte("openid"), }, false, }, { "multiple scopes", args{ - []byte("openid email custom:scope"), + Scopes{"openid", "email", "custom:scope"}, }, res{ - []string{"openid", "email", "custom:scope"}, + []byte("openid email custom:scope"), }, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var scopes Scopes - if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { - t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + text, err := tt.args.scopes.MarshalText() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + if !bytes.Equal(text, tt.res.scopes) { + t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes) } - assert.ElementsMatch(t, scopes, tt.res.scopes) }) } } diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index b3613ce..cba70f3 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -138,18 +138,18 @@ func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenReque } func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { - assertion, err := ParseJWTProfileRequest(r, exchanger.Decoder()) + profileRequest, err := ParseJWTProfileRequest(r, exchanger.Decoder()) if err != nil { RequestError(w, r, err) } - claims, err := VerifyJWTAssertion(r.Context(), assertion, exchanger.JWTProfileVerifier()) + tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest, exchanger.JWTProfileVerifier()) if err != nil { RequestError(w, r, err) return } - resp, err := CreateJWTTokenResponse(r.Context(), claims, exchanger) + resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger) if err != nil { RequestError(w, r, err) return @@ -157,17 +157,17 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { utils.MarshalJSON(w, resp) } -func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (string, error) { +func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*tokenexchange.JWTProfileRequest, error) { err := r.ParseForm() if err != nil { - return "", ErrInvalidRequest("error parsing form") + return nil, ErrInvalidRequest("error parsing form") } tokenReq := new(tokenexchange.JWTProfileRequest) err = decoder.Decode(tokenReq, r.Form) if err != nil { - return "", ErrInvalidRequest("error decoding form") + return nil, ErrInvalidRequest("error decoding form") } - return tokenReq.Assertion, nil + return tokenReq, nil } func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index b30bdc5..8a31253 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -8,6 +8,7 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/oidc/grants/tokenexchange" ) type JWTProfileVerifier interface { @@ -47,9 +48,9 @@ func (v *jwtProfileVerifier) Offset() time.Duration { return v.offset } -func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { +func VerifyJWTAssertion(ctx context.Context, profileRequest *tokenexchange.JWTProfileRequest, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { request := new(oidc.JWTTokenRequest) - payload, err := oidc.ParseToken(assertion, request) + payload, err := oidc.ParseToken(profileRequest.Assertion, request) if err != nil { return nil, err } @@ -72,9 +73,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif keySet := &jwtProfileKeySet{v.Storage(), request.Subject} - if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { + if err = oidc.CheckSignature(ctx, profileRequest.Assertion, payload, request, nil, keySet); err != nil { return nil, err } + request.Scopes = profileRequest.Scope return request, nil } diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go index fd9ee95..a8bb9bb 100644 --- a/pkg/rp/relaying_party.go +++ b/pkg/rp/relaying_party.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/google/uuid" @@ -313,38 +314,45 @@ func CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc //ClientCredentials is the `RelayingParty` interface implementation //handling the oauth2 client credentials grant func ClientCredentials(ctx context.Context, rp RelayingParty, scopes ...string) (newToken *oauth2.Token, err error) { - return CallTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...), rp) + return CallTokenEndpointAuthorized(grants.ClientCredentialsGrantBasic(scopes...), rp) +} + +func CallTokenEndpointAuthorized(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { + config := rp.OAuthConfig() + var fn interface{} = utils.AuthorizeBasic(config.ClientID, config.ClientSecret) + if config.Endpoint.AuthStyle == oauth2.AuthStyleInParams { + fn = func(form url.Values) { + form.Set("client_id", config.ClientID) + form.Set("client_secret", config.ClientSecret) + } + } + return callTokenEndpoint(request, fn, rp) } func CallTokenEndpoint(request interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { - config := rp.OAuthConfig() - req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, config.ClientID, config.ClientSecret, config.Endpoint.AuthStyle != oauth2.AuthStyleInParams) - if err != nil { - return nil, err - } - token := new(oauth2.Token) - if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil { - return nil, err - } - return token, nil + return callTokenEndpoint(request, nil, rp) } -func CallJWTProfileEndpoint(assertion string, rp RelayingParty) (*oauth2.Token, error) { - form := url.Values{} - form.Add("assertion", assertion) - form.Add("grant_type", jwtProfileKey) - req, err := http.NewRequest("POST", rp.OAuthConfig().Endpoint.TokenURL, strings.NewReader(form.Encode())) +func callTokenEndpoint(request interface{}, authFn interface{}, rp RelayingParty) (newToken *oauth2.Token, err error) { + req, err := utils.FormRequest(rp.OAuthConfig().Endpoint.TokenURL, request, authFn) if err != nil { return nil, err } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - token := new(oauth2.Token) - if err := utils.HttpRequest(rp.HttpClient(), req, token); err != nil { + var tokenRes struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + if err := utils.HttpRequest(rp.HttpClient(), req, &tokenRes); err != nil { return nil, err } - return token, nil + return &oauth2.Token{ + AccessToken: tokenRes.AccessToken, + TokenType: tokenRes.TokenType, + RefreshToken: tokenRes.RefreshToken, + Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second), + }, nil } func trySetStateCookie(w http.ResponseWriter, state string, rp RelayingParty) error { diff --git a/pkg/rp/tockenexchange.go b/pkg/rp/tockenexchange.go index 24b588a..4396dc4 100644 --- a/pkg/rp/tockenexchange.go +++ b/pkg/rp/tockenexchange.go @@ -43,12 +43,17 @@ func DelegationTokenExchange(ctx context.Context, subjectToken string, rp Relayi } //JWTProfileExchange handles the oauth2 jwt profile exchange -func JWTProfileExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, rp RelayingParty) (*oauth2.Token, error) { +func JWTProfileExchange(ctx context.Context, jwtProfileRequest *tokenexchange.JWTProfileRequest, rp RelayingParty) (*oauth2.Token, error) { + return CallTokenEndpoint(jwtProfileRequest, rp) +} + +//JWTProfileExchange handles the oauth2 jwt profile exchange +func JWTProfileAssertionExchange(ctx context.Context, assertion *oidc.JWTProfileAssertion, scopes oidc.Scopes, rp RelayingParty) (*oauth2.Token, error) { token, err := generateJWTProfileToken(assertion) if err != nil { return nil, err } - return CallJWTProfileEndpoint(token, rp) + return JWTProfileExchange(ctx, tokenexchange.NewJWTProfileRequest(token, scopes...), rp) } func generateJWTProfileToken(assertion *oidc.JWTProfileAssertion) (string, error) { diff --git a/pkg/utils/http.go b/pkg/utils/http.go index 993febb..e785472 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -27,23 +27,31 @@ type Encoder interface { Encode(src interface{}, dst map[string][]string) error } -func FormRequest(endpoint string, request interface{}, clientID, clientSecret string, header bool) (*http.Request, error) { - form := make(map[string][]string) +type FormAuthorization func(url.Values) +type RequestAuthorization func(*http.Request) + +func AuthorizeBasic(user, password string) RequestAuthorization { + return func(req *http.Request) { + req.SetBasicAuth(user, password) + } +} + +func FormRequest(endpoint string, request interface{}, authFn interface{}) (*http.Request, error) { + form := url.Values{} encoder := schema.NewEncoder() if err := encoder.Encode(request, form); err != nil { return nil, err } - if !header { - form["client_id"] = []string{clientID} - form["client_secret"] = []string{clientSecret} + if fn, ok := authFn.(FormAuthorization); ok { + fn(form) } - body := strings.NewReader(url.Values(form).Encode()) + body := strings.NewReader(form.Encode()) req, err := http.NewRequest("POST", endpoint, body) if err != nil { return nil, err } - if header { - req.SetBasicAuth(clientID, clientSecret) + if fn, ok := authFn.(RequestAuthorization); ok { + fn(req) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") return req, nil