From b2613d9a81ae3404fc55f437ddc8f3deef83d91e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 9 Jun 2026 19:35:40 +0000 Subject: [PATCH 1/5] feat(cpu): add flash attention for non-quantized GQA Add an FP32 tiled online-softmax flash attention kernel for the CPU GroupQueryAttention contrib op, mirroring the existing quantized-KV flash path. Avoids materializing the full attention score matrix and adds a two-phase flash-decoding path for single-token decode. - New MLAS kernel core/mlas/lib/flashattn_gqa.cpp (MlasFlashAttentionGQA) supporting GQA head grouping, causal masking, local window, attention bias, ragged/per-batch seqlens, packed QKV, and flash-decoding. - New ApplyAttentionFlash dispatch in gqa_attention_base.h; wired into group_query_attention.cc (float only, gated like the quantized flash path: no softcap/smooth softmax/head sink/QK output). - Reuses ORT_GQA_DISABLE_FLASH_ATTENTION to fall back to the naive path. --- cmake/onnxruntime_mlas.cmake | 1 + docs/contrib_ops/cpu/gqa.md | 58 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 310 +++++++++ .../cpu/bert/group_query_attention.cc | 21 + onnxruntime/core/mlas/inc/mlas.h | 57 ++ onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 610 ++++++++++++++++++ 6 files changed, 1052 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/flashattn_gqa.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b254b40f88e76..1c47ac4ef4569 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -56,6 +56,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/flashattn_qkv.cpp + ${MLAS_SRC_DIR}/flashattn_gqa.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp ${MLAS_SRC_DIR}/layernorm.cpp diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index e5a211c9fd11a..d3b7f25c6fdba 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -17,7 +17,12 @@ Quantized KV-cache GEMM helpers are implemented in MLAS: - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp` -- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel) +- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (quantized-KV flash attention tiled kernel) + +The non-quantized flash attention tiled kernel is implemented in MLAS: + +- `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (FP32-KV flash attention tiled kernel) +- `onnxruntime/core/mlas/inc/mlas.h` (`MlasFlashAttentionGQA` declaration and `MlasFlashAttentionGQAArgs`) The operator schema itself is defined in: @@ -48,12 +53,14 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages: The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs. -The quantized path has two execution strategies: +Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, online-softmax, and flash-decoding structure. + +The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). ## Supported Cache Modes @@ -144,9 +151,9 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with: As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly. -## Flash Attention Path +## Quantized Flash Attention Path -The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. +The quantized flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. ### Algorithm @@ -204,6 +211,47 @@ The partials buffer is allocated alongside the per-thread scratch in a single al - Per-thread scratch: `scores[Bc]` (one float per KV block element) - Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk) +## Non-Quantized Flash Attention Path + +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, masking, and flash-decoding structure. + +### Differences from the Quantized Path + +- **Cache element type**: The present K/V cache is FP32, laid out as BNSH (`[batch, kv_num_heads, seqlen_present, head_size]`). There is no quantize-on-write or dequantize-on-read step. +- **QK GEMM**: Uses the single-threaded SGEMM primitive `MlasSgemmOperation(CblasNoTrans, CblasTrans, ...)` on an FP32 K block instead of `MlasQKGemm`. +- **SV accumulate**: Uses `MlasSgemmOperation(CblasNoTrans, CblasNoTrans, ..., beta)` with `beta = 0` for the first KV block and `beta = 1` afterwards (accumulate) instead of `MlasSVGemm`. +- **Cache concat**: New K/V tokens are appended into the FP32 present cache with `ConcatStateChunkGQA` before the tiled loop runs. + +### Algorithm + +For each (batch, head, q_block) tile: + +1. **QK GEMM** — `MlasSgemmOperation` of the query tile against a block slice of the FP32 K cache (Bc rows at a time) +1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores +2. **Causal + local window masking** — Set masked positions to −∞ before softmax +3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)` +4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile +5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed + +### Activation Conditions + +The non-quantized flash path is selected when ALL of the following hold: + +- The kernel specialization is `float` (FP16 uses the naive path) +- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) +- `total_sequence_length > 1` +- No softcap +- No smooth softmax +- No head sink +- No output QK capture +- `present_key` and `present_value` are provided + +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Sizes, Threading, and Flash Decoding + +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. + ## MLAS Dispatch Paths MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table. diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 12f61cddea18c..d66ed2cb0fb7d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -903,6 +903,316 @@ class GQAAttentionBase { return Status::OK(); } + // Non-quantized flash attention path. Only supports T = float. + // Concatenates new K/V into the FP32 present cache, then runs the tiled + // online-softmax kernel MlasFlashAttentionGQA (QK^T + softmax + S*V fused). + Status ApplyAttentionFlash( + const float* Q, // Q data [B, N, S, H] BNSH + const float* K, // K data [B, N_kv, L, H] or nullptr for packed_qkv + const float* V, // V data [B, N_kv, L, H] or nullptr for packed_qkv + const Tensor* attention_bias, // additive bias [B|1, N|1, S, T] or nullptr + const Tensor* past_key, // past K (float) + const Tensor* past_value, // past V (float) + Tensor* output, // output [B, S, N*H] float + Tensor* present_key, // present K (float) + Tensor* present_value, // present V (float) + const Tensor* seqlens_k, + GroupQueryAttentionParameters& parameters, + AllocatorPtr allocator, + OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.total_sequence_length; + + if (kv_sequence_length == 0) { + ORT_ENFORCE(parameters.total_sequence_length <= seqlen_past_kv_cache, + "total_seqlen (", parameters.total_sequence_length, ") exceeds past buffer size (", + seqlen_past_kv_cache, ") in shared KV mode"); + } + + ORT_RETURN_IF(present_key == nullptr || present_value == nullptr, + "present_key and present_value must be provided for flash attention"); + + const float* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + float* present_key_data = present_key->MutableData(); + const float* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + float* present_value_data = present_value->MutableData(); + + bool past_present_share_buffer = (past_key_data == present_key_data) && + (past_value_data == present_value_data); + + const int32_t* seqlens_k_data = seqlens_k->Data(); + + // Attention bias setup + const float* attention_bias_data = nullptr; + int attention_bias_seqlen_stride = 0; + bool attention_bias_broadcast_batch = true; + bool attention_bias_broadcast_head = true; + if (attention_bias != nullptr) { + attention_bias_data = attention_bias->Data(); + auto bias_shape = attention_bias->Shape().GetDims(); + attention_bias_seqlen_stride = static_cast(bias_shape[3]); + attention_bias_broadcast_batch = (bias_shape[0] == 1); + attention_bias_broadcast_head = (bias_shape[1] == 1); + } + + // K/V base pointers (FP32, new tokens) + const float* k_base = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + const float* v_base = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const size_t kv_input_chunk_length = SafeInt(kv_sequence_length) * head_size; + const size_t past_buff_chunk_length = SafeInt(seqlen_past_kv_cache) * head_size; + const size_t present_buff_chunk_length = SafeInt(seqlen_present_kv_cache) * head_size; + + // ---- Phase 1: Concat new K/V into present cache ---- + // We must do this first so the flash attention kernel can read the full present cache. + if (present_key_data && !past_present_share_buffer) { + memset(present_key_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + memset(present_value_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + } + + // Concat K and V caches (parallelize over batch * kv_num_heads) + { + const size_t concat_loop_len = batch_size * kv_num_heads_; + TensorOpCost concat_cost; + concat_cost.compute_cycles = static_cast(kv_sequence_length * head_size); + concat_cost.bytes_loaded = static_cast((past_buff_chunk_length + kv_input_chunk_length) * sizeof(float)); + concat_cost.bytes_stored = static_cast(present_buff_chunk_length * sizeof(float)); + + ThreadPool::TryParallelFor(tp, concat_loop_len, concat_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t kv_idx = begin; kv_idx != end; ++kv_idx) { + const size_t batch_index = kv_idx / kv_num_heads_; + const size_t kv_head_index = kv_idx % kv_num_heads_; + const size_t total_seqlen = SafeInt(seqlens_k_data[batch_index]) + 1; + + size_t past_seqlen; + if (past_key == nullptr) { + past_seqlen = 0; + } else if (kv_sequence_length == 0) { + past_seqlen = total_seqlen; + } else if (is_prompt) { + past_seqlen = 0; + } else { + past_seqlen = total_seqlen - sequence_length; + } + const size_t past_chunk_length = past_seqlen * head_size; + + // Concat K + const float* k_new; + if (packed_qkv) { + k_new = k_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + k_new = k_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_key_data, k_new, present_key_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + + // Concat V + const float* v_new; + if (packed_qkv) { + v_new = v_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + v_new = v_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_value_data, v_new, present_value_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + } + }); + } + + // ---- Phase 2: Flash Attention with FP32 KV cache ---- + // Compute L2-aware block sizes (same formula as MHA flash attention). + const auto& env = Env::Default(); + int l2_cache_size = env.GetL2CacheSize(); + + int kv_block_size = l2_cache_size / (static_cast(sizeof(float)) * 4 * (head_size + head_size)); + kv_block_size = std::max(kv_block_size, 1); + int q_block_size = std::min(kv_block_size, 2 * head_size); + + // The flash kernel uses a single (past_seqlen, total_seqlen) pair for all batch items. + // When batch items have different seqlens_k (ragged), fall back to per-batch invocation + // so each batch item gets its own correct causal offset. + int max_total_seqlen = 0; + int min_total_seqlen = std::numeric_limits::max(); + int common_past_seqlen = 0; + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + max_total_seqlen = std::max(max_total_seqlen, total_sl); + min_total_seqlen = std::min(min_total_seqlen, total_sl); + } + const bool ragged_seqlens = (max_total_seqlen != min_total_seqlen); + + if (ragged_seqlens) { + common_past_seqlen = -1; // sentinel: per-batch + } else if (past_key == nullptr || is_prompt) { + common_past_seqlen = 0; + } else if (kv_sequence_length == 0) { + // Shared buffer mode: each batch item has its own past_seqlen. + common_past_seqlen = -1; // sentinel: per-batch + } else { + common_past_seqlen = max_total_seqlen - sequence_length; + } + + // Cap block sizes + kv_block_size = std::min(kv_block_size, max_total_seqlen); + q_block_size = std::min(q_block_size, sequence_length); + + int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + thread_count = std::max(thread_count, 1); + + // Flash decoding: for decode (sequence_length==1), partition KV across threads + // to improve parallelism when batch*heads < thread_count. + const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; + const bool use_flash_decoding = (sequence_length == 1 && + batch_size * num_heads_ < thread_count && + kv_chunk_count > 1); + + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + // Flash decoding: per-thread scratch only needs scores[kv_block_size] + buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats + partials_buffer_bytes = static_cast(batch_size) * num_heads_ * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (static_cast(q_block_size) * 2 + // l + m + static_cast(q_block_size) * static_cast(kv_block_size) + // scores + static_cast(q_block_size) * static_cast(head_size)) * // temp_output + sizeof(float); + } + size_t total_buffer_bytes = buffer_size_per_thread * thread_count + partials_buffer_bytes; + auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); + BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + + // Partials buffer is placed after per-thread scratch + float* partials_ptr = use_flash_decoding + ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + + buffer_size_per_thread * thread_count) + : nullptr; + + const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + // If all batch items share the same past_seqlen, use the unified flash kernel. + // Otherwise, fall back to per-batch invocation. + if (common_past_seqlen >= 0) { + MlasFlashAttentionGQAArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = max_total_seqlen; + args.head_size = head_size; + args.past_seqlen = common_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.q_batch_stride = packed_qkv + ? static_cast(packed_batch_stride) + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.k_cache = present_key_data; + args.v_cache = present_value_data; + args.output = output->MutableData(); + args.attention_bias = attention_bias_data; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; + + MlasFlashAttentionGQA(&args, tp); + } else { + // Per-batch handling for variable past_seqlen (shared KV buffer mode or ragged seqlens) + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + int batch_past_seqlen = (past_key == nullptr || is_prompt) + ? 0 + : std::max(0, total_sl - sequence_length); + + MlasFlashAttentionGQAArgs args; + args.batch_size = 1; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = total_sl; + args.head_size = head_size; + args.past_seqlen = batch_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = std::min(kv_block_size, total_sl); + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + + // Offset Q and output for this batch + const ptrdiff_t q_batch_stride_elems = packed_batch_stride > 0 + ? packed_batch_stride + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.query = Q + static_cast(b) * static_cast(q_batch_stride_elems); + args.q_batch_stride = SafeInt(num_heads_) * sequence_length * head_size; + args.k_cache = present_key_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.v_cache = present_value_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.output = output->MutableData() + + static_cast(b) * sequence_length * hidden_size; + + // Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside). + // Bias shape is [batch|1, num_heads|1, S, T]; the batch stride uses the actual head + // extent (1 when the head dim is broadcast). + const float* batch_bias = attention_bias_data; + if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) { + const size_t bias_head_extent = attention_bias_broadcast_head ? 1 : static_cast(num_heads_); + batch_bias += static_cast(b) * bias_head_extent * sequence_length * attention_bias_seqlen_stride; + } + args.attention_bias = batch_bias; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = true; // batch offset handled above + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding + args.kv_chunk_count = 0; + + MlasFlashAttentionGQA(&args, tp); + } + } + + return Status::OK(); + } + private: // Helper function to compute the attention probs. It does 2 things: // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 61ae474703213..29d372eb7a4bb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -343,6 +343,27 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V const T* k_data = packed_qkv ? nullptr : k_rotary; const T* v_data = packed_qkv ? nullptr : V.Get().Data(); + + // Non-quantized flash attention path (float only). Uses the tiled online-softmax + // kernel to avoid materializing the full attention score matrix. Falls back to the + // naive path when an unsupported feature is requested (softcap, smooth softmax, + // head sink, or QK output). + if constexpr (std::is_same_v) { + const bool use_flash = !disable_gqa_flash_ && + parameters.total_sequence_length > 1 && + softcap_ == 0.0f && + !use_smooth_softmax_ && + head_sink_data == nullptr && + output_qk == nullptr && + present_k != nullptr && present_v != nullptr; + if (use_flash) { + return ApplyAttentionFlash(q_rotary, k_data, v_data, + attention_bias, past_key, past_value, + output, present_k, present_v, seqlens_k, + parameters, allocator, context); + } + } + return ApplyAttention(q_rotary, k_data, v_data, head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, parameters, allocator, context); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 99b72dc756663..ec2398dd1ee0f 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2281,6 +2281,63 @@ MlasFlashAttention( MLAS_THREADPOOL* ThreadPool ); +// +// Flash Attention for non-quantized (FP32) GroupQueryAttention KV cache. +// +// Adapts the online-softmax tiled algorithm to operate on an FP32 present +// K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). +// Supports GQA head grouping, causal masking, local window attention, +// additive attention bias, and an optional flash-decoding split over the KV +// sequence dimension for the single-token decode case. +// +struct MlasFlashAttentionGQAArgs { + int batch_size; + int num_heads; // number of query heads + int kv_num_heads; // number of key/value heads (num_heads % kv_num_heads == 0) + int sequence_length; // number of new query tokens (S) + int total_seqlen; // total tokens (past + new) for this invocation (T) + int head_size; // per-head size (H) + int past_seqlen; // causal offset (number of cached tokens before the new ones) + int local_window_size; // -1 disables local window masking + int seqlen_present_kv; // sequence dimension of the present K/V buffer + int q_block_size; // query tile size (Br) + int kv_block_size; // key/value tile size (Bc) + float scale; // QK scaling factor + int thread_count; // number of partitions / threads + float* buffer; // per-thread scratch (+ optional flash-decoding partials) + size_t buffer_size_per_thread; + + const float* query; // [batch, num_heads, sequence_length, head_size] BNSH + size_t q_batch_stride; // element stride between consecutive batches in `query` + // (num_heads*S*H for unpacked, (num_heads+2*kv_num_heads)*S*H for packed QKV) + const float* k_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + const float* v_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + float* output; // [batch, sequence_length, num_heads, head_size] BSNH + + const float* attention_bias; // [batch|1, num_heads|1, S, T] additive bias, or nullptr + int attention_bias_seqlen_stride; + bool attention_bias_broadcast_batch; + bool attention_bias_broadcast_head; + + // Flash decoding (sequence_length == 1): partition KV across threads. + // Set flash_decoding_partials != nullptr to enable; otherwise the standard + // per-(batch, head, q_block) partitioning is used. + float* flash_decoding_partials; + int kv_chunk_count; +}; + +/** + * @brief FP32 Flash Attention for GroupQueryAttention with an FP32 KV cache. + * @param args Arguments + * @param ThreadPool Thread pool + */ +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +); + /** * @brief Enumeration of supported GELU algorithm variants. * diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp new file mode 100644 index 0000000000000..2f62b9a8b0735 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -0,0 +1,610 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + flashattn_gqa.cpp + +Abstract: + + Flash Attention kernel for the non-quantized (FP32) GroupQueryAttention + KV cache. + + Adapts the online-softmax tiled algorithm from flashattn.cpp to operate on + an FP32 present K/V cache laid out as BNSH + ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head + grouping (num_heads % kv_num_heads == 0), causal masking, local window + attention, additive attention bias, and an optional flash-decoding split + over the KV sequence dimension for single-token decode. + + QK^T and S*V use the single-threaded SGEMM primitive MlasSgemmOperation; + the outer parallelism is provided by MlasExecuteThreaded. + +--*/ + +#include +#include +#include +#include + +#include "mlasi.h" + +void +MlasFlashAttentionGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t q_block_size = static_cast(args->q_block_size); + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t sequence_length = static_cast(args->sequence_length); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: one per (batch, head, q_block) + const ptrdiff_t q_chunk_count = (sequence_length + q_block_size - 1) / q_block_size; + const ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + // Per-thread buffer layout: + // l[q_block_size] - running sum for online softmax + // m[q_block_size] - running max for online softmax + // scores[q_block_size * kv_block_size] - QK scores (S) + // temp_output[q_block_size * head_size] - accumulated output + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_ptr); + float* m = l + q_block_size; + float* scores = m + q_block_size; + float* temp_output = scores + q_block_size * kv_block_size; + + // Initialize running state + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + l[t] = 0.0f; + } + memset(temp_output, 0, static_cast(q_block_size * head_size) * sizeof(float)); + + const size_t row_size_q = static_cast(std::min(q_block_size, sequence_length - q_idx)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers. Layout: [batch, kv_num_heads, seqlen_present, head_size] + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, seq, head_size]. The batch stride is + // supplied separately (args->q_batch_stride) so the kernel works with both the + // standard BNSH layout and packed-QKV input where Q/K/V are interleaved per batch. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(sequence_length) * static_cast(head_size) + + static_cast(q_idx) * static_cast(head_size); + + // Iterate over KV blocks + for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Step 1: QK^T GEMM with FP32 K block + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasTrans, + row_size_q, // M + row_size_kv, // N + static_cast(head_size), // K + scale, // alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (FP32 K block) + static_cast(head_size), // ldb + 0.0f, // beta + scores, // C (output scores) + row_size_kv // ldc + ); + + // Step 1b: Apply attention bias (additive) if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = + static_cast(sequence_length) * bias_seqlen_stride; + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch + // stride uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + // Add bias tile: bias[q_idx + irow, ir + jcol] + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + const float* bias_row = args->attention_bias + bias_offset + + (q_idx + irow) * bias_seqlen_stride + ir; + float* s_row = scores + irow * static_cast(row_size_kv); + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + s_row[jcol] += bias_row[jcol]; + } + } + } + + // Step 2: Apply causal mask and Step 3: Online softmax update + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float* p = scores + irow * static_cast(row_size_kv); + const ptrdiff_t global_q_pos = past_seqlen + q_idx + irow; + const ptrdiff_t causal_limit = global_q_pos + 1; // can attend to positions [0, causal_limit) + + // Apply causal masking + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + p[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + p[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Online softmax: find row max, update running max +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv); +#endif + + // If the entire row is masked (all scores are -inf), zero the scores + // so the S*V GEMM contributes nothing and skip the softmax state update. + if (rowmax == std::numeric_limits::lowest()) { + memset(p, 0, row_size_kv * sizeof(float)); + continue; + } + + float m_old = m[irow]; + m[irow] = std::max(m[irow], rowmax); + float m_diff = m_old - m[irow]; // <= 0 + + // Compute exp(score - m_new) for each element + float negmax = -m[irow]; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#endif + + // Rescale previous state + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + // Rescale accumulated output + float* out_row = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + out_row[icol] *= exp_diff; + } + } else { + l[irow] = rowsum; + } + } + + // Step 4: Accumulate O += S_exp * V_block + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasNoTrans, + row_size_q, // M + static_cast(head_size), // N + row_size_kv, // K + 1.0f, // alpha + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (FP32 V block) + static_cast(head_size), // ldb + ir == 0 ? 0.0f : 1.0f, // beta (accumulate after first block) + temp_output, // C (accumulated output) + static_cast(head_size) // ldc + ); + } + + // Final: normalize output by l (softmax denominator) + // Output layout: [batch, sequence_length, num_heads, head_size] + float* output_row = args->output + + (static_cast(batch_idx) * static_cast(sequence_length) + + static_cast(q_idx)) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + const ptrdiff_t output_row_stride = num_heads * head_size; + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float inv_l = (l[irow] > 0.0f) ? (1.0f / l[irow]) : 0.0f; + float* src = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + output_row[icol] = src[icol] * inv_l; + } + output_row += output_row_stride; + } + } +} + +// +// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). +// Each task computes attention for one KV chunk and stores (m, l, partial_output) +// into the partials buffer. +// +void +MlasFlashDecodingGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + // Partials layout per entry: [m, l, output[head_size]] + const ptrdiff_t partial_stride = 2 + head_size; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: (batch, head, kv_chunk) + const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) + ptrdiff_t tmp = task_index; + ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; + tmp /= kv_chunk_count; + ptrdiff_t head_idx = tmp % num_heads; + ptrdiff_t batch_idx = tmp / num_heads; + + // Per-thread scratch buffer: just scores[kv_block_size] + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + + // KV block range for this chunk + const ptrdiff_t ir = kv_chunk_idx * kv_block_size; + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1). + // The batch stride is supplied separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMM for this KV chunk + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasTrans, + 1, // M (single query row) + row_size_kv, // N + static_cast(head_size), // K + scale, // alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (FP32 K block) + static_cast(head_size), // ldb + 0.0f, // beta + scores, // C (output scores) + row_size_kv // ldc + ); + + // Step 1b: Apply attention bias if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch stride + // uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset + ir; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + scores[jcol] += bias_row[jcol]; + } + } + + // Step 2: Apply causal mask + const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 + const ptrdiff_t causal_limit = global_q_pos + 1; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Step 3: Compute local softmax statistics (m, l) and exp scores +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); +#endif + + // Pointer to this task's partial in the partials buffer + const ptrdiff_t partial_index = + (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; + float* partial = args->flash_decoding_partials + partial_index * partial_stride; + float* partial_m = partial; + float* partial_l = partial + 1; + float* partial_output = partial + 2; + + if (rowmax == std::numeric_limits::lowest()) { + // Entire chunk is masked: store sentinel + *partial_m = std::numeric_limits::lowest(); + *partial_l = 0.0f; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + *partial_m = rowmax; + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#endif + *partial_l = rowsum; + + // Step 4: S_exp * V_block -> partial_output + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasNoTrans, + 1, // M + static_cast(head_size), // N + row_size_kv, // K + 1.0f, // alpha + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (FP32 V block) + static_cast(head_size), // ldb + 0.0f, // beta (overwrite) + partial_output, // C (output for this chunk) + static_cast(head_size) // ldc + ); + } +} + +// +// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. +// +void +MlasFlashDecodingGQAReduceThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + const ptrdiff_t thread_count = static_cast(args->thread_count); + const ptrdiff_t partial_stride = 2 + head_size; + + // Total reduction tasks: one per (batch, head) + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t head_idx = task_index % num_heads; + ptrdiff_t batch_idx = task_index / num_heads; + + // Pointer to this (batch, head)'s partials: kv_chunk_count entries + const float* partials_base = args->flash_decoding_partials + + task_index * kv_chunk_count * partial_stride; + + // Find global max across all chunks + float global_m = std::numeric_limits::lowest(); + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + float chunk_m = partials_base[c * partial_stride]; + global_m = std::max(global_m, chunk_m); + } + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + + // If all chunks are masked, output zeros + if (global_m == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + // Accumulate rescaled outputs and l values + float global_l = 0.0f; + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + const float* partial = partials_base + c * partial_stride; + float chunk_m = partial[0]; + float chunk_l = partial[1]; + const float* chunk_output = partial + 2; + + if (chunk_l <= 0.0f) { + continue; // masked chunk contributes nothing + } + + float rescale = std::exp(chunk_m - global_m); + global_l += rescale * chunk_l; + + // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). + // Rescale by exp(m_c - global_m) to align all chunks to the same max. + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] += rescale * chunk_output[i]; + } + } + + // output = sum_c(rescale_c * partial_output_c) / global_l + float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] *= inv_l; + } + } +} + +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + if (args->flash_decoding_partials != nullptr && args->sequence_length == 1) { + // Flash decoding: two-phase approach. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingGQAReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } +} From 980b497e418bfc7af4d3cbdfeb150197bf415c1b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 02:51:24 +0000 Subject: [PATCH 2/5] benchmark and doc --- docs/contrib_ops/cpu/gqa.md | 45 ++++ onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 12 ++ .../transformers/benchmark_gqa_cpu_flash.py | 197 +++++++++++++----- 3 files changed, 201 insertions(+), 53 deletions(-) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index d3b7f25c6fdba..840dcea5b0cfd 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -233,6 +233,18 @@ For each (batch, head, q_block) tile: 4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile 5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed +#### Causal early-termination + +During prefill, every KV block whose start index is at or beyond the largest global query +position in the current q_block is fully causally masked and contributes nothing. The kernel +computes a per-q_block bound +`kv_causal_limit = past_seqlen + q_idx + row_size_q` and breaks out of the KV loop once +`ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle +QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the +main reason the FP32 flash path is faster than naive even at short sequence lengths +(see the benchmark results below). Decode (q_block of size 1 at the cache tail) attends to all +KV positions, so the bound equals `total_seqlen` and nothing is skipped. + ### Activation Conditions The non-quantized flash path is selected when ALL of the following hold: @@ -476,7 +488,40 @@ Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle | 4096 (N=32) | +2131 | +87 | 24.5x | **Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences. +### Non-Quantized (FP32) Flash Attention vs Naive benchmark results + +Measured on an AMD EPYC 7763 (32 logical / 16 physical cores), threads=8, FP32 KV cache, +`B=1, num_heads=16, kv_num_heads=8, head_size=128`. Operator-level, measured with: + +```bash +python onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py \ + --fp32 --prompt_only --warmup 10 --repeats 30 +``` + +#### Latency — Prefill (S = T, prompt phase) + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 512 | 5.8\u20138.4 | 4.2\u20135.3 | 1.4\u20131.6x | +| 1024 | 25\u201329 | 13\u201318 | 1.6\u20132.0x | +| 2048 | 87\u2013118 | 52\u201365 | 1.5\u20132.0x | +| 4096 | 365\u2013380 | 213\u2013234 | 1.6\u20131.7x | + +The FP32 flash path is faster than naive across all measured prefill lengths. With the causal +early-termination described above, roughly half of the QK/SV work (the causally masked +upper triangle of the square prefill attention matrix) is skipped entirely, which more than +offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/output rescale). +The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is +algorithmic rather than purely from threading. + +#### Latency — Decode (S = 1, token generation) +For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so +flash decoding KV-partitioning is not active), the workload per `Run` is tiny and dominated +by KV-cache concatenation overhead. Operator-level decode latency is therefore noisy and +roughly at parity between the two paths, with longer total sequence lengths (T\u22652049) +tending to favor flash. The FP32 decode path is not the target of the prefill-oriented +causal early-termination optimization. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp index 2f62b9a8b0735..25f3733f59cca 100644 --- a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -126,8 +126,20 @@ MlasFlashAttentionGQAThreaded( static_cast(head_idx) * static_cast(sequence_length) * static_cast(head_size) + static_cast(q_idx) * static_cast(head_size); + // Causal early-termination bound: the largest global query position in this + // q_block is (past_seqlen + q_idx + row_size_q - 1), so it can attend to KV + // positions up to that index inclusive. Any KV block starting at or beyond + // (past_seqlen + q_idx + row_size_q) is fully causally masked for every row in + // the block, so it contributes nothing and can be skipped. This avoids the + // wasted QK/SV GEMMs over the causal upper triangle during prefill. + const ptrdiff_t kv_causal_limit = + past_seqlen + q_idx + static_cast(row_size_q); + // Iterate over KV blocks for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + if (ir >= kv_causal_limit) { + break; + } const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); // Step 1: QK^T GEMM with FP32 K block diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py index 77ac08cf50d6c..7dbcb16a75973 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py @@ -106,6 +106,70 @@ def create_quantized_gqa_graph( return model.SerializeToString() +def create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with a non-quantized FP32 KV cache.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + ] + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + domain="com.microsoft", + ) + + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "past_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "present_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + ] + + graph = helper.make_graph([node], "BenchGQA", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + def benchmark_gqa( batch_size, seq_len, @@ -117,6 +181,7 @@ def benchmark_gqa( past_seq_len=0, warmup=5, repeats=20, + non_quantized=False, ): """Benchmark a single GQA configuration. Returns elapsed time in ms.""" hidden_size = num_heads * head_size @@ -126,54 +191,76 @@ def benchmark_gqa( total_seqlen = past_seq_len + seq_len buffer_seq_len = total_seqlen - onnx_model_str = create_quantized_gqa_graph( - batch_size, - seq_len, - num_heads, - kv_num_heads, - head_size, - quant_type, - bit_width, - buffer_seq_len=buffer_seq_len, - ) - sess_options = SessionOptions() sess_options.intra_op_num_threads = 8 - sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - # Generate inputs np.random.seed(42) query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) key = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) value = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) - - cache_dtype = np.uint8 if bit_width == 4 else np.int8 - past_k = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - past_v = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - seqlens_k = np.array([total_seqlen - 1] * batch_size, dtype=np.int32) total_seq = np.array([total_seqlen], dtype=np.int32) - per_channel = quant_type == "PER_CHANNEL" - scale_size = kv_num_heads * head_size if per_channel else 1 - k_scale = np.full(scale_size, 0.01, dtype=np.float32) - v_scale = np.full(scale_size, 0.01, dtype=np.float32) - - feeds = { - "query": query, - "key": key, - "value": value, - "past_key": past_k, - "past_value": past_v, - "seqlens_k": seqlens_k, - "total_sequence_length": total_seq, - "k_scale": k_scale, - "v_scale": v_scale, - } + if non_quantized: + onnx_model_str = create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + past_k = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + past_v = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + } + else: + onnx_model_str = create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + cache_dtype = np.uint8 if bit_width == 4 else np.int8 + past_k = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + past_v = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + + per_channel = quant_type == "PER_CHANNEL" + scale_size = kv_num_heads * head_size if per_channel else 1 + k_scale = np.full(scale_size, 0.01, dtype=np.float32) + v_scale = np.full(scale_size, 0.01, dtype=np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "k_scale": k_scale, + "v_scale": v_scale, + } # Warmup for _ in range(warmup): @@ -242,20 +329,21 @@ def run_benchmarks(args): "past_seq_len": 2048, } ) - # INT4 prefill - configs.append( - { - "label": "Prefill S=2048 INT4", - "batch_size": 1, - "seq_len": 2048, - "num_heads": 16, - "kv_num_heads": 8, - "head_size": 128, - "quant_type": "PER_TENSOR", - "bit_width": 4, - "past_seq_len": 0, - } - ) + # INT4 prefill (quantized mode only) + if not args.fp32: + configs.append( + { + "label": "Prefill S=2048 INT4", + "batch_size": 1, + "seq_len": 2048, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 4, + "past_seq_len": 0, + } + ) warmup = args.warmup repeats = args.repeats @@ -263,13 +351,15 @@ def run_benchmarks(args): # Save and restore env var to avoid side effects on callers saved_env = os.environ.get("ORT_GQA_DISABLE_FLASH_ATTENTION") + kv_mode = "FP32 (non-quantized)" if args.fp32 else "INT8/INT4 quantized" print("\nBenchmark: CPU GroupQueryAttention — Flash vs Naive") - print(f"Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") + print(f"KV cache: {kv_mode}, Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") print(f"{'Config':<25} {'Naive (ms)':>12} {'Flash (ms)':>12} {'Speedup':>10}") print("-" * 62) for cfg in configs: label = cfg.pop("label") + cfg["non_quantized"] = args.fp32 # Flash path (default) os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) @@ -296,5 +386,6 @@ def run_benchmarks(args): parser.add_argument("--repeats", type=int, default=20, help="Measurement iterations") parser.add_argument("--decode_only", action="store_true", help="Only run decode benchmarks") parser.add_argument("--prompt_only", action="store_true", help="Only run prompt benchmarks") + parser.add_argument("--fp32", action="store_true", help="Use non-quantized FP32 KV cache instead of quantized") args = parser.parse_args() run_benchmarks(args) From 125172e29a843b49f6b037f3aef20c42d0998112 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 02:56:59 +0000 Subject: [PATCH 3/5] Use SafeInt for FP32 flash attention scratch buffer sizing Guard the per-thread scratch and flash-decoding partials buffer size computations against size_t overflow for large or malformed shapes, matching the SafeInt usage elsewhere in this file. --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index d66ed2cb0fb7d..caa6a89f02fd7 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -1096,18 +1096,18 @@ class GQAAttentionBase { size_t partials_buffer_bytes = 0; if (use_flash_decoding) { // Flash decoding: per-thread scratch only needs scores[kv_block_size] - buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + buffer_size_per_thread = SafeInt(kv_block_size) * sizeof(float); // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats - partials_buffer_bytes = static_cast(batch_size) * num_heads_ * + partials_buffer_bytes = SafeInt(batch_size) * num_heads_ * kv_chunk_count * (2 + head_size) * sizeof(float); } else { buffer_size_per_thread = - (static_cast(q_block_size) * 2 + // l + m - static_cast(q_block_size) * static_cast(kv_block_size) + // scores - static_cast(q_block_size) * static_cast(head_size)) * // temp_output + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output sizeof(float); } - size_t total_buffer_bytes = buffer_size_per_thread * thread_count + partials_buffer_bytes; + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count + partials_buffer_bytes; auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); From 2f945ebab46cbdf01ca24705323e4e046a65a573 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 18:39:45 +0000 Subject: [PATCH 4/5] limit query length > 1 for flash --- onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 29d372eb7a4bb..ddbbc28b91700 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -349,8 +349,12 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // naive path when an unsupported feature is requested (softcap, smooth softmax, // head sink, or QK output). if constexpr (std::is_same_v) { + // Restrict the flash path to prefill / chunked-prefill (query length > 1). Single-token + // decode (sequence_length == 1) has no flash benefit: the naive score matrix is only + // [1, total_sequence_length] per head, so there is nothing to tile away, and the extra + // online-softmax bookkeeping makes it slower in practice. const bool use_flash = !disable_gqa_flash_ && - parameters.total_sequence_length > 1 && + parameters.sequence_length > 1 && softcap_ == 0.0f && !use_smooth_softmax_ && head_sink_data == nullptr && From 11195e7df3837e256f2f12b8290c939d99c6bd6e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jun 2026 18:57:31 +0000 Subject: [PATCH 5/5] Remove dead flash-decoding code from FP32 GQA path The FP32 flash path is now restricted to prefill (sequence_length > 1), so the single-token flash-decoding kernels are unreachable. Remove MlasFlashDecodingGQAThreaded/MlasFlashDecodingGQAReduceThreaded, the flash_decoding_partials/kv_chunk_count args fields, and the partials buffer wiring in ApplyAttentionFlash. Simplify the MlasFlashAttentionGQA dispatch to call the prefill kernel directly, and update the GQA doc to describe the prefill-only behavior. --- docs/contrib_ops/cpu/gqa.md | 28 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 40 +-- onnxruntime/core/mlas/inc/mlas.h | 14 +- onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 328 +----------------- 4 files changed, 32 insertions(+), 378 deletions(-) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index 840dcea5b0cfd..8b81fdba8f1a6 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -58,7 +58,7 @@ Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, online-softmax, and flash-decoding structure. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, and online-softmax structure. The quantized path additionally provides a two-phase flash-decoding strategy for single-token decode; the non-quantized FP32 path is limited to prefill (`sequence_length > 1`) and uses the naive path for decode. The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). @@ -213,7 +213,7 @@ The partials buffer is allocated alongside the per-thread scratch in a single al ## Non-Quantized Flash Attention Path -The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, masking, and flash-decoding structure. +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, and masking structure. Unlike the quantized path, it is limited to prefill / chunked-prefill (`sequence_length > 1`); single-token decode (`sequence_length == 1`) uses the naive path, which is why there is no flash-decoding variant here. ### Differences from the Quantized Path @@ -242,8 +242,7 @@ computes a per-q_block bound `ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the main reason the FP32 flash path is faster than naive even at short sequence lengths -(see the benchmark results below). Decode (q_block of size 1 at the cache tail) attends to all -KV positions, so the bound equals `total_seqlen` and nothing is skipped. +(see the benchmark results below). ### Activation Conditions @@ -251,18 +250,18 @@ The non-quantized flash path is selected when ALL of the following hold: - The kernel specialization is `float` (FP16 uses the naive path) - `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) -- `total_sequence_length > 1` +- `sequence_length > 1` (prefill / chunked-prefill; single-token decode uses the naive path) - No softcap - No smooth softmax - No head sink - No output QK capture - `present_key` and `present_value` are provided -Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. -### Block Sizes, Threading, and Flash Decoding +### Block Sizes and Threading -Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. Because this path is prefill-only, it does not include the quantized path's two-phase flash-decoding strategy for single-token decode. ## MLAS Dispatch Paths @@ -514,14 +513,13 @@ offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/outp The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is algorithmic rather than purely from threading. -#### Latency — Decode (S = 1, token generation) +#### Decode (S = 1, token generation) -For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so -flash decoding KV-partitioning is not active), the workload per `Run` is tiny and dominated -by KV-cache concatenation overhead. Operator-level decode latency is therefore noisy and -roughly at parity between the two paths, with longer total sequence lengths (T\u22652049) -tending to favor flash. The FP32 decode path is not the target of the prefill-oriented -causal early-termination optimization. +Single-token decode (`sequence_length == 1`) is **not** handled by the FP32 flash path; it falls +back to the naive path. Decode produces only a `[1, total_sequence_length]` score row per head, +so there is nothing to tile away, and the extra online-softmax bookkeeping made the flash kernel +slower and noisier in practice. Restricting the flash path to prefill (`sequence_length > 1`) keeps +the consistent prefill win without regressing decode. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index caa6a89f02fd7..dbafcb38acc91 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -1085,38 +1085,16 @@ class GQAAttentionBase { int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); thread_count = std::max(thread_count, 1); - // Flash decoding: for decode (sequence_length==1), partition KV across threads - // to improve parallelism when batch*heads < thread_count. - const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; - const bool use_flash_decoding = (sequence_length == 1 && - batch_size * num_heads_ < thread_count && - kv_chunk_count > 1); - - size_t buffer_size_per_thread; - size_t partials_buffer_bytes = 0; - if (use_flash_decoding) { - // Flash decoding: per-thread scratch only needs scores[kv_block_size] - buffer_size_per_thread = SafeInt(kv_block_size) * sizeof(float); - // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats - partials_buffer_bytes = SafeInt(batch_size) * num_heads_ * - kv_chunk_count * (2 + head_size) * sizeof(float); - } else { - buffer_size_per_thread = - (SafeInt(q_block_size) * 2 + // l + m - SafeInt(q_block_size) * kv_block_size + // scores - SafeInt(q_block_size) * head_size) * // temp_output - sizeof(float); - } - size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count + partials_buffer_bytes; + // Per-thread scratch: l + m + scores[q_block_size * kv_block_size] + temp_output[q_block_size * head_size] + const size_t buffer_size_per_thread = + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output + sizeof(float); + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count; auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); - // Partials buffer is placed after per-thread scratch - float* partials_ptr = use_flash_decoding - ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + - buffer_size_per_thread * thread_count) - : nullptr; - const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; // If all batch items share the same past_seqlen, use the unified flash kernel. @@ -1149,8 +1127,6 @@ class GQAAttentionBase { args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; args.attention_bias_broadcast_head = attention_bias_broadcast_head; - args.flash_decoding_partials = partials_ptr; - args.kv_chunk_count = kv_chunk_count; MlasFlashAttentionGQA(&args, tp); } else { @@ -1203,8 +1179,6 @@ class GQAAttentionBase { args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; args.attention_bias_broadcast_batch = true; // batch offset handled above args.attention_bias_broadcast_head = attention_bias_broadcast_head; - args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding - args.kv_chunk_count = 0; MlasFlashAttentionGQA(&args, tp); } diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index cbca2d85a97a4..2410fcc83e7cd 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2302,9 +2302,9 @@ MlasFlashAttention( // // Adapts the online-softmax tiled algorithm to operate on an FP32 present // K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). -// Supports GQA head grouping, causal masking, local window attention, -// additive attention bias, and an optional flash-decoding split over the KV -// sequence dimension for the single-token decode case. +// Supports GQA head grouping, causal masking, local window attention, and +// additive attention bias. Intended for prefill / chunked-prefill +// (sequence_length > 1). // struct MlasFlashAttentionGQAArgs { int batch_size; @@ -2320,7 +2320,7 @@ struct MlasFlashAttentionGQAArgs { int kv_block_size; // key/value tile size (Bc) float scale; // QK scaling factor int thread_count; // number of partitions / threads - float* buffer; // per-thread scratch (+ optional flash-decoding partials) + float* buffer; // per-thread scratch size_t buffer_size_per_thread; const float* query; // [batch, num_heads, sequence_length, head_size] BNSH @@ -2334,12 +2334,6 @@ struct MlasFlashAttentionGQAArgs { int attention_bias_seqlen_stride; bool attention_bias_broadcast_batch; bool attention_bias_broadcast_head; - - // Flash decoding (sequence_length == 1): partition KV across threads. - // Set flash_decoding_partials != nullptr to enable; otherwise the standard - // per-(batch, head, q_block) partitioning is used. - float* flash_decoding_partials; - int kv_chunk_count; }; /** diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp index 25f3733f59cca..4d0ff65733a44 100644 --- a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -17,8 +17,8 @@ Module Name: an FP32 present K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head grouping (num_heads % kv_num_heads == 0), causal masking, local window - attention, additive attention bias, and an optional flash-decoding split - over the KV sequence dimension for single-token decode. + attention, and additive attention bias. Intended for prefill / + chunked-prefill (sequence_length > 1). QK^T and S*V use the single-threaded SGEMM primitive MlasSgemmOperation; the outer parallelism is provided by MlasExecuteThreaded. @@ -294,300 +294,6 @@ MlasFlashAttentionGQAThreaded( } } -// -// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). -// Each task computes attention for one KV chunk and stores (m, l, partial_output) -// into the partials buffer. -// -void -MlasFlashDecodingGQAThreaded( - void* argptr, - std::ptrdiff_t thread_id -) -{ - const MlasFlashAttentionGQAArgs* args = - reinterpret_cast(argptr); - - const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); - const ptrdiff_t batch_size = static_cast(args->batch_size); - const ptrdiff_t num_heads = static_cast(args->num_heads); - const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); - const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); - const ptrdiff_t head_size = static_cast(args->head_size); - const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); - const ptrdiff_t local_window_size = static_cast(args->local_window_size); - const float scale = args->scale; - - float* buffer = args->buffer; - const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); - const ptrdiff_t thread_count = static_cast(args->thread_count); - - const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); - const size_t kv_head_stride = - static_cast(args->seqlen_present_kv) * static_cast(head_size); - - const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); - // Partials layout per entry: [m, l, output[head_size]] - const ptrdiff_t partial_stride = 2 + head_size; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - auto&& mlas_platform = GetMlasPlatform(); -#endif - - // Total tasks: (batch, head, kv_chunk) - const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; - - ptrdiff_t task_start = 0; - ptrdiff_t task_end = 0; - ptrdiff_t quotient = total_task_count / thread_count; - ptrdiff_t remainder = total_task_count % thread_count; - if (thread_id < remainder) { - task_start = (quotient + 1) * thread_id; - task_end = task_start + quotient + 1; - } else { - task_start = quotient * thread_id + remainder; - task_end = task_start + quotient; - } - - for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { - // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) - ptrdiff_t tmp = task_index; - ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; - tmp /= kv_chunk_count; - ptrdiff_t head_idx = tmp % num_heads; - ptrdiff_t batch_idx = tmp / num_heads; - - // Per-thread scratch buffer: just scores[kv_block_size] - char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; - float* scores = reinterpret_cast(buffer_ptr); - - // KV block range for this chunk - const ptrdiff_t ir = kv_chunk_idx * kv_block_size; - const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); - - // Determine KV head index for GQA head sharing - const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; - - // K/V cache pointers - const size_t kv_batch_head_offset = - (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * - kv_head_stride; - const float* k_cache_head = args->k_cache + kv_batch_head_offset; - const float* v_cache_head = args->v_cache + kv_batch_head_offset; - - // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1). - // The batch stride is supplied separately to support packed-QKV input. - const float* q_ptr = args->query + - static_cast(batch_idx) * args->q_batch_stride + - static_cast(head_idx) * static_cast(head_size); - - // Step 1: QK^T GEMM for this KV chunk - const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); - MlasSgemmOperation( - CblasNoTrans, - CblasTrans, - 1, // M (single query row) - row_size_kv, // N - static_cast(head_size), // K - scale, // alpha - q_ptr, // A (FP32 query) - static_cast(head_size), // lda - k_block, // B (FP32 K block) - static_cast(head_size), // ldb - 0.0f, // beta - scores, // C (output scores) - row_size_kv // ldc - ); - - // Step 1b: Apply attention bias if present - if (args->attention_bias != nullptr) { - const ptrdiff_t bias_seqlen_stride = - static_cast(args->attention_bias_seqlen_stride); - const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 - // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch stride - // uses the actual head extent (1 when the head dim is broadcast). - const ptrdiff_t bias_head_extent = - args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); - ptrdiff_t bias_offset = 0; - if (!args->attention_bias_broadcast_batch) { - bias_offset += static_cast(batch_idx) * - bias_head_extent * bias_matrix_size; - } - if (!args->attention_bias_broadcast_head) { - bias_offset += static_cast(head_idx) * bias_matrix_size; - } - const float* bias_row = args->attention_bias + bias_offset + ir; - for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { - scores[jcol] += bias_row[jcol]; - } - } - - // Step 2: Apply causal mask - const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 - const ptrdiff_t causal_limit = global_q_pos + 1; - for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { - ptrdiff_t kv_pos = ir + jcol; - if (kv_pos >= causal_limit) { - scores[jcol] = std::numeric_limits::lowest(); - } - } - - // Apply local window masking if enabled - if (local_window_size >= 0) { - const ptrdiff_t window_start = - (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; - for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { - ptrdiff_t kv_pos = ir + jcol; - if (kv_pos < window_start) { - scores[jcol] = std::numeric_limits::lowest(); - } - } - } - - // Step 3: Compute local softmax statistics (m, l) and exp scores -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); -#else - float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); -#endif - - // Pointer to this task's partial in the partials buffer - const ptrdiff_t partial_index = - (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; - float* partial = args->flash_decoding_partials + partial_index * partial_stride; - float* partial_m = partial; - float* partial_l = partial + 1; - float* partial_output = partial + 2; - - if (rowmax == std::numeric_limits::lowest()) { - // Entire chunk is masked: store sentinel - *partial_m = std::numeric_limits::lowest(); - *partial_l = 0.0f; - memset(partial_output, 0, static_cast(head_size) * sizeof(float)); - continue; - } - - *partial_m = rowmax; - float negmax = -rowmax; -#if defined(MLAS_TARGET_AMD64) - float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); -#else - float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); -#endif - *partial_l = rowsum; - - // Step 4: S_exp * V_block -> partial_output - const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); - MlasSgemmOperation( - CblasNoTrans, - CblasNoTrans, - 1, // M - static_cast(head_size), // N - row_size_kv, // K - 1.0f, // alpha - scores, // A (exp softmax scores) - row_size_kv, // lda - v_block, // B (FP32 V block) - static_cast(head_size), // ldb - 0.0f, // beta (overwrite) - partial_output, // C (output for this chunk) - static_cast(head_size) // ldc - ); - } -} - -// -// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. -// -void -MlasFlashDecodingGQAReduceThreaded( - void* argptr, - std::ptrdiff_t thread_id -) -{ - const MlasFlashAttentionGQAArgs* args = - reinterpret_cast(argptr); - - const ptrdiff_t batch_size = static_cast(args->batch_size); - const ptrdiff_t num_heads = static_cast(args->num_heads); - const ptrdiff_t head_size = static_cast(args->head_size); - const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); - const ptrdiff_t thread_count = static_cast(args->thread_count); - const ptrdiff_t partial_stride = 2 + head_size; - - // Total reduction tasks: one per (batch, head) - const ptrdiff_t total_task_count = batch_size * num_heads; - - ptrdiff_t task_start = 0; - ptrdiff_t task_end = 0; - ptrdiff_t quotient = total_task_count / thread_count; - ptrdiff_t remainder = total_task_count % thread_count; - if (thread_id < remainder) { - task_start = (quotient + 1) * thread_id; - task_end = task_start + quotient + 1; - } else { - task_start = quotient * thread_id + remainder; - task_end = task_start + quotient; - } - - for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { - ptrdiff_t head_idx = task_index % num_heads; - ptrdiff_t batch_idx = task_index / num_heads; - - // Pointer to this (batch, head)'s partials: kv_chunk_count entries - const float* partials_base = args->flash_decoding_partials + - task_index * kv_chunk_count * partial_stride; - - // Find global max across all chunks - float global_m = std::numeric_limits::lowest(); - for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { - float chunk_m = partials_base[c * partial_stride]; - global_m = std::max(global_m, chunk_m); - } - - // Output layout: [batch, sequence_length=1, num_heads, head_size] - float* output_ptr = args->output + - static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + - static_cast(head_idx) * static_cast(head_size); - - // If all chunks are masked, output zeros - if (global_m == std::numeric_limits::lowest()) { - memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); - continue; - } - - // Accumulate rescaled outputs and l values - float global_l = 0.0f; - memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); - - for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { - const float* partial = partials_base + c * partial_stride; - float chunk_m = partial[0]; - float chunk_l = partial[1]; - const float* chunk_output = partial + 2; - - if (chunk_l <= 0.0f) { - continue; // masked chunk contributes nothing - } - - float rescale = std::exp(chunk_m - global_m); - global_l += rescale * chunk_l; - - // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). - // Rescale by exp(m_c - global_m) to align all chunks to the same max. - for (ptrdiff_t i = 0; i < head_size; ++i) { - output_ptr[i] += rescale * chunk_output[i]; - } - } - - // output = sum_c(rescale_c * partial_output_c) / global_l - float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; - for (ptrdiff_t i = 0; i < head_size; ++i) { - output_ptr[i] *= inv_l; - } - } -} - void MLASCALL MlasFlashAttentionGQA( @@ -595,28 +301,10 @@ MlasFlashAttentionGQA( MLAS_THREADPOOL* ThreadPool ) { - if (args->flash_decoding_partials != nullptr && args->sequence_length == 1) { - // Flash decoding: two-phase approach. - // Phase 1: parallel partial computation over (batch, head, kv_chunk). - MlasExecuteThreaded( - MlasFlashDecodingGQAThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool - ); - // Phase 2: reduce partials into final output (parallel over batch*heads). - MlasExecuteThreaded( - MlasFlashDecodingGQAReduceThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool - ); - } else { - MlasExecuteThreaded( - MlasFlashAttentionGQAThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool - ); - } + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); }