feat: check PKCE even when the Auth Method is not “none”.
This commit is contained in:
parent
6a80712fbe
commit
acfc8ad99b
2 changed files with 28 additions and 0 deletions
|
@ -7,6 +7,7 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -33,6 +34,10 @@ func main() {
|
||||||
port := os.Getenv("PORT")
|
port := os.Getenv("PORT")
|
||||||
scopes := strings.Split(os.Getenv("SCOPES"), " ")
|
scopes := strings.Split(os.Getenv("SCOPES"), " ")
|
||||||
responseMode := os.Getenv("RESPONSE_MODE")
|
responseMode := os.Getenv("RESPONSE_MODE")
|
||||||
|
pkce, err := strconv.ParseBool(os.Getenv("PKCE"))
|
||||||
|
if err != nil {
|
||||||
|
logrus.Fatalf("error parsing PKCE %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
|
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
|
||||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||||
|
@ -64,6 +69,9 @@ func main() {
|
||||||
if keyPath != "" {
|
if keyPath != "" {
|
||||||
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
|
||||||
}
|
}
|
||||||
|
if pkce {
|
||||||
|
options = append(options, rp.WithPKCE(cookieHandler))
|
||||||
|
}
|
||||||
|
|
||||||
// One can add a logger to the context,
|
// One can add a logger to the context,
|
||||||
// pre-defining log attributes as required.
|
// pre-defining log attributes as required.
|
||||||
|
|
|
@ -84,6 +84,16 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
|
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
|
||||||
|
|
||||||
|
codeChallenge := request.GetCodeChallenge()
|
||||||
|
if codeChallenge != nil && codeChallenge.Challenge != "" {
|
||||||
|
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return request, client, err
|
return request, client, err
|
||||||
}
|
}
|
||||||
client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
|
client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
|
||||||
|
@ -109,6 +119,16 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
|
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
|
||||||
|
|
||||||
|
codeChallenge := request.GetCodeChallenge()
|
||||||
|
if codeChallenge != nil && codeChallenge.Challenge != "" {
|
||||||
|
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return request, client, err
|
return request, client, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue