[CPU] Add FP32 GEMV decode kernel for GroupQueryAttention#29216
Open
tianleiwu wants to merge 3 commits into
Open
[CPU] Add FP32 GEMV decode kernel for GroupQueryAttention#29216tianleiwu wants to merge 3 commits into
tianleiwu wants to merge 3 commits into
Conversation
Single-token decode (sequence_length == 1) falls back to the naive path. A dedicated FP32 decode kernel will be added in a follow-up PR. The quantized path is unchanged.
49ffc40 to
4042bd2
Compare
Contributor
There was a problem hiding this comment.
Pull request overview
Adds an optimized CPU FP32 single-token decode path for com.microsoft.GroupQueryAttention, aiming to eliminate the decode regression from routing M=1 work through per-block SGEMM by introducing GEMV-based decode (and GEMV-based flash-decoding partials).
Changes:
- Add GEMV-based decode helpers/kernels for
sequence_length == 1, including optional two-phase “flash decoding” KV-chunk reduction. - Update FP32 flash gating to activate when
total_sequence_length > 1, enabling prefill via tiled flash attention and decode via the new GEMV path. - Update CPU GQA documentation to describe the new decode behavior and performance characteristics.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
onnxruntime/core/mlas/lib/flashattn_gqa.cpp |
Adds GEMV decode helpers plus new decode/flash-decoding threaded kernels and dispatch logic. |
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc |
Adjusts FP32 flash routing gate to include decode when total_sequence_length > 1. |
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h |
Allocates/partitions scratch buffers for decode vs flash-decoding; wires new args fields into MLAS call. |
docs/contrib_ops/cpu/gqa.md |
Updates docs to describe FP32 decode GEMV kernel and flash-decoding behavior. |
Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression.
4042bd2 to
284d4a5
Compare
…test - Gate use_flash_decoding on common_past_seqlen >= 0 so the small per-thread flash-decoding scratch buffer is only selected when the unified KV-split kernel runs. Ragged/per-batch decode falls back to MlasGQADecodeGQAThreaded which needs a larger scratch (scores[total_seqlen] + temp_output[head_size]); previously it reused the small buffer and threads overran each other's scratch, producing non-deterministic output for batch>1 ragged decode. - Add test_gqa_decode_flash_vs_naive_parity comparing both the flash and naive (ORT_GQA_DISABLE_FLASH_ATTENTION=1) decode paths to the reference (addresses review thread 4). - Correct flashattn_gqa.cpp file header to describe the decode GEMV helpers (addresses review thread 1).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
PR1 #28962 adds flash attention for prefill, and removed flash decoding. This PR will add optimized kernel for single-token decode, which will be faster than other kernels including flash decoding.
This PR builds on the prefill-only flash attention change and additionally introduces a dedicated decode kernel.
What's included
MlasGQADecodeGQAThreaded) forsequence_length == 1, parallelized over (batch, head) with a two-pass softmax, using GEMV (acc[8]-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression.group_query_attention.cc) is enabled fortotal_sequence_length > 1, routing prefill to the tiled kernel and decode to the GEMV kernel.Results (AMD EPYC 7763, AVX2, 8 threads)
Motivation and Context
The naive GQA path materializes the full score matrix, which is memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case.
Testing
--compile_no_warning_as_error.benchmark_gqa_cpu_flash.py.