feat: check PKCE even when the Auth Method is not “none”.

This commit is contained in:
Ayato 2025-03-01 22:01:42 +09:00
parent 6a80712fbe
commit acfc8ad99b
No known key found for this signature in database
GPG key ID: 56E05AE09DBA012D
2 changed files with 28 additions and 0 deletions

View file

@ -7,6 +7,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -33,6 +34,10 @@ func main() {
port := os.Getenv("PORT") port := os.Getenv("PORT")
scopes := strings.Split(os.Getenv("SCOPES"), " ") scopes := strings.Split(os.Getenv("SCOPES"), " ")
responseMode := os.Getenv("RESPONSE_MODE") responseMode := os.Getenv("RESPONSE_MODE")
pkce, err := strconv.ParseBool(os.Getenv("PKCE"))
if err != nil {
logrus.Fatalf("error parsing PKCE %s", err.Error())
}
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
@ -64,6 +69,9 @@ func main() {
if keyPath != "" { if keyPath != "" {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
} }
if pkce {
options = append(options, rp.WithPKCE(cookieHandler))
}
// One can add a logger to the context, // One can add a logger to the context,
// pre-defining log attributes as required. // pre-defining log attributes as required.

View file

@ -84,6 +84,16 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
return nil, nil, err return nil, nil, err
} }
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
codeChallenge := request.GetCodeChallenge()
if codeChallenge != nil && codeChallenge.Challenge != "" {
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
if err != nil {
return nil, nil, err
}
}
return request, client, err return request, client, err
} }
client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
@ -109,6 +119,16 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
return nil, nil, err return nil, nil, err
} }
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
codeChallenge := request.GetCodeChallenge()
if codeChallenge != nil && codeChallenge.Challenge != "" {
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
if err != nil {
return nil, nil, err
}
}
return request, client, err return request, client, err
} }