Skip to content

Add flash attention for non-quantized CPU GroupQueryAttention#28962

Merged
tianleiwu merged 6 commits into
mainfrom
tlwu/20260608/gqa_cpu_flash_att
Jun 24, 2026
Merged

Add flash attention for non-quantized CPU GroupQueryAttention#28962
tianleiwu merged 6 commits into
mainfrom
tlwu/20260608/gqa_cpu_flash_att

Conversation

@tianleiwu

@tianleiwu tianleiwu commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds an FP32 flash attention path for the CPU com.microsoft.GroupQueryAttention (GQA) contrib op, mirroring the existing quantized-KV flash attention path. The new tiled, online-softmax kernel avoids materializing the full [S, T] attention score matrix. It is restricted to prefill / chunked-prefill (sequence_length > 1); single-token decode falls back to the naive path. With causal early-termination it is faster than the naive path across all measured prefill lengths while using a fraction of the memory.

Key changes

  • New MLAS kernel onnxruntime/core/mlas/lib/flashattn_gqa.cpp (MlasFlashAttentionGQA):
    • Tiled QK / softmax / SV with online-softmax (running max/sum rescaling).
    • GQA head grouping (num_heads % kv_num_heads == 0), causal masking, local window, additive attention bias, and packed-QKV input.
    • Causal early-termination: during prefill, KV blocks that fall entirely in the causally masked upper triangle are skipped (break once ir >= past_seqlen + q_idx + row_size_q), avoiding the wasted QK/SV GEMMs over roughly half of the square prefill attention matrix.
    • Per-batch invocation for ragged / shared-buffer seqlens_k.
  • MLAS API onnxruntime/core/mlas/inc/mlas.h: new MlasFlashAttentionGQAArgs struct and MlasFlashAttentionGQA declaration.
  • Dispatch onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: new ApplyAttentionFlash that concatenates new K/V into the FP32 present cache and invokes the kernel. The per-thread scratch buffer size is computed with SafeInt<size_t> to guard against size_t overflow on large/malformed shapes before allocation.
  • Wiring onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc: float-only flash dispatch, active only for prefill (sequence_length > 1) and when softcap == 0, no smooth softmax, no head sink, no QK output; falls back to the naive path otherwise. The existing ORT_GQA_DISABLE_FLASH_ATTENTION env var disables it.
  • CMake cmake/onnxruntime_mlas.cmake: register the new source file.
  • Docs docs/contrib_ops/cpu/gqa.md: document the non-quantized flash attention path, activation conditions, causal early-termination, file list, and FP32 flash-vs-naive benchmark results.
  • Benchmark onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py: add an FP32 (non-quantized) mode (--fp32) for operator-level flash-vs-naive comparison.

Why prefill-only (sequence_length > 1)

Single-token decode (sequence_length == 1) produces only a [1, total_sequence_length] score row per head, so there is nothing to tile away and the extra online-softmax bookkeeping makes the flash kernel slower and noisier than naive in practice. Restricting the flash path to prefill keeps the consistent prefill win without regressing decode. Because decode is excluded, the two-phase flash-decoding kernels are unreachable and have been removed for a smaller, simpler implementation.

float16 continues to use the naive path (the kernel is float-only, matching the quantized flash constraint).

Performance

Operator-level, AMD EPYC 7763 (16 physical cores), threads=8, FP32 KV cache, B=1, num_heads=16, kv_num_heads=8, head_size=128. Flash is faster than naive across all measured prefill lengths (and single-threaded as well, 1.4-1.8x), confirming the gain is algorithmic - the causal early-termination removes the wasted upper-triangle work that previously made flash slower than naive at short sequences.

Prefill Seq Length Naive (ms) Flash (ms) Speedup
512 5.8-8.4 4.2-5.3 1.4-1.6x
1024 25-29 13-18 1.6-2.0x
2048 87-118 52-65 1.5-2.0x
4096 365-380 213-234 1.6-1.7x

The flash path's primary structural benefit is memory: it never allocates the full O(N x S x T) attention matrix (~1 GB at S=4096, N=16) and instead uses an O(S x Bc) per-thread tile.

Testing

  • C++ op tests: onnxruntime_provider_test --gtest_filter="GroupQueryAttentionTest.*" - 38 passed (12 GPU/WebGPU skipped) with flash on (default) and with ORT_GQA_DISABLE_FLASH_ATTENTION=1.
  • Flash vs. naive parity (FP32): output of the flash path matches the naive path (max abs diff ~1e-7) across prefill (block-aligned and non-aligned S), MHA and GQA head ratios, and local window. Decode now uses the naive path on both sides (diff 0).
  • Python parity (test_gqa_cpu.py, flash vs. naive reference): focused FP32 sweep of 600 prompt configurations covering all head sizes (32-256), GQA ratios (6,6)/(6,3)/(9,9)/(9,3), batches 1/3/5, causal/local window, attention bias, position ids, packed QKV, and with/without KV buffer - all passed. The official test_gqa_cpu.py suite passes.

