[KDA] KDA MTP decode: recurrent + KVBuffer chunkwise verify + flush#96
Open
Longxmas wants to merge 3 commits into
Open
[KDA] KDA MTP decode: recurrent + KVBuffer chunkwise verify + flush#96Longxmas wants to merge 3 commits into
Longxmas wants to merge 3 commits into
Conversation
Single-kernel recurrent gated-delta-rule multi-token-prediction decode with register-resident state. vk (lane=K butterfly-reduce) and ws (4-warp warp-spec) CuTe ops behind a unified dispatch; single-token T=1 routes to vk regardless of batch.
Chunkwise parallel-verification KVBuffer ops for KDA MTP speculative decoding: tp (token-parallel SIMT) and cute-gemm (sm90 tensor-core) verify emit a compact u-buffer instead of T per-token states; rank-m flush rebuilds the accepted state. Adds unit + determinism tests and the unified decode-mtp benchmark.
…added CTAs, empty dummies flush kvbuffer: accept_len is now per-request and read at runtime from an [N] int32 buffer (m_buf[i_n]) instead of a compile-time constant. The kernel statically unrolls T and masks i_i < m_n, so it compiles exactly once per (shape, BV) regardless of accept length; b_m uses the per-request token m_n-1. Host API accepts an int (broadcast to all N) or a per-request [N] tensor. small-batch decode (vk + kv): wrap the compute body in `if cache_idx >= 0:` so padded slots (cache_idx < 0) skip the whole T-loop, matching the ws kernel (~1.3x on a half-padded batch). kv hoists its k_split constexpr decisions to top level so they stay python constants inside the guarded block. kvbuffer verify: torch.empty instead of torch.zeros for the write_ubuf=False dummy buffers (only ever written, never read) — drops a per-call memset.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
📌 Description
Adds KDA (Kimi Delta Attention) multi-token-prediction (MTP) decode — the target-side gated-delta-rule recurrence for speculative decoding — in two complementary forms:
Recurrent (
kda_decode_mtp): a single register-resident CuTe kernel threading the recurrence over the T draft tokens per (batch, head). Two 1-CTA ops —vk(lane=K butterfly-reduce) /ws(4-warp warp-spec) — behind a unified dispatch.KVBuffer chunkwise verify (
kda_decode_mtp_kvbuffer): the parallel-verification path — verify emits a compact u-buffer (~2·T·d) instead of the T per-token states (T·d²), and a rank-m flush rebuilds the accepted state. Two variants —tp(token-parallel SIMT) /cute-gemm(sm90 tensor-core, flat-in-T).## What changed
cula/ops/kda_decode_mtp.py— recurrent vk/ws ops + unified dispatch.cula/ops/kda_decode_mtp_kvbuffer.py— tp / cute-gemm chunkwise verify + rank-m flush.tests/test_kda_decode_mtp.py— unit (vs fp32 oracle) + bit-exact determinism.benchmarks/bench_kda_decode_mtp.py— unified verify-chain benchmark.🔍 Related Issues
Closes #17
🧪 Tests
pytest tests/test_kda_decode_mtp.py— recurrent (vk/ws/kv) + kvbuffer (tp/cg) verify output & rank-m flush vs the fp32 single-token recurrence oracle, plus bit-exact determinism (torch.equal).⚡ Performance
H200 (HBM3e), K=V=128, bf16, accept m=full, official sglang scatter commit, CUDA-graph kernel-only. Each cell = the best-dispatch verify + state-update chain speedup vs official Triton recurrent (
fused_sigmoid_gating_delta_rule_update, SGLang). The unified dispatch picks the fastest of {vk, ws, tp, cg} per shape — recurrent vk/ws write T·d² states + scatter commit; KVBuffer tp/cg write a u-buffer + rank-m flush. >1 means faster than Triton.Best method / Triton (HV=H=32)
Best method / Triton (HV=H=64)
Takeaways:
Best-dispatch chain vs Triton: 1.41× – 2.50× across all shapes (both HV), strongest at T≥4. The dispatch routes T=2 → tp (token-parallel SIMT, best at small T), T≥3 → cute-gemm (flat-in-T), and tiny B≤2 at small T → recurrent vk.
flat-in-T: the cute-gemm verify kernel grows only +14–20% from T=2→6 while Triton recurrent grows +104–124%, reproducing the KVBuffer paper's Fig.4; the verify kernel alone (accept-independent) reaches up to 3.51× (B=128, T=6).
Memory: the u-buffer (~2·T·d) replaces the T·d² intermediate states → ~43× less rollback storage (d=128), independent of latency.
Correctness: tp ≤ 6.1e-5, cg ≤ 2.44e-4 max|Δ| vs Triton (bf16), well within bf16 noise.