Skip to content

KV cache quantization for ORT WebGPU#28059

Open
sushraja-msft wants to merge 35 commits into
mainfrom
user/sushraja/turbo_quant
Open

KV cache quantization for ORT WebGPU#28059
sushraja-msft wants to merge 35 commits into
mainfrom
user/sushraja/turbo_quant

Conversation

@sushraja-msft

@sushraja-msft sushraja-msft commented Apr 14, 2026

Copy link
Copy Markdown
Contributor

Description

Turbo quant implementation for ORT WebGPU, using a Hadamard matrix for rotation instead of a regular matrix, which deviates from the paper.

seq_length turbo_quant prefill_tps generation_tps working_set_gb gpu_memory_gb % saving
1024 ❌ Off 2007.21 108.48 1.84 4.14
✅ On 2053.70 113.06 1.64 3.37 18.6%
2048 ❌ Off 1778.32 111.26 1.95 5.49
✅ On 1763.69 111.83 1.90 3.75 31.7%
4096 ❌ Off 1373.88 29.78 1.61 7.21
✅ On 1367.89 29.51 2.30 4.96 31.2%
8192 ❌ Off 948.90 84.17 2.44 10.14
✅ On 943.96 82.36 2.31 6.50 35.9%
  • Hadamard transform is kept as its own class and standalone shader -hadamard_transform.h‎ - used to rotate/unrotate Q. Can be used by other feature in the future like activation quantization.

  • TurboQuantHadamard applies Hadamard transform and then quantizes using the centroid look up for q4.

  • Dequantization is all fused into the various flash attention kernels.

Caller for LLMs like gen-ai have to set kvCacheQuantizationBits:4 in the EP provider options and pass in present,past kv cache input, output tensors that have a reduced headsize.

With turboquant the headsize reduces from say using 16bits per value to 4bits and in addition there is a 32bit scale in the front of each token per head.


Note on impact on quality. Evaluting KV quantization 4 bits with Phi4 mini

Metric kv0 (no quant) kv4 (4-bit) Δ (kv4 − kv0)
Mean quality score (0–5) 3.64 3.36 −0.28
Head-to-head wins 71 37
Ties 92
Broken/failed (score ≤ 1) 13 17 +4
Hard-broken (score = 0) 2 2 0

Verdict: Under graded rubric scoring, 4-bit KV quantization shows a small but consistent quality penalty (≈ −0.28 on a 5-point scale). kv0 wins roughly 2× as many head-to-head matchups as kv4 (71 vs 37), though nearly half of all prompts (92/200) tie. The degradation is mild and uneven, not catastrophic — it concentrates in specific use cases rather than degrading everything.

Quality-score distribution

Score kv0 kv4
5 (perfect) 62 42
4 (good) 56 55
3 (bearable) 44 54
2 (significant issues) 25 32
1 (serious issues) 11 15
0 (broken) 2 2

The main shift is at the top end: kv0 earns 62 perfect scores vs kv4's 42 (−20). Those lost 5s mostly slide down to 3s (+10) and 2s (+7). kv4 doesn't produce dramatically more total failures — it produces fewer flawless answers.

4-bit KV quant is acceptable for latency/memory-sensitive deployments where a ~0.3-point average quality dip is tolerable, except for tag-generation and content-detection workloads, where kv0 (no quant) is meaningfully better. If those use cases matter, keep KV quant off or pair it with repetition_penalty > 1.0 to suppress the tag-loop failures that dominate kv4's losses.

@sushraja-msft sushraja-msft changed the title User/sushraja/turbo quant WIP: Turbo quant for ORT WebGPU Apr 14, 2026
@sushraja-msft sushraja-msft marked this pull request as draft April 14, 2026 02:59

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.h Outdated
Comment thread samples/cxx/GqaTest.cc Outdated
Comment thread samples/cxx/GqaTest.cc Outdated
Comment thread samples/cxx/GqaTest.cc Outdated
Comment thread samples/cxx/generate_gqa_model.py Outdated
Comment thread samples/cxx/generate_gqa_model.py Outdated
Comment thread samples/cxx/generate_gqa_model.py Outdated
Comment thread samples/cxx/generate_gqa_model.py Outdated
Comment thread samples/cxx/generate_gqa_model.py Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Fixed
Comment thread samples/cxx/GqaTest.cc Fixed
Comment thread samples/cxx/generate_gqa_model.py Fixed
@sushraja-msft sushraja-msft changed the title WIP: Turbo quant for ORT WebGPU WIP: TurboQuant for ORT WebGPU Apr 14, 2026
@sushraja-msft sushraja-msft force-pushed the user/sushraja/turbo_quant branch from 6bd2002 to 74ba45a Compare June 4, 2026 06:04
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h Fixed

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.h Outdated

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h Outdated
Comment thread onnxruntime/core/providers/webgpu/program_manager.cc Outdated
@sushraja-msft sushraja-msft force-pushed the user/sushraja/turbo_quant branch from ae56ae6 to fb35f72 Compare June 12, 2026 01:45

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.h Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces an experimental “TurboQuant” path for the WebGPU EP to reduce KV-cache memory by applying a Walsh–Hadamard rotation plus 4-bit quantization for K/V, with corresponding dequantization support in FlashAttention. It also adds an EP config option to enable the feature and improves WGSL shader error surfacing during shader-module creation.