Two correctness bugs were found and fixed via the parity sweep while developing this path:

  1. Attention-bias batch stride ignored head broadcasting for [batch, 1, S, T] bias.
  2. Query batch stride was hardcoded to num_heads * S * H, which is incorrect for packed-QKV input (correct stride is (num_heads + 2 * kv_num_heads) * S * H).

Add an FP32 tiled online-softmax flash attention kernel for the CPU
GroupQueryAttention contrib op, mirroring the existing quantized-KV flash
path. Avoids materializing the full attention score matrix and adds a
two-phase flash-decoding path for single-token decode.

- New MLAS kernel core/mlas/lib/flashattn_gqa.cpp (MlasFlashAttentionGQA)
  supporting GQA head grouping, causal masking, local window, attention
  bias, ragged/per-batch seqlens, packed QKV, and flash-decoding.
- New ApplyAttentionFlash dispatch in gqa_attention_base.h; wired into
  group_query_attention.cc (float only, gated like the quantized flash
  path: no softcap/smooth softmax/head sink/QK output).
- Reuses ORT_GQA_DISABLE_FLASH_ATTENTION to fall back to the naive path.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new FP32 “flash attention” execution path for the CPU com.microsoft.GroupQueryAttention contrib op by introducing an MLAS tiled/online-softmax kernel for FP32 KV caches, and wiring it into the existing GQA CPU implementation under the same gating conditions as the quantized flash path.

Changes:

  • Added new MLAS FP32 KV-cache flash-attention kernel (MlasFlashAttentionGQA) including a two-phase flash-decoding variant for sequence_length == 1.
  • Extended the MLAS public header with MlasFlashAttentionGQAArgs and the MlasFlashAttentionGQA API.
  • Added CPU GQA dispatch/wiring so float kernels can use the new flash path when supported, plus build-system and documentation updates.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
onnxruntime/core/mlas/lib/flashattn_gqa.cpp Implements FP32 KV-cache flash attention kernel + flash-decoding reduce path.
onnxruntime/core/mlas/inc/mlas.h Adds MLAS args struct + declaration for MlasFlashAttentionGQA.
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc Routes float CPU GQA to the new flash path when eligible.
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Implements ApplyAttentionFlash to build FP32 present cache and invoke MLAS kernel.
docs/contrib_ops/cpu/gqa.md Documents the non-quantized (FP32) flash attention path and activation conditions.
cmake/onnxruntime_mlas.cmake Registers the new MLAS source file in the build.

Comment thread onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Outdated
@tianleiwu tianleiwu marked this pull request as draft June 21, 2026 07:38
Guard the per-thread scratch and flash-decoding partials buffer size
computations against size_t overflow for large or malformed shapes,
matching the SafeInt usage elsewhere in this file.
@tianleiwu tianleiwu marked this pull request as ready for review June 22, 2026 04:45
@tianleiwu tianleiwu marked this pull request as draft June 22, 2026 07:51
The FP32 flash path is now restricted to prefill (sequence_length > 1),
so the single-token flash-decoding kernels are unreachable. Remove
MlasFlashDecodingGQAThreaded/MlasFlashDecodingGQAReduceThreaded, the
flash_decoding_partials/kv_chunk_count args fields, and the partials
buffer wiring in ApplyAttentionFlash. Simplify the MlasFlashAttentionGQA
dispatch to call the prefill kernel directly, and update the GQA doc to
describe the prefill-only behavior.
@tianleiwu tianleiwu marked this pull request as ready for review June 23, 2026 01:05
@hariharans29

Copy link
Copy Markdown
Member

Review: PR #28962 — Add flash attention for non-quantized CPU GroupQueryAttention

Author: tianleiwu · Branch: tlwu/20260608/gqa_cpu_flash_attmain
Size: +911 / −58 across 7 files · CI: 86 / 86 green · Status: awaiting approval (Copilot reviewed; apsonawane, hariharans29 requested)


Summary

Adds an FP32 flash-attention path for CPU com.microsoft.GroupQueryAttention, mirroring the existing quantized flash path. Tiled online-softmax kernel MlasFlashAttentionGQA in MLAS, dispatched only when the float specialization is used AND sequence_length > 1 AND none of {softcap, smooth-softmax, head-sink, QK-output} are requested. Single-token decode and fp16 continue to use the naive path.

  • Memory: avoids O(N·S·T) attention matrix; ~1 GB → ~tens of MB at S=4096, N=16.
  • Perf: 1.4–2.0× on prefill (and 1.4–1.8× single-threaded, so the win is algorithmic — mostly causal early-termination skipping the upper triangle).

