KV cache quantization for ORT WebGPU#28059
Conversation
6bd2002 to
74ba45a
Compare
ae56ae6 to
fb35f72
Compare
There was a problem hiding this comment.
Pull request overview
This PR introduces an experimental “TurboQuant” path for the WebGPU EP to reduce KV-cache memory by applying a Walsh–Hadamard rotation plus 4-bit quantization for K/V, with corresponding dequantization support in FlashAttention. It also adds an EP config option to enable the feature and improves WGSL shader error surfacing during shader-module creation.
Changes:
- Add a new WebGPU EP option (
ep.webgpuexecutionprovider.turboQuant) and plumb it through WebGPU EP config/context. - Implement Hadamard transform + TurboQuant quantize/copy kernels (WGSL + C++ program wrappers) and integrate them into WebGPU GQA/FlashAttention routing (prefill + a dedicated decode path).
- Add shader compilation-info logging in
ProgramManager::Buildto surface WGSL compilation errors earlier.
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/providers/webgpu/webgpu_provider_options.h | Adds a new provider option key/constants for TurboQuant. |
| onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc | Parses turboQuant config and stores quantization bits into EP config. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Stores TurboQuant config in EP state and exposes accessors. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Plumbs TurboQuant config into the EP constructor. |
| onnxruntime/core/providers/webgpu/program_manager.cc | Adds shader compilation-info fetching/logging after shader-module creation. |
| onnxruntime/core/providers/webgpu/compute_context.h | Exposes TurboQuant settings on the compute context. |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Adjusts KV-cache shape expectations and FlashAttention routing for TurboQuant. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Adds TurboQuant parameters/uniforms to FlashAttention programs and adds new decode program classes. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Integrates TurboQuant fused paths, Hadamard pre/post transforms, and a 3-pass decode route. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template | Adds TurboQuant dequantization path for KV loads (prefill). |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template | Adds TurboQuant dequantize-on-load implementation for decode QK^T. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template | Adds TurboQuant dequantize-on-load implementation for decode SplitVx. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce_tq.wgsl.template | Adds a TurboQuant-specific reduction shader for decode. |
| onnxruntime/contrib_ops/webgpu/bert/turbo_quant_common.wgsl.template | Introduces shared TurboQuant codebook + dequant helpers. |
| onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.wgsl.template | Adds fused Hadamard + quantize + copy-to-KV-cache WGSL kernel. |
| onnxruntime/contrib_ops/webgpu/bert/turbo_quant_fused_rotary.wgsl.template | Adds fused split+rotary+hadamard+quantize WGSL kernel for packed QKV. |
| onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h | Declares TurboQuant program wrappers and entrypoints. |
| onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc | Implements TurboQuant program wrappers and dispatch setup. |
| onnxruntime/contrib_ops/webgpu/bert/hadamard_transform_common.wgsl.template | Adds shared WGSL butterfly routine used by multiple kernels. |
| onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.wgsl.template | Adds a standalone Hadamard transform shader. |
| onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h | Declares the Hadamard transform program wrapper and entrypoint. |
| onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.cc | Implements the Hadamard transform program wrapper and dispatch setup. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Extends shape-check helpers to support a compressed KV head dimension override (used by WebGPU TurboQuant). |
Comments suppressed due to low confidence (1)
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h:152
- The error text says past_key dim[3] should match
head_size, but whenkv_compressed_head_sizeoverrides the packed KV head dimension (e.g., TurboQuant), the expected value is no longerhead_size. Updating the message avoids confusing users when the override path is active.
if (past_key_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
past_key_dims[3], " expected ", packed_head_size);
1ba7ddf to
e594522
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (1)
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h:163
- The updated CheckPast() now supports quantized/packed KV caches (packed_head_size may differ from head_size), but the error message still says the dimension "should be same as head_size". This is misleading for TurboQuant/kv_cache_bit_width!=0 cases; it should report the expected packed dimension and/or mention quantization parameters.
if (past_key_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
past_key_dims[3], " expected ", packed_head_size);
}
if (past_value_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 3 should be same as head_size, got ",
past_value_dims[3], " expected ", packed_head_size);
}
1e837ac to
24089cb
Compare
qjia7
left a comment
There was a problem hiding this comment.
PR #28059 — TurboQuant KV Cache Compression for ORT WebGPU
Summary
Walsh-Hadamard rotation + 4-bit quantization on the WebGPU flash-attention path. Math is
sound (orthogonal WHT preserves QKᵀ; H is self-inverse on the V output). 18–36% memory
reduction with near-neutral throughput. Comments below.
Correctness
C1: TQ kernels read seqlen_k[0] regardless of batch index
turbo_quant_hadamard.wgsl.template and turbo_quant_fused_rotary_hadamard.wgsl.template
read u32(seqlen_k[0u]), while the sibling
split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template reads seqlen_k[batch].
seqlen_k is a [batch_size] storage buffer, so for batch_size > 1 the TQ kernels would
use batch 0's past length for every batch. If the TQ path is meant to assume identical past
lengths across batches, please add a comment at the read site (and ideally an assertion on
the C++ side) so the constraint is visible; otherwise this should be seqlen_k[batch].
C2: CheckPast refactor may break the CUDA INT4 KV-cache path (latent)
group_query_attention_helper.h::CheckPast switched from packed_head_size = head_size/2 to
(head_size * kv_cache_bit_width + kv_cache_extra_bits) / bits_per_element. CUDA GQA shares
this helper and passes kv_cache_bit_width=4. With fp16 past_key the new formula yields
head_size/4, but CUDA's INT4 cache dim is (head_size+1)/2. Gated today by
#ifdef USE_INT4_KV_CACHE, but it will fire the moment that flag is enabled. Either keep the
old formula on the kv_cache_extra_bits == 0 branch, or pass compressed_head_size_u32
directly (see Q2).
Input Validation
V1: ORT_ENFORCE for runtime shape validation in ApplyFlashAttention
ORT_ENFORCE aborts on failure; these checks are reachable from bad inputs. Use
ORT_RETURN_IF_ERROR(ORT_MAKE_STATUS(..., INVALID_ARGUMENT, ...)).
V2: Use == instead of >= for the TQ shape check
ORT_ENFORCE(present_key->Shape()[3] * bytes_per_elem >= compressed_head_size_u32 * 4, ...);== more precisely states the contract; >= quietly accepts mis-sized allocations.
V3: const_cast<void*>(past_key->DataRaw()) for the u32 view
The Tensor ctor takes a non-const data pointer, implying write ownership the alias does not
have. Add a comment that past_key_u32 / past_value_u32 are read-only, or use a const-data
construction path if available.
Compatibility / Scope
S1: ORT_ENFORCE for unsupported QKV format in TurboQuantCopyToQuantizedKVCache
ORT_ENFORCE(parameters.qkv_format_ == Q_K_V_BSNH, "...");Use ORT_RETURN_IF_ERROR(ORT_MAKE_STATUS(..., INVALID_ARGUMENT, ...)).
S2: Non-power-of-2 head sizes hard-error the whole model
TQ requires power-of-2 head_size (Hadamard constraint). The turboQuant flag is global, so
real-world models with head_size 96 / 80 / 48 / 160 fail end-to-end the moment it's set.
Please document supported configurations prominently, and consider per-node opt-in for a
follow-up PR.
S3: Compressed present-KV shape is a hidden contract with external allocators
The output last dim becomes (head_size*4 + 32) / bits_per_element u32 words, but ONNX shape
inference still reports head_size. Pre-allocated outputs (IO-binding, graph-capture, ORT
GenAI) get no signal about the compressed layout. Please document the allocator contract or
surface the compressed shape via metadata.
Design / Quality
Q1: head_size_log2 computation duplicated 3×
int head_size_log2 = 0;
for (int tmp = head_size; tmp > 1; tmp >>= 1) head_size_log2++;Extract to a file-static helper or use std::bit_width(unsigned(head_size)) - 1 (C++20).
Q2: CheckPast should take compressed_head_size_u32 directly
The current (head_size * bit_width + extra_bits) / bits_per_element encoding is
dtype-dependent and hides the "+32 bits for norm" convention. Passing
compressed_head_size_u32 directly makes the contract explicit and avoids the C2 latent bug.
Q3: Pre-multiply uniforms.alpha into all_q in flash_attention_decode_qkv.wgsl.template
let mq_lo = all_q[m][word_idx * 2u] * q_element_t(uniforms.alpha);alpha is constant; folding it into the Q-load phase eliminates the redundant multiply
inside the inner loop.
Q4: tq_unpack_nibbles duplicated across two templates
Same body in flash_attention.wgsl.template and flash_attention_decode_qkv.wgsl.template.
Move to turbo_quant_common.wgsl.template, or pass tq_lut as a parameter.
Q5: TurboQuantHadamardProgram initializer list on one line
200+ chars — format one member per line, as TurboQuantFusedRotaryProgram already does.
Q6: New TQ shaders bypass get/setByOffset for some mandatory bindings
turbo_quant_hadamard.wgsl.template:key/value/present_key/present_value/past_key/
past_valueare registeredUseUniform, but the template only opts into
#use .indicesToOffset. Extend to#use .indicesToOffset .getByOffset .setByOffsetand
switch tokey.getByOffset(...)/present_key.setByOffset(...).turbo_quant_fused_rotary_hadamard.wgsl.templatealready declares
#use .getByOffset .setByOffsetbut still writespresent_key[...]/present_value[...]
raw (lines 212, 214, 232, 234). Switch those tosetByOffset.
The bindings are u32-typed views, so setByOffset(idx, packed_u32) works without any extra
cast. hadamard_transform.wgsl.template is the model.
Nits
- Centroid values in
turbo_quant_common.wgsl.templateshould cite their source (QuaRot /
KVQuant / custom). use_smooth_softmaxandsliding_windowdisable flash attention, so TQ is implicitly
excluded via that path — worth a one-line comment inCanApplyFlashAttention.select(0.0f, 1.0f/l2_norm, l2_norm > 0.0f): WGSLselectis eager, so1.0/0.0is
computed and discarded. Harmless but please confirm NaN inputs land at centroid 0 with
scale 0.kTurboQuantparses only "0" / "4", but the plumbing (turbo_quantization_bits, "8-bit
future" comment) suggests more. Make sure unsupported values fail at every layer.
Verdict
Approve with changes. Pre-merge: C2 (CUDA INT4 latent regression) is the main item that
needs attention; S3 (allocator contract) needs at least documentation; everything else is
cleanup.
24089cb to
745ddbe
Compare
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Addressed Comment 1 - https://github.com/microsoft/onnxruntime-genai/pull/2084/changes#diff-8a33668cd8981709ba426dab1b3d3145bd0756d1ebf1203657ba8b8aae860d45 in genai repo. |
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
| const bool turbo_quant_enabled = context.KvCacheQuantizationEnabled(); | ||
|
|
||
| // Compressed head dimension, expressed in two units: | ||
| // compressed_head_size_u32 — u32 words per head (1 scale + head_size/8 packed 4-bit indices), | ||
| // passed to the shaders as the packed KV dimension. | ||
| // present_last_dim — the same span counted in Q elements (fp16/fp32), used to size an | ||
| // internally-allocated present buffer so its u32 view lines up | ||
| // (compressed_head_size_u32 * 4 bytes == present_last_dim * sizeof(Q elem)). | ||
| const int compressed_head_size_u32 = turbo_quant_enabled ? (parameters.head_size_ / 8 + 1) : 0; | ||
| const int64_t present_last_dim = | ||
| turbo_quant_enabled | ||
| ? static_cast<int64_t>(compressed_head_size_u32) * 4 / static_cast<int64_t>(Q->DataType()->Size()) | ||
| : parameters.head_size_; | ||
|
|
| std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; | ||
| execution_providers.push_back(WebGpuEPWithTurboQuant4()); |
| std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; | ||
| execution_providers.push_back(WebGpuEPWithTurboQuant4()); |
| tester.AddInput<float>("value", {batch_size, sequence_length, kv_hidden_size}, value_data); | ||
| } | ||
|
|
||
| // Past KV in compressed TQ4 format (random u32 data reinterpreted as float) |
Description
Turbo quant implementation for ORT WebGPU, using a Hadamard matrix for rotation instead of a regular matrix, which deviates from the paper.
Hadamard transform is kept as its own class and standalone shader -hadamard_transform.h - used to rotate/unrotate Q. Can be used by other feature in the future like activation quantization.
TurboQuantHadamard applies Hadamard transform and then quantizes using the centroid look up for q4.
Dequantization is all fused into the various flash attention kernels.
Caller for LLMs like gen-ai have to set kvCacheQuantizationBits:4 in the EP provider options and pass in present,past kv cache input, output tensors that have a reduced headsize.
With turboquant the headsize reduces from say using 16bits per value to 4bits and in addition there is a 32bit scale in the front of each token per head.
Note on impact on quality. Evaluting KV quantization 4 bits with Phi4 mini
Verdict: Under graded rubric scoring, 4-bit KV quantization shows a small but consistent quality penalty (≈ −0.28 on a 5-point scale). kv0 wins roughly 2× as many head-to-head matchups as kv4 (71 vs 37), though nearly half of all prompts (92/200) tie. The degradation is mild and uneven, not catastrophic — it concentrates in specific use cases rather than degrading everything.
Quality-score distribution
The main shift is at the top end: kv0 earns 62 perfect scores vs kv4's 42 (−20). Those lost 5s mostly slide down to 3s (+10) and 2s (+7). kv4 doesn't produce dramatically more total failures — it produces fewer flawless answers.
4-bit KV quant is acceptable for latency/memory-sensitive deployments where a ~0.3-point average quality dip is tolerable, except for tag-generation and content-detection workloads, where kv0 (no quant) is meaningfully better. If those use cases matter, keep KV quant off or pair it with
repetition_penalty > 1.0to suppress the tag-loop failures that dominate kv4's losses.