Changes:

  • Add a new WebGPU EP option (ep.webgpuexecutionprovider.turboQuant) and plumb it through WebGPU EP config/context.
  • Implement Hadamard transform + TurboQuant quantize/copy kernels (WGSL + C++ program wrappers) and integrate them into WebGPU GQA/FlashAttention routing (prefill + a dedicated decode path).
  • Add shader compilation-info logging in ProgramManager::Build to surface WGSL compilation errors earlier.

Reviewed changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
onnxruntime/core/providers/webgpu/webgpu_provider_options.h Adds a new provider option key/constants for TurboQuant.
onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc Parses turboQuant config and stores quantization bits into EP config.
onnxruntime/core/providers/webgpu/webgpu_execution_provider.h Stores TurboQuant config in EP state and exposes accessors.
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc Plumbs TurboQuant config into the EP constructor.
onnxruntime/core/providers/webgpu/program_manager.cc Adds shader compilation-info fetching/logging after shader-module creation.
onnxruntime/core/providers/webgpu/compute_context.h Exposes TurboQuant settings on the compute context.
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc Adjusts KV-cache shape expectations and FlashAttention routing for TurboQuant.
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Adds TurboQuant parameters/uniforms to FlashAttention programs and adds new decode program classes.
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Integrates TurboQuant fused paths, Hadamard pre/post transforms, and a 3-pass decode route.
onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template Adds TurboQuant dequantization path for KV loads (prefill).
onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template Adds TurboQuant dequantize-on-load implementation for decode QK^T.
onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template Adds TurboQuant dequantize-on-load implementation for decode SplitVx.
onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce_tq.wgsl.template Adds a TurboQuant-specific reduction shader for decode.
onnxruntime/contrib_ops/webgpu/bert/turbo_quant_common.wgsl.template Introduces shared TurboQuant codebook + dequant helpers.
onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.wgsl.template Adds fused Hadamard + quantize + copy-to-KV-cache WGSL kernel.
onnxruntime/contrib_ops/webgpu/bert/turbo_quant_fused_rotary.wgsl.template Adds fused split+rotary+hadamard+quantize WGSL kernel for packed QKV.
onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h Declares TurboQuant program wrappers and entrypoints.
onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Implements TurboQuant program wrappers and dispatch setup.
onnxruntime/contrib_ops/webgpu/bert/hadamard_transform_common.wgsl.template Adds shared WGSL butterfly routine used by multiple kernels.
onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.wgsl.template Adds a standalone Hadamard transform shader.
onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h Declares the Hadamard transform program wrapper and entrypoint.
onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.cc Implements the Hadamard transform program wrapper and dispatch setup.
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h Extends shape-check helpers to support a compressed KV head dimension override (used by WebGPU TurboQuant).
Comments suppressed due to low confidence (1)

onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h:152

  • The error text says past_key dim[3] should match head_size, but when kv_compressed_head_size overrides the packed KV head dimension (e.g., TurboQuant), the expected value is no longer head_size. Updating the message avoids confusing users when the override path is active.
  if (past_key_dims[3] != packed_head_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'past_key' dimension 3 should be same as head_size, got ",
                           past_key_dims[3], " expected ", packed_head_size);

Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template Outdated
@sushraja-msft sushraja-msft force-pushed the user/sushraja/turbo_quant branch from 1ba7ddf to e594522 Compare June 13, 2026 16:31
Comment thread onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h Fixed

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h:163

  • The updated CheckPast() now supports quantized/packed KV caches (packed_head_size may differ from head_size), but the error message still says the dimension "should be same as head_size". This is misleading for TurboQuant/kv_cache_bit_width!=0 cases; it should report the expected packed dimension and/or mention quantization parameters.
  if (past_key_dims[3] != packed_head_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'past_key' dimension 3 should be same as head_size, got ",
                           past_key_dims[3], " expected ", packed_head_size);
  }
  if (past_value_dims[3] != packed_head_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'past_value' dimension 3 should be same as head_size, got ",
                           past_value_dims[3], " expected ", packed_head_size);
  }

