[CPU] Add FP32 GEMV decode kernel for GroupQueryAttention #29216

Open
tianleiwu wants to merge 3 commits into main from tlwu/20260608/gqa_cpu_decode_gemv

@tianleiwu tianleiwu commented Jun 22, 2026

Contributor

Description

The first PR (#28962) added flash attention for prefill and removed flash decoding. This PR adds an optimized kernel for single-token decode, which is intended to be faster than the other kernels, including flash decoding.

This PR builds on the prefill-only flash attention change and additionally introduces a dedicated decode kernel.

What's included

  • Decode (GEMV) kernel — A dedicated single-token decode kernel (MlasGQADecodeGQAThreaded) for sequence_length == 1, parallelized over (batch, head) with a two-pass softmax, using GEMV (acc[8]-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression.
  • The FP32 flash gate (group_query_attention.cc) is enabled for total_sequence_length > 1, routing prefill to the tiled kernel and decode to the GEMV kernel.
  • The quantized KV-cache path is unchanged (FP32-only scope).
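As a rough illustration of the decode bullet above, here is a minimal single-head sketch of a GEMV-style decode with a two-pass softmax. It is not the MLAS implementation: `DotAcc8`, `Axpy`, and `DecodeOneHead` are hypothetical names, the row-major K/V layout is an assumption, and "two-pass" is read here as one pass that computes scores and their maximum followed by one pass that exponentiates and accumulates the weighted V rows.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>

// Hypothetical helper: dot product with 8 independent accumulators, so the
// inner loop can map onto one 8-lane FP32 vector (for example AVX2).
inline float DotAcc8(const float* a, const float* b, std::size_t n) {
    float acc[8] = {0.0f};
    std::size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        for (std::size_t lane = 0; lane < 8; ++lane) {
            acc[lane] += a[i + lane] * b[i + lane];
        }
    }
    float sum = 0.0f;
    for (std::size_t lane = 0; lane < 8; ++lane) sum += acc[lane];
    for (; i < n; ++i) sum += a[i] * b[i];  // scalar tail
    return sum;
}

// Hypothetical helper: y += alpha * x (AXPY).
inline void Axpy(float alpha, const float* x, float* y, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) y[i] += alpha * x[i];
}

// Single-token decode for one (batch, head) pair.
// q: [head_size]; k, v: [total_seqlen, head_size] row-major (assumed layout);
// scores: scratch of total_seqlen floats; out: [head_size].
inline void DecodeOneHead(const float* q, const float* k, const float* v,
                          std::size_t total_seqlen, std::size_t head_size,
                          float scale, float* scores, float* out) {
    assert(total_seqlen > 0 && head_size > 0);

    // Pass 1: scores = scale * (K q), tracking the maximum for stability.
    float max_score = std::numeric_limits<float>::lowest();
    for (std::size_t t = 0; t < total_seqlen; ++t) {
        scores[t] = scale * DotAcc8(q, k + t * head_size, head_size);
        max_score = std::max(max_score, scores[t]);
    }

    // Pass 2: exponentiate, accumulate the normalizer, and AXPY the V rows.
    std::fill(out, out + head_size, 0.0f);
    float denom = 0.0f;
    for (std::size_t t = 0; t < total_seqlen; ++t) {
        const float p = std::exp(scores[t] - max_score);
        denom += p;
        Axpy(p, v + t * head_size, out, head_size);
    }

    const float inv_denom = 1.0f / denom;
    for (std::size_t d = 0; d < head_size; ++d) out[d] *= inv_denom;
}
```

In a threaded version, each (batch, head) work item would call something like `DecodeOneHead` with its own `scores` scratch, so no SGEMM call is issued for the M=1 case.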

Results (AMD EPYC 7763, AVX2, 8 threads)

  • Decode: max abs diff ~1e-8 vs. the naive path; long-context decode speedup ~1.0–1.5x (at T = 4097, ~1.3–1.5x).

Motivation and Context

The naive GQA path materializes the full score matrix, which is memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case.

Testing

  • Built with --compile_no_warning_as_error.
  • Correctness verified against the naive path for both prefill and decode (max abs diff ~1e-8).
  • Benchmarked via benchmark_gqa_cpu_flash.py.

@tianleiwu tianleiwu changed the title from "[CPU] Add FP32 flash attention (prefill) and GEMV decode kernel for GroupQueryAttention" to "[CPU] Add FP32 GEMV decode kernel for GroupQueryAttention" Jun 23, 2026
Single-token decode (sequence_length == 1) falls back to the naive path. A dedicated FP32 decode kernel will be added in a follow-up PR. The quantized path is unchanged.

Copilot AI left a comment

Contributor

Pull request overview

Adds an optimized CPU FP32 single-token decode path for com.microsoft.GroupQueryAttention, aiming to eliminate the decode regression from routing M=1 work through per-block SGEMM by introducing GEMV-based decode (and GEMV-based flash-decoding partials).

Changes:

  • Add GEMV-based decode helpers/kernels for sequence_length == 1, including optional two-phase “flash decoding” KV-chunk reduction.
  • Update FP32 flash gating to activate when total_sequence_length > 1, enabling prefill via tiled flash attention and decode via the new GEMV path.
  • Update CPU GQA documentation to describe the new decode behavior and performance characteristics.
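The routing described in the second bullet can be sketched as a small selection function. This is only an illustration under stated assumptions: `GqaFp32Path` and `SelectFp32Path` are hypothetical names that do not appear in the PR, `flash_disabled` stands in for the `ORT_GQA_DISABLE_FLASH_ATTENTION=1` switch mentioned in the commit notes, and the quantized KV-cache path is out of scope here.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical labels for the three FP32 code paths discussed in this PR.
enum class GqaFp32Path { kNaive, kTiledFlashPrefill, kGemvDecode };

// Hypothetical gate: the flash path is taken when total_sequence_length > 1;
// within it, sequence_length == 1 goes to the GEMV decode kernel and
// everything else goes to the tiled prefill kernel.
inline GqaFp32Path SelectFp32Path(std::size_t sequence_length,
                                  std::size_t total_sequence_length,
                                  bool flash_disabled) {
    assert(sequence_length >= 1 && total_sequence_length >= sequence_length);
    if (flash_disabled || total_sequence_length <= 1) {
        return GqaFp32Path::kNaive;
    }
    return sequence_length == 1 ? GqaFp32Path::kGemvDecode
                                : GqaFp32Path::kTiledFlashPrefill;
}
```

Under these assumptions, a decode step with a 4096-token past (`sequence_length == 1`, `total_sequence_length == 4097`) selects the GEMV decode path, while a 128-token prompt selects the tiled prefill path.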

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/mlas/lib/flashattn_gqa.cpp | Adds GEMV decode helpers plus new decode/flash-decoding threaded kernels and dispatch logic. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Adjusts FP32 flash routing gate to include decode when total_sequence_length > 1. |
| onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | Allocates/partitions scratch buffers for decode vs flash-decoding; wires new args fields into MLAS call. |
| docs/contrib_ops/cpu/gqa.md | Updates docs to describe FP32 decode GEMV kernel and flash-decoding behavior. |

Comment thread onnxruntime/core/mlas/lib/flashattn_gqa.cpp Outdated
Comment thread onnxruntime/core/mlas/lib/flashattn_gqa.cpp
Comment thread onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression.
@tianleiwu tianleiwu force-pushed the tlwu/20260608/gqa_cpu_decode_gemv branch from 4042bd2 to 284d4a5 Compare June 25, 2026 02:36
…test

- Gate use_flash_decoding on common_past_seqlen >= 0 so the small per-thread
  flash-decoding scratch buffer is only selected when the unified KV-split
  kernel runs. Ragged/per-batch decode falls back to MlasGQADecodeGQAThreaded
  which needs a larger scratch (scores[total_seqlen] + temp_output[head_size]);
  previously it reused the small buffer and threads overran each other's
  scratch, producing non-deterministic output for batch>1 ragged decode.
- Add test_gqa_decode_flash_vs_naive_parity comparing both the flash and naive
  (ORT_GQA_DISABLE_FLASH_ATTENTION=1) decode paths to the reference (addresses
  review thread 4).
- Correct flashattn_gqa.cpp file header to describe the decode GEMV helpers
  (addresses review thread 1).
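The scratch-buffer fix in the first item can be pictured with a small sizing sketch. The names `DecodeScratchPlan`, `PlanDecodeScratch`, and `ThreadScratch` are hypothetical, the flash-decoding per-thread size is passed in as a parameter because the commit message does not state it, and `flash_decoding_requested` stands in for whatever other conditions the real code checks. Only the `common_past_seqlen >= 0` gate and the `scores[total_seqlen] + temp_output[head_size]` requirement come from the commit message.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical result of choosing the decode kernel and its scratch size.
struct DecodeScratchPlan {
    bool use_flash_decoding;
    std::size_t floats_per_thread;
};

// common_past_seqlen < 0 is taken to mean ragged (per-batch) past lengths.
inline DecodeScratchPlan PlanDecodeScratch(bool flash_decoding_requested,
                                           std::ptrdiff_t common_past_seqlen,
                                           std::size_t total_seqlen,
                                           std::size_t head_size,
                                           std::size_t flash_decoding_floats_per_thread) {
    DecodeScratchPlan plan{};
    // The small buffer is chosen only when the unified KV-split kernel runs.
    plan.use_flash_decoding = flash_decoding_requested && common_past_seqlen >= 0;
    plan.floats_per_thread = plan.use_flash_decoding
        ? flash_decoding_floats_per_thread
        : total_seqlen + head_size;  // scores[total_seqlen] + temp_output[head_size]
    return plan;
}

// Each thread gets a disjoint slice, so threads cannot overrun each other
// as long as the buffer holds thread_count * floats_per_thread floats.
inline float* ThreadScratch(float* base, std::size_t thread_index,
                            const DecodeScratchPlan& plan) {
    assert(base != nullptr);
    return base + thread_index * plan.floats_per_thread;
}
```

The point of tying the size to the selected kernel is that the allocation and the kernel can no longer disagree, which is the mismatch the commit describes for batch>1 ragged decode.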
@tianleiwu tianleiwu marked this pull request as ready for review June 25, 2026 04:50
@tianleiwu tianleiwu requested a review from Copilot June 25, 2026 04:50
@tianleiwu tianleiwu requested a review from hariharans29 June 25, 2026 04:51

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated no new comments.
