Merged
1 change: 1 addition & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -56,6 +56,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/flashattn_qkv.cpp
${MLAS_SRC_DIR}/flashattn_gqa.cpp
${MLAS_SRC_DIR}/qkv_quant.cpp
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/layernorm.cpp
101 changes: 96 additions & 5 deletions docs/contrib_ops/cpu/gqa.md
@@ -17,7 +17,12 @@ Quantized KV-cache GEMM helpers are implemented in MLAS:
- `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp`
- `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp`
- `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp`
- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel)
- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (quantized-KV flash attention tiled kernel)

The non-quantized flash attention tiled kernel is implemented in MLAS:

- `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (FP32-KV flash attention tiled kernel)
- `onnxruntime/core/mlas/inc/mlas.h` (`MlasFlashAttentionGQA` declaration and `MlasFlashAttentionGQAArgs`)

The operator schema itself is defined in:

@@ -48,12 +53,14 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages:

The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs.

The quantized path has two execution strategies:
Both the non-quantized and quantized paths have two execution strategies:

- **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences.
- **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool.

The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path.
The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, and online-softmax structure. The quantized path additionally provides a two-phase flash-decoding strategy for single-token decode; the non-quantized FP32 path is limited to prefill (`sequence_length > 1`) and uses the naive path for decode.

The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths).

## Supported Cache Modes

@@ -144,9 +151,9 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with:

As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly.

## Flash Attention Path
## Quantized Flash Attention Path

The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix.
The quantized flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix.

### Algorithm

@@ -204,6 +211,58 @@ The partials buffer is allocated alongside the per-thread scratch in a single al
- Per-thread scratch: `scores[Bc]` (one float per KV block element)
- Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk)
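
The two sizes above can be turned into a quick calculation. This is a minimal sketch, not code from the kernel: the function names are hypothetical, `kv_chunks=2` and `kv_block_size=256` are illustrative values only, and the sketch assumes one `scores[Bc]` scratch row per thread.

```python
def partials_floats(batch, num_heads, kv_chunks, head_size):
    """Float count for the partials buffer: m, l, and H partial-output floats per chunk."""
    per_chunk = 2 + head_size  # running max m + running sum l + partial output[H]
    return batch * num_heads * kv_chunks * per_chunk


def scratch_floats(num_threads, kv_block_size):
    """Float count for per-thread scratch: one scores[Bc] row per thread (assumption)."""
    return num_threads * kv_block_size


# Illustrative values only; kv_chunks and kv_block_size are not taken from the kernel.
partials = partials_floats(batch=1, num_heads=4, kv_chunks=2, head_size=128)
scratch = scratch_floats(num_threads=8, kv_block_size=256)
print(partials, scratch, (partials + scratch) * 4)  # floats, floats, total bytes
```

With these inputs the partials term is `1 × 4 × 2 × 130 = 1040` floats, small next to the K/V cache itself.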

## Non-Quantized Flash Attention Path

The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, and masking structure. Unlike the quantized path, it is limited to prefill / chunked-prefill (`sequence_length > 1`); single-token decode (`sequence_length == 1`) uses the naive path, which is why there is no flash-decoding variant here.

### Differences from the Quantized Path

- **Cache element type**: The present K/V cache is FP32, laid out as BNSH (`[batch, kv_num_heads, seqlen_present, head_size]`). There is no quantize-on-write or dequantize-on-read step.
- **QK GEMM**: Uses the single-threaded SGEMM primitive `MlasSgemmOperation(CblasNoTrans, CblasTrans, ...)` on an FP32 K block instead of `MlasQKGemm`.
- **SV accumulate**: Uses `MlasSgemmOperation(CblasNoTrans, CblasNoTrans, ..., beta)` with `beta = 0` for the first KV block and `beta = 1` afterwards (accumulate) instead of `MlasSVGemm`.
- **Cache concat**: New K/V tokens are appended into the FP32 present cache with `ConcatStateChunkGQA<float>` before the tiled loop runs.
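
The `beta` convention in the SV accumulate bullet can be illustrated with a tiny pure-Python stand-in for an SGEMM call. This is a sketch under stated assumptions: `gemm_nn` is a hypothetical helper implementing `C = A·B + beta·C`, not the MLAS primitive, and the online-softmax rescale of the output between blocks is deliberately left out so that only the accumulate behaviour is visible.

```python
def gemm_nn(a, b, c, beta):
    """C = A @ B + beta * C on row-major lists of lists (stand-in for an SGEMM call).

    As in BLAS, beta == 0 means C is overwritten and its old contents are never read.
    """
    inner, cols = len(b), len(b[0])
    for i in range(len(a)):
        for j in range(cols):
            dot = sum(a[i][k] * b[k][j] for k in range(inner))
            c[i][j] = dot if beta == 0.0 else dot + beta * c[i][j]
    return c


probs = [[0.5, 0.25, 0.125, 0.125]]                        # one row of softmax weights
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]       # V: 4 tokens x head_size 2
block = 2                                                  # illustrative KV block size

out = [[float("nan"), float("nan")]]                       # uninitialized output tile
for start in range(0, len(v), block):
    beta = 0.0 if start == 0 else 1.0                      # overwrite first, then accumulate
    p_block = [row[start:start + block] for row in probs]
    gemm_nn(p_block, v[start:start + block], out, beta)

print(out)  # [[2.75, 3.75]], same as the unblocked product probs @ v
```

Using `beta = 0` for the first block is what lets the output tile start uninitialized: no separate zero-fill pass is needed before the loop.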

### Algorithm

For each (batch, head, q_block) tile:

1. **QK GEMM** — `MlasSgemmOperation` of the query tile against a block slice of the FP32 K cache (Bc rows at a time)
1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores
2. **Causal + local window masking** — Set masked positions to −∞ before softmax
3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)`
4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile
5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed
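
Steps 1, 3, 4, and 5 can be sketched for a single query row in pure Python and checked against a full-materialization softmax. This is an illustrative sketch, not the MLAS kernel: `naive_row` and `flash_row` are hypothetical names, the bias and masking steps (1b and 2) are omitted, and the block size of 2 is arbitrary.

```python
import math


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def naive_row(q, keys, values, scale):
    """Reference: materialize all scores, softmax, then weight V."""
    scores = [scale * dot(q, k) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(w[t] * values[t][j] for t in range(len(values))) / total
            for j in range(len(values[0]))]


def flash_row(q, keys, values, scale, block):
    """Tiled version: one KV block at a time with a running max m and sum l."""
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        kb, vb = keys[start:start + block], values[start:start + block]
        s = [scale * dot(q, k) for k in kb]            # step 1: QK for this block
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)                    # step 3: rescale factor (0.0 on first block)
        p = [math.exp(x - m_new) for x in s]
        l = alpha * l + sum(p)
        acc = [alpha * a + sum(p[t] * vb[t][j] for t in range(len(vb)))
               for j, a in enumerate(acc)]             # step 4: rescale, then SV accumulate
        m = m_new
    return [a / l for a in acc]                        # step 5: normalize by 1/l


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0], [0.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0], [0.0, 2.0]]
ref = naive_row(q, keys, values, scale=0.5)
tiled = flash_row(q, keys, values, scale=0.5, block=2)
print(max(abs(a - b) for a, b in zip(ref, tiled)))     # tiny floating-point difference
```

The tiled version only ever holds one block of scores, which is the property that removes the `[S, T]` score matrix from the naive path.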

#### Causal early-termination

During prefill, every KV block whose start index lies beyond the largest global query
position in the current q_block is fully causally masked and contributes nothing. The kernel
computes a per-q_block bound
`kv_causal_limit = past_seqlen + q_idx + row_size_q` (one past that largest query position)
and breaks out of the KV loop once `ir >= kv_causal_limit`, where `ir` is the start index of
the current KV block, instead of computing and then discarding the masked upper-triangle
QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the
main reason the FP32 flash path is faster than naive even at short sequence lengths
(see the benchmark results below).
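
The "roughly half" figure can be checked by counting KV blocks with and without the early exit. The sketch below is a counting model only: `kv_blocks_visited` is a hypothetical helper, and the block sizes of 64 are illustrative rather than the sizes the kernel actually selects.

```python
def kv_blocks_visited(past_seqlen, seq_len, total_len, q_block, kv_block, early_exit=True):
    """Count (q_block, kv_block) tiles that run QK/SV GEMMs during prefill."""
    visited = 0
    for q_idx in range(0, seq_len, q_block):
        row_size_q = min(q_block, seq_len - q_idx)
        # One past the largest global query position covered by this q_block.
        kv_causal_limit = past_seqlen + q_idx + row_size_q
        for ir in range(0, total_len, kv_block):
            if early_exit and ir >= kv_causal_limit:
                break  # this block and every later one is fully masked
            visited += 1
    return visited


with_exit = kv_blocks_visited(0, 1024, 1024, q_block=64, kv_block=64)
without_exit = kv_blocks_visited(0, 1024, 1024, q_block=64, kv_block=64, early_exit=False)
print(with_exit, without_exit)  # 136 256
```

For this square prefill the early exit visits 136 of 256 tiles, about 53%, and the fraction approaches one half as the number of blocks grows.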

### Activation Conditions

The non-quantized flash path is selected when ALL of the following hold:

- The kernel specialization is `float` (FP16 uses the naive path)
- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`)
- `sequence_length > 1` (prefill / chunked-prefill; single-token decode uses the naive path)
- No softcap
- No smooth softmax
- No head sink
- No output QK capture
- `present_key` and `present_value` are provided

Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path.
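
The conditions above amount to a single boolean predicate. The following is a hedged sketch rather than the kernel's dispatch code: `GqaCall`, its field names, and `use_fp32_flash` are hypothetical, and treating every value other than unset or `"0"` as "disabled" is an assumption about how the environment variable is parsed.

```python
import os
from dataclasses import dataclass


@dataclass
class GqaCall:
    """Hypothetical summary of one GroupQueryAttention invocation."""
    is_float: bool               # kernel specialization is float (not FP16)
    sequence_length: int
    softcap: float = 0.0
    smooth_softmax: bool = False
    has_head_sink: bool = False
    output_qk: bool = False
    has_present_kv: bool = True  # present_key and present_value are provided


def flash_disabled_by_env(env):
    # Assumption: unset or "0" leaves flash enabled; any other value disables it.
    value = env.get("ORT_GQA_DISABLE_FLASH_ATTENTION")
    return value is not None and value != "0"


def use_fp32_flash(call, env=os.environ):
    return (call.is_float
            and not flash_disabled_by_env(env)
            and call.sequence_length > 1      # prefill / chunked-prefill only
            and call.softcap == 0.0
            and not call.smooth_softmax
            and not call.has_head_sink
            and not call.output_qk
            and call.has_present_kv)


print(use_fp32_flash(GqaCall(is_float=True, sequence_length=512), env={}))  # True
print(use_fp32_flash(GqaCall(is_float=True, sequence_length=1), env={}))    # False
```

Attention bias, local window size, and head grouping do not appear in the predicate because they are supported by the flash path rather than gating it.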

### Block Sizes and Threading

Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. Because this path is prefill-only, it does not include the quantized path's two-phase flash-decoding strategy for single-token decode.

## MLAS Dispatch Paths

MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table.
@@ -428,7 +487,39 @@ Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle
| 4096 (N=32) | +2131 | +87 | 24.5x |

**Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences.
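
The ~1 GB figure for the naive path follows from the O(N×S×T) term directly. A minimal arithmetic sketch, assuming FP32 scores, batch size 1, and that scores for all heads are materialized at once (`naive_score_bytes` is a hypothetical helper); the ~80 MB flash figure depends on block sizes and thread count and is not derived here.

```python
def naive_score_bytes(num_heads, seq_len, total_len, batch=1, bytes_per_float=4):
    """Bytes needed to hold the full [batch, N, S, T] FP32 attention score tensor."""
    return batch * num_heads * seq_len * total_len * bytes_per_float


gib = naive_score_bytes(num_heads=16, seq_len=4096, total_len=4096) / 2**30
print(gib)  # 1.0
```

Because the term is quadratic in sequence length for square prefill, doubling S and T to 8192 would quadruple it to 4 GiB.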

### Non-Quantized (FP32) Flash Attention vs Naive benchmark results

Measured on an AMD EPYC 7763 (32 logical / 16 physical cores), threads=8, FP32 KV cache,
`B=1, num_heads=16, kv_num_heads=8, head_size=128`. Operator-level, measured with:

```bash
python onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py \
--fp32 --prompt_only --warmup 10 --repeats 30
```

#### Latency — Prefill (S = T, prompt phase)

| Seq Length | Naive (ms) | Flash (ms) | Speedup |
|---:|---:|---:|---:|
| 512 | 5.8–8.4 | 4.2–5.3 | 1.4–1.6x |
| 1024 | 25–29 | 13–18 | 1.6–2.0x |
| 2048 | 87–118 | 52–65 | 1.5–2.0x |
| 4096 | 365–380 | 213–234 | 1.6–1.7x |

The FP32 flash path is faster than naive across all measured prefill lengths. With the causal
early-termination described above, roughly half of the QK/SV work (the causally masked
upper triangle of the square prefill attention matrix) is skipped entirely, which more than
offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/output rescale).
The same advantage holds single-threaded (1.4–1.8x at threads=1), confirming the gain is
algorithmic rather than purely from threading.

#### Decode (S = 1, token generation)

Single-token decode (`sequence_length == 1`) is **not** handled by the FP32 flash path; it falls
back to the naive path. Decode produces only a `[1, total_sequence_length]` score row per head,
so there is nothing to tile away, and the extra online-softmax bookkeeping made the flash kernel
slower and noisier in practice. Restricting the flash path to prefill (`sequence_length > 1`) keeps
the consistent prefill win without regressing decode.

## Current CPU Limitations

The current CPU GroupQueryAttention implementation has a few important limitations: