feat: pkce

This commit is contained in:
Livio Amstutz 2020-01-28 08:51:34 +01:00
parent c1f4d01965
commit be6737328c
6 changed files with 100 additions and 15 deletions

View file

@ -18,6 +18,7 @@ import (
const (
idTokenKey = "id_token"
stateParam = "state"
pkceCode = "pkce"
)
var (
@ -32,6 +33,7 @@ type DefaultRP struct {
oauthConfig oauth2.Config
config *Config
pkce bool
httpClient *http.Client
cookieHandler *utils.CookieHandler
@ -80,6 +82,16 @@ func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts {
}
}
//WithPKCE sets the RP to use PKCE (oauth2 code challenge)
//it also sets a `CookieHandler` for securing the various redirects
//and exchanging the code challenge
func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts {
return func(p *DefaultRP) {
p.pkce = true
p.cookieHandler = cookieHandler
}
}
//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
func WithHTTPClient(client *http.Client) DefaultRPOpts {
return func(p *DefaultRP) {
@ -90,28 +102,55 @@ func WithHTTPClient(client *http.Client) DefaultRPOpts {
//AuthURL is the `RelayingParty` interface implementation
//wrapping the oauth2 `AuthCodeURL`
//returning the url of the auth request
func (p *DefaultRP) AuthURL(state string) string {
return p.oauthConfig.AuthCodeURL(state)
func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string {
authOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
authOpts = append(authOpts, opt()...)
}
return p.oauthConfig.AuthCodeURL(state, authOpts...)
}
//AuthURL is the `RelayingParty` interface implementation
//extending the `AuthURL` method with a http redirect handler
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
opts := make([]AuthURLOpt, 0)
if err := p.trySetStateCookie(w, state); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
return
}
http.Redirect(w, r, p.AuthURL(state), http.StatusFound)
if p.pkce {
codeChallenge, err := p.generateAndStoreCodeChallenge(w)
if err != nil {
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
return
}
opts = append(opts, WithCodeChallenge(codeChallenge))
}
http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound)
}
}
func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) {
var codeVerifier string
codeVerifier = "s"
if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil {
return "", err
}
return oidc.NewSHACodeChallenge(codeVerifier), nil
}
//AuthURL is the `RelayingParty` interface implementation
//handling the oauth2 code exchange, extracting and validating the id_token
//returning it paresed together with the oauth2 tokens (access, refresh)
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
token, err := p.oauthConfig.Exchange(ctx, code)
codeOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
codeOpts = append(codeOpts, opt()...)
}
token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...)
if err != nil {
return nil, err //TODO: our error
}
@ -142,7 +181,16 @@ func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
return
}
tokens, err := p.CodeExchange(r.Context(), params.Get("code"))
var codeOpts CodeExchangeOpt
if p.pkce {
codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode)
if err != nil {
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
return
}
codeOpts = WithCodeVerifier(codeVerifier)
}
tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts)
if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
return
@ -219,7 +267,7 @@ func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Tok
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
if p.cookieHandler != nil {
if err := p.cookieHandler.SetQueryCookie(w, stateParam, state); err != nil {
if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil {
return err
}
}