Comment thread onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc
Comment thread onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc
@sushraja-msft sushraja-msft force-pushed the user/sushraja/turbo_quant branch from 1e837ac to 24089cb Compare June 15, 2026 19:17
@sushraja-msft sushraja-msft marked this pull request as ready for review June 15, 2026 22:48
@sushraja-msft sushraja-msft changed the title WIP: TurboQuant for ORT WebGPU TurboQuant style KV cache compression for ORT WebGPU Jun 15, 2026

@qjia7 qjia7 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR #28059 — TurboQuant KV Cache Compression for ORT WebGPU

Summary

Walsh-Hadamard rotation + 4-bit quantization on the WebGPU flash-attention path. Math is
sound (orthogonal WHT preserves QKᵀ; H is self-inverse on the V output). 18–36% memory
reduction with near-neutral throughput. Comments below.


Correctness

C1: TQ kernels read seqlen_k[0] regardless of batch index

turbo_quant_hadamard.wgsl.template and turbo_quant_fused_rotary_hadamard.wgsl.template
read u32(seqlen_k[0u]), while the sibling
split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template reads seqlen_k[batch].
seqlen_k is a [batch_size] storage buffer, so for batch_size > 1 the TQ kernels would
use batch 0's past length for every batch. If the TQ path is meant to assume identical past
lengths across batches, please add a comment at the read site (and ideally an assertion on
the C++ side) so the constraint is visible; otherwise this should be seqlen_k[batch].

C2: CheckPast refactor may break the CUDA INT4 KV-cache path (latent)

group_query_attention_helper.h::CheckPast switched from packed_head_size = head_size/2 to
(head_size * kv_cache_bit_width + kv_cache_extra_bits) / bits_per_element. CUDA GQA shares
this helper and passes kv_cache_bit_width=4. With fp16 past_key the new formula yields
head_size/4, but CUDA's INT4 cache dim is (head_size+1)/2. Gated today by
#ifdef USE_INT4_KV_CACHE, but it will fire the moment that flag is enabled. Either keep the
old formula on the kv_cache_extra_bits == 0 branch, or pass compressed_head_size_u32
directly (see Q2).


Input Validation

V1: ORT_ENFORCE for runtime shape validation in ApplyFlashAttention

ORT_ENFORCE aborts on failure; these checks are reachable from bad inputs. Use
ORT_RETURN_IF_ERROR(ORT_MAKE_STATUS(..., INVALID_ARGUMENT, ...)).

V2: Use == instead of >= for the TQ shape check

ORT_ENFORCE(present_key->Shape()[3] * bytes_per_elem >= compressed_head_size_u32 * 4, ...);

== more precisely states the contract; >= quietly accepts mis-sized allocations.

V3: const_cast<void*>(past_key->DataRaw()) for the u32 view

The Tensor ctor takes a non-const data pointer, implying write ownership the alias does not
have. Add a comment that past_key_u32 / past_value_u32 are read-only, or use a const-data
construction path if available.


Compatibility / Scope

S1: ORT_ENFORCE for unsupported QKV format in TurboQuantCopyToQuantizedKVCache

ORT_ENFORCE(parameters.qkv_format_ == Q_K_V_BSNH, "...");

Use ORT_RETURN_IF_ERROR(ORT_MAKE_STATUS(..., INVALID_ARGUMENT, ...)).

S2: Non-power-of-2 head sizes hard-error the whole model

TQ requires power-of-2 head_size (Hadamard constraint). The turboQuant flag is global, so
real-world models with head_size 96 / 80 / 48 / 160 fail end-to-end the moment it's set.
Please document supported configurations prominently, and consider per-node opt-in for a
follow-up PR.

S3: Compressed present-KV shape is a hidden contract with external allocators

The output last dim becomes (head_size*4 + 32) / bits_per_element u32 words, but ONNX shape
inference still reports head_size. Pre-allocated outputs (IO-binding, graph-capture, ORT
GenAI) get no signal about the compressed layout. Please document the allocator contract or
surface the compressed shape via metadata.


