diff --git a/example/client/app/app.go b/example/client/app/app.go index 0b9b19d..5ba3a7a 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "os" + "strconv" "strings" "sync/atomic" "time" @@ -33,6 +34,10 @@ func main() { port := os.Getenv("PORT") scopes := strings.Split(os.Getenv("SCOPES"), " ") responseMode := os.Getenv("RESPONSE_MODE") + pkce, err := strconv.ParseBool(os.Getenv("PKCE")) + if err != nil { + logrus.Fatalf("error parsing PKCE %s", err.Error()) + } redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) @@ -64,6 +69,9 @@ func main() { if keyPath != "" { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } + if pkce { + options = append(options, rp.WithPKCE(cookieHandler)) + } // One can add a logger to the context, // pre-defining log attributes as required. diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 3612240..50dff04 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -84,6 +84,16 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, return nil, nil, err } request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) + + codeChallenge := request.GetCodeChallenge() + if codeChallenge != nil && codeChallenge.Challenge != "" { + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) + + if err != nil { + return nil, nil, err + } + } + return request, client, err } client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) @@ -109,6 +119,16 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, return nil, nil, err } request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) + + codeChallenge := request.GetCodeChallenge() + if codeChallenge != nil && codeChallenge.Challenge != "" { + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) + + if err != nil { + return nil, nil, err + } + } + return request, client, err }