feat: pkce
This commit is contained in:
parent
c1f4d01965
commit
be6737328c
6 changed files with 100 additions and 15 deletions
|
@ -18,6 +18,7 @@ import (
|
|||
const (
|
||||
idTokenKey = "id_token"
|
||||
stateParam = "state"
|
||||
pkceCode = "pkce"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -32,6 +33,7 @@ type DefaultRP struct {
|
|||
|
||||
oauthConfig oauth2.Config
|
||||
config *Config
|
||||
pkce bool
|
||||
|
||||
httpClient *http.Client
|
||||
cookieHandler *utils.CookieHandler
|
||||
|
@ -80,6 +82,16 @@ func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
|||
}
|
||||
}
|
||||
|
||||
//WithPKCE sets the RP to use PKCE (oauth2 code challenge)
|
||||
//it also sets a `CookieHandler` for securing the various redirects
|
||||
//and exchanging the code challenge
|
||||
func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
p.pkce = true
|
||||
p.cookieHandler = cookieHandler
|
||||
}
|
||||
}
|
||||
|
||||
//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
|
||||
func WithHTTPClient(client *http.Client) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
|
@ -90,28 +102,55 @@ func WithHTTPClient(client *http.Client) DefaultRPOpts {
|
|||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//wrapping the oauth2 `AuthCodeURL`
|
||||
//returning the url of the auth request
|
||||
func (p *DefaultRP) AuthURL(state string) string {
|
||||
return p.oauthConfig.AuthCodeURL(state)
|
||||
func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string {
|
||||
authOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
authOpts = append(authOpts, opt()...)
|
||||
}
|
||||
return p.oauthConfig.AuthCodeURL(state, authOpts...)
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//extending the `AuthURL` method with a http redirect handler
|
||||
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
opts := make([]AuthURLOpt, 0)
|
||||
if err := p.trySetStateCookie(w, state); err != nil {
|
||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, p.AuthURL(state), http.StatusFound)
|
||||
if p.pkce {
|
||||
codeChallenge, err := p.generateAndStoreCodeChallenge(w)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||
}
|
||||
http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) {
|
||||
var codeVerifier string
|
||||
codeVerifier = "s"
|
||||
if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//handling the oauth2 code exchange, extracting and validating the id_token
|
||||
//returning it paresed together with the oauth2 tokens (access, refresh)
|
||||
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
|
||||
func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
|
||||
token, err := p.oauthConfig.Exchange(ctx, code)
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
codeOpts = append(codeOpts, opt()...)
|
||||
}
|
||||
|
||||
token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...)
|
||||
if err != nil {
|
||||
return nil, err //TODO: our error
|
||||
}
|
||||
|
@ -142,7 +181,16 @@ func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http
|
|||
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
|
||||
return
|
||||
}
|
||||
tokens, err := p.CodeExchange(r.Context(), params.Get("code"))
|
||||
var codeOpts CodeExchangeOpt
|
||||
if p.pkce {
|
||||
codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
codeOpts = WithCodeVerifier(codeVerifier)
|
||||
}
|
||||
tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -219,7 +267,7 @@ func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Tok
|
|||
|
||||
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
|
||||
if p.cookieHandler != nil {
|
||||
if err := p.cookieHandler.SetQueryCookie(w, stateParam, state); err != nil {
|
||||
if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue