From 4f0ed79c0a49c9de7341300c0d1e45e5c1e38796 Mon Sep 17 00:00:00 2001 From: Ayato Date: Tue, 29 Apr 2025 23:33:31 +0900 Subject: [PATCH] fix(op): Add mitigation for PKCE Downgrade Attack (#741) * fix(op): Add mitigation for PKCE downgrade attack * chore(op): add test for PKCE verification --- pkg/op/token_code.go | 9 ++--- pkg/op/token_request.go | 12 +++++- pkg/op/token_request_test.go | 75 ++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 8 deletions(-) create mode 100644 pkg/op/token_request_test.go diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 019aa63..fb636b4 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -80,12 +80,9 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, } codeChallenge := request.GetCodeChallenge() - if codeChallenge != nil { - err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, codeChallenge) - - if err != nil { - return nil, nil, err - } + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, codeChallenge) + if err != nil { + return nil, nil, err } if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 85e2270..66e4c83 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -132,11 +132,19 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, // AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent // code_challenge of the auth request (PKCE) func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error { + if challenge == nil { + if codeVerifier != "" { + return oidc.ErrInvalidRequest().WithDescription("code_verifier unexpectedly provided") + } + + return nil + } + if codeVerifier == "" { - return oidc.ErrInvalidRequest().WithDescription("code_challenge required") + return oidc.ErrInvalidRequest().WithDescription("code_verifier required") } if !oidc.VerifyCodeChallenge(challenge, codeVerifier) { - return oidc.ErrInvalidGrant().WithDescription("invalid code challenge") + return oidc.ErrInvalidGrant().WithDescription("invalid code_verifier") } return nil } diff --git a/pkg/op/token_request_test.go b/pkg/op/token_request_test.go new file mode 100644 index 0000000..21cf20b --- /dev/null +++ b/pkg/op/token_request_test.go @@ -0,0 +1,75 @@ +package op_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +func TestAuthorizeCodeChallenge(t *testing.T) { + tests := []struct { + name string + codeVerifier string + codeChallenge *oidc.CodeChallenge + want func(t *testing.T, err error) + }{ + { + name: "missing both code_verifier and code_challenge", + codeVerifier: "", + codeChallenge: nil, + want: func(t *testing.T, err error) { + assert.Nil(t, err) + }, + }, + { + name: "valid code_verifier", + codeVerifier: "Hello World!", + codeChallenge: &oidc.CodeChallenge{ + Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk", + Method: oidc.CodeChallengeMethodS256, + }, + want: func(t *testing.T, err error) { + assert.Nil(t, err) + }, + }, + { + name: "invalid code_verifier", + codeVerifier: "Hi World!", + codeChallenge: &oidc.CodeChallenge{ + Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk", + Method: oidc.CodeChallengeMethodS256, + }, + want: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "invalid code_verifier") + }, + }, + { + name: "code_verifier provided without code_challenge", + codeVerifier: "code_verifier", + codeChallenge: nil, + want: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "code_verifier unexpectedly provided") + }, + }, + { + name: "empty code_verifier", + codeVerifier: "", + codeChallenge: &oidc.CodeChallenge{ + Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk", + Method: oidc.CodeChallengeMethodS256, + }, + want: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "code_verifier required") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := op.AuthorizeCodeChallenge(tt.codeVerifier, tt.codeChallenge) + + tt.want(t, err) + }) + } +}