Design / Quality

Q1: head_size_log2 computation duplicated 3×

int head_size_log2 = 0;
for (int tmp = head_size; tmp > 1; tmp >>= 1) head_size_log2++;

Extract to a file-static helper or use std::bit_width(unsigned(head_size)) - 1 (C++20).

Q2: CheckPast should take compressed_head_size_u32 directly

The current (head_size * bit_width + extra_bits) / bits_per_element encoding is
dtype-dependent and hides the "+32 bits for norm" convention. Passing
compressed_head_size_u32 directly makes the contract explicit and avoids the C2 latent bug.

Q3: Pre-multiply uniforms.alpha into all_q in flash_attention_decode_qkv.wgsl.template

let mq_lo = all_q[m][word_idx * 2u] * q_element_t(uniforms.alpha);

alpha is constant; folding it into the Q-load phase eliminates the redundant multiply
inside the inner loop.

Q4: tq_unpack_nibbles duplicated across two templates

Same body in flash_attention.wgsl.template and flash_attention_decode_qkv.wgsl.template.
Move to turbo_quant_common.wgsl.template, or pass tq_lut as a parameter.

Q5: TurboQuantHadamardProgram initializer list on one line

200+ chars — format one member per line, as TurboQuantFusedRotaryProgram already does.

Q6: New TQ shaders bypass get/setByOffset for some mandatory bindings

  • turbo_quant_hadamard.wgsl.template: key/value/present_key/present_value/past_key/
    past_value are registered UseUniform, but the template only opts into
    #use .indicesToOffset. Extend to #use .indicesToOffset .getByOffset .setByOffset and
    switch to key.getByOffset(...) / present_key.setByOffset(...).
  • turbo_quant_fused_rotary_hadamard.wgsl.template already declares
    #use .getByOffset .setByOffset but still writes present_key[...] / present_value[...]
    raw (lines 212, 214, 232, 234). Switch those to setByOffset.

The bindings are u32-typed views, so setByOffset(idx, packed_u32) works without any extra
cast. hadamard_transform.wgsl.template is the model.


