Add flash attention for non-quantized CPU GroupQueryAttention#28962
Conversation
Add an FP32 tiled online-softmax flash attention kernel for the CPU GroupQueryAttention contrib op, mirroring the existing quantized-KV flash path. Avoids materializing the full attention score matrix and adds a two-phase flash-decoding path for single-token decode. - New MLAS kernel core/mlas/lib/flashattn_gqa.cpp (MlasFlashAttentionGQA) supporting GQA head grouping, causal masking, local window, attention bias, ragged/per-batch seqlens, packed QKV, and flash-decoding. - New ApplyAttentionFlash dispatch in gqa_attention_base.h; wired into group_query_attention.cc (float only, gated like the quantized flash path: no softcap/smooth softmax/head sink/QK output). - Reuses ORT_GQA_DISABLE_FLASH_ATTENTION to fall back to the naive path.
There was a problem hiding this comment.
Pull request overview
Adds a new FP32 “flash attention” execution path for the CPU com.microsoft.GroupQueryAttention contrib op by introducing an MLAS tiled/online-softmax kernel for FP32 KV caches, and wiring it into the existing GQA CPU implementation under the same gating conditions as the quantized flash path.
Changes:
- Added new MLAS FP32 KV-cache flash-attention kernel (
MlasFlashAttentionGQA) including a two-phase flash-decoding variant forsequence_length == 1. - Extended the MLAS public header with
MlasFlashAttentionGQAArgsand theMlasFlashAttentionGQAAPI. - Added CPU GQA dispatch/wiring so float kernels can use the new flash path when supported, plus build-system and documentation updates.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/flashattn_gqa.cpp | Implements FP32 KV-cache flash attention kernel + flash-decoding reduce path. |
| onnxruntime/core/mlas/inc/mlas.h | Adds MLAS args struct + declaration for MlasFlashAttentionGQA. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Routes float CPU GQA to the new flash path when eligible. |
| onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | Implements ApplyAttentionFlash to build FP32 present cache and invoke MLAS kernel. |
| docs/contrib_ops/cpu/gqa.md | Documents the non-quantized (FP32) flash attention path and activation conditions. |
| cmake/onnxruntime_mlas.cmake | Registers the new MLAS source file in the build. |
Guard the per-thread scratch and flash-decoding partials buffer size computations against size_t overflow for large or malformed shapes, matching the SafeInt usage elsewhere in this file.
The FP32 flash path is now restricted to prefill (sequence_length > 1), so the single-token flash-decoding kernels are unreachable. Remove MlasFlashDecodingGQAThreaded/MlasFlashDecodingGQAReduceThreaded, the flash_decoding_partials/kv_chunk_count args fields, and the partials buffer wiring in ApplyAttentionFlash. Simplify the MlasFlashAttentionGQA dispatch to call the prefill kernel directly, and update the GQA doc to describe the prefill-only behavior.
Review: PR #28962 — Add flash attention for non-quantized CPU GroupQueryAttentionAuthor: SummaryAdds an FP32 flash-attention path for CPU
What I'd approve as-is
Minor things to leave as comments
Things I would NOT ask for
VerdictApprove with the two minor comments (lint include + per-batch stride comment). |
|
I could address these feedbacks in next PR 29216. |
Summary
Adds an FP32 flash attention path for the CPU
com.microsoft.GroupQueryAttention(GQA) contrib op, mirroring the existing quantized-KV flash attention path. The new tiled, online-softmax kernel avoids materializing the full[S, T]attention score matrix. It is restricted to prefill / chunked-prefill (sequence_length > 1); single-token decode falls back to the naive path. With causal early-termination it is faster than the naive path across all measured prefill lengths while using a fraction of the memory.Key changes
onnxruntime/core/mlas/lib/flashattn_gqa.cpp(MlasFlashAttentionGQA):num_heads % kv_num_heads == 0), causal masking, local window, additive attention bias, and packed-QKV input.breakonceir >= past_seqlen + q_idx + row_size_q), avoiding the wasted QK/SV GEMMs over roughly half of the square prefill attention matrix.seqlens_k.onnxruntime/core/mlas/inc/mlas.h: newMlasFlashAttentionGQAArgsstruct andMlasFlashAttentionGQAdeclaration.onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: newApplyAttentionFlashthat concatenates new K/V into the FP32 present cache and invokes the kernel. The per-thread scratch buffer size is computed withSafeInt<size_t>to guard againstsize_toverflow on large/malformed shapes before allocation.onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc: float-only flash dispatch, active only for prefill (sequence_length > 1) and whensoftcap == 0, no smooth softmax, no head sink, no QK output; falls back to the naive path otherwise. The existingORT_GQA_DISABLE_FLASH_ATTENTIONenv var disables it.cmake/onnxruntime_mlas.cmake: register the new source file.docs/contrib_ops/cpu/gqa.md: document the non-quantized flash attention path, activation conditions, causal early-termination, file list, and FP32 flash-vs-naive benchmark results.onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py: add an FP32 (non-quantized) mode (--fp32) for operator-level flash-vs-naive comparison.Why prefill-only (
sequence_length > 1)Single-token decode (
sequence_length == 1) produces only a[1, total_sequence_length]score row per head, so there is nothing to tile away and the extra online-softmax bookkeeping makes the flash kernel slower and noisier than naive in practice. Restricting the flash path to prefill keeps the consistent prefill win without regressing decode. Because decode is excluded, the two-phase flash-decoding kernels are unreachable and have been removed for a smaller, simpler implementation.float16continues to use the naive path (the kernel is float-only, matching the quantized flash constraint).Performance
Operator-level, AMD EPYC 7763 (16 physical cores), threads=8, FP32 KV cache,
B=1, num_heads=16, kv_num_heads=8, head_size=128. Flash is faster than naive across all measured prefill lengths (and single-threaded as well, 1.4-1.8x), confirming the gain is algorithmic - the causal early-termination removes the wasted upper-triangle work that previously made flash slower than naive at short sequences.The flash path's primary structural benefit is memory: it never allocates the full O(N x S x T) attention matrix (~1 GB at S=4096, N=16) and instead uses an O(S x Bc) per-thread tile.
Testing
onnxruntime_provider_test --gtest_filter="GroupQueryAttentionTest.*"- 38 passed (12 GPU/WebGPU skipped) with flash on (default) and withORT_GQA_DISABLE_FLASH_ATTENTION=1.S), MHA and GQA head ratios, and local window. Decode now uses the naive path on both sides (diff 0).test_gqa_cpu.py, flash vs. naive reference): focused FP32 sweep of 600 prompt configurations covering all head sizes (32-256), GQA ratios(6,6)/(6,3)/(9,9)/(9,3), batches1/3/5, causal/local window, attention bias, position ids, packed QKV, and with/without KV buffer - all passed. The officialtest_gqa_cpu.pysuite passes.Two correctness bugs were found and fixed via the parity sweep while developing this path:
[batch, 1, S, T]bias.num_heads * S * H, which is incorrect for packed-QKV input (correct stride is(num_heads + 2 * kv_num_heads) * S * H).