What I'd approve as-is

  • Math is sound. Online softmax with (m, l, temp_output) state per row; rescale guarded by ir != 0; all--inf rows correctly handled by zeroing scores, continue-ing to skip softmax update, and relying on inv_l = (l > 0) ? 1/l : 0 at finalize. The m_diff = m_old - m_new path is correct when m_old = lowest() because exp(very-negative) ≈ 0 zeroes the old accumulator cleanly.
  • Causal early-termination bound is right. kv_causal_limit = past_seqlen + q_idx + row_size_q reflects the max global query position in the tile; the break is the legitimate optimization for the upper triangle.
  • Threading is race-free. (batch, head, q_block) task partitioning across MlasExecuteThreaded — each task writes to disjoint output tiles and a disjoint per-thread scratch slot. K/V concat uses TryParallelFor, which joins before MlasFlashAttentionGQA reads from the cache. No barriers needed.
  • Scratch sizing uses SafeInt<size_t> for the multiplications (buffer_size_per_thread, total_buffer_bytes) — guards against shape-driven overflow before the allocator call.
  • Fallback gates are correct. The if constexpr (std::is_same_v<T, float>) plus the conditional list mirrors the activation conditions documented in gqa.md. fp16 and decode paths untouched.
  • Layouts are consistent with the naive path. Q is BNSH, output is BSNH (flattens to [B, S, num_heads*head_size]). q_batch_stride is parameterized so packed-QKV works.
  • Bug fixes called out in the PR description (attention-bias batch stride ignoring head broadcasting; packed-QKV query stride hardcoded as num_heads * S * H) — these are real bugs that exist today in the quantized flash path and are tracked separately in Fix packed-QKV and broadcast-head bias strides in quantized GQA flash attention #28963. Good split.

Minor things to leave as comments

  1. Lint warning unaddressed: [cpplint] Add #include <algorithm> for min at gqa_attention_base.h:1151. One-line fix — add #include <algorithm> near the other <...> includes in that header. (The new flashattn_gqa.cpp already does.)

  2. Per-batch fallback foot-gun: in the else branch around L1133, args.q_batch_stride = SafeInt<size_t>(num_heads_) * sequence_length * head_size; is set but unused (kernel runs with batch_size=1, so batch_idx * stride = 0). For packed-QKV this stride would even be wrong, but the per-batch offset is applied to args.query = Q + b * q_batch_stride_elems before the kernel sees it. Worth a one-line comment so a future reader doesn't "fix" it incorrectly:

    // batch_size=1 means the kernel never multiplies by q_batch_stride; left set
    // only for completeness. Inter-batch offset is applied to args.query above.
  3. No new C++ unit tests in the diff. PR claims the existing GroupQueryAttentionTest.* suite covers the flash path because flash is on by default for float — that's true and reasonable, but it does mean we get only one execution flavor on CI. Worth asking if a focused _FlashPath test (or running the suite twice with ORT_GQA_DISABLE_FLASH_ATTENTION=0 and =1) would be tractable, so future regressions in flash-only logic surface immediately rather than via "did I forget to unset the env var?".

  4. The "600-config Python sweep" is great but not reproducible from this PR. PR description references an ad-hoc parity sweep against the naive path. If a script for it exists, committing it under test/python/transformers/ as a slow-marked test (or even a manual script) would make future flash-touching PRs cheap to re-validate. Nit, not a blocker.

  5. Local-window semantics: local_window_size == 0 masks self too (because window_start = causal_limit and kv_pos < window_start includes the current position). Consistent with the existing paths, so don't change here, but worth noting if anyone's ever tempted to "fix" it.


Things I would NOT ask for

  • Refactoring to share more with flashattn_qkv.cpp. Tightly-coupled hot kernels typically benefit from being separate, and the two diverge in non-trivial ways (dequant-on-read vs FP32 read, different SGEMM primitives, no flash-decoding here). Leave as parallel implementations.
  • Adding flash-decoding (sequence_length == 1) for FP32. PR explicitly justifies skipping it — bookkeeping overhead exceeded the tiling benefit, and the naive path is already optimal for a single row of scores. Author has follow-up [CPU] Add FP32 GEMV decode kernel for GroupQueryAttention #29216 for an FP32 GEMV decode kernel that takes a different angle. Right call to keep this PR focused.

Verdict

Approve with the two minor comments (lint include + per-batch stride comment).
The big risks — correctness of online softmax + causal early-termination, race-freedom across batches/heads/tiles, dispatch fallback completeness — all look handled. The perf wins are credible (single-threaded result rules out "just from parallelism"). Tests are a thin spot but not blocking given the integration coverage and the parity-sweep claim.

@tianleiwu

Copy link
Copy Markdown
Contributor Author

I could address these feedbacks in next PR 29216.

@tianleiwu tianleiwu merged commit 996cea1 into main Jun 24, 2026
89 of 90 checks passed
@tianleiwu tianleiwu deleted the tlwu/20260608/gqa_cpu_flash_att branch June 24, 2026 22:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants