feat: pkce

This commit is contained in:
Livio Amstutz 2020-01-28 08:51:34 +01:00
parent c1f4d01965
commit be6737328c
6 changed files with 100 additions and 15 deletions

View file

@ -13,11 +13,12 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/rp"
"github.com/caos/oidc/pkg/utils"
) )
var ( var (
callbackPath string = "/auth/callback" callbackPath string = "/auth/callback"
hashKey []byte = []byte("test") key []byte = []byte("test1234test1234")
) )
func main() { func main() {
@ -35,10 +36,10 @@ func main() {
CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath), CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath),
Scopes: []string{"openid", "profile", "email"}, Scopes: []string{"openid", "profile", "email"},
} }
// cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure()) cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler)) provider, err := rp.NewDefaultRP(rpConfig, rp.WithPKCE(cookieHandler)) //, rp.WithCookieHandler(cookieHandler))
if err != nil { if err != nil {
logrus.Panicf("error creating provider %s", err.Error()) logrus.Fatalf("error creating provider %s", err.Error())
} }
// state := "foobar" // state := "foobar"

View file

@ -18,6 +18,10 @@ type CodeChallenge struct {
Method CodeChallengeMethod Method CodeChallengeMethod
} }
func NewSHACodeChallenge(code string) string {
return utils.HashString(sha256.New(), code)
}
func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool { func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool {
if c == nil { if c == nil {
return false //TODO: ? return false //TODO: ?

View file

@ -165,6 +165,9 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
return return
} }
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code) callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
if authReq.GetState() != "" {
callback = callback + "&state=" + authReq.GetState()
}
} else { } else {
var accessToken string var accessToken string
var err error var err error

View file

@ -18,6 +18,7 @@ import (
const ( const (
idTokenKey = "id_token" idTokenKey = "id_token"
stateParam = "state" stateParam = "state"
pkceCode = "pkce"
) )
var ( var (
@ -32,6 +33,7 @@ type DefaultRP struct {
oauthConfig oauth2.Config oauthConfig oauth2.Config
config *Config config *Config
pkce bool
httpClient *http.Client httpClient *http.Client
cookieHandler *utils.CookieHandler cookieHandler *utils.CookieHandler
@ -80,6 +82,16 @@ func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts {
} }
} }
//WithPKCE sets the RP to use PKCE (oauth2 code challenge)
//it also sets a `CookieHandler` for securing the various redirects
//and exchanging the code challenge
func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts {
return func(p *DefaultRP) {
p.pkce = true
p.cookieHandler = cookieHandler
}
}
//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier //WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
func WithHTTPClient(client *http.Client) DefaultRPOpts { func WithHTTPClient(client *http.Client) DefaultRPOpts {
return func(p *DefaultRP) { return func(p *DefaultRP) {
@ -90,28 +102,55 @@ func WithHTTPClient(client *http.Client) DefaultRPOpts {
//AuthURL is the `RelayingParty` interface implementation //AuthURL is the `RelayingParty` interface implementation
//wrapping the oauth2 `AuthCodeURL` //wrapping the oauth2 `AuthCodeURL`
//returning the url of the auth request //returning the url of the auth request
func (p *DefaultRP) AuthURL(state string) string { func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string {
return p.oauthConfig.AuthCodeURL(state) authOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
authOpts = append(authOpts, opt()...)
}
return p.oauthConfig.AuthCodeURL(state, authOpts...)
} }
//AuthURL is the `RelayingParty` interface implementation //AuthURL is the `RelayingParty` interface implementation
//extending the `AuthURL` method with a http redirect handler //extending the `AuthURL` method with a http redirect handler
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc { func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
opts := make([]AuthURLOpt, 0)
if err := p.trySetStateCookie(w, state); err != nil { if err := p.trySetStateCookie(w, state); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
return return
} }
http.Redirect(w, r, p.AuthURL(state), http.StatusFound) if p.pkce {
codeChallenge, err := p.generateAndStoreCodeChallenge(w)
if err != nil {
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
return
}
opts = append(opts, WithCodeChallenge(codeChallenge))
}
http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound)
} }
} }
func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) {
var codeVerifier string
codeVerifier = "s"
if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil {
return "", err
}
return oidc.NewSHACodeChallenge(codeVerifier), nil
}
//AuthURL is the `RelayingParty` interface implementation //AuthURL is the `RelayingParty` interface implementation
//handling the oauth2 code exchange, extracting and validating the id_token //handling the oauth2 code exchange, extracting and validating the id_token
//returning it paresed together with the oauth2 tokens (access, refresh) //returning it paresed together with the oauth2 tokens (access, refresh)
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) { func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
token, err := p.oauthConfig.Exchange(ctx, code) codeOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
codeOpts = append(codeOpts, opt()...)
}
token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...)
if err != nil { if err != nil {
return nil, err //TODO: our error return nil, err //TODO: our error
} }
@ -142,7 +181,16 @@ func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state) p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
return return
} }
tokens, err := p.CodeExchange(r.Context(), params.Get("code")) var codeOpts CodeExchangeOpt
if p.pkce {
codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode)
if err != nil {
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
return
}
codeOpts = WithCodeVerifier(codeVerifier)
}
tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts)
if err != nil { if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return return
@ -219,7 +267,7 @@ func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Tok
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error { func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
if p.cookieHandler != nil { if p.cookieHandler != nil {
if err := p.cookieHandler.SetQueryCookie(w, stateParam, state); err != nil { if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil {
return err return err
} }
} }

View file

@ -13,7 +13,7 @@ import (
type RelayingParty interface { type RelayingParty interface {
//AuthURL returns the authorization endpoint with a given state //AuthURL returns the authorization endpoint with a given state
AuthURL(state string) string AuthURL(state string, opts ...AuthURLOpt) string
//AuthURLHandler should implement the AuthURL func as http.HandlerFunc //AuthURLHandler should implement the AuthURL func as http.HandlerFunc
//(redirecting to the auth endpoint) //(redirecting to the auth endpoint)
@ -21,7 +21,7 @@ type RelayingParty interface {
//CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant) //CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant)
//returning an `Access Token` and `ID Token Claims` //returning an `Access Token` and `ID Token Claims`
CodeExchange(ctx context.Context, code string) (*oidc.Tokens, error) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (*oidc.Tokens, error)
//CodeExchangeHandler extends the CodeExchange func, //CodeExchangeHandler extends the CodeExchange func,
//calling the provided callback func on success with additional returned `state` //calling the provided callback func on success with additional returned `state`
@ -82,3 +82,24 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
JKWsURL: discoveryConfig.JwksURI, JKWsURL: discoveryConfig.JwksURI,
} }
} }
type AuthURLOpt func() []oauth2.AuthCodeOption
//WithCodeChallenge sets the `code_challenge` params in the auth request
func WithCodeChallenge(codeChallenge string) AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
}
}
}
type CodeExchangeOpt func() []oauth2.AuthCodeOption
//WithCodeVerifier sets the `code_verifier` param in the token request
func WithCodeVerifier(codeVerifier string) CodeExchangeOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)}
}
}

View file

@ -55,7 +55,7 @@ func WithDomain(domain string) CookieHandlerOpt {
} }
} }
func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) { func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error) {
cookie, err := r.Cookie(name) cookie, err := r.Cookie(name)
if err != nil { if err != nil {
return "", err return "", err
@ -64,13 +64,21 @@ func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string,
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil { if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil {
return "", err return "", err
} }
return value, nil
}
func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) {
value, err := c.CheckCookie(r, name)
if err != nil {
return "", err
}
if value != r.FormValue(name) { if value != r.FormValue(name) {
return "", errors.New(name + " does not compare") return "", errors.New(name + " does not compare")
} }
return value, nil return value, nil
} }
func (c *CookieHandler) SetQueryCookie(w http.ResponseWriter, name, value string) error { func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
encoded, err := c.securecookie.Encode(name, value) encoded, err := c.securecookie.Encode(name, value)
if err != nil { if err != nil {
return err return err