Skip to content

Fix packed-QKV and broadcast-head bias strides in quantized GQA flash attention#28963

Open
tianleiwu wants to merge 4 commits into
mainfrom
tlwu/fix_gqa_quantized_kv
Open

Fix packed-QKV and broadcast-head bias strides in quantized GQA flash attention#28963
tianleiwu wants to merge 4 commits into
mainfrom
tlwu/fix_gqa_quantized_kv

Conversation

@tianleiwu

@tianleiwu tianleiwu commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Description

The quantized KV-cache flash-attention path for the CPU GroupQueryAttention contrib op carried two latent batch-stride bugs that produced incorrect results for batch_size > 1. This PR ports the fixes already landed for the FP32 path (PR #28962) to the quantized path, and adds a regression test that exercises the previously-uncovered scenario.

Summary of Changes

Bug fixes

File Change
onnxruntime/core/mlas/inc/mlas_qkv_quant.h Add q_batch_stride field to MlasFlashAttentionQuantizedKVArgs so the kernel uses a caller-supplied Q batch stride instead of assuming the unpacked num_heads*S*H layout.
onnxruntime/core/mlas/lib/flashattn_qkv.cpp Use args->q_batch_stride for the Q pointer in both the main tiled kernel and the flash-decoding kernel; compute the attention-bias batch stride from bias_head_extent = broadcast_head ? 1 : num_heads (two sites) instead of always using num_heads.
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Set q_batch_stride to (num_heads + 2*kv_num_heads)*S*H for packed QKV (else num_heads*S*H) in the unified dispatch; correct the per-batch Q offset and bias slice to use the packed stride and head extent.

Tests

  • Add test_int8_bias_broadcast_head_multi_batch to test_gqa_cpu_quantized.py covering [B, 1, S, T] bias with batch_size > 1. The existing test_int8_bias_broadcast_head used batch_size == 1, which masked the head-broadcast batch-stride bug.

Testing

  • cd onnxruntime/test/python/transformers && python -m pytest test_gqa_cpu_quantized.py -q → 21 passed, 2 skipped.
  • Verified the new test catches the bug: temporarily reverting the bias-stride fix makes test_int8_bias_broadcast_head_multi_batch fail (mismatches starting at batch index 1) while the batch_size == 1 variant still passes; restoring the fix makes all tests pass.
  • lintrunner -a → no lint issues.

Motivation and Context

Follow-up to PR #28962, which fixed the identical two bugs (packed-QKV Q batch stride and attention-bias head-broadcast batch stride) in the non-quantized FP32 CPU GQA path. The quantized path was left untouched there as out of scope; existing quantized parity tests did not cover batch > 1 with [B, 1, S, T] bias, so the bugs went undetected.

…ized GQA

The quantized KV-cache flash-attention path used a hardcoded Q batch stride
(num_heads*S*H) and a num_heads-based attention-bias batch stride, producing
incorrect results for batch_size > 1 with packed QKV or [B, 1, S, T] bias.

Add a caller-supplied q_batch_stride to MlasFlashAttentionQuantizedKVArgs and
compute the bias batch stride from the actual head extent (1 when the head dim
is broadcast). Mirrors the FP32 fixes from #28962. Adds a regression test for
batch > 1 with head-broadcast bias.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes two correctness issues in the CPU quantized KV-cache flash-attention path for the com.microsoft.GroupQueryAttention contrib op that can produce wrong results when batch_size > 1: (1) query (Q) batch stride handling for packed QKV layouts, and (2) attention-bias batch stride computation when the head dimension is broadcast (bias shape [B, 1, S, T]). It also adds a Python regression test to cover the previously-missed multi-batch broadcast-head bias case.

Changes:

  • Add q_batch_stride to MlasFlashAttentionQuantizedKVArgs and use it in the quantized flash-attention kernels to support both standard and packed-QKV layouts.
  • Fix attention-bias batch-stride calculation in quantized flash attention to account for head broadcasting (bias_head_extent = 1 when broadcast).
  • Add a new quantized Python regression test for [B, 1, S, T] bias with batch_size > 1.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
onnxruntime/core/mlas/inc/mlas_qkv_quant.h Extends quantized flash-attention args with q_batch_stride for correct Q addressing across batches/layouts.
onnxruntime/core/mlas/lib/flashattn_qkv.cpp Uses caller-provided Q batch stride and fixes bias batch stride when head dim is broadcast (tiled + decoding paths).
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Wires q_batch_stride from the op dispatch and fixes per-batch bias slicing stride for broadcast-head bias.
onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py Adds a multi-batch regression test for broadcast-head attention bias.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread onnxruntime/core/mlas/inc/mlas_qkv_quant.h Outdated
Comment thread onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

Comment thread onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Outdated
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants