feat: pkce
This commit is contained in:
parent
c1f4d01965
commit
be6737328c
6 changed files with 100 additions and 15 deletions
|
@ -13,11 +13,12 @@ import (
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
"github.com/caos/oidc/pkg/rp"
|
"github.com/caos/oidc/pkg/rp"
|
||||||
|
"github.com/caos/oidc/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
callbackPath string = "/auth/callback"
|
callbackPath string = "/auth/callback"
|
||||||
hashKey []byte = []byte("test")
|
key []byte = []byte("test1234test1234")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -35,10 +36,10 @@ func main() {
|
||||||
CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath),
|
CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath),
|
||||||
Scopes: []string{"openid", "profile", "email"},
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
}
|
}
|
||||||
// cookieHandler := utils.NewCookieHandler(hashKey, nil, utils.WithUnsecure())
|
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
|
||||||
provider, err := rp.NewDefaultRP(rpConfig) //, rp.WithCookieHandler(cookieHandler))
|
provider, err := rp.NewDefaultRP(rpConfig, rp.WithPKCE(cookieHandler)) //, rp.WithCookieHandler(cookieHandler))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Panicf("error creating provider %s", err.Error())
|
logrus.Fatalf("error creating provider %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// state := "foobar"
|
// state := "foobar"
|
||||||
|
|
|
@ -18,6 +18,10 @@ type CodeChallenge struct {
|
||||||
Method CodeChallengeMethod
|
Method CodeChallengeMethod
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewSHACodeChallenge(code string) string {
|
||||||
|
return utils.HashString(sha256.New(), code)
|
||||||
|
}
|
||||||
|
|
||||||
func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool {
|
func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return false //TODO: ?
|
return false //TODO: ?
|
||||||
|
|
|
@ -165,6 +165,9 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
||||||
|
if authReq.GetState() != "" {
|
||||||
|
callback = callback + "&state=" + authReq.GetState()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
var accessToken string
|
var accessToken string
|
||||||
var err error
|
var err error
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
const (
|
const (
|
||||||
idTokenKey = "id_token"
|
idTokenKey = "id_token"
|
||||||
stateParam = "state"
|
stateParam = "state"
|
||||||
|
pkceCode = "pkce"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -32,6 +33,7 @@ type DefaultRP struct {
|
||||||
|
|
||||||
oauthConfig oauth2.Config
|
oauthConfig oauth2.Config
|
||||||
config *Config
|
config *Config
|
||||||
|
pkce bool
|
||||||
|
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
cookieHandler *utils.CookieHandler
|
cookieHandler *utils.CookieHandler
|
||||||
|
@ -80,6 +82,16 @@ func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//WithPKCE sets the RP to use PKCE (oauth2 code challenge)
|
||||||
|
//it also sets a `CookieHandler` for securing the various redirects
|
||||||
|
//and exchanging the code challenge
|
||||||
|
func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
||||||
|
return func(p *DefaultRP) {
|
||||||
|
p.pkce = true
|
||||||
|
p.cookieHandler = cookieHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
|
//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
|
||||||
func WithHTTPClient(client *http.Client) DefaultRPOpts {
|
func WithHTTPClient(client *http.Client) DefaultRPOpts {
|
||||||
return func(p *DefaultRP) {
|
return func(p *DefaultRP) {
|
||||||
|
@ -90,28 +102,55 @@ func WithHTTPClient(client *http.Client) DefaultRPOpts {
|
||||||
//AuthURL is the `RelayingParty` interface implementation
|
//AuthURL is the `RelayingParty` interface implementation
|
||||||
//wrapping the oauth2 `AuthCodeURL`
|
//wrapping the oauth2 `AuthCodeURL`
|
||||||
//returning the url of the auth request
|
//returning the url of the auth request
|
||||||
func (p *DefaultRP) AuthURL(state string) string {
|
func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string {
|
||||||
return p.oauthConfig.AuthCodeURL(state)
|
authOpts := make([]oauth2.AuthCodeOption, 0)
|
||||||
|
for _, opt := range opts {
|
||||||
|
authOpts = append(authOpts, opt()...)
|
||||||
|
}
|
||||||
|
return p.oauthConfig.AuthCodeURL(state, authOpts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
//AuthURL is the `RelayingParty` interface implementation
|
//AuthURL is the `RelayingParty` interface implementation
|
||||||
//extending the `AuthURL` method with a http redirect handler
|
//extending the `AuthURL` method with a http redirect handler
|
||||||
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
|
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
opts := make([]AuthURLOpt, 0)
|
||||||
if err := p.trySetStateCookie(w, state); err != nil {
|
if err := p.trySetStateCookie(w, state); err != nil {
|
||||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, p.AuthURL(state), http.StatusFound)
|
if p.pkce {
|
||||||
|
codeChallenge, err := p.generateAndStoreCodeChallenge(w)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) {
|
||||||
|
var codeVerifier string
|
||||||
|
codeVerifier = "s"
|
||||||
|
if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//AuthURL is the `RelayingParty` interface implementation
|
//AuthURL is the `RelayingParty` interface implementation
|
||||||
//handling the oauth2 code exchange, extracting and validating the id_token
|
//handling the oauth2 code exchange, extracting and validating the id_token
|
||||||
//returning it paresed together with the oauth2 tokens (access, refresh)
|
//returning it paresed together with the oauth2 tokens (access, refresh)
|
||||||
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
|
func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
|
||||||
token, err := p.oauthConfig.Exchange(ctx, code)
|
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||||
|
for _, opt := range opts {
|
||||||
|
codeOpts = append(codeOpts, opt()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err //TODO: our error
|
return nil, err //TODO: our error
|
||||||
}
|
}
|
||||||
|
@ -142,7 +181,16 @@ func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http
|
||||||
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
|
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tokens, err := p.CodeExchange(r.Context(), params.Get("code"))
|
var codeOpts CodeExchangeOpt
|
||||||
|
if p.pkce {
|
||||||
|
codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
codeOpts = WithCodeVerifier(codeVerifier)
|
||||||
|
}
|
||||||
|
tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
|
@ -219,7 +267,7 @@ func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Tok
|
||||||
|
|
||||||
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
|
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
|
||||||
if p.cookieHandler != nil {
|
if p.cookieHandler != nil {
|
||||||
if err := p.cookieHandler.SetQueryCookie(w, stateParam, state); err != nil {
|
if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
type RelayingParty interface {
|
type RelayingParty interface {
|
||||||
|
|
||||||
//AuthURL returns the authorization endpoint with a given state
|
//AuthURL returns the authorization endpoint with a given state
|
||||||
AuthURL(state string) string
|
AuthURL(state string, opts ...AuthURLOpt) string
|
||||||
|
|
||||||
//AuthURLHandler should implement the AuthURL func as http.HandlerFunc
|
//AuthURLHandler should implement the AuthURL func as http.HandlerFunc
|
||||||
//(redirecting to the auth endpoint)
|
//(redirecting to the auth endpoint)
|
||||||
|
@ -21,7 +21,7 @@ type RelayingParty interface {
|
||||||
|
|
||||||
//CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant)
|
//CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant)
|
||||||
//returning an `Access Token` and `ID Token Claims`
|
//returning an `Access Token` and `ID Token Claims`
|
||||||
CodeExchange(ctx context.Context, code string) (*oidc.Tokens, error)
|
CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (*oidc.Tokens, error)
|
||||||
|
|
||||||
//CodeExchangeHandler extends the CodeExchange func,
|
//CodeExchangeHandler extends the CodeExchange func,
|
||||||
//calling the provided callback func on success with additional returned `state`
|
//calling the provided callback func on success with additional returned `state`
|
||||||
|
@ -82,3 +82,24 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||||
JKWsURL: discoveryConfig.JwksURI,
|
JKWsURL: discoveryConfig.JwksURI,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthURLOpt func() []oauth2.AuthCodeOption
|
||||||
|
|
||||||
|
//WithCodeChallenge sets the `code_challenge` params in the auth request
|
||||||
|
func WithCodeChallenge(codeChallenge string) AuthURLOpt {
|
||||||
|
return func() []oauth2.AuthCodeOption {
|
||||||
|
return []oauth2.AuthCodeOption{
|
||||||
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||||
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CodeExchangeOpt func() []oauth2.AuthCodeOption
|
||||||
|
|
||||||
|
//WithCodeVerifier sets the `code_verifier` param in the token request
|
||||||
|
func WithCodeVerifier(codeVerifier string) CodeExchangeOpt {
|
||||||
|
return func() []oauth2.AuthCodeOption {
|
||||||
|
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ func WithDomain(domain string) CookieHandlerOpt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) {
|
func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error) {
|
||||||
cookie, err := r.Cookie(name)
|
cookie, err := r.Cookie(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -64,13 +64,21 @@ func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string,
|
||||||
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil {
|
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) {
|
||||||
|
value, err := c.CheckCookie(r, name)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
if value != r.FormValue(name) {
|
if value != r.FormValue(name) {
|
||||||
return "", errors.New(name + " does not compare")
|
return "", errors.New(name + " does not compare")
|
||||||
}
|
}
|
||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CookieHandler) SetQueryCookie(w http.ResponseWriter, name, value string) error {
|
func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
|
||||||
encoded, err := c.securecookie.Encode(name, value)
|
encoded, err := c.securecookie.Encode(name, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue