[KDA] KDA MTP decode: recurrent + KVBuffer chunkwise verify + flush#96

Open: Longxmas wants to merge 3 commits into inclusionAI:main from Longxmas:feat/kda-mtp-verify-4ops

@Longxmas

Collaborator


📌 Description

Adds KDA (Kimi Delta Attention) multi-token-prediction (MTP) decode — the target-side gated-delta-rule recurrence for speculative decoding — in two complementary forms:

  • Recurrent (kda_decode_mtp): a single register-resident CuTe kernel threading the recurrence over the T draft tokens per (batch, head). Two 1-CTA ops — vk (lane=K butterfly-reduce) / ws (4-warp warp-spec) — behind a unified dispatch.

  • KVBuffer chunkwise verify (kda_decode_mtp_kvbuffer): the parallel-verification path — verify emits a compact u-buffer (~2·T·d) instead of the T per-token states (T·d²), and a rank-m flush rebuilds the accepted state. Two variants — tp (token-parallel SIMT) / cute-gemm (sm90 tensor-core, flat-in-T).
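For readers who want the recurrence spelled out, here is a minimal NumPy sketch of a single-token gated delta rule threaded over T draft tokens for one (batch, head). It assumes the common formulation with a per-channel decay `a`, a scalar `beta` per token, and a `[K, V]` state; the function name, argument layout, and gating convention are illustrative assumptions and may differ from what `kda_decode_mtp` actually implements.

```python
import numpy as np

def kda_recurrent_reference(S0, q, k, v, a, beta):
    """Thread a gated delta rule over T tokens for one (batch, head), in fp32.

    S0: [K, V] initial state; q, k, a: [T, K]; v: [T, V]; beta: [T].
    Returns per-token outputs [T, V] and the T per-token states [T, K, V].
    Hypothetical helper: the gating convention is an assumption, not the PR's code.
    """
    q, k, v, a, beta = (np.asarray(x, dtype=np.float32) for x in (q, k, v, a, beta))
    S = np.asarray(S0, dtype=np.float32).copy()
    outs, states = [], []
    for t in range(q.shape[0]):
        S = a[t][:, None] * S              # per-channel decay of the state rows
        u = beta[t] * (v[t] - k[t] @ S)    # delta-rule correction, shape [V]
        S = S + np.outer(k[t], u)          # rank-1 state update
        outs.append(q[t] @ S)              # read out with the query
        states.append(S.copy())            # one [K, V] state per draft token
    return np.stack(outs), np.stack(states)
```

The `states` array is what makes rollback expensive in the recurrent form: it holds one full `[K, V]` matrix per draft token.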

What changed

  • cula/ops/kda_decode_mtp.py — recurrent vk/ws ops + unified dispatch.

  • cula/ops/kda_decode_mtp_kvbuffer.py — tp / cute-gemm chunkwise verify + rank-m flush.

  • tests/test_kda_decode_mtp.py — unit (vs fp32 oracle) + bit-exact determinism.

  • benchmarks/bench_kda_decode_mtp.py — unified verify-chain benchmark.
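The idea behind the u-buffer and rank-m flush can be checked numerically. Under the assumed recurrence S_t = diag(a_t)·S_{t-1} + k_t·u_tᵀ, the state after m accepted tokens is a decayed copy of S_0 plus a sum of m rank-1 terms, so keeping the u vectors (together with the keys and gates, which the flush also needs) is enough to rebuild it. The sketch below computes u sequentially as a stand-in for the parallel verify kernels; all names are hypothetical.

```python
import numpy as np

def verify_emit_ubuf(S0, k, v, a, beta):
    """Sequential stand-in for verify: return the u-buffer [T, V], not T states."""
    S = S0.copy()
    ubuf = np.empty_like(v)
    for t in range(k.shape[0]):
        S = a[t][:, None] * S                      # decay
        ubuf[t] = beta[t] * (v[t] - k[t] @ S)      # correction kept for the flush
        S = S + np.outer(k[t], ubuf[t])
    return ubuf

def flush_rank_m(S0, k, a, ubuf, m):
    """Rebuild the state after m accepted tokens as decay(S0) + a rank-m update."""
    if m == 0:
        return S0.copy()
    # suffix[i] = product of a[j] for i < j < m: the decay that later tokens
    # apply to the rank-1 term written by token i.
    suffix = np.ones_like(a[:m])
    for i in range(m - 2, -1, -1):
        suffix[i] = suffix[i + 1] * a[i + 1]
    total = suffix[0] * a[0]                       # decay applied to S0 over m tokens
    return total[:, None] * S0 + (suffix * k[:m]).T @ ubuf[:m]
```

The flush is a single `[K, m] @ [m, V]` product, which is why it is described as rank-m.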

🔍 Related Issues

Closes #17

🧪 Tests

pytest tests/test_kda_decode_mtp.py — recurrent (vk/ws/kv) + kvbuffer (tp/cg) verify output & rank-m flush vs the fp32 single-token recurrence oracle, plus bit-exact determinism (torch.equal).

tests/test_kda_decode_mtp.py::test_oracle_vs_loop[N16-T8-H8-HV16-randstate] PASSED
tests/test_kda_decode_mtp.py::test_ws_decode[N64-T8-H8-HV16-tvNone-ilpNone-smemNone] PASSED
tests/test_kda_decode_mtp.py::test_small_batch_decode[N16-T4-H16-HV32-vk-bv-1-ks1] PASSED
tests/test_kda_decode_mtp.py::test_small_batch_decode[N16-T4-H16-HV32-kv-bv32-ks-1] PASSED
tests/test_kda_decode_mtp.py::test_determinism[sb_vk] PASSED
tests/test_kda_decode_mtp.py::test_intermediate_vs_oracle_and_final[64-4-True] PASSED
tests/test_kda_decode_mtp.py::test_tp_kvbuffer_verify_and_flush[4-4-16-16] PASSED
tests/test_kda_decode_mtp.py::test_cg_kvbuffer_verify_and_flush[4-6-16-16] PASSED
tests/test_kda_decode_mtp.py::test_kvbuffer_dispatch_routes_by_T[4-cg] PASSED
tests/test_kda_decode_mtp.py::test_kvbuffer_verify_determinism[cg-4-6-16-16] PASSED
tests/test_kda_decode_mtp.py::test_kvbuffer_flush_determinism[tp-4-4-16-16] PASSED
============== 106 passed, 27 warnings in 287.71s ==============
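The suite checks two different properties: closeness to an fp32 oracle, and bit-exact repeatability. The plain-Python sketch below (not taken from the test file) illustrates why the second is stricter. Two floating-point reductions over the same values in a different order agree within tolerance but are not bitwise equal, which is the kind of difference a `torch.equal` check would catch.

```python
import math

vals = [0.1, 0.2, 0.3]
left_to_right = (vals[0] + vals[1]) + vals[2]   # one reduction order
right_to_left = vals[0] + (vals[1] + vals[2])   # another reduction order

# Tolerance check: passes, the two sums differ only in the last bit.
close = math.isclose(left_to_right, right_to_left, rel_tol=1e-12)

# Bit-exact check: fails, because the rounding depends on the order.
bit_exact = left_to_right == right_to_left
```

A kernel whose reduction order is fixed from run to run passes both checks; one whose order varies can pass the tolerance check and still fail the bit-exact one.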

⚡ Performance

Setup: H200 (HBM3e), K=V=128, bf16, accept m=full, official SGLang scatter commit, kernel-only timing under CUDA graphs. Each cell is the speedup of the best-dispatch verify + state-update chain over the official Triton recurrent kernel (fused_sigmoid_gating_delta_rule_update, SGLang); values above 1 mean faster than Triton. The unified dispatch picks the fastest of {vk, ws, tp, cg} per shape: the recurrent vk/ws paths write T·d² states and then run the scatter commit, while the KVBuffer tp/cg paths write a u-buffer and then run the rank-m flush.
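The routing rules quoted in this PR (T=1 to vk in the first commit message, and the T=2 / T≥3 / tiny-batch split in the takeaways) can be written as a plain function. This is a hypothetical sketch: the function name is invented, the reading of "tiny B at small T" as B≤2 and T≤2 is an assumption, and the conditions under which ws is chosen are not stated in the description, so ws is left out.

```python
def pick_mtp_path(batch: int, num_draft_tokens: int) -> str:
    """Hypothetical routing rule; thresholds are read off the PR text, not the code."""
    if batch < 1 or num_draft_tokens < 1:
        raise ValueError("batch and num_draft_tokens must be positive")
    if num_draft_tokens == 1:
        return "vk"   # single-token decode routes to vk regardless of batch
    if batch <= 2 and num_draft_tokens <= 2:
        return "vk"   # assumption: "tiny B at small T" means B <= 2 and T <= 2
    if num_draft_tokens == 2:
        return "tp"   # token-parallel SIMT, reported best at small T
    return "cg"       # cute-gemm for T >= 3
```

For example, `pick_mtp_path(64, 4)` returns `"cg"` under these assumptions, in line with the `test_kvbuffer_dispatch_routes_by_T[4-cg]` test ID.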

Best method / Triton (HV=H=32)

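| B | T=2 | T=3 | T=4 | T=6 |
|---|---|---|---|---|
| 1 | 1.80 | 1.73 | 1.75 | 1.91 |
| 2 | 1.48 | 1.52 | 1.63 | 1.86 |
| 4 | 1.41 | 1.41 | 1.63 | 2.10 |
| 8 | 1.70 | 1.72 | 1.82 | 2.18 |
| 16 | 1.97 | 1.98 | 2.20 | 2.46 |
| 32 | 1.61 | 1.61 | 1.78 | 2.05 |
| 64 | 1.61 | 1.59 | 1.80 | 1.96 |
| 128 | 1.59 | 1.61 | 1.84 | 2.05 |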

Best method / Triton (HV=H=64)

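| B | T=2 | T=3 | T=4 | T=6 |
|---|---|---|---|---|
| 1 | 1.46 | 1.49 | 1.68 | 1.91 |
| 2 | 1.41 | 1.41 | 1.57 | 2.09 |
| 4 | 1.65 | 1.71 | 1.82 | 2.13 |
| 8 | 2.03 | 2.03 | 2.23 | 2.50 |
| 16 | 1.61 | 1.62 | 1.78 | 2.06 |
| 32 | 1.63 | 1.60 | 1.80 | 1.96 |
| 64 | 1.58 | 1.61 | 1.84 | 2.06 |
| 128 | 1.59 | 1.64 | 1.85 | 2.09 |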

Takeaways:

  • Best-dispatch chain vs Triton: 1.41× – 2.50× across all shapes (both HV), strongest at T≥4. The dispatch routes T=2 → tp (token-parallel SIMT, best at small T), T≥3 → cute-gemm (flat-in-T), and tiny B≤2 at small T → recurrent vk.

  • Flat-in-T: the cute-gemm verify kernel grows only +14–20% from T=2 to T=6, while the Triton recurrent kernel grows +104–124%, reproducing Fig. 4 of the KVBuffer paper. The verify kernel alone (accept-independent) reaches up to 3.51× (B=128, T=6).

  • Memory: the u-buffer (~2·T·d) replaces the T·d² intermediate states → ~43× less rollback storage (d=128), independent of latency.

  • Correctness: tp ≤ 6.1e-5, cg ≤ 2.44e-4 max|Δ| vs Triton (bf16), well within bf16 noise.
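The memory takeaway is simple arithmetic: T per-token states cost T·d² elements, while a buffer of c length-d vectors per token costs c·T·d, so the ratio is d/c and T cancels. The description does not spell out which element counts or dtypes give the quoted ~43×, so the sketch below only evaluates the ratio for two hypothetical values of c at d=128.

```python
def rollback_storage_ratio(d: int, vectors_per_token: float) -> float:
    """Ratio of T per-token states (T*d*d elements) to a buffer holding
    vectors_per_token * T * d elements. T cancels out of the ratio."""
    return (d * d) / (vectors_per_token * d)

ratio_two = rollback_storage_ratio(128, 2)     # 64.0, matching a literal 2*T*d buffer
ratio_three = rollback_storage_ratio(128, 3)   # about 42.67, close to the quoted ~43x
```

Read this way, the quoted ~43× is consistent with roughly three d-length vectors per token rather than two, though that interpretation is an inference from the numbers and not something the PR states.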

Longxmas added 2 commits June 16, 2026 16:37
Single-kernel recurrent gated-delta-rule multi-token-prediction decode with register-resident state. vk (lane=K butterfly-reduce) and ws (4-warp warp-spec) CuTe ops behind a unified dispatch; single-token T=1 routes to vk regardless of batch.
Chunkwise parallel-verification KVBuffer ops for KDA MTP speculative decoding: tp (token-parallel SIMT) and cute-gemm (sm90 tensor-core) verify emit a compact u-buffer instead of T per-token states; rank-m flush rebuilds the accepted state. Adds unit + determinism tests and the unified decode-mtp benchmark.
Longxmas requested a review from icavan June 16, 2026 09:12
Longxmas requested a review from zheyang0825 June 16, 2026 09:30
…added CTAs, empty dummies

flush kvbuffer: accept_len is now per-request and read at runtime from an
[N] int32 buffer (m_buf[i_n]) instead of a compile-time constant. The kernel
statically unrolls T and masks i_i < m_n, so it compiles exactly once per
(shape, BV) regardless of accept length; b_m uses the per-request token m_n-1.
Host API accepts an int (broadcast to all N) or a per-request [N] tensor.
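A host-side sketch of the two ideas in this commit message: normalizing `accept_len` to an `[N]` int32 buffer, and looping over all T positions with a mask instead of a data-dependent trip count. The function names and the allowed range 0 ≤ m ≤ T are assumptions made for illustration; the real kernel is CuTe code, and the masked sum here only stands in for its masked per-token work.

```python
import numpy as np

def normalize_accept_len(accept_len, n: int, t: int) -> np.ndarray:
    """Return an [N] int32 buffer from an int (broadcast) or a per-request sequence."""
    if isinstance(accept_len, (int, np.integer)):
        m_buf = np.full(n, accept_len, dtype=np.int32)    # same length for all requests
    else:
        m_buf = np.asarray(accept_len, dtype=np.int32)
        if m_buf.shape != (n,):
            raise ValueError(f"expected shape ({n},), got {m_buf.shape}")
    # Assumed valid range; the PR does not state whether m = 0 is allowed.
    if (m_buf < 0).any() or (m_buf > t).any():
        raise ValueError("accept_len must lie in [0, T]")
    return m_buf

def masked_token_sum(x: np.ndarray, m_buf: np.ndarray) -> np.ndarray:
    """Visit all T positions for every request, masking positions i >= m_n."""
    n, t = x.shape
    out = np.zeros(n, dtype=x.dtype)
    for i in range(t):                       # fixed trip count, independent of m_n
        out += np.where(i < m_buf, x[:, i], 0)
    return out
```

Because the loop bound depends only on T, the same compiled code serves every accept length, which is the property the commit message describes as compiling once per (shape, BV).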

small-batch decode (vk + kv): wrap the compute body in `if cache_idx >= 0:` so
padded slots (cache_idx < 0) skip the whole T-loop, matching the ws kernel
(~1.3x on a half-padded batch). kv hoists its k_split constexpr decisions to
top level so they stay python constants inside the guarded block.
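The padded-slot guard can be pictured with a small counting sketch. It is not the kernel code: each loop iteration stands in for one recurrence step, and the function name is hypothetical. The only point is that a slot with `cache_idx < 0` skips the entire T-loop.

```python
def count_decode_steps(cache_indices, num_draft_tokens: int) -> int:
    """Count emulated T-loop steps; slots with cache_idx < 0 are padding."""
    steps = 0
    for cache_idx in cache_indices:
        if cache_idx >= 0:                       # guard wraps the whole compute body
            for _ in range(num_draft_tokens):
                steps += 1                       # stand-in for one recurrence step
    return steps
```

On a half-padded batch such as `[0, -1, 3, -1]` with T=4, this does 8 steps instead of 16. The measured gain quoted above (~1.3x) is smaller than 2x, presumably because other per-call costs do not shrink with the skipped work.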

kvbuffer verify: torch.empty instead of torch.zeros for the write_ubuf=False
dummy buffers (only ever written, never read) — drops a per-call memset.
Linked issue (may be closed when this pull request merges): KDA MTP (Multi-Token Prediction) support