Nits

  • Centroid values in turbo_quant_common.wgsl.template should cite their source (QuaRot /
    KVQuant / custom).
  • use_smooth_softmax and sliding_window disable flash attention, so TQ is implicitly
    excluded via that path — worth a one-line comment in CanApplyFlashAttention.
  • select(0.0f, 1.0f/l2_norm, l2_norm > 0.0f): WGSL select is eager, so 1.0/0.0 is
    computed and discarded. Harmless but please confirm NaN inputs land at centroid 0 with
    scale 0.
  • kTurboQuant parses only "0" / "4", but the plumbing (turbo_quantization_bits, "8-bit
    future" comment) suggests more. Make sure unsupported values fail at every layer.

Verdict

Approve with changes. Pre-merge: C2 (CUDA INT4 latent regression) is the main item that
needs attention; S3 (allocator contract) needs at least documentation; everything else is
cleanup.

@sushraja-msft sushraja-msft force-pushed the user/sushraja/turbo_quant branch from 24089cb to 745ddbe Compare June 23, 2026 00:04
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
sushraja-msft and others added 4 commits June 25, 2026 08:31
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
@sushraja-msft

Copy link
Copy Markdown
Contributor Author

Review: TurboQuant / KVCacheQuantization

Overall: APPROVE, conditional on (1) resolving the WebGPU CI test failures and (2) addressing Comment 1 below. Comment 2 is a follow-up that can land in a separate PR — it is not a blocker for this one.

Comment 1 — 🔴 BLOCKING: quantized KV cache cannot load on any model this stack produces ("existing models" claim unmet)

Where: GenAI src/python/py/models/builders/base.py:165-166 (past inputs) and 186-187 (present outputs); interacts with ORT core/session/inference_session.cc:2885-2904 (CheckShapes).

The PR enables KV-cache quantization via kvCacheQuantizationBits, but turning it on makes the model fail to load at session creation:

Got invalid dimensions for input: past_key_values.0.key for the following indices
 index: 3 Got: 34 Expected: 128
 Please fix either the inputs/outputs or the model.

Root cause: with quantization on, GenAI's kv_cache.cpp correctly allocates a compressed KV buffer (head_dim 128 → 34 for 4-bit fp16: 128/8 + 1 = 17 u32 words × 2 = 34), but builder.py declares the KV last dim as a static value (self.head_size). ORT's InferenceSession::CheckShapes enforces every non-symbolic declared dim (expected_shape[i] < 0 is the only skip — I verified there is no quantization-aware relaxation anywhere in core/framework or the WebGPU EP), so it rejects the 34-wide feed before the (quantization-aware) GQA CheckPast ever runs.

I verified this end-to-end on phi4-prune: with the KV last dim static, load fails with the error above; after making past_key_values.*/present.* dim 3 symbolic, the same quantized run produces correct output ("1 + 1 = 2"). As written, the feature works on neither existing exported models nor newly built ones without a manual graph edit. Please address one of:

  1. (GenAI) Make builder.py emit a symbolic KV last dim for past inputs/present outputs, so one export serves both quantized and non-quantized runs; and/or
  2. (ORT) Relax session-level KV input-shape validation when kvCacheQuantizationBits != 0 (the kernel's CheckPast already validates the compressed buffer at runtime).

Comment 2 — 🟡 Non-blocking (follow-up PR): 8-bit support is asymmetric between GenAI and ORT

Where: GenAI src/config.cpp:1459-1461 + src/models/kv_cache.cpp:165-180; ORT core/providers/webgpu/webgpu_provider_factory.cc:103-108 + contrib_ops/webgpu/bert/group_query_attention.cc:248-249.

The GenAI side already has a full 8-bit path (GetKvCacheQuantizationBits returns 8; ComputeQuantizedKvCacheHeadSize handles 8-bit with indices_per_word = 4), but ORT does not: webgpu_provider_factory.cc accepts only "0"/"4" and ORT_THROWs on "8", and group_query_attention.cc hardcodes kv_cache_bit_width = quant ? 4 : 0. Setting kvCacheQuantizationBits: "8" therefore makes GenAI allocate an 8-bit cache that ORT can't honor (I observed the run hang rather than error cleanly when the model is symbolic).

This doesn't need to block this PR — the 4-bit path is the intended scope. Suggestion: for now, either gate "8" out of GenAI or have it reject values ORT doesn't support, and implement the full 8-bit kernel in a separate follow-up PR once this one lands. Just keep the two sides' accepted value-sets in sync at all times so "8" never silently mis-allocates.

On the CI failures: they are related to this PR. The two failing tests — WebGPU_TurboQuant_Decode_MultiBatch_K24 and WebGPU_TurboQuant_CrossValidate_MultiBatch — are TurboQuant tests added by this PR itself, and they're the only failures in the run (all CUDA/TensorRT/CPU/Android/minimal jobs pass). They must be green before merge.

Addressed Comment 1 - https://github.com/microsoft/onnxruntime-genai/pull/2084/changes#diff-8a33668cd8981709ba426dab1b3d3145bd0756d1ebf1203657ba8b8aae860d45 in genai repo.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 1 comment.

Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 3 comments.

Comment thread onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Comment thread onnxruntime/core/providers/webgpu/webgpu_provider_options.h
Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 3 comments.

Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Outdated
Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Outdated
sushraja-msft and others added 2 commits June 25, 2026 17:40
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Outdated
sushraja-msft and others added 2 commits June 25, 2026 18:29
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 4 comments.

Comment on lines +430 to +443
const bool turbo_quant_enabled = context.KvCacheQuantizationEnabled();

// Compressed head dimension, expressed in two units:
// compressed_head_size_u32 — u32 words per head (1 scale + head_size/8 packed 4-bit indices),
// passed to the shaders as the packed KV dimension.
// present_last_dim — the same span counted in Q elements (fp16/fp32), used to size an
// internally-allocated present buffer so its u32 view lines up
// (compressed_head_size_u32 * 4 bytes == present_last_dim * sizeof(Q elem)).
const int compressed_head_size_u32 = turbo_quant_enabled ? (parameters.head_size_ / 8 + 1) : 0;
const int64_t present_last_dim =
turbo_quant_enabled
? static_cast<int64_t>(compressed_head_size_u32) * 4 / static_cast<int64_t>(Q->DataType()->Size())
: parameters.head_size_;

Comment on lines +3039 to +3040
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(WebGpuEPWithTurboQuant4());
Comment on lines +3099 to +3100
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(WebGpuEPWithTurboQuant4());
tester.AddInput<float>("value", {batch_size, sequence_length, kv_hidden_size}, value_data);
}

// Past KV in compressed TQ4 format (random u32 data reinterpreted as float)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants