Fix packed-QKV and broadcast-head bias strides in quantized GQA flash attention #28963
Open
tianleiwu wants to merge 4 commits into
…ized GQA

The quantized KV-cache flash-attention path used a hardcoded Q batch stride (num_heads*S*H) and a num_heads-based attention-bias batch stride, producing incorrect results for batch_size > 1 with packed QKV or [B, 1, S, T] bias. Add a caller-supplied q_batch_stride to MlasFlashAttentionQuantizedKVArgs and compute the bias batch stride from the actual head extent (1 when the head dim is broadcast). Mirrors the FP32 fixes from #28962. Adds a regression test for batch > 1 with head-broadcast bias.
Pull request overview
This PR fixes two correctness issues in the CPU quantized KV-cache flash-attention path for the com.microsoft.GroupQueryAttention contrib op that can produce wrong results when batch_size > 1: (1) query (Q) batch stride handling for packed QKV layouts, and (2) attention-bias batch stride computation when the head dimension is broadcast (bias shape [B, 1, S, T]). It also adds a Python regression test to cover the previously-missed multi-batch broadcast-head bias case.
Changes:
- Add `q_batch_stride` to `MlasFlashAttentionQuantizedKVArgs` and use it in the quantized flash-attention kernels to support both standard and packed-QKV layouts.
- Fix attention-bias batch-stride calculation in quantized flash attention to account for head broadcasting (`bias_head_extent = 1` when broadcast).
- Add a new quantized Python regression test for `[B, 1, S, T]` bias with `batch_size > 1`.
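The Q batch-stride change can be illustrated with a small, self-contained sketch. This is not the MLAS kernel code: the helper name `q_batch_stride` and the example sizes are assumptions made for illustration, and the packed formula `(num_heads + 2*kv_num_heads)*S*H` is the one quoted in the PR description.

```python
def q_batch_stride(num_heads, kv_num_heads, seq_len, head_size, packed_qkv):
    """Elements between the start of Q for batch b and for batch b + 1.

    Hypothetical helper for illustration only; it mirrors the formula quoted
    in the PR description, not the actual C++ implementation.
    """
    # With packed QKV, each batch entry also carries the K and V heads, so the
    # distance to the next batch's Q is larger than num_heads * S * H.
    heads_per_batch = (num_heads + 2 * kv_num_heads) if packed_qkv else num_heads
    return heads_per_batch * seq_len * head_size


# Example sizes (assumed): 8 query heads, 2 KV heads, S = 4, H = 16.
num_heads, kv_num_heads, S, H = 8, 2, 4, 16

hardcoded = num_heads * S * H  # the stride the old quantized path always used
unpacked = q_batch_stride(num_heads, kv_num_heads, S, H, packed_qkv=False)
packed = q_batch_stride(num_heads, kv_num_heads, S, H, packed_qkv=True)

for b in range(3):
    print(f"batch {b}: hardcoded offset {b * hardcoded}, packed offset {b * packed}")
```

In this sketch the hardcoded and packed strides differ, so the two Q offsets agree only at batch 0 and diverge for every later batch.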
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `onnxruntime/core/mlas/inc/mlas_qkv_quant.h` | Extends quantized flash-attention args with `q_batch_stride` for correct Q addressing across batches/layouts. |
| `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` | Uses caller-provided Q batch stride and fixes bias batch stride when head dim is broadcast (tiled + decoding paths). |
| `onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h` | Wires `q_batch_stride` from the op dispatch and fixes per-batch bias slicing stride for broadcast-head bias. |
| `onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py` | Adds a multi-batch regression test for broadcast-head attention bias. |
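As a rough illustration of the bias addressing described in the table, the sketch below computes the flat offset of one bias row. The function name `bias_row_offset`, the contiguous row-major layout, and the non-broadcast shape `[B, num_heads, S, T]` are assumptions; only `bias_head_extent` (1 when the head dim is broadcast, otherwise `num_heads`) comes from the PR text.

```python
def bias_row_offset(b, n, s, num_heads, S, T, broadcast_head):
    """Flat offset of bias row (b, n, s) in a contiguous bias buffer.

    Hypothetical helper: assumes shape [B, 1, S, T] when broadcast_head is
    True and [B, num_heads, S, T] otherwise.
    """
    # The batch stride must follow the head extent actually stored in memory.
    bias_head_extent = 1 if broadcast_head else num_heads
    batch_stride = bias_head_extent * S * T
    # With a broadcast head dim, every head reads the single stored head.
    head_index = 0 if broadcast_head else n
    return b * batch_stride + head_index * S * T + s * T


# Example sizes (assumed): 4 heads, S = 2, T = 3.
num_heads, S, T = 4, 2, 3

broadcast = bias_row_offset(1, 3, 1, num_heads, S, T, broadcast_head=True)
full = bias_row_offset(1, 3, 1, num_heads, S, T, broadcast_head=False)
print("broadcast-head offset:", broadcast)
print("per-head offset:", full)
```

Under these assumptions, all heads of a given batch and row resolve to the same offset when the head dim is broadcast, and the batch stride shrinks from `num_heads*S*T` to `S*T`.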
Description
The quantized KV-cache flash-attention path for the CPU `GroupQueryAttention` contrib op carried two latent batch-stride bugs that produced incorrect results for `batch_size > 1`. This PR ports the fixes already landed for the FP32 path (PR #28962) to the quantized path, and adds a regression test that exercises the previously-uncovered scenario.

Summary of Changes

Bug fixes
- `onnxruntime/core/mlas/inc/mlas_qkv_quant.h`: Add a `q_batch_stride` field to `MlasFlashAttentionQuantizedKVArgs` so the kernel uses a caller-supplied Q batch stride instead of assuming the unpacked `num_heads*S*H` layout.
- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp`: Use `args->q_batch_stride` for the Q pointer in both the main tiled kernel and the flash-decoding kernel; compute the attention-bias batch stride from `bias_head_extent = broadcast_head ? 1 : num_heads` (two sites) instead of always using `num_heads`.
- `onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h`: Set `q_batch_stride` to `(num_heads + 2*kv_num_heads)*S*H` for packed QKV (else `num_heads*S*H`) in the unified dispatch; correct the per-batch Q offset and bias slice to use the packed stride and head extent.

Tests
- Add `test_int8_bias_broadcast_head_multi_batch` to `test_gqa_cpu_quantized.py`, covering `[B, 1, S, T]` bias with `batch_size > 1`. The existing `test_int8_bias_broadcast_head` used `batch_size == 1`, which masked the head-broadcast batch-stride bug.

Testing
- `cd onnxruntime/test/python/transformers && python -m pytest test_gqa_cpu_quantized.py -q` → 21 passed, 2 skipped.
- Reverting the fix makes `test_int8_bias_broadcast_head_multi_batch` fail (mismatches starting at batch index 1) while the `batch_size == 1` variant still passes; restoring the fix makes all tests pass.
- `lintrunner -a` → no lint issues.

Motivation and Context

Follow-up to PR #28962, which fixed the identical two bugs (packed-QKV Q batch stride and attention-bias head-broadcast batch stride) in the non-quantized FP32 CPU GQA path. The quantized path was left untouched there as out of scope; existing quantized parity tests did not cover `batch > 1` with `[B, 1, S, T]` bias, so the bugs went undetected.
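A toy model, not the actual regression test, can show why a `batch_size == 1` case cannot detect the bias-stride bug. The helpers `make_tagged_bias` and `mismatching_batches` are hypothetical, and treating an out-of-range offset as a mismatch is a simplification of what a real out-of-bounds read would do.

```python
def make_tagged_bias(batch_size, S, T):
    """Flat [B, 1, S, T] buffer in which every element stores its batch index."""
    return [b for b in range(batch_size) for _ in range(S * T)]


def mismatching_batches(batch_size, num_heads, S, T, use_fix):
    """Return the batch indices whose bias slice is read from the wrong place."""
    bias = make_tagged_bias(batch_size, S, T)
    # Fixed path: head extent is 1 for broadcast bias. Old path: always num_heads.
    head_extent = 1 if use_fix else num_heads
    batch_stride = head_extent * S * T
    bad = []
    for b in range(batch_size):
        offset = b * batch_stride
        # Toy simplification: an offset past the buffer counts as a mismatch.
        value = bias[offset] if offset < len(bias) else None
        if value != b:
            bad.append(b)
    return bad


# Example sizes (assumed): 4 heads, S = 2, T = 3.
print("old stride, batch_size=1:", mismatching_batches(1, 4, 2, 3, use_fix=False))
print("old stride, batch_size=3:", mismatching_batches(3, 4, 2, 3, use_fix=False))
print("fixed stride, batch_size=3:", mismatching_batches(3, 4, 2, 3, use_fix=True))
```

Because the batch-0 offset is zero regardless of the stride, a single-batch case passes in this model with either stride; only a multi-batch case exposes the difference.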