Skip to content

webgpu: Enable FlashAttention for batched GQA with right-padded prompts#29247

Merged
qjia7 merged 8 commits into
microsoft:mainfrom
qjia7:feat/webgpu-flash-attention-batched
Jun 26, 2026
Merged

webgpu: Enable FlashAttention for batched GQA with right-padded prompts#29247
qjia7 merged 8 commits into
microsoft:mainfrom
qjia7:feat/webgpu-flash-attention-batched

Conversation

@qjia7

@qjia7 qjia7 commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Summary

Lift WebGPU FlashAttention's batch_size == 1 restriction so batched GQA with right-padded prompts (the common GenAI batched-prefill shape) takes the fused FlashAttention path instead of falling back to ApplyAttention.

  • Per-batch seqlens in FlashAttention shaders. Prefill, decode split-reduce, CopyKVCache, and the fused rotary-and-copyKV template now read seqlens_k[batch_idx] instead of hardcoding seqlens_k[0]. All past_X = total_X - new_X subtractions are clamped to avoid u32 underflow when a short batch's per-batch total is less than the batch-wide sequence_length.
  • Indirect-dispatch sizing uses GQA's total_sequence_length input. CopyKVCache, SplitPackedQKVWithRotaryEmbeddingAndCopyKV, and FlashAttentionDecodeQKV now take a new total_sequence_length_input binding (GQA input Add doxygen generated website for the project #6, GPU-resident under graph capture) for the indirect-dispatch grid sizing. This is the global max KV span across the batch by construction, replacing the previous seqlens_k[0] + 1u that under-dispatched whenever batch 0 wasn't the longest. Per-batch seqlens_k[batch] + 1 still drives causal masking and K/V bounds inside the kernels. GQA now enforces graph_capture_enabled -> past_present_share_buffer_ so the host-side use_indirect_dispatch predicate stays simple.
  • Decoupled attention_bias stride from per-batch OOB. attention_bias is still allocated to the global max total_sequence_length; only the causal-mask / softmax tile loops are gated by the per-batch total. The one-past-end fallback was tightened to clamp inside the same row (offset_base + stride_total_seq - 1u).
  • Decode workgroup grid stays at global max. decode_qkv keeps a workgroup grid sized to the global max tile count to keep workgroup_idx slicing consistent across batches, with neutral (-inf, 0) early-exit for tiles beyond a short batch's per-batch total so the VxReduce online softmax rescaling is not skewed.
  • New use_seqlen_k template parameter (separate from use_indirect_dispatch which still requires graph capture). It is enabled whenever seqlen_k is provided and (graph_capture || batch_size_ > 1).
  • Rotary fix prerequisite (webgpu: fix GQA batched right-padded prefill with do_rotary, 591df5b): clamps past_seqlen to 0 in RotaryEmbeddingProgram, FusedQKRotaryEmbeddingProgram, and split_packed_qkv_with_rotary_embedding, which previously produced gibberish for the shorter batches.

Motivation

GenAI's batched prefill right-pads short prompts to the batch max and reports each batch's real length via seqlens_k[b] = real_len[b] - 1. The previous FlashAttention gate forced every batched call onto the slower ApplyAttention path, and the rotary shaders underflowed u32 for any batch shorter than the batch-wide sequence_length, producing garbage Q/K positions and gibberish output text for the shorter batches.

Test plan

  • All GroupQueryAttentionTest.WebGPU_* op tests pass, including BatchedRightPaddedRotaryPrefill (FlashAttention path) and the new BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU covering a real_lens spread > tile_size
  • phi4-prune three-prompt batched generation: coherent outputs on WebGPU matching CPU reference (3 prompts, 384 tokens, 173 tps)
  • phi4-prune single-prompt generation regression: coherent
  • phi4-graph-prune (graph capture enabled): verify_model_correctness.py 4/4 PASS; verify_multi_gen.py sequential + overlapping both PASS
  • whisper-tiny-int4 transcription regression: 2/2 byte-exact with CPU
  • Lintrunner clean on all changed files

qjia7 added 6 commits June 11, 2026 16:43
When GenAI runs a batched prefill with prompts of unequal lengths, short
prompts are right-padded up to the batch max sequence_length and each
batch's real length is reported via seqlens_k[b] = real_len[b] - 1. The
WebGPU rotary embedding shaders computed past_seqlen = (seqlens_k[b]+1)
- sequence_length per batch, which underflowed u32 for any batch shorter
than sequence_length. The resulting astronomically large position_id
indexed past the cos/sin caches and produced garbage rotated Q/K, which
manifested as gibberish output text for the shorter batches in the
batch.

Clamp past_seqlen to 0 in all three rotary embedding shaders:
RotaryEmbeddingProgram (seqlens variant), FusedQKRotaryEmbeddingProgram,
and the split_packed_qkv_with_rotary_embedding template. Also extend
CanApplyFlashAttention to bypass FlashAttention for batched cases with
per-batch seqlens (which exercise the unpatched and-copykv variant),
while still allowing it for shared-KV layers where it is mandatory.

Adds a regression test exercising the packed-QKV do_rotary path with
three batches of unequal real lengths.
The FlashAttention path on WebGPU previously gated batched GQA out via
(batch_size_ == 1 || seqlen_k == nullptr || kv_empty) in
CanApplyFlashAttention because three shaders hardcoded seqlens_k[0] and
several KV-cache write offsets / causal-mask / rotary-position derivations
underflowed u32 when the per-batch prompt was shorter than the
max-across-batches sequence_length (right-padded batches).

This change reads seqlens_k per batch in all FlashAttention shaders
(prefill + decode split-reduce + CopyKVCache + the rotary-and-copyKV
template), clamps every past_X = total_X - new_X subtraction to avoid
u32 underflow, and decouples attention_bias stride (still allocated to
the global max total_sequence_length) from the per-batch OOB check.

The decode_qkv shader retains a workgroup-grid sized to the global max
total_sequence_length tile count (so workgroup_idx slicing remains
self-consistent across batches), and early-exits with neutral metadata
(-inf, 0) for tiles beyond a short batch's per-batch total so the
VxReduce online softmax rescaling is not skewed by garbage tiles.

A new use_seqlen_k template parameter (separate from
use_indirect_dispatch, which still requires graph capture) drives the
per-batch path; it is enabled whenever seqlen_k is provided and
(graph_capture || batch_size_ > 1).

Verified:
  - All 7 GroupQueryAttentionTest.WebGPU_* op tests pass, including
    BatchedRightPaddedRotaryPrefill which now exercises FlashAttention
    instead of ApplyAttention.
  - phi4-prune three-prompt batched generation produces coherent,
    correct outputs on WebGPU matching the CPU reference.
  - phi4-prune single-prompt generation regression: coherent output.
  - whisper-tiny-int4 transcription regression: 2/2 byte-exact with CPU.
# Conflicts:
#	onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
#	onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
- Drop the now-unused seqlen_k parameter from CanApplyFlashAttention and
  update the GQA caller. The argument was already ORT_UNUSED_PARAMETER.

- Simplify use_seqlen_k to (seqlen_k != nullptr) so batch=1 and batch>1
  share one path. The previous (graph_capture || batch_size > 1) qualifier
  was redundant: both conditions imply seqlen_k is supplied, and reading
  seqlens_k[batch_idx] in the shader is a no-op for batch=1.

- Read attention_bias's actual last-dim stride from a new attn_bias_dim3
  uniform instead of uniforms.total_sequence_length. The shader stride
  must match the tensor's storage shape, which can differ from the
  per-step total (e.g. graph capture sets total_sequence_length=0 on the
  host, which would have produced a zero-stride bias offset).

- Strip the unreachable indirect-dispatch branch in flash_attention_decode_
  vx_reduce.wgsl.template. After the use_seqlen_k simplification,
  use_indirect_dispatch=true implies use_seqlen_k=true (the host gate
  requires seqlen_k != nullptr), so the inner branch could never execute.
  Keep the use_indirect_dispatch flag elsewhere for forward compatibility.

- Add BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU test that
  exercises the FlashAttentionProgram prefill path (sequence_length=33
  crosses the split-reduce threshold of 32) with right-padded batches
  and do_rotary, matching the Phi-style batched prefill scenario.
… prefill tests

Thread a smooth_softmax flag through RunGQAPackedQKVRotaryPrefill /
RunBatchedRightPaddedRotaryPrefillForEP and add a new WebGPU test:

  BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU

With smooth_softmax=1 the WebGPU EP skips CanApplyFlashAttention and routes
through ApplyAttention, so the three WebGPU prefill tests now cover all three
batched right-padded code paths:

  - split-reduce decode shader  (FlashAttentionDecodeQKV + VxReduce)
  - fused prefill shader        (FlashAttentionProgram)
  - non-flash attention path    (ApplyAttention)

Verified by temporarily printing the chosen path during dispatch: each test
hits its intended path with batch_size=3 and right-padded seqlens_k.
use_indirect_dispatch implies seqlen_k != nullptr (graph-capture path), and
use_seqlen_k_ = (seqlen_k != nullptr), so use_indirect_dispatch_ ⇒ use_seqlen_k_.
The needs_seqlens_k = use_indirect_dispatch_ || use_seqlen_k_ disjunction is
therefore equivalent to use_seqlen_k_ at every call site. Drop the local.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the WebGPU GroupQueryAttention FlashAttention implementation to support common GenAI batched-prefill shapes with right-padded prompts by consuming per-batch seqlens_k in the WGSL shaders, relaxing the previous batch_size == 1 gating, and extending the test suite to cover the newly-enabled fused paths.

Changes:

  • Plumb per-batch seqlens_k[batch_idx] through FlashAttention-related WGSL (prefill + split-reduce decode) and clamp “past = total - new” style subtractions to avoid u32 underflow.
  • Decouple attention_bias last-dimension stride from per-batch total_sequence_length by passing a new attn_bias_dim3 uniform.
  • Add/extend WebGPU tests to cover (a) FlashAttention prefill for right-padded batches and (b) the non-flash route when smooth_softmax=1.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Adds WebGPU regression tests that force FlashAttention vs non-flash paths under right-padded batched rotary prefill.
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc Updates FlashAttention routing to use the new CanApplyFlashAttention signature (and removes the former seqlen-based restriction).
onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template Switches to per-batch total sequence length and introduces attn_bias_dim3-based stride handling.
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Extends uniform layouts (attn_bias_dim3) and adds use_seqlen_k plumbing for decode programs; changes CanApplyFlashAttention signature.
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Implements per-batch seqlen consumption across kernels, propagates attn_bias_dim3, and relaxes flash applicability gating.
onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template Uses per-batch seqlen when enabled to avoid skewing online softmax rescaling for “extra” tiles.
onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template Adds early-exit neutral metadata for tiles beyond a batch’s real K/V length and uses per-batch seqlen for bounds/masking when enabled.
Comments suppressed due to low confidence (1)

onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template:184

  • Same one-past-the-end issue in the vec4 attention_bias path: offset_max should be the last valid element index, not offset_base + stride_total_seq (which is exclusive). As written, the min(offset + N, offset_max) clamp can still select an out-of-bounds element when stride_total_seq is smaller than the accessed range.
  let offset_max = offset_base + stride_total_seq;
  let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
  let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]);
  let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]);
  let c4 = q_element_t(attention_bias[min(offset + 3, offset_max)]);

Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Outdated
In graph-capture mode the host total_sequence_length scalar is 0 and the
dispatch grid for the flash-attention pipeline is computed on the GPU.
The three shaders that prepare or consume the indirect-dispatch buffer
(CopyKVCache, SplitPackedQKVWithRotaryEmbeddingAndCopyKV,
FlashAttentionDecodeQKV) previously sized the grid from
seqlens_k[batch=0] + 1. For batched right-padded prefill, batch 0 is
not guaranteed to hold the maximum KV span, so when the spread across
the batch crosses a tile boundary other batches lose tiles and produce
wrong output.

Thread GQA's input microsoft#6 (total_sequence_length, GPU-resident exactly when
graph capture is enabled) through ApplyFlashAttention into the three
shaders and use it for the indirect-dispatch sizing only. Per-batch
seqlens_k[batch] + 1 still drives causal masking and per-batch bounds
inside the kernels.

Also enforce in GroupQueryAttention that graph capture implies
past_present_share_buffer_, so the use_indirect_dispatch predicate only
needs to check seqlen_k, total_seqlen, and IsGraphCaptureEnabled.

Address PR review:
- Clamp attention_bias load to offset_base + stride_total_seq - 1u in
  both scalar and vec4 paths so the one-past-end fallback stays within
  the same row.
- Reword the smooth_softmax test comment to reference the outer gating
  in GroupQueryAttention::ComputeInternal that routes through
  ApplyAttention.
- Extend the indirect-dispatch fix to FlashAttentionDecodeQKV; the new
  use of use_indirect_dispatch_ also resolves the -Wunused-private-field
  Clang error on the wasm and arm64 builds.

Add BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU with
real_lens spread > tile_size so a future regression in the dispatch
sizing surfaces in the WebGPU test suite (graph capture itself cannot
be toggled from OpTester).

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.

Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
FlashAttentionDecodeVxReduceProgram no longer branches on
use_indirect_dispatch in its shader template (the per-batch iteration
is gated by use_seqlen_k instead), so the field is dead and Clang
rejects it as -Wunused-private-field on wasm and arm64 builds.

Remove the parameter from the program ctor, the member field, the
ComputeFlashAttentionDecodeVxReduce signature, and the CacheHint so
identical shaders share one cached pipeline.
@qjia7 qjia7 marked this pull request as ready for review June 25, 2026 10:09
@qjia7 qjia7 requested review from guschmue and hariharans29 June 25, 2026 10:10
@qjia7

qjia7 commented Jun 26, 2026

Copy link
Copy Markdown
Contributor Author

FYI @sushraja-msft Once this PR is merged, you may need to rebase #28059

@qjia7 qjia7 merged commit 92b4c66 into microsoft:main Jun 26, 2026
91 of 92 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants