webgpu: Enable FlashAttention for batched GQA with right-padded prompts #29247
Merged
When GenAI runs a batched prefill with prompts of unequal lengths, short prompts are right-padded up to the batch max sequence_length and each batch's real length is reported via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary embedding shaders computed past_seqlen = (seqlens_k[b]+1) - sequence_length per batch, which underflowed u32 for any batch shorter than sequence_length. The wrapped value, close to 2^32, produced a position_id that indexed past the cos/sin caches and yielded garbage rotated Q/K, which showed up as gibberish output text for the shorter prompts in the batch.

Clamp past_seqlen to 0 in all three rotary embedding shaders: RotaryEmbeddingProgram (seqlens variant), FusedQKRotaryEmbeddingProgram, and the split_packed_qkv_with_rotary_embedding template.

Also extend CanApplyFlashAttention to bypass FlashAttention for batched cases with per-batch seqlens (which exercise the unpatched and-copykv variant), while still allowing it for shared-KV layers where it is mandatory.

Adds a regression test exercising the packed-QKV do_rotary path with three batches of unequal real lengths.
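The underflow and the clamp can be reproduced on the host, since WGSL u32 subtraction wraps the same way uint32_t does in C++. The following is a minimal sketch, not the shader code from this change; the function names are hypothetical and the clamp is written as a ternary purely for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Host-side emulation of the shader arithmetic. seqlens_k holds
// real_len - 1 for a batch, as described in the commit message.

// Unclamped form: wraps around when the batch's real length is shorter
// than the padded sequence_length.
constexpr std::uint32_t PastSeqlenUnclamped(std::uint32_t seqlens_k,
                                            std::uint32_t sequence_length) {
  return (seqlens_k + 1u) - sequence_length;
}

// Clamped form: saturate at 0 instead of wrapping.
constexpr std::uint32_t PastSeqlenClamped(std::uint32_t seqlens_k,
                                          std::uint32_t sequence_length) {
  const std::uint32_t total = seqlens_k + 1u;
  return total > sequence_length ? total - sequence_length : 0u;
}

// Prompt of 5 real tokens (seqlens_k = 4) padded to sequence_length = 8.
static_assert(PastSeqlenUnclamped(4u, 8u) == 4294967293u, "5 - 8 wraps to 2^32 - 3");
static_assert(PastSeqlenClamped(4u, 8u) == 0u, "short batch starts at position 0");
// Decode step: 10 cached tokens plus 1 new token (seqlens_k = 10).
static_assert(PastSeqlenClamped(10u, 1u) == 10u, "non-padded case is unchanged");
```

For batches that are not shorter than sequence_length the two forms agree, so the clamp only changes behavior in the right-padded case.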
The FlashAttention path on WebGPU previously gated batched GQA out via
(batch_size_ == 1 || seqlen_k == nullptr || kv_empty) in
CanApplyFlashAttention because three shaders hardcoded seqlens_k[0] and
several KV-cache write offsets / causal-mask / rotary-position derivations
underflowed u32 when the per-batch prompt was shorter than the
max-across-batches sequence_length (right-padded batches).
This change reads seqlens_k per batch in all FlashAttention shaders
(prefill + decode split-reduce + CopyKVCache + the rotary-and-copyKV
template), clamps every past_X = total_X - new_X subtraction to avoid
u32 underflow, and decouples attention_bias stride (still allocated to
the global max total_sequence_length) from the per-batch OOB check.
The decode_qkv shader retains a workgroup-grid sized to the global max
total_sequence_length tile count (so workgroup_idx slicing remains
self-consistent across batches), and early-exits with neutral metadata
(-inf, 0) for tiles beyond a short batch's per-batch total so the
VxReduce online softmax rescaling is not skewed by garbage tiles.
A new use_seqlen_k template parameter (separate from
use_indirect_dispatch, which still requires graph capture) drives the
per-batch path; it is enabled whenever seqlen_k is provided and
(graph_capture || batch_size_ > 1).
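To see why (-inf, 0) is a neutral value for the split-reduce, here is a small host-side sketch of combining per-tile softmax metadata. It is an illustration under assumptions: TileMeta, CombinedDenominator, and NeutralTile are hypothetical names, and the real VxReduce shader may organize the reduction differently.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Per-tile softmax metadata: the max logit seen in the tile and the sum of
// exp(logit - max) over the tile. Hypothetical layout for illustration.
struct TileMeta {
  float max;
  float sum;
};

// Combine tiles: rescale each tile's sum to the global max, then add.
inline float CombinedDenominator(const std::vector<TileMeta>& tiles) {
  float global_max = -std::numeric_limits<float>::infinity();
  for (const TileMeta& t : tiles) global_max = std::fmax(global_max, t.max);
  float denom = 0.0f;
  for (const TileMeta& t : tiles) {
    // A tile with sum == 0 contributes nothing. Skipping it also avoids
    // evaluating (-inf) - (-inf), which is NaN, when every tile is neutral.
    if (t.sum > 0.0f) denom += t.sum * std::exp(t.max - global_max);
  }
  return denom;
}

// Metadata written by a tile that lies beyond a short batch's real length.
inline TileMeta NeutralTile() {
  return {-std::numeric_limits<float>::infinity(), 0.0f};
}
```

Appending NeutralTile() to a list of real tiles leaves both the global max and the denominator unchanged, whereas a tile holding uninitialized values could raise the global max and shrink every real tile's contribution.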
Verified:
- All 7 GroupQueryAttentionTest.WebGPU_* op tests pass, including
BatchedRightPaddedRotaryPrefill which now exercises FlashAttention
instead of ApplyAttention.
- phi4-prune three-prompt batched generation produces coherent,
correct outputs on WebGPU matching the CPU reference.
- phi4-prune single-prompt generation regression: coherent output.
- whisper-tiny-int4 transcription regression: 2/2 byte-exact with CPU.
# Conflicts:
#   onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
#   onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
- Drop the now-unused seqlen_k parameter from CanApplyFlashAttention and update the GQA caller. The argument was already ORT_UNUSED_PARAMETER.
- Simplify use_seqlen_k to (seqlen_k != nullptr) so batch=1 and batch>1 share one path. The previous (graph_capture || batch_size > 1) qualifier was redundant: both conditions imply seqlen_k is supplied, and reading seqlens_k[batch_idx] in the shader is a no-op for batch=1.
- Read attention_bias's actual last-dim stride from a new attn_bias_dim3 uniform instead of uniforms.total_sequence_length. The shader stride must match the tensor's storage shape, which can differ from the per-step total (e.g. graph capture sets total_sequence_length=0 on the host, which would have produced a zero-stride bias offset).
- Strip the unreachable indirect-dispatch branch in flash_attention_decode_vx_reduce.wgsl.template. After the use_seqlen_k simplification, use_indirect_dispatch=true implies use_seqlen_k=true (the host gate requires seqlen_k != nullptr), so the inner branch could never execute. Keep the use_indirect_dispatch flag elsewhere for forward compatibility.
- Add BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU test that exercises the FlashAttentionProgram prefill path (sequence_length=33 crosses the split-reduce threshold of 32) with right-padded batches and do_rotary, matching the Phi-style batched prefill scenario.
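The zero-stride problem mentioned in the attn_bias_dim3 item can be illustrated with plain index arithmetic. This is a sketch under assumptions: it treats attention_bias as a dense 4-D tensor with no broadcasting, and BiasOffset and its parameter names are hypothetical rather than taken from the shader.

```cpp
#include <cassert>
#include <cstdint>

// Flat offset of attention_bias[b][h][q][k] for a dense 4-D tensor whose
// last dimension holds last_dim_stride elements (assumed layout).
constexpr std::uint32_t BiasOffset(std::uint32_t b, std::uint32_t h,
                                   std::uint32_t q, std::uint32_t k,
                                   std::uint32_t num_heads, std::uint32_t q_len,
                                   std::uint32_t last_dim_stride) {
  return ((b * num_heads + h) * q_len + q) * last_dim_stride + k;
}

// Bias stored as [1, 2, 4, 16]: the row for head 1, q = 3 starts at
// (1 * 4 + 3) * 16 = 112, so k = 5 lands at 117.
static_assert(BiasOffset(0u, 1u, 3u, 5u, 2u, 4u, 16u) == 117u,
              "stride taken from the storage shape");
// If the stride came from a host scalar that is 0, every row collapses
// onto row 0 and only k survives.
static_assert(BiasOffset(0u, 1u, 3u, 5u, 2u, 4u, 0u) == 5u,
              "zero stride aliases all rows");
```

Passing the tensor's real last dimension as a separate uniform keeps the offset tied to how the data is stored rather than to a per-step length.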
… prefill tests

Thread a smooth_softmax flag through RunGQAPackedQKVRotaryPrefill / RunBatchedRightPaddedRotaryPrefillForEP and add a new WebGPU test: BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU.

With smooth_softmax=1 the WebGPU EP skips CanApplyFlashAttention and routes through ApplyAttention, so the three WebGPU prefill tests now cover all three batched right-padded code paths:
- split-reduce decode shader (FlashAttentionDecodeQKV + VxReduce)
- fused prefill shader (FlashAttentionProgram)
- non-flash attention path (ApplyAttention)

Verified by temporarily printing the chosen path during dispatch: each test hits its intended path with batch_size=3 and right-padded seqlens_k.
use_indirect_dispatch implies seqlen_k != nullptr (graph-capture path), and use_seqlen_k_ = (seqlen_k != nullptr), so use_indirect_dispatch_ ⇒ use_seqlen_k_. The needs_seqlens_k = use_indirect_dispatch_ || use_seqlen_k_ disjunction is therefore equivalent to use_seqlen_k_ at every call site. Drop the local.
Pull request overview
This PR updates the WebGPU GroupQueryAttention FlashAttention implementation to support common GenAI batched-prefill shapes with right-padded prompts by consuming per-batch seqlens_k in the WGSL shaders, relaxing the previous batch_size == 1 gating, and extending the test suite to cover the newly-enabled fused paths.
Changes:
- Plumb per-batch seqlens_k[batch_idx] through FlashAttention-related WGSL (prefill + split-reduce decode) and clamp "past = total - new" style subtractions to avoid u32 underflow.
- Decouple attention_bias last-dimension stride from per-batch total_sequence_length by passing a new attn_bias_dim3 uniform.
- Add/extend WebGPU tests to cover (a) FlashAttention prefill for right-padded batches and (b) the non-flash route when smooth_softmax=1.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds WebGPU regression tests that force FlashAttention vs non-flash paths under right-padded batched rotary prefill. |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Updates FlashAttention routing to use the new CanApplyFlashAttention signature (and removes the former seqlen-based restriction). |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template | Switches to per-batch total sequence length and introduces attn_bias_dim3-based stride handling. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Extends uniform layouts (attn_bias_dim3) and adds use_seqlen_k plumbing for decode programs; changes CanApplyFlashAttention signature. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements per-batch seqlen consumption across kernels, propagates attn_bias_dim3, and relaxes flash applicability gating. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template | Uses per-batch seqlen when enabled to avoid skewing online softmax rescaling for "extra" tiles. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template | Adds early-exit neutral metadata for tiles beyond a batch's real K/V length and uses per-batch seqlen for bounds/masking when enabled. |
Comments suppressed due to low confidence (1)
onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template:184
- Same one-past-the-end issue in the vec4 attention_bias path: offset_max should be the last valid element index, not offset_base + stride_total_seq (which is exclusive). As written, the min(offset + N, offset_max) clamp can still select an out-of-bounds element when stride_total_seq is smaller than the accessed range.
let offset_max = offset_base + stride_total_seq;
let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]);
let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]);
let c4 = q_element_t(attention_bias[min(offset + 3, offset_max)]);
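The difference between an exclusive and an inclusive upper bound in this clamp can be checked with a host-side analogue. The sketch below is an assumption-laden illustration rather than the shader fix itself: LoadBiasClamped is a hypothetical helper, the buffer is a std::vector instead of a storage buffer, and stride_total_seq is assumed to be at least 1.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Read one bias element, clamping the index to the LAST element of the row
// that starts at offset_base and holds stride_total_seq elements.
// Requires stride_total_seq >= 1.
inline float LoadBiasClamped(const std::vector<float>& bias,
                             std::uint32_t offset_base,
                             std::uint32_t stride_total_seq,
                             std::uint32_t offset) {
  // Inclusive bound: offset_base + stride_total_seq would be one past the row.
  const std::uint32_t offset_last = offset_base + stride_total_seq - 1u;
  return bias[std::min(offset, offset_last)];
}
```

With two rows of four elements, a read at offset 6 against row 0 is clamped to index 3 (the end of row 0), whereas the exclusive bound would have selected index 4, which belongs to row 1. For the final row the exclusive bound would point past the end of the buffer.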
In graph-capture mode the host total_sequence_length scalar is 0 and the dispatch grid for the flash-attention pipeline is computed on the GPU. The three shaders that prepare or consume the indirect-dispatch buffer (CopyKVCache, SplitPackedQKVWithRotaryEmbeddingAndCopyKV, FlashAttentionDecodeQKV) previously sized the grid from seqlens_k[batch=0] + 1. For batched right-padded prefill, batch 0 is not guaranteed to hold the maximum KV span, so when the spread across the batch crosses a tile boundary, other batches lose tiles and produce wrong output.

Thread GQA's input #6 (total_sequence_length, GPU-resident exactly when graph capture is enabled) through ApplyFlashAttention into the three shaders and use it for the indirect-dispatch sizing only. Per-batch seqlens_k[batch] + 1 still drives causal masking and per-batch bounds inside the kernels.

Also enforce in GroupQueryAttention that graph capture implies past_present_share_buffer_, so the use_indirect_dispatch predicate only needs to check seqlen_k, total_seqlen, and IsGraphCaptureEnabled.

Address PR review:
- Clamp attention_bias load to offset_base + stride_total_seq - 1u in both scalar and vec4 paths so the one-past-end fallback stays within the same row.
- Reword the smooth_softmax test comment to reference the outer gating in GroupQueryAttention::ComputeInternal that routes through ApplyAttention.
- Extend the indirect-dispatch fix to FlashAttentionDecodeQKV; the new use of use_indirect_dispatch_ also resolves the -Wunused-private-field Clang error on the wasm and arm64 builds.

Add BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU with real_lens spread > tile_size so a future regression in the dispatch sizing surfaces in the WebGPU test suite (graph capture itself cannot be toggled from OpTester).
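The under-dispatch can be shown with a tile-count calculation on the host. This is a minimal sketch with hypothetical names; the tile size of 32 used in the example is an assumption for illustration and is not taken from the shaders.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Number of tiles needed to cover total_len K/V positions (ceiling division).
constexpr std::uint32_t TileCount(std::uint32_t total_len, std::uint32_t tile_size) {
  return (total_len + tile_size - 1u) / tile_size;
}

// Previous sizing: derived from batch 0 only.
inline std::uint32_t TilesFromBatch0(const std::vector<std::uint32_t>& seqlens_k,
                                     std::uint32_t tile_size) {
  return TileCount(seqlens_k[0] + 1u, tile_size);
}

// New sizing: derived from the batch-wide total_sequence_length.
inline std::uint32_t TilesFromTotal(std::uint32_t total_sequence_length,
                                    std::uint32_t tile_size) {
  return TileCount(total_sequence_length, tile_size);
}
```

For real lengths 10, 70, and 40 (seqlens_k = 9, 69, 39) and an assumed tile size of 32, sizing from batch 0 dispatches 1 tile while the batch-wide total of 70 needs 3, so the second and third batches would lose tiles under the old scheme.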
FlashAttentionDecodeVxReduceProgram no longer branches on use_indirect_dispatch in its shader template (the per-batch iteration is gated by use_seqlen_k instead), so the field is dead and Clang rejects it as -Wunused-private-field on wasm and arm64 builds. Remove the parameter from the program ctor, the member field, the ComputeFlashAttentionDecodeVxReduce signature, and the CacheHint so identical shaders share one cached pipeline.
guschmue
approved these changes
Jun 25, 2026
FYI @sushraja-msft Once this PR is merged, you may need to rebase #28059
hariharans29
approved these changes
Jun 26, 2026
Summary
Lift WebGPU FlashAttention's batch_size == 1 restriction so batched GQA with right-padded prompts (the common GenAI batched-prefill shape) takes the fused FlashAttention path instead of falling back to ApplyAttention.
- All FlashAttention shaders read seqlens_k[batch_idx] instead of hardcoding seqlens_k[0]. All past_X = total_X - new_X subtractions are clamped to avoid u32 underflow when a short batch's per-batch total is less than the batch-wide sequence_length.
- total_sequence_length input. CopyKVCache, SplitPackedQKVWithRotaryEmbeddingAndCopyKV, and FlashAttentionDecodeQKV now take a new total_sequence_length_input binding (GQA input #6, GPU-resident under graph capture) for the indirect-dispatch grid sizing. This is the global max KV span across the batch by construction, replacing the previous seqlens_k[0] + 1u that under-dispatched whenever batch 0 wasn't the longest. Per-batch seqlens_k[batch] + 1 still drives causal masking and K/V bounds inside the kernels. GQA now enforces graph_capture_enabled -> past_present_share_buffer_ so the host-side use_indirect_dispatch predicate stays simple.
- attention_bias is still allocated to the global max total_sequence_length; only the causal-mask / softmax tile loops are gated by the per-batch total. The one-past-end fallback was tightened to clamp inside the same row (offset_base + stride_total_seq - 1u).
- decode_qkv keeps a workgroup grid sized to the global max tile count to keep workgroup_idx slicing consistent across batches, with neutral (-inf, 0) early-exit for tiles beyond a short batch's per-batch total so the VxReduce online softmax rescaling is not skewed.
- New use_seqlen_k template parameter (separate from use_indirect_dispatch, which still requires graph capture). It is enabled whenever seqlen_k is provided and (graph_capture || batch_size_ > 1).
- The commit "webgpu: fix GQA batched right-padded prefill with do_rotary" (591df5b) clamps past_seqlen to 0 in RotaryEmbeddingProgram, FusedQKRotaryEmbeddingProgram, and split_packed_qkv_with_rotary_embedding, which previously produced gibberish for the shorter batches.
Motivation
GenAI's batched prefill right-pads short prompts to the batch max and reports each batch's real length via seqlens_k[b] = real_len[b] - 1. The previous FlashAttention gate forced every batched call onto the slower ApplyAttention path, and the rotary shaders underflowed u32 for any batch shorter than the batch-wide sequence_length, producing garbage Q/K positions and gibberish output text for the shorter batches.
Test plan
- GroupQueryAttentionTest.WebGPU_* op tests pass, including BatchedRightPaddedRotaryPrefill (FlashAttention path) and the new BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU covering a real_lens spread > tile_size
- verify_model_correctness.py 4/4 PASS; verify_multi_gen.py sequential + overlapping both PASS