feat(op): always verify code challenge when available (#721)

Finally the RFC Best Current Practice for OAuth 2.0 Security has been approved.

According to the RFC:

> Authorization servers MUST support PKCE [RFC7636].
> 
> If a client sends a valid PKCE code_challenge parameter in the authorization request, the authorization server MUST enforce the correct usage of code_verifier at the token endpoint.

Isn’t it time we strengthen PKCE support a bit more?

This PR updates the logic so that PKCE is always verified, even when the Auth Method is not "none".
This commit is contained in:
Ayato 2025-03-25 01:00:04 +09:00 committed by GitHub
parent 7096406e71
commit c51628ea27
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 45 additions and 15 deletions

View file

@ -7,6 +7,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -34,6 +35,14 @@ func main() {
scopes := strings.Split(os.Getenv("SCOPES"), " ") scopes := strings.Split(os.Getenv("SCOPES"), " ")
responseMode := os.Getenv("RESPONSE_MODE") responseMode := os.Getenv("RESPONSE_MODE")
var pkce bool
if pkceEnv, ok := os.LookupEnv("PKCE"); ok {
var err error
pkce, err = strconv.ParseBool(pkceEnv)
if err != nil {
logrus.Fatalf("error parsing PKCE %s", err.Error())
}
}
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
@ -64,6 +73,9 @@ func main() {
if keyPath != "" { if keyPath != "" {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
} }
if pkce {
options = append(options, rp.WithPKCE(cookieHandler))
}
// One can add a logger to the context, // One can add a logger to the context,
// pre-defining log attributes as required. // pre-defining log attributes as required.

View file

@ -25,5 +25,5 @@
<button type="submit">Login</button> <button type="submit">Login</button>
</form> </form>
</body> </body>
</html>` </html>
{{- end }} {{- end }}

View file

@ -18,7 +18,7 @@ const (
// CustomClaim is an example for how to return custom claims with this library // CustomClaim is an example for how to return custom claims with this library
CustomClaim = "custom_claim" CustomClaim = "custom_claim"
// CustomScopeImpersonatePrefix is an example scope prefix for passing user id to impersonate using token exchage // CustomScopeImpersonatePrefix is an example scope prefix for passing user id to impersonate using token exchange
CustomScopeImpersonatePrefix = "custom_scope:impersonate:" CustomScopeImpersonatePrefix = "custom_scope:impersonate:"
) )
@ -143,6 +143,14 @@ func MaxAgeToInternal(maxAge *uint) *time.Duration {
} }
func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthRequest { func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthRequest {
var codeChallenge *OIDCCodeChallenge
if authReq.CodeChallenge != "" {
codeChallenge = &OIDCCodeChallenge{
Challenge: authReq.CodeChallenge,
Method: string(authReq.CodeChallengeMethod),
}
}
return &AuthRequest{ return &AuthRequest{
CreationDate: time.Now(), CreationDate: time.Now(),
ApplicationID: authReq.ClientID, ApplicationID: authReq.ClientID,
@ -157,10 +165,7 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques
ResponseType: authReq.ResponseType, ResponseType: authReq.ResponseType,
ResponseMode: authReq.ResponseMode, ResponseMode: authReq.ResponseMode,
Nonce: authReq.Nonce, Nonce: authReq.Nonce,
CodeChallenge: &OIDCCodeChallenge{ CodeChallenge: codeChallenge,
Challenge: authReq.CodeChallenge,
Method: string(authReq.CodeChallengeMethod),
},
} }
} }

View file

@ -102,6 +102,7 @@ func TestRoutes(t *testing.T) {
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1") authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
require.NoError(t, err) require.NoError(t, err)
storage.AuthRequestDone(authReq.GetID()) storage.AuthRequestDone(authReq.GetID())
storage.SaveAuthCode(ctx, authReq.GetID(), "123")
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err) require.NoError(t, err)

View file

@ -130,7 +130,7 @@ func TestServerRoutes(t *testing.T) {
"client_id": client.GetID(), "client_id": client.GetID(),
"client_secret": "secret", "client_secret": "secret",
"redirect_uri": "https://example.com", "redirect_uri": "https://example.com",
"code": "123", "code": "abc",
}, },
wantCode: http.StatusBadRequest, wantCode: http.StatusBadRequest,
json: `{"error":"invalid_grant", "error_description":"invalid code"}`, json: `{"error":"invalid_grant", "error_description":"invalid code"}`,

View file

@ -74,6 +74,20 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
ctx, span := tracer.Start(ctx, "AuthorizeCodeClient") ctx, span := tracer.Start(ctx, "AuthorizeCodeClient")
defer span.End() defer span.End()
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
if err != nil {
return nil, nil, err
}
codeChallenge := request.GetCodeChallenge()
if codeChallenge != nil {
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, codeChallenge)
if err != nil {
return nil, nil, err
}
}
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger) jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger)
if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() { if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() {
@ -83,9 +97,9 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
return request, client, err return request, client, err
} }
client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil { if err != nil {
return nil, nil, oidc.ErrInvalidClient().WithParent(err) return nil, nil, oidc.ErrInvalidClient().WithParent(err)
@ -94,12 +108,10 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
} }
if client.AuthMethod() == oidc.AuthMethodNone { if client.AuthMethod() == oidc.AuthMethodNone {
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) if codeChallenge == nil {
if err != nil { return nil, nil, oidc.ErrInvalidRequest().WithDescription("PKCE required")
return nil, nil, err
} }
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) return request, client, nil
return request, client, err
} }
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
@ -108,7 +120,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
return request, client, err return request, client, err
} }