From b06e76c54e81453a32b2567c985f7e1a45fdf930 Mon Sep 17 00:00:00 2001 From: Ayato Date: Thu, 13 Mar 2025 23:32:26 +0900 Subject: [PATCH] fix: simplify verifying PKCE --- example/server/storage/oidc.go | 13 ++++++---- pkg/op/op_test.go | 1 + pkg/op/token_code.go | 44 ++++++++++++++-------------------- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index c04877f..03d91ed 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -143,6 +143,14 @@ func MaxAgeToInternal(maxAge *uint) *time.Duration { } func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthRequest { + var codeChallenge *OIDCCodeChallenge + if authReq.CodeChallenge != "" { + codeChallenge = &OIDCCodeChallenge{ + Challenge: authReq.CodeChallenge, + Method: string(authReq.CodeChallengeMethod), + } + } + return &AuthRequest{ CreationDate: time.Now(), ApplicationID: authReq.ClientID, @@ -157,10 +165,7 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques ResponseType: authReq.ResponseType, ResponseMode: authReq.ResponseMode, Nonce: authReq.Nonce, - CodeChallenge: &OIDCCodeChallenge{ - Challenge: authReq.CodeChallenge, - Method: string(authReq.CodeChallengeMethod), - }, + CodeChallenge: codeChallenge, } } diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index 9a4a624..c1520e2 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -102,6 +102,7 @@ func TestRoutes(t *testing.T) { authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1") require.NoError(t, err) storage.AuthRequestDone(authReq.GetID()) + storage.SaveAuthCode(ctx, authReq.GetID(), "123") accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") require.NoError(t, err) diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 50dff04..019aa63 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -74,6 +74,20 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, ctx, span := tracer.Start(ctx, "AuthorizeCodeClient") defer span.End() + request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) + if err != nil { + return nil, nil, err + } + + codeChallenge := request.GetCodeChallenge() + if codeChallenge != nil { + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, codeChallenge) + + if err != nil { + return nil, nil, err + } + } + if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger) if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() { @@ -83,19 +97,9 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if err != nil { return nil, nil, err } - request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) - - codeChallenge := request.GetCodeChallenge() - if codeChallenge != nil && codeChallenge.Challenge != "" { - err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) - - if err != nil { - return nil, nil, err - } - } - return request, client, err } + client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) if err != nil { return nil, nil, oidc.ErrInvalidClient().WithParent(err) @@ -104,12 +108,10 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") } if client.AuthMethod() == oidc.AuthMethodNone { - request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) - if err != nil { - return nil, nil, err + if codeChallenge == nil { + return nil, nil, oidc.ErrInvalidRequest().WithDescription("PKCE required") } - err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) - return request, client, err + return request, client, nil } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") @@ -118,16 +120,6 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if err != nil { return nil, nil, err } - request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) - - codeChallenge := request.GetCodeChallenge() - if codeChallenge != nil && codeChallenge.Challenge != "" { - err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) - - if err != nil { - return nil, nil, err - } - } return request, client, err }