diff --git a/benchmarks/bench_kda_decode_mtp.py b/benchmarks/bench_kda_decode_mtp.py new file mode 100644 index 00000000..ed730eab --- /dev/null +++ b/benchmarks/bench_kda_decode_mtp.py @@ -0,0 +1,578 @@ +"""KDA MTP decode benchmark — recurrent vs KVBuffer (chunkwise) verify CHAIN. + +Unified bench (supersedes the old forward-only bench_kda_decode_mtp and +bench_kda_kvbuffer). Variants, selectable via --only / --profile: + recurrent verify: vk / ws / tri (official Triton), all writing T*d^2 states; + kvbuffer verify: tpkvb (token-parallel) / cgkvb (CuTe sm_90 tensor-core GEMM + form, flat-in-T), both writing the compact u-buffer; + forward-only baselines (no rollback cost, breakdown table only): kv / auto / loop. + +Chain: REC = recurrent verify (writes T·d² intermediate states) + commit; KVB = +kvbuffer verify (emit output + write a compact u-buffer) + flush (rank-m rebuild of +S_m). spd = REC / KVB. The commit uses the REAL sglang fused_mamba_state_scatter_with_mask +(from KDA_SCATTER_FILE) so the recurrent rollback cost is official code, not a model. + +Self-contained (inlines input/timing helpers). Triton recurrent baseline (numerical +check only) from KDA_TRITON_FILE; scatter commit from KDA_SCATTER_FILE. +""" + +import argparse +import importlib.util +import os + +import torch + +from cula.ops.kda_decode import kda_decode +from cula.ops.kda_decode_mtp import ( + kda_decode_mtp, + kda_decode_mtp_small_batch, + kda_decode_mtp_ws, +) +from cula.ops.kda_decode_mtp_kvbuffer import kda_flush_kvbuffer + +# tp-kvbuffer (token-parallel, structure B) is optional too. +try: + from cula.ops.kda_decode_mtp_kvbuffer import kda_decode_mtp_tp_kvbuffer + _HAVE_TPKVB = True +except Exception: + _HAVE_TPKVB = False + +# gemm-kvbuffer (CuTe sm_90 tensor-core, flat-in-T verify). +try: + from cula.ops.kda_decode_mtp_kvbuffer import kda_decode_mtp_gemm_kvbuffer_cute + _HAVE_CGKVB = True +except Exception: + _HAVE_CGKVB = False + + +def _load_from_file(path, attr): + """Load a single attribute from a standalone .py file via importlib.""" + spec = importlib.util.spec_from_file_location(f"_standalone_{attr}", path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return getattr(mod, attr) + + +# Triton recurrent baseline (numerical check only). +_HAVE_TRITON, _TRITON_ERR = True, "" +fused_sigmoid_gating_delta_rule_update = None +try: + _f = os.environ.get("KDA_TRITON_FILE", "") + if _f and os.path.exists(_f): + fused_sigmoid_gating_delta_rule_update = _load_from_file( + _f, "fused_sigmoid_gating_delta_rule_update") + else: + from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( + fused_sigmoid_gating_delta_rule_update, + ) +except Exception as e: + _HAVE_TRITON, _TRITON_ERR = False, repr(e) + +# Official sglang scatter commit (update_mamba_state_after_mtp_verify). +_HAVE_SCATTER, _SCATTER_ERR = True, "" +fused_mamba_state_scatter_with_mask = None +try: + _f = os.environ.get("KDA_SCATTER_FILE", "") + if _f and os.path.exists(_f): + fused_mamba_state_scatter_with_mask = _load_from_file( + _f, "fused_mamba_state_scatter_with_mask") + else: + from sglang.srt.layers.attention.mamba.mamba_state_scatter_triton import ( + fused_mamba_state_scatter_with_mask, + ) +except Exception as e: + _HAVE_SCATTER, _SCATTER_ERR = False, repr(e) + + +def make_dense_inputs(N, T, H, HV, K, V, device, seed=42): + g = torch.Generator(device=device).manual_seed(seed) + bf16 = torch.bfloat16 + q = torch.randn(N, T, H, K, device=device, dtype=bf16, generator=g) + k = torch.randn(N, T, H, K, device=device, dtype=bf16, generator=g) + v = torch.randn(N, T, HV, V, device=device, dtype=bf16, generator=g) + a = (torch.randn(N, T, HV, K, device=device, dtype=torch.float32, generator=g) * 0.1).to(bf16) + b = torch.randn(N, T, HV, device=device, dtype=bf16, generator=g) + A_log = -torch.rand(HV, device=device, dtype=torch.float32, generator=g) * 2 + dt_bias = torch.randn(HV, K, device=device, dtype=torch.float32, generator=g) * 0.1 + state = torch.randn(N, HV, V, K, device=device, dtype=torch.float32, generator=g) * 0.01 + indices = torch.arange(N, device=device, dtype=torch.int32) + return q, k, v, a, b, A_log, dt_bias, state, indices + + +def to_triton_varlen(q, k, v, a, b): + N, T, H, K = q.shape + HV, V = v.shape[2], v.shape[3] + NT = N * T + q_t = q.reshape(1, NT, H, K).contiguous() + k_t = k.reshape(1, NT, H, K).contiguous() + v_t = v.reshape(1, NT, HV, V).contiguous() + a_t = a.reshape(1, NT, HV * K).contiguous() + b_t = b.reshape(1, NT, HV).contiguous() + cu_seqlens = torch.arange(0, (N + 1) * T, T, device=q.device, dtype=torch.int32) + return q_t, k_t, v_t, a_t, b_t, cu_seqlens + + +def make_triton_call(qt, kt, vt, at, bt, cu_seqlens, A_log, dt_bias, state, indices, scale, dsu, + inter_buf=None, inter_idx=None, cache_steps=None): + """Official sglang recurrent verify. In verify mode (inter_buf set) it writes the T·d² + intermediate_states_buffer, same rollback cost as our production vk_v/ws_v.""" + def call(): + return fused_sigmoid_gating_delta_rule_update( + A_log=A_log, a=at, dt_bias=dt_bias, softplus_beta=1.0, softplus_threshold=20.0, + q=qt, k=kt, v=vt, b=bt, initial_state_source=state, initial_state_indices=indices, + scale=scale, use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlens, is_kda=True, + disable_state_update=dsu, intermediate_states_buffer=inter_buf, + intermediate_state_indices=inter_idx, cache_steps=cache_steps, + retrieve_parent_token=None, lower_bound=None, + ) + return call + + +def warmup(fn, n): + for _ in range(n): + fn() + torch.cuda.synchronize() + + +def t_graph_ms(fn, warmup_iters, rep, graph_calls=1): + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(warmup_iters): + fn() + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(graph_calls): + fn() + for _ in range(10): + g.replay() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(rep): + g.replay() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / rep / graph_calls + + +_VK_BV = -1 +_ONLY = set() # empty = all variants + + +def _want(name): + return not _ONLY or name in _ONLY + + +def make_vk_call(q, k, v, a, b, A_log, dt_bias, state, indices, scale, dsu, inter_buf=None): + """Production recurrent vk. In verify mode (inter_buf set) it writes the T·d² + intermediate_states_buffer — the rollback cost kvbuffer replaces with a u-buffer.""" + def call(): + return kda_decode_mtp_small_batch( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=state, initial_state_indices=indices, scale=scale, + use_qk_l2norm_in_kernel=True, softplus_beta=1.0, softplus_threshold=20.0, + disable_state_update=dsu, variant="vk", bv=_VK_BV, + intermediate_states_buffer=inter_buf, + ) + return call + + +def make_ws_call(q, k, v, a, b, A_log, dt_bias, state, indices, scale, dsu, inter_buf=None): + """Production recurrent ws. In verify mode (inter_buf set) it also writes T·d² states.""" + def call(): + return kda_decode_mtp_ws( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=state, initial_state_indices=indices, scale=scale, + use_qk_l2norm_in_kernel=True, softplus_beta=1.0, softplus_threshold=20.0, + disable_state_update=dsu, + intermediate_states_buffer=inter_buf, + ) + return call + + +def make_tpkvb_call(q, k, v, a, b, A_log, dt_bias, state, indices, scale, dsu, ubufs=None): + """tp-kvbuffer (token-parallel chunkwise, structure B) — target: verify latency ~flat in T. + tile_v / ilp_rows overridable via env KDA_TPKVB_TILE_V / KDA_TPKVB_ILP_ROWS (-1 = auto).""" + u_buf, kinv_buf, b_buf = (ubufs if ubufs is not None else (None, None, None)) + _tv = int(os.environ.get("KDA_TPKVB_TILE_V", "-1")) + _ilp = int(os.environ.get("KDA_TPKVB_ILP_ROWS", "-1")) + def call(): + return kda_decode_mtp_tp_kvbuffer( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=state, initial_state_indices=indices, scale=scale, + use_qk_l2norm_in_kernel=True, softplus_beta=1.0, softplus_threshold=20.0, + disable_state_update=dsu, emit_output=True, + u_buffer=u_buf, kinv_buffer=kinv_buf, b_buffer=b_buf, + tile_v=_tv, ilp_rows=_ilp, + ) + return call + + +def make_cgkvb_call(q, k, v, a, b, A_log, dt_bias, state, indices, scale, dsu, ubufs=None): + """CuTe sm_90 tensor-core gemm-kvbuffer. env KDA_CGKVB_BV / KDA_CGKVB_NUM_V_TILES (-1 = auto).""" + u_buf, kinv_buf, b_buf = (ubufs if ubufs is not None else (None, None, None)) + _bv = int(os.environ.get("KDA_CGKVB_BV", "32")) + _num_v_tiles = int(os.environ.get("KDA_CGKVB_NUM_V_TILES", "-1")) + def call(): + return kda_decode_mtp_gemm_kvbuffer_cute( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=state, initial_state_indices=indices, scale=scale, + use_qk_l2norm_in_kernel=True, softplus_beta=1.0, softplus_threshold=20.0, + disable_state_update=dsu, emit_output=True, + u_buffer=u_buf, kinv_buffer=kinv_buf, b_buffer=b_buf, + bv=_bv, num_v_tiles=_num_v_tiles, + ) + return call + + +def make_kv_call(q, k, v, a, b, A_log, dt_bias, state, indices, scale, dsu): + """Forward-only production kv (lane=V small_batch; no intermediate-state support).""" + state_kv = state.transpose(-2, -1).contiguous() # vk->kv once, outside timing + def call(): + return kda_decode_mtp_small_batch( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=state_kv, initial_state_indices=indices, scale=scale, + use_qk_l2norm_in_kernel=True, softplus_beta=1.0, softplus_threshold=20.0, + disable_state_update=dsu, variant="kv", + ) + return call + + +def make_auto_call(q, k, v, a, b, A_log, dt_bias, state, indices, scale, dsu, inter_buf=None): + """kda_decode_mtp dispatch (small_batch vk for N*HV<=512, else ws).""" + def call(): + return kda_decode_mtp( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=state, initial_state_indices=indices, scale=scale, + use_qk_l2norm_in_kernel=True, softplus_beta=1.0, softplus_threshold=20.0, + disable_state_update=dsu, state_layout="vk", intermediate_states_buffer=inter_buf, + ) + return call + + +def make_loop_call(q, k, v, a, b, A_log, dt_bias, state, indices, scale, dsu): + """Per-token kda_decode loop baseline (slices pre-cut; kda_decode always writes state).""" + N, T = q.shape[0], q.shape[1] + HV, V = v.shape[2], v.shape[3] + qs = [q[:, t].unsqueeze(1).contiguous() for t in range(T)] + ks = [k[:, t].unsqueeze(1).contiguous() for t in range(T)] + vs = [v[:, t].unsqueeze(1).contiguous() for t in range(T)] + as_ = [a[:, t].unsqueeze(1).contiguous() for t in range(T)] + bs = [b[:, t].unsqueeze(1).contiguous() for t in range(T)] + st = state.clone().contiguous() + o = torch.empty(N, T, HV, V, device=q.device, dtype=torch.bfloat16) + def call(): + for t in range(T): + o_t = kda_decode( + A_log=A_log, dt_bias=dt_bias, q=qs[t], k=ks[t], v=vs[t], a=as_[t], b=bs[t], + initial_state_source=st, initial_state_indices=indices, scale=scale, + use_qk_l2norm_in_kernel=True, + ) + o[:, t] = o_t.squeeze(1) + return o + return call + + +# ---- verify-chain components: commit (recurrent rollback) & flush (kvbuffer) ---- +def make_scatter_commit_call(state_pool, inter_buf, m, N, T, HV, V, K): + """Recurrent rollback via the OFFICIAL sglang fused_mamba_state_scatter_with_mask: + gather each request's accepted-step state from the intermediate cache into the pool + (num_layers=1; step = m-1 for all requests).""" + dst = state_pool.view(1, N, HV, V, K) # [layers, cache, *state] + src = inter_buf.view(1, N, T, HV, V, K) # [layers, req, step, *state] + dst_idx = torch.arange(N, device=state_pool.device, dtype=torch.int32) + step_idx = torch.full((N,), m - 1, device=state_pool.device, dtype=torch.int32) + def call(): + fused_mamba_state_scatter_with_mask(dst, src, dst_idx, step_idx) + return state_pool + return call + + +def make_gather_commit_call(state_pool, inter_buf, m): + """Recurrent rollback, strided gather model: copy inter_buf[:,m-1] (a T-strided view) + into the pool. Less coalesced than the official kernel — kept for sensitivity only.""" + midx = m - 1 + def call(): + state_pool.copy_(inter_buf[:, midx]) + return state_pool + return call + + +def make_flush_call(state_pool, indices, ubufs, m): + """KVBuffer flush: read the compact u-buffer, rank-m rebuild S_m (no recompute).""" + u_b, kinv_b, b_b = ubufs + def call(): + return kda_flush_kvbuffer(state_pool, indices, u_b, kinv_b, b_b, m) + return call + + +def _accept_len(T, accept, N=0): + if accept == "full": + return T + if accept == "half": + return max(1, (T + 1) // 2) + if accept == "one": + return 1 + if accept == "random": + # Deterministic per-(N,T) accept length in [1,T] (real serving is per-req variable). + g = torch.Generator().manual_seed(1000 * N + T) + return int(torch.randint(1, T + 1, (1,), generator=g).item()) + return max(1, min(int(accept), T)) + + +def _profile_one(args, DSU, device): + """Run ONE method's kernel in a loop so ncu can wrap it. Shape = (batch_sizes[0], Ts[0]).""" + N, T = args.batch_sizes[0], args.Ts[0] + q, k, v, a, b, A_log, dt_bias, state0, indices = make_dense_inputs( + N, T, args.H, args.HV, args.K, args.V, device) + scale = args.K ** -0.5 + m = _accept_len(T, args.accept, N) + inter_buf = torch.empty(N, T, args.HV, args.V, args.K, dtype=torch.float32, device=device) + ubufs = ( + torch.empty(N, T, args.HV, args.V, dtype=torch.float32, device=device), + torch.empty(N, T, args.HV, args.K, dtype=torch.float32, device=device), + torch.empty(N, T, args.HV, args.K, dtype=torch.float32, device=device), + ) + p = args.profile + if p == "vk": + fn = make_vk_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, inter_buf) + elif p == "ws": + fn = make_ws_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, inter_buf) + elif p == "tpkvb": + fn = make_tpkvb_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, ubufs) + elif p == "cgkvb": + fn = make_cgkvb_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, ubufs) + elif p == "triton": + qt, kt, vt, at, bt, cu = to_triton_varlen(q, k, v, a, b) + tri_idx = torch.arange(N, device=device, dtype=torch.int32) + fn = make_triton_call(qt, kt, vt, at, bt, cu, A_log, dt_bias, state0.clone(), + indices, scale, DSU, inter_buf, tri_idx, T) + elif p == "commit": + make_vk_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, inter_buf)() + fn = make_scatter_commit_call(state0.clone(), inter_buf, m, N, T, args.HV, args.V, args.K) + elif p == "kv": + fn = make_kv_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU) + elif p == "auto": + fn = make_auto_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU) + elif p == "loop": + fn = make_loop_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU) + elif p == "flush": + make_tpkvb_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, ubufs)() + fn = make_flush_call(state0.clone(), indices, ubufs, m) + for _ in range(5): + fn() + torch.cuda.synchronize() + for _ in range(args.profile_iters): + fn() + torch.cuda.synchronize() + print(f"profiled {p} N={N} T={T} HV={args.HV} m={m} iters={args.profile_iters}") + + +def main(): + ap = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--batch-sizes", type=int, nargs="+", default=[1, 2, 4, 8]) + ap.add_argument("--Ts", type=int, nargs="+", default=[2, 3, 4, 6, 8]) + ap.add_argument("--H", type=int, default=16) + ap.add_argument("--HV", type=int, default=64) + ap.add_argument("--K", type=int, default=128) + ap.add_argument("--V", type=int, default=128) + ap.add_argument("--rep", type=int, default=300) + ap.add_argument("--warmup", type=int, default=5, help="warmup iters before each timed segment") + ap.add_argument("--graph-calls", type=int, default=4, + help="ops per CUDA graph to amortize fixed launch overhead at small batch " + "(N<16; N>=16 uses 1). needs idempotent dsu=1.") + ap.add_argument("--dsu", type=int, default=1, choices=[0, 1], + help="disable_state_update; 1=forward-only (idempotent, default), 0=write state") + ap.add_argument("--vk-bv", type=int, default=-1, choices=[-1, 8, 16, 32]) + ap.add_argument("--accept", default="random", + help="chain accept length m: full(=T)/half/one/random/; drives commit/flush.") + ap.add_argument("--commit", default="scatter", choices=["scatter", "gather"], + help="recurrent commit model: scatter=official sglang " + "fused_mamba_state_scatter_with_mask (coalesced N·d², default); " + "gather=strided copy (sensitivity). kvbuffer flush always counted.") + ap.add_argument("--only", nargs="+", default=[], + choices=["vk", "ws", "tri", "tpkvb", "cgkvb", "kv", "auto", "loop"], + help="restrict check/timing to these verify variants (default: all). " + "REC/spd columns show n/a for skipped baselines.") + ap.add_argument("--check", action="store_true", help="numerical check only, no timing") + ap.add_argument("--atol", type=float, default=5e-2) + ap.add_argument("--profile", default="", + choices=["", "vk", "ws", "tpkvb", "cgkvb", "triton", "commit", "flush", "kv", "auto", "loop"], + help="ncu profile mode: run one method's kernel in a loop (uses batch-sizes[0], Ts[0])") + ap.add_argument("--profile-iters", type=int, default=20, help="kernel launches in the profiled loop") + args = ap.parse_args() + + global _VK_BV + _VK_BV = args.vk_bv + global _ONLY + _ONLY = set(args.only) + DSU = bool(args.dsu) + device = "cuda" + if args.profile: + _profile_one(args, DSU, device) + return + print(f"GPU: {torch.cuda.get_device_name()}") + print(f"shape H={args.H} HV={args.HV} K={args.K} V={args.V} dsu={DSU} " + f"tpkvb_impl={_HAVE_TPKVB} cgkvb_impl={_HAVE_CGKVB}") + + # ---------------- numerical check (vs Triton recurrent) ---------------- + if not _HAVE_TRITON: + print(f"[warn] Triton baseline unavailable ({_TRITON_ERR}); skipping numerical check.") + else: + print("\n=== numerical check (max|Δ| vs Triton recurrent, threshold " + f"{args.atol}) ===") + print(f"{'N':>4} {'T':>3} | {'Δ vk':>10} | {'Δ ws':>10} | " + f"{'Δ tpkvb':>10} | {'Δ cgkvb':>10} | flag") + for N in args.batch_sizes: + for T in args.Ts: + q, k, v, a, b, A_log, dt_bias, state0, indices = make_dense_inputs( + N, T, args.H, args.HV, args.K, args.V, device) + scale = args.K ** -0.5 + qt, kt, vt, at, bt, cu = to_triton_varlen(q, k, v, a, b) + o_tri = make_triton_call(qt, kt, vt, at, bt, cu, A_log, dt_bias, + state0.clone(), indices, scale, True)() + o_tri = o_tri.reshape(N, T, args.HV, args.V) + d_vk = d_ws = float("nan") + if _want("vk"): + o_vk = make_vk_call(q, k, v, a, b, A_log, dt_bias, + state0.clone(), indices, scale, True)() + d_vk = (o_vk - o_tri).abs().max().item() + if _want("ws"): + o_ws = make_ws_call(q, k, v, a, b, A_log, dt_bias, + state0.clone(), indices, scale, True)() + d_ws = (o_ws - o_tri).abs().max().item() + d_tpkvb = float("nan") + if _HAVE_TPKVB and _want("tpkvb"): + o_tpkvb = make_tpkvb_call(q, k, v, a, b, A_log, dt_bias, + state0.clone(), indices, scale, True)() + d_tpkvb = (o_tpkvb - o_tri).abs().max().item() + d_cgkvb = float("nan") + if _HAVE_CGKVB and _want("cgkvb"): + o_cgkvb = make_cgkvb_call(q, k, v, a, b, A_log, dt_bias, + state0.clone(), indices, scale, True)() + d_cgkvb = (o_cgkvb - o_tri).abs().max().item() + cand = [x for x in (d_vk, d_ws, d_tpkvb, d_cgkvb) if x == x] + flag = ("OK" if max(cand) < args.atol else "DIFF!") if cand else "n/a" + print(f"{N:>4} {T:>3} | {d_vk:>10.2e} | {d_ws:>10.2e} | " + f"{d_tpkvb:>10.2e} | {d_cgkvb:>10.2e} | {flag}") + + if args.check: + return + + _timing_verify_chain(args, DSU, device) + + +def _timing_verify_chain(args, DSU, device): + """Fair spec-decode verify CHAIN (each segment timed in its own CUDA graph, summed). All verify + kernels run dsu=1 + verify-mode: recurrent vk/ws/triton write the T·d² intermediate states, + kvbuffer writes its compact u-buffer. REC = recurrent verify + commit; KVB = kvbuffer verify + + flush. spd_vk/spd_ws = REC/KVB vs production vk/ws; spd_vkbf/spd_wsbf = official triton REC chain + / kvbuffer KVB chain. Prints chain totals + speedups first, per-segment breakdown after.""" + def us(x): + return f"{x * 1e3:.1f}" if x else "n/a" + + def rat(a_, b_): + return f"{a_ / b_:.2f}x" if (a_ and b_) else "n/a" + + if args.commit == "scatter" and not _HAVE_SCATTER: + raise RuntimeError( + f"commit=scatter needs the official sglang kernel; set KDA_SCATTER_FILE to " + f"mamba_state_scatter_triton.py (load error: {_SCATTER_ERR})") + + # ---- measure every segment for every (N, T) into `results` ---- + results = [] + for N in args.batch_sizes: + for T in args.Ts: + q, k, v, a, b, A_log, dt_bias, state0, indices = make_dense_inputs( + N, T, args.H, args.HV, args.K, args.V, device) + scale = args.K ** -0.5 + m = _accept_len(T, args.accept, N) + gc = 1 if N >= 16 else args.graph_calls # amortize launch overhead at small batch + inter_buf = torch.empty(N, T, args.HV, args.V, args.K, dtype=torch.float32, device=device) + ubufs = ( + torch.empty(N, T, args.HV, args.V, dtype=torch.float32, device=device), + torch.empty(N, T, args.HV, args.K, dtype=torch.float32, device=device), + torch.empty(N, T, args.HV, args.K, dtype=torch.float32, device=device), + ) + tg = {} + + def time_seg(fn): + warmup(fn, args.warmup) + return t_graph_ms(fn, args.warmup, args.rep, gc) + + # recurrent verify (dsu=1, writes T·d² states) + commit + if _want("vk"): + tg["vk_v"] = time_seg(make_vk_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, inter_buf)) + if _want("vk") or _want("ws") or _want("tri"): + if args.commit == "scatter": + fn_cmt = make_scatter_commit_call(state0.clone(), inter_buf, m, N, T, args.HV, args.V, args.K) + else: + fn_cmt = make_gather_commit_call(state0.clone(), inter_buf, m) + tg["cmt"] = time_seg(fn_cmt) + if _want("ws"): + tg["ws_v"] = time_seg(make_ws_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, inter_buf)) + # kvbuffer verify (dsu=1, writes u-buffer) + flush + if _want("tpkvb") or _want("cgkvb"): + # flush needs a populated u-buffer: run one kvbuffer verify first to fill it + if _HAVE_TPKVB and _want("tpkvb"): + make_tpkvb_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, ubufs)() + elif _HAVE_CGKVB and _want("cgkvb"): + make_cgkvb_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, ubufs)() + tg["flush"] = time_seg(make_flush_call(state0.clone(), indices, ubufs, m)) + if _HAVE_TPKVB and _want("tpkvb"): + tg["tpkvb_v"] = time_seg(make_tpkvb_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, ubufs)) + if _HAVE_CGKVB and _want("cgkvb"): + tg["cgkvb_v"] = time_seg(make_cgkvb_call(q, k, v, a, b, A_log, dt_bias, state0.clone(), indices, scale, DSU, ubufs)) + # official triton recurrent verify (dsu=1, writes T·d² states) + if _HAVE_TRITON and _want("tri"): + qt, kt, vt, at, bt, cu = to_triton_varlen(q, k, v, a, b) + tri_inter = torch.empty(N, T, args.HV, args.V, args.K, dtype=torch.float32, device=device) + tri_idx = torch.arange(N, device=device, dtype=torch.int32) + tg["tri_v"] = time_seg(make_triton_call(qt, kt, vt, at, bt, cu, A_log, dt_bias, + state0.clone(), indices, scale, DSU, tri_inter, tri_idx, T)) + + r = {"N": N, "T": T, "m": m, "tg": tg} + + def _sum(av, bv): + return tg[av] + tg[bv] if (av in tg and bv in tg) else None + + r["REC_vk"] = _sum("vk_v", "cmt") + r["REC_ws"] = _sum("ws_v", "cmt") + r["KVB_tp"] = _sum("tpkvb_v", "flush") + r["KVB_cg"] = _sum("cgkvb_v", "flush") + r["REC_tri"] = _sum("tri_v", "cmt") + results.append(r) + + # ---- table 1: chain totals + speedups ---- + print(f"\n=== verify-CHAIN total latency (us) + speedup — accept m={args.accept} commit={args.commit} ===") + print(" REC_* = recurrent verify (writes T·d² states) + commit; KVB_* = kvbuffer verify (u-buffer) + flush") + print(" spd_(vk/ws/tp/cg) = REC_tri (official triton) / (REC_vk/REC_ws/KVB_tp/KVB_cg) -- chain speedup over triton") + hdr = (f"{'N':>4} {'T':>3} {'m':>3} | {'REC_vk':>7} {'REC_ws':>7} {'REC_tri':>7} | {'KVB_tp':>7} {'KVB_cg':>7} | " + f"{'spd_vk':>7} {'spd_ws':>7} {'spd_tp':>7} {'spd_cg':>7}") + print(hdr) + print("-" * len(hdr)) + for r in results: + print(f"{r['N']:>4} {r['T']:>3} {r['m']:>3} | {us(r['REC_vk']):>7} {us(r['REC_ws']):>7} {us(r['REC_tri']):>7} | " + f"{us(r['KVB_tp']):>7} {us(r['KVB_cg']):>7} | " + f"{rat(r['REC_tri'], r['REC_vk']):>7} {rat(r['REC_tri'], r['REC_ws']):>7} {rat(r['REC_tri'], r['KVB_tp']):>7} {rat(r['REC_tri'], r['KVB_cg']):>7}") + + # ---- table 2: per-segment breakdown ---- + print("\n=== per-segment breakdown (us) — verify kernels + shared commit/flush ===") + hdr2 = (f"{'N':>4} {'T':>3} | {'vk_v':>6} {'ws_v':>6} {'tri_v':>6} | {'tpkvb_v':>7} {'cgkvb_v':>7} | " + f"{'cmt':>5} {'flush':>6}") + print(hdr2) + print("-" * len(hdr2)) + for r in results: + tg = r["tg"] + print(f"{r['N']:>4} {r['T']:>3} | {us(tg.get('vk_v')):>6} {us(tg.get('ws_v')):>6} {us(tg.get('tri_v')):>6} | " + f"{us(tg.get('tpkvb_v')):>7} {us(tg.get('cgkvb_v')):>7} | " + f"{us(tg.get('cmt')):>5} {us(tg.get('flush')):>6}") + + +if __name__ == "__main__": + main() diff --git a/cula/kda/__init__.py b/cula/kda/__init__.py index ee1a2bb9..dd7d22b8 100644 --- a/cula/kda/__init__.py +++ b/cula/kda/__init__.py @@ -16,11 +16,19 @@ from cula.kda.chunk import chunk_kda from cula.kda.hopper_fused_fwd import cula_kda_prefill as kda_prefill_hopper from cula.ops.kda_decode import fused_sigmoid_gating_delta_rule_update, kda_decode +from cula.ops.kda_decode_mtp import ( + kda_decode_mtp, + kda_decode_mtp_small_batch, + kda_decode_mtp_ws, +) __all__ = [ "chunk_kda", "kda_prefill_blackwell", "kda_decode", + "kda_decode_mtp", + "kda_decode_mtp_ws", + "kda_decode_mtp_small_batch", "fused_sigmoid_gating_delta_rule_update", "kda_prefill_hopper", ] diff --git a/cula/ops/__init__.py b/cula/ops/__init__.py index 6450488b..052f2edb 100644 --- a/cula/ops/__init__.py +++ b/cula/ops/__init__.py @@ -13,10 +13,18 @@ # limitations under the License. from cula.ops.kda_decode import fused_sigmoid_gating_delta_rule_update, kda_decode +from cula.ops.kda_decode_mtp import ( + kda_decode_mtp, + kda_decode_mtp_small_batch, + kda_decode_mtp_ws, +) from cula.ops.la_decode import linear_attention_decode __all__ = [ "kda_decode", + "kda_decode_mtp", + "kda_decode_mtp_ws", + "kda_decode_mtp_small_batch", "fused_sigmoid_gating_delta_rule_update", "linear_attention_decode", ] diff --git a/cula/ops/kda_decode.py b/cula/ops/kda_decode.py index d84c77bf..757133bf 100644 --- a/cula/ops/kda_decode.py +++ b/cula/ops/kda_decode.py @@ -144,6 +144,7 @@ def _try_fast_dense_decode( softplus_threshold: float, out: torch.Tensor | None, state_layout: str | None, + opt_level: int = 1, ): """Fast path for the common dense decode case used by the benchmark. @@ -267,6 +268,7 @@ def _try_fast_dense_decode( dense_small_hv_parallel=dense_small_hv_parallel, softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, + opt_level=opt_level, ) compiled_kernel( cu_seqlens_to_use, @@ -1552,12 +1554,18 @@ def _get_compiled_kernel( dense_small_hv_parallel, softplus_beta, softplus_threshold, + opt_level=1, ): """Get or lazily compile one CuteDSL decode kernel variant. Compile-time specialization is still important here, so we cache the result by shape, layout, and constexpr options. The compiled function is emitted with TVM-FFI enabled so runtime calls can pass torch tensors directly. + + ``opt_level`` selects the CuTe DSL ``--opt-level`` (codegen optimization; + NOT a kernel constexpr). It is part of the cache key so the same shape can + be compiled at multiple opt-levels without colliding. Default 1 keeps the + historical behavior; 2/3 are experiments (see issue 17 compile-knob tuning). """ global _compiled_kernels @@ -1578,6 +1586,7 @@ def _get_compiled_kernel( dense_small_hv_parallel, softplus_beta, softplus_threshold, + opt_level, ) if key in _compiled_kernels: return _compiled_kernels[key] @@ -1656,7 +1665,7 @@ def _get_compiled_kernel( num_blocks_per_state_small=num_blocks_per_state_small, dense_small_hv_parallel=dense_small_hv_parallel, stream=stream, - options="--enable-tvm-ffi --opt-level 1", + options=f"--enable-tvm-ffi --opt-level {opt_level}", ) _compiled_kernels[key] = compiled_kernel @@ -1809,6 +1818,7 @@ def fused_sigmoid_gating_delta_rule_update( is_kda: bool = False, out: torch.Tensor | None = None, state_layout: str = "vk", + opt_level: int = 1, ): """Public cuLA decode API backed by CuTe DSL. @@ -1839,6 +1849,7 @@ def fused_sigmoid_gating_delta_rule_update( softplus_threshold=softplus_threshold, out=out, state_layout=state_layout, + opt_level=opt_level, ) @@ -1859,6 +1870,7 @@ def kda_decode( softplus_threshold: float = 20.0, out: torch.Tensor | None = None, state_layout: str = "vk", + opt_level: int = 1, ) -> torch.Tensor: """CuTe DSL implementation of fused sigmoid gating KDA update. @@ -1911,6 +1923,7 @@ def kda_decode( softplus_threshold, out, state_layout, + opt_level, ) if fast_dense_out is not None: return fast_dense_out @@ -2074,6 +2087,7 @@ def kda_decode( dense_small_hv_parallel=dense_small_hv_parallel, softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, + opt_level=opt_level, ) # With TVM-FFI enabled at compile time, the runtime launch can pass torch diff --git a/cula/ops/kda_decode_mtp.py b/cula/ops/kda_decode_mtp.py new file mode 100644 index 00000000..a8bc112c --- /dev/null +++ b/cula/ops/kda_decode_mtp.py @@ -0,0 +1,2091 @@ +"""CuTe DSL KDA MTP decode + +Production KDA MTP decode kernel. Public entry point: ``kda_decode_mtp_ws`` +(warp-spec). The defining feature is KDA's per-K-channel decay gate ``g_t in R^K`` +(``beta`` stays a per-(head, token) scalar); the whole kernel is built around that +channel axis. + +Grid = N*HV*num_v_tiles, one CTA per (i_n, i_hv, i_v V-tile). State is +register-resident across the T tokens; the K-reduce is a full-warp shuffle. The +recurrence uses the DECAY-FIRST order (decay the whole state, then dot with raw k); +bf16 rounding differs slightly from the single-token ``kda_decode`` (accumulation +order), both validated against the fp32 torch oracle at atol 3e-2 / rtol 2e-2. + +Scope (this file): +- Warp-spec variant. ``ilp_rows in {2, 4}``: ilp=2 covers every + tile_v in {8,16,32,64}; ilp=4 fuses steps 1+2 and 4+5 with double accumulators + + packed F32x2 FMA on SM100 (scalar ``fma_pair`` fallback elsewhere) and requires + ``tile_v % 16 == 0`` (so {16,32,64}). +- ``vk`` state layout only. +- ``use_smem_v`` (Stage C): preload the v-tile into SMEM + coalesced merged output + writeback. Constexpr, off unless the heuristic / an explicit arg turns it on. +- ``cache_intermediate_states`` (Stage D): when an ``intermediate_states_buffer`` + ([N, T, HV, V, K] vk) is passed, snapshot every token's post-state to GMEM + (sequence-indexed) for speculative-decoding rollback. Produce-only. +- ``disable_state_update`` supported (default False = always write back). + +Math per token t (decay-first, per-channel g): + g_t = exp(-exp(A_log) * softplus(a_t + dt_bias)) # (K,) per-channel + S <- S * diag(g_t) # step 1 (per channel) + s = S @ k_norm # step 2 (reduce K) + v_new = sigmoid(b_t) * (v_t - s) # step 3 + S += v_new (x) k_norm # step 4 (rank-1, raw k) + o_t = S @ (l2norm(q_t) * scale) # step 5 (reduce K) +""" + +import logging + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.runtime import from_dlpack + +from cula.ops.kda_decode import ( + NUM_THREADS, + TILE_K, + _canonicalize_state_layout, + _get_cached_stream, + _normalize_A_log, + _normalize_dt_bias, + _normalize_state_indices, + _normalize_state_source, + _prepare_output_tensor, +) + +logger = logging.getLogger(__name__) + +# vec_size = 4 -> 32 threads/group = a full warp, 4 groups (warps) per block. +VEC_SIZE_MTP = 4 + +_compiled_mtp_ws_kernels: dict[tuple, object] = {} + + +def _normalize_mtp_a(a: torch.Tensor, *, N: int, T: int, HV: int, K: int) -> torch.Tensor: + """Normalize `a` to the compile-time dense MTP shape (N, T, HV, K).""" + if a.dim() == 4 and tuple(a.shape) == (N, T, HV, K): + return a + if a.dim() == 3 and tuple(a.shape) == (N, T, HV * K): + return a.view(N, T, HV, K) + raise ValueError(f"Unexpected a shape for MTP dense: {tuple(a.shape)}; expected {(N, T, HV, K)}") + + +# Valid V-tile sizes {8,16,32,64}: each a multiple of NUM_WARPS (4) so V_PER_WARP +_MTP_TILE_V_CHOICES = (8, 16, 32, 64) + + +def _select_mtp_config( + N: int, + HV: int, + V: int, + T: int, + *, + disable_state_update: bool = False, +) -> tuple[int, int, bool]: + work_units = N * HV + + if work_units <= 64: + tile_v, ilp_rows, use_smem_v = 8, 2, False + elif work_units <= 128: + tile_v, ilp_rows, use_smem_v = 16, 4, False + elif work_units <= 448: + if T <= 2: + tile_v, ilp_rows, use_smem_v = 16, 2, False + else: + tile_v, ilp_rows, use_smem_v = 32, 4, False + elif work_units <= 1024: + tile_v, ilp_rows, use_smem_v = 32, 4, False + else: + # Large batches: ilp capped at 4, so (64, 4, True) uniformly. + tile_v, ilp_rows, use_smem_v = 64, 4, True + + tile_v = min(tile_v, V) + while tile_v > _MTP_TILE_V_CHOICES[0] and V % tile_v != 0: + tile_v //= 2 + + # Legality backstop: ilp=4 requires (tile_v//4) % 4 == 0, i.e. tile_v % 16 == 0 + if ilp_rows == 4 and tile_v % 16 != 0: + ilp_rows = 2 + + return tile_v, ilp_rows, use_smem_v + + +def _select_mtp_tile_v(N: int, HV: int, V: int, T: int) -> int: + return _select_mtp_config(N, HV, V, T)[0] + + +@cute.jit +def fma_pair(a1, a2, b1, b2, c1, c2): + # FMA two pairs: (a1*b1+c1, a2*b2+c2). + result1 = a1 * b1 + c1 + result2 = a2 * b2 + c2 + return result1, result2 + + +@cute.kernel +def kda_verify_kernel_mtp_ws( + h0_source: cute.Tensor, # [pool_size * HV, V, K] fp32, K-last (VK layout) + intermediate_states: cute.Tensor, # [N*T*HV, V, K] fp32 snapshot cache (or dummy) + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + A_log: cute.Tensor, # [HV] fp32 (per-channel decay) + a: cute.Tensor, # [N, T, HV, K] (per-channel decay input) + dt_bias: cute.Tensor, # [HV, K] (per-channel decay bias) + q: cute.Tensor, # [N, T, H, K] + k: cute.Tensor, # [N, T, H, K] + v: cute.Tensor, # [N, T, HV, V] + b: cute.Tensor, # [N, T, HV] (update-gate logit) + o: cute.Tensor, # [N, T, HV, V] output + h0_indices: cute.Tensor, # [N] int32 (state-pool slot per sequence; <0 = pad) + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + ilp_rows: cutlass.Constexpr[int], + use_packed_fma: cutlass.Constexpr[bool], + use_smem_v: cutlass.Constexpr[bool], + cache_intermediate_states: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], +): + tidx, _, _ = cute.arch.thread_idx() + lane_id = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # vec_size=4 -> threads_per_group=32 (full warp), 4 groups (one per warp). + threads_per_group: cutlass.Constexpr[int] = K // vec_size # 32 + num_groups: cutlass.Constexpr[int] = 4 + lane_in_group = lane_id % threads_per_group + group_idx = warp_idx + + batch_idx, _, _ = cute.arch.block_idx() + + # Decode the flat CTA index into (i_n sequence, i_hv value-head, i_v V-tile). + i_v = batch_idx % num_v_tiles + tmp = batch_idx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) # GVA: HV//H value-heads share one q/k head + + cache_idx = h0_indices[i_n] + + # exp(A_log) is per-head, shared across all K channels — hoist once. + r_A_log = cutlass.Float32(A_log[i_hv]) + r_exp_A = cute.exp(r_A_log, fastmath=fast_math) + + # SMEM broadcast buffers (warp 0 -> all warps). sG is [T, K] (per-channel); + smem = cutlass.utils.SmemAllocator() + sQ = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16 + ) + sK = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16 + ) + sG = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16 + ) + sBeta = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T,)), 16) + + # use_smem_v (Stage C): preload the v-tile into SMEM + accumulate outputs for a + # coalesced merged writeback. Allocated last/conditionally so off-path offsets stay put. + if cutlass.const_expr(use_smem_v): + sVdata = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((T, tile_v), stride=(tile_v, 1)), 16 + ) + sOutput = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((T, tile_v), stride=(tile_v, 1)), 16 + ) + + # Per-lane registers: r_g = this lane's vec_size channels of g; r_h = up to 8 + # V-rows of state (only ilp_rows used), each row spanning 32 lanes over K=128. + r_q = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_k = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_g = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_h = cute.make_rmem_tensor( + cute.make_layout((8, vec_size), stride=(vec_size, 1)), cutlass.Float32 + ) + r_q_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_k_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + + if cache_idx >= 0: + k_start = lane_in_group * vec_size # this lane's first K channel + rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups + flat_state_idx = cache_idx * HV + i_hv # row in [pool*HV, V, K] + + # ============ Phase 1: warp specialization ============ + if warp_idx == 0: + # Warp 0 computes q/k/g/beta for all T tokens, broadcasts via SMEM. + for i_t in cutlass.range_constexpr(T): + q_tile = cute.local_tile( + q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group) + ) + k_tile = cute.local_tile( + k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group) + ) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + if cutlass.const_expr(use_qk_l2norm): + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + # Full-warp reduction (32 lanes x vec_size=4 = all 128 K). + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) + inv_norm_q_scaled = cute.rsqrt(sum_q + 1e-6, fastmath=fast_math) * scale + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=fast_math) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q_scaled + r_k[i] = r_k[i] * inv_norm_k + else: + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale + + # vec_size=4 -> warp 0's 32 lanes cover all 128 K channels. + for i in cutlass.range_constexpr(vec_size): + sQ[(i_t, k_start + i)] = r_q[i] + sK[(i_t, k_start + i)] = r_k[i] + + # KDA per-channel decay gate: each lane computes g for its own + # vec_size channels. g[kk] = exp(-exp(A_log) * softplus(a+dt_bias)). + for i in cutlass.range_constexpr(vec_size): + kk = k_start + i + x = cutlass.Float32(a[i_n, i_t, i_hv, kk]) + cutlass.Float32( + dt_bias[i_hv, kk] + ) + beta_x = softplus_beta * x + exp_beta_x = cute.exp(beta_x, fastmath=fast_math) + softplus_val = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + exp_beta_x, fastmath=fast_math + ) + use_softplus = ( + cutlass.Float32(1.0) + if beta_x <= softplus_threshold + else cutlass.Float32(0.0) + ) + softplus_x = ( + use_softplus * softplus_val + + (cutlass.Float32(1.0) - use_softplus) * x + ) + sG[(i_t, kk)] = cute.exp(-r_exp_A * softplus_x, fastmath=fast_math) + + # Update gate beta is a per-(head, token) scalar (warp-uniform). + r_b = cutlass.Float32(b[i_n, i_t, i_hv]) + r_beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cute.exp(-r_b, fastmath=fast_math) + ) + sBeta[i_t] = r_beta + + # Preload the v-tile into SMEM: warp 0 covers tile-local cols 0..31, + # warps 1-3 the rest (tidx each col written once). + if cutlass.const_expr(use_smem_v): + if tidx < tile_v: + v_global_idx = i_v * tile_v + tidx + if v_global_idx < V: + sVdata[(i_t, tidx)] = cutlass.Float32( + v[i_n, i_t, i_hv, v_global_idx] + ) + else: + # Warps 1-3: prefetch the first ILP set of state rows into registers, + # overlapping the h-state DRAM latency with warp 0's Phase 1 compute. + v_base_prefetch = i_v * tile_v + group_idx * rows_per_group + if cutlass.const_expr(ilp_rows == 4): + # Prefetch 4 h-state rows (4 independent load streams). + v_pf_d = v_base_prefetch + 3 + if v_pf_d < V: + pf_a = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_base_prefetch, lane_in_group), + ) + pf_b = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_base_prefetch + 1, lane_in_group), + ) + pf_c = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_base_prefetch + 2, lane_in_group), + ) + pf_d = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_base_prefetch + 3, lane_in_group), + ) + cute.autovec_copy(pf_a, cute.slice_(r_h, (0, None))) + cute.autovec_copy(pf_b, cute.slice_(r_h, (1, None))) + cute.autovec_copy(pf_c, cute.slice_(r_h, (2, None))) + cute.autovec_copy(pf_d, cute.slice_(r_h, (3, None))) + elif cutlass.const_expr(ilp_rows == 2): + v_pf_b = v_base_prefetch + 1 + if v_pf_b < V: + pf_a = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_base_prefetch, lane_in_group), + ) + pf_b = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_base_prefetch + 1, lane_in_group), + ) + cute.autovec_copy(pf_a, cute.slice_(r_h, (0, None))) + cute.autovec_copy(pf_b, cute.slice_(r_h, (1, None))) + + # Warps 1-3 cover the tile-local v columns warp 0 can't reach + # (tidx 32..127); same tidx each column written once. + if cutlass.const_expr(use_smem_v): + for i_t in cutlass.range_constexpr(T): + if tidx < tile_v: + v_global_idx = i_v * tile_v + tidx + if v_global_idx < V: + sVdata[(i_t, tidx)] = cutlass.Float32( + v[i_n, i_t, i_hv, v_global_idx] + ) + + # Publish warp 0's SMEM writes (q/k/g/beta + preloaded v) to all warps + # before the recurrence reads them. + cute.arch.barrier() + + # ============ Recurrence: ilp_rows == 2 (process 2 V-rows together) === + if cutlass.const_expr(ilp_rows == 2): + half_rows: cutlass.Constexpr[int] = rows_per_group // 2 + + for row_pair in cutlass.range_constexpr(half_rows): + v_idx_a = i_v * tile_v + group_idx * rows_per_group + row_pair * 2 + v_idx_b = v_idx_a + 1 + + if v_idx_b < V: + # Load state for both rows. Warps 1-3 reuse the Phase-1 + # prefetch on the first pair; everyone else loads in place. + if warp_idx == 0 or row_pair > 0: + h_tile_a = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_a, lane_in_group), + ) + h_tile_b = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_b, lane_in_group), + ) + cute.autovec_copy(h_tile_a, cute.slice_(r_h, (0, None))) + cute.autovec_copy(h_tile_b, cute.slice_(r_h, (1, None))) + + for i_t in cutlass.range_constexpr(T): + # Read warp-0-staged q/k/g for this token (shared by both rows). + sQ_tile = cute.local_tile(sQ, (1, vec_size), (i_t, lane_in_group)) + sK_tile = cute.local_tile(sK, (1, vec_size), (i_t, lane_in_group)) + sG_tile = cute.local_tile(sG, (1, vec_size), (i_t, lane_in_group)) + cute.autovec_copy(sQ_tile, r_q) + cute.autovec_copy(sK_tile, r_k) + cute.autovec_copy(sG_tile, r_g) + r_beta = sBeta[i_t] + + # Step 1: per-channel decay (KDA: r_g[i], not a scalar). + for i in cutlass.range_constexpr(vec_size): + r_h[0, i] = r_h[0, i] * r_g[i] + r_h[1, i] = r_h[1, i] * r_g[i] + + # Step 2: s = (decayed S) @ k_norm (reduce over K). + sum_hk_a = 0.0 + sum_hk_b = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_hk_a += r_h[0, i] * r_k[i] + sum_hk_b += r_h[1, i] * r_k[i] + for offset in [16, 8, 4, 2, 1]: + sum_hk_a += cute.arch.shuffle_sync_bfly( + sum_hk_a, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hk_b += cute.arch.shuffle_sync_bfly( + sum_hk_b, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Step 3: delta rule. v from SMEM (preloaded) or GMEM. + if cutlass.const_expr(use_smem_v): + v_local_a = v_idx_a - i_v * tile_v + r_v_a = sVdata[(i_t, v_local_a)] + r_v_b = sVdata[(i_t, v_local_a + 1)] + else: + r_v_a = cutlass.Float32(v[i_n, i_t, i_hv, v_idx_a]) + r_v_b = cutlass.Float32(v[i_n, i_t, i_hv, v_idx_b]) + v_new_a = (r_v_a - sum_hk_a) * r_beta + v_new_b = (r_v_b - sum_hk_b) * r_beta + + # Step 4: rank-1 update with raw k (decay already applied). + for i in cutlass.range_constexpr(vec_size): + r_h[0, i] += r_k[i] * v_new_a + r_h[1, i] += r_k[i] * v_new_b + + # Stage D: snapshot post-token state, sequence-indexed + # (flat_idx = i_n*T*HV + i_t*HV + i_hv), race-free before step 5. + if cutlass.const_expr(cache_intermediate_states): + flat_idx = i_n * T * HV + i_t * HV + i_hv + inter_a = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx_a, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (0, None)), inter_a) + inter_b = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx_b, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (1, None)), inter_b) + + # Step 5: o = S_new @ q_scaled (reduce over K). + sum_hq_a = 0.0 + sum_hq_b = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_hq_a += r_h[0, i] * r_q[i] + sum_hq_b += r_h[1, i] * r_q[i] + for offset in [16, 8, 4, 2, 1]: + sum_hq_a += cute.arch.shuffle_sync_bfly( + sum_hq_a, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hq_b += cute.arch.shuffle_sync_bfly( + sum_hq_b, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Reduction result is identical on all lanes -> lane 0 + # writes. To SMEM (merged flush at kernel end) or GMEM. + if lane_in_group == 0: + if cutlass.const_expr(use_smem_v): + vla = v_idx_a - i_v * tile_v + sOutput[(i_t, vla)] = cutlass.BFloat16(sum_hq_a) + sOutput[(i_t, vla + 1)] = cutlass.BFloat16(sum_hq_b) + else: + o[(i_n, i_t, i_hv, v_idx_a)] = cutlass.BFloat16(sum_hq_a) + o[(i_n, i_t, i_hv, v_idx_b)] = cutlass.BFloat16(sum_hq_b) + + # Write final state for both rows back to the pool (once). + if cutlass.const_expr(not disable_state_update): + h_tile_out_a = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_a, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (0, None)), h_tile_out_a) + h_tile_out_b = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_b, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (1, None)), h_tile_out_b) + + # ============ Recurrence: ilp_rows == 4 (process 4 V-rows together) === + # Steps 1+2 fused (decay then h@k) and 4+5 fused (rank-1 then h@q), with + # double accumulators (halve the K-reduce FFMA chain) + packed F32x2 FMA on + # SM100. Per-channel decay r_g[i]/r_g[i+1] loaded from sG. + elif cutlass.const_expr(ilp_rows == 4): + quarter_rows: cutlass.Constexpr[int] = rows_per_group // 4 + + for row_quad in cutlass.range_constexpr(quarter_rows): + v_idx_a = i_v * tile_v + group_idx * rows_per_group + row_quad * 4 + v_idx_b = v_idx_a + 1 + v_idx_c = v_idx_a + 2 + v_idx_d = v_idx_a + 3 + + if v_idx_d < V: + if warp_idx == 0 or row_quad > 0: + h_tile_a = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_a, lane_in_group), + ) + h_tile_b = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_b, lane_in_group), + ) + h_tile_c = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_c, lane_in_group), + ) + h_tile_d = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_d, lane_in_group), + ) + cute.autovec_copy(h_tile_a, cute.slice_(r_h, (0, None))) + cute.autovec_copy(h_tile_b, cute.slice_(r_h, (1, None))) + cute.autovec_copy(h_tile_c, cute.slice_(r_h, (2, None))) + cute.autovec_copy(h_tile_d, cute.slice_(r_h, (3, None))) + + for i_t in cutlass.range_constexpr(T): + # Warp-0-staged q/k/g for this token (shared by all 4 rows). + sQ_tile = cute.local_tile(sQ, (1, vec_size), (i_t, lane_in_group)) + sK_tile = cute.local_tile(sK, (1, vec_size), (i_t, lane_in_group)) + sG_tile = cute.local_tile(sG, (1, vec_size), (i_t, lane_in_group)) + cute.autovec_copy(sQ_tile, r_q) + cute.autovec_copy(sK_tile, r_k) + cute.autovec_copy(sG_tile, r_g) + r_beta = sBeta[i_t] + + # Steps 1+2 fused: per-channel decay then h@k. + sum_hk_a = cutlass.Float32(0.0) + sum_hk_a2 = cutlass.Float32(0.0) + sum_hk_b = cutlass.Float32(0.0) + sum_hk_b2 = cutlass.Float32(0.0) + sum_hk_c = cutlass.Float32(0.0) + sum_hk_c2 = cutlass.Float32(0.0) + sum_hk_d = cutlass.Float32(0.0) + sum_hk_d2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, vec_size, 2): + # Step 1: per-channel decay (KDA: r_g[i]/r_g[i+1]). + r_h[0, i] = r_h[0, i] * r_g[i] + r_h[0, i + 1] = r_h[0, i + 1] * r_g[i + 1] + r_h[1, i] = r_h[1, i] * r_g[i] + r_h[1, i + 1] = r_h[1, i + 1] * r_g[i + 1] + r_h[2, i] = r_h[2, i] * r_g[i] + r_h[2, i + 1] = r_h[2, i + 1] * r_g[i + 1] + r_h[3, i] = r_h[3, i] * r_g[i] + r_h[3, i + 1] = r_h[3, i + 1] * r_g[i + 1] + # Step 2: h@k, two channels per step (packed on SM100). + if cutlass.const_expr(use_packed_fma): + sum_hk_a, sum_hk_a2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[0, i], r_h[0, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(sum_hk_a, sum_hk_a2), + ) + sum_hk_b, sum_hk_b2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[1, i], r_h[1, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(sum_hk_b, sum_hk_b2), + ) + sum_hk_c, sum_hk_c2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[2, i], r_h[2, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(sum_hk_c, sum_hk_c2), + ) + sum_hk_d, sum_hk_d2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[3, i], r_h[3, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(sum_hk_d, sum_hk_d2), + ) + else: + sum_hk_a, sum_hk_a2 = fma_pair( + r_h[0, i], r_h[0, i + 1], r_k[i], r_k[i + 1], sum_hk_a, sum_hk_a2 + ) + sum_hk_b, sum_hk_b2 = fma_pair( + r_h[1, i], r_h[1, i + 1], r_k[i], r_k[i + 1], sum_hk_b, sum_hk_b2 + ) + sum_hk_c, sum_hk_c2 = fma_pair( + r_h[2, i], r_h[2, i + 1], r_k[i], r_k[i + 1], sum_hk_c, sum_hk_c2 + ) + sum_hk_d, sum_hk_d2 = fma_pair( + r_h[3, i], r_h[3, i + 1], r_k[i], r_k[i + 1], sum_hk_d, sum_hk_d2 + ) + sum_hk_a = sum_hk_a + sum_hk_a2 + sum_hk_b = sum_hk_b + sum_hk_b2 + sum_hk_c = sum_hk_c + sum_hk_c2 + sum_hk_d = sum_hk_d + sum_hk_d2 + + # Full-warp reduction for all 4 h@k dot products. + for offset in [16, 8, 4, 2, 1]: + sum_hk_a += cute.arch.shuffle_sync_bfly( + sum_hk_a, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hk_b += cute.arch.shuffle_sync_bfly( + sum_hk_b, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hk_c += cute.arch.shuffle_sync_bfly( + sum_hk_c, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hk_d += cute.arch.shuffle_sync_bfly( + sum_hk_d, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Step 3: delta rule for all 4 rows. v from SMEM or GMEM. + if cutlass.const_expr(use_smem_v): + v_local_a = v_idx_a - i_v * tile_v + r_v_a = sVdata[(i_t, v_local_a)] + r_v_b = sVdata[(i_t, v_local_a + 1)] + r_v_c = sVdata[(i_t, v_local_a + 2)] + r_v_d = sVdata[(i_t, v_local_a + 3)] + else: + r_v_a = cutlass.Float32(v[i_n, i_t, i_hv, v_idx_a]) + r_v_b = cutlass.Float32(v[i_n, i_t, i_hv, v_idx_b]) + r_v_c = cutlass.Float32(v[i_n, i_t, i_hv, v_idx_c]) + r_v_d = cutlass.Float32(v[i_n, i_t, i_hv, v_idx_d]) + v_new_a = (r_v_a - sum_hk_a) * r_beta + v_new_b = (r_v_b - sum_hk_b) * r_beta + v_new_c = (r_v_c - sum_hk_c) * r_beta + v_new_d = (r_v_d - sum_hk_d) * r_beta + + # Steps 4+5 FUSED: rank-1 update with raw k (step 4) then + # h@q (step 5), per row. Double accumulators again. + sum_hq_a = cutlass.Float32(0.0) + sum_hq_a2 = cutlass.Float32(0.0) + sum_hq_b = cutlass.Float32(0.0) + sum_hq_b2 = cutlass.Float32(0.0) + sum_hq_c = cutlass.Float32(0.0) + sum_hq_c2 = cutlass.Float32(0.0) + sum_hq_d = cutlass.Float32(0.0) + sum_hq_d2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, vec_size, 2): + if cutlass.const_expr(use_packed_fma): + r_h[0, i], r_h[0, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(v_new_a, v_new_a), + src_c=(r_h[0, i], r_h[0, i + 1]), + ) + r_h[1, i], r_h[1, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(v_new_b, v_new_b), + src_c=(r_h[1, i], r_h[1, i + 1]), + ) + r_h[2, i], r_h[2, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(v_new_c, v_new_c), + src_c=(r_h[2, i], r_h[2, i + 1]), + ) + r_h[3, i], r_h[3, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(v_new_d, v_new_d), + src_c=(r_h[3, i], r_h[3, i + 1]), + ) + sum_hq_a, sum_hq_a2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[0, i], r_h[0, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(sum_hq_a, sum_hq_a2), + ) + sum_hq_b, sum_hq_b2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[1, i], r_h[1, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(sum_hq_b, sum_hq_b2), + ) + sum_hq_c, sum_hq_c2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[2, i], r_h[2, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(sum_hq_c, sum_hq_c2), + ) + sum_hq_d, sum_hq_d2 = cute.arch.fma_packed_f32x2( + src_a=(r_h[3, i], r_h[3, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(sum_hq_d, sum_hq_d2), + ) + else: + r_h[0, i], r_h[0, i + 1] = fma_pair( + r_k[i], r_k[i + 1], v_new_a, v_new_a, r_h[0, i], r_h[0, i + 1] + ) + r_h[1, i], r_h[1, i + 1] = fma_pair( + r_k[i], r_k[i + 1], v_new_b, v_new_b, r_h[1, i], r_h[1, i + 1] + ) + r_h[2, i], r_h[2, i + 1] = fma_pair( + r_k[i], r_k[i + 1], v_new_c, v_new_c, r_h[2, i], r_h[2, i + 1] + ) + r_h[3, i], r_h[3, i + 1] = fma_pair( + r_k[i], r_k[i + 1], v_new_d, v_new_d, r_h[3, i], r_h[3, i + 1] + ) + sum_hq_a, sum_hq_a2 = fma_pair( + r_h[0, i], r_h[0, i + 1], r_q[i], r_q[i + 1], sum_hq_a, sum_hq_a2 + ) + sum_hq_b, sum_hq_b2 = fma_pair( + r_h[1, i], r_h[1, i + 1], r_q[i], r_q[i + 1], sum_hq_b, sum_hq_b2 + ) + sum_hq_c, sum_hq_c2 = fma_pair( + r_h[2, i], r_h[2, i + 1], r_q[i], r_q[i + 1], sum_hq_c, sum_hq_c2 + ) + sum_hq_d, sum_hq_d2 = fma_pair( + r_h[3, i], r_h[3, i + 1], r_q[i], r_q[i + 1], sum_hq_d, sum_hq_d2 + ) + sum_hq_a = sum_hq_a + sum_hq_a2 + sum_hq_b = sum_hq_b + sum_hq_b2 + sum_hq_c = sum_hq_c + sum_hq_c2 + sum_hq_d = sum_hq_d + sum_hq_d2 + + # Full-warp reduction for all 4 h@q dot products. + for offset in [16, 8, 4, 2, 1]: + sum_hq_a += cute.arch.shuffle_sync_bfly( + sum_hq_a, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hq_b += cute.arch.shuffle_sync_bfly( + sum_hq_b, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hq_c += cute.arch.shuffle_sync_bfly( + sum_hq_c, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_hq_d += cute.arch.shuffle_sync_bfly( + sum_hq_d, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Reduction result is identical on all lanes -> lane 0 + # writes. To SMEM (merged flush at kernel end) or GMEM. + if lane_in_group == 0: + if cutlass.const_expr(use_smem_v): + vla = v_idx_a - i_v * tile_v + sOutput[(i_t, vla)] = cutlass.BFloat16(sum_hq_a) + sOutput[(i_t, vla + 1)] = cutlass.BFloat16(sum_hq_b) + sOutput[(i_t, vla + 2)] = cutlass.BFloat16(sum_hq_c) + sOutput[(i_t, vla + 3)] = cutlass.BFloat16(sum_hq_d) + else: + o[(i_n, i_t, i_hv, v_idx_a)] = cutlass.BFloat16(sum_hq_a) + o[(i_n, i_t, i_hv, v_idx_b)] = cutlass.BFloat16(sum_hq_b) + o[(i_n, i_t, i_hv, v_idx_c)] = cutlass.BFloat16(sum_hq_c) + o[(i_n, i_t, i_hv, v_idx_d)] = cutlass.BFloat16(sum_hq_d) + + # Stage D: snapshot post-token state (sequence-indexed), + # last here since fused 4+5 means r_h is final only now. + if cutlass.const_expr(cache_intermediate_states): + flat_idx = i_n * T * HV + i_t * HV + i_hv + inter_a = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx_a, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (0, None)), inter_a) + inter_b = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx_b, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (1, None)), inter_b) + inter_c = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx_c, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (2, None)), inter_c) + inter_d = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx_d, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (3, None)), inter_d) + + # Write final state for all 4 rows back to the pool (once). + if cutlass.const_expr(not disable_state_update): + h_tile_out_a = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_a, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (0, None)), h_tile_out_a) + h_tile_out_b = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_b, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (1, None)), h_tile_out_b) + h_tile_out_c = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_c, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (2, None)), h_tile_out_c) + h_tile_out_d = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx_d, lane_in_group), + ) + cute.autovec_copy(cute.slice_(r_h, (3, None)), h_tile_out_d) + + # ============ Merged output writeback (use_smem_v only) ============ + # Barrier publishes all groups' disjoint lane-0 sOutput writes, then all 128 + # threads flush sOutput -> o (one tile-local column each, all T tokens) so the + # GMEM writes coalesce. Inside `cache_idx >= 0` so the barrier never deadlocks. + if cutlass.const_expr(use_smem_v): + cute.arch.barrier() + v_tile_base = i_v * tile_v + for t_idx in cutlass.range_constexpr(T): + if tidx < tile_v: + v_global = v_tile_base + tidx + if v_global < V: + o[(i_n, t_idx, i_hv, v_global)] = sOutput[(t_idx, tidx)] + + +@cute.jit +def run_kda_verify_kernel_mtp_ws( + h0_source: cute.Tensor, + intermediate_states: cute.Tensor, + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + vec_size: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + ilp_rows: cutlass.Constexpr[int], + use_packed_fma: cutlass.Constexpr[bool], + use_smem_v: cutlass.Constexpr[bool], + cache_intermediate_states: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], + stream: cuda.CUstream, +): + """Host-side launcher: grid = N * HV * num_v_tiles, block = 128 (4 warps).""" + n_indices = h0_indices.layout.shape[0] + v_dim = h0_source.layout.shape[1] + k_dim = h0_source.layout.shape[2] + + num_v_tiles = cute.ceil_div(v_dim, tile_v) + grid_size = n_indices * HV * num_v_tiles + + smem_bytes = ( + 4 * T * (k_dim + 8) # sQ + + 4 * T * (k_dim + 8) # sK + + 4 * T * (k_dim + 8) # sG (per-channel) + + 4 * T # sBeta + + 128 # alignment slack + ) + if cutlass.const_expr(use_smem_v): + smem_bytes += 4 * T * tile_v # sVdata (fp32) + smem_bytes += 2 * T * tile_v # sOutput (bf16) + + kda_verify_kernel_mtp_ws( + h0_source, + intermediate_states, + vec_size, + num_v_tiles, + tile_v, + A_log, + a, + dt_bias, + q, + k, + v, + b, + o, + h0_indices, + softplus_beta, + softplus_threshold, + scale, + HV, + T, + H, + K, + V, + use_qk_l2norm, + disable_state_update, + ilp_rows, + use_packed_fma, + use_smem_v, + cache_intermediate_states, + fast_math, + ).launch( + grid=(grid_size, 1, 1), + block=[NUM_THREADS, 1, 1], + smem=smem_bytes, + stream=stream, + ) + + +def _get_compiled_mtp_ws_kernel( + N, + T, + H, + HV, + K, + V, + pool_size, + scale, + use_qk_l2norm, + disable_state_update, + softplus_beta, + softplus_threshold, + tile_v, + ilp_rows, + use_packed_fma, + use_smem_v, + cache_intermediate_states, + opt_level=3, + fast_math=True, +): + """Get or lazily compile the warp-spec MTP kernel for one shape/config. + + ``opt_level`` (``--opt-level``) and ``fast_math`` are part of the cache key. + """ + key = ( + N, + T, + H, + HV, + K, + V, + pool_size, + scale, + use_qk_l2norm, + disable_state_update, + softplus_beta, + softplus_threshold, + tile_v, + ilp_rows, + use_packed_fma, + use_smem_v, + cache_intermediate_states, + opt_level, + fast_math, + ) + if key in _compiled_mtp_ws_kernels: + return _compiled_mtp_ws_kernels[key] + + q = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + k = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + v = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + a = torch.zeros(N, T, HV, K, dtype=torch.bfloat16, device="cuda") + b = torch.zeros(N, T, HV, dtype=torch.bfloat16, device="cuda") + o = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + A_log = torch.zeros(HV, dtype=torch.float32, device="cuda") + dt_bias = torch.zeros(HV, K, dtype=torch.float32, device="cuda") + # Warp-spec kernel uses the flat 3D state view [pool*HV, V, K] (VK layout). + h0_source = torch.zeros(pool_size * HV, V, K, dtype=torch.float32, device="cuda") + h0_indices = torch.zeros(N, dtype=torch.int32, device="cuda") + if cache_intermediate_states: + intermediate_states = torch.zeros( + N * T * HV, V, K, dtype=torch.float32, device="cuda" + ) + else: + intermediate_states = torch.zeros(1, 1, 1, dtype=torch.float32, device="cuda") + + q_tensor = from_dlpack(q, assumed_align=16) + k_tensor = from_dlpack(k, assumed_align=16) + v_tensor = from_dlpack(v, assumed_align=16) + a_tensor = from_dlpack(a, assumed_align=16) + b_tensor = from_dlpack(b, assumed_align=16) + A_log_tensor = from_dlpack(A_log, assumed_align=16) + dt_bias_tensor = from_dlpack(dt_bias, assumed_align=16) + h0_source_tensor = from_dlpack(h0_source, assumed_align=16) + h0_indices_tensor = from_dlpack(h0_indices, assumed_align=16) + o_tensor = from_dlpack(o, assumed_align=16) + intermediate_states_tensor = from_dlpack(intermediate_states, assumed_align=16) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compiled_kernel = cute.compile( + run_kda_verify_kernel_mtp_ws, + h0_source_tensor, + intermediate_states_tensor, + A_log_tensor, + a_tensor, + dt_bias_tensor, + q_tensor, + k_tensor, + v_tensor, + b_tensor, + o_tensor, + h0_indices_tensor, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + HV=HV, + T=T, + H=H, + K=K, + V=V, + tile_v=tile_v, + vec_size=VEC_SIZE_MTP, + use_qk_l2norm=use_qk_l2norm, + disable_state_update=disable_state_update, + ilp_rows=ilp_rows, + use_packed_fma=use_packed_fma, + use_smem_v=use_smem_v, + cache_intermediate_states=cache_intermediate_states, + fast_math=fast_math, + stream=stream, + options=f"--enable-tvm-ffi --opt-level {opt_level}", + ) + + _compiled_mtp_ws_kernels[key] = compiled_kernel + logger.info( + "CuTe DSL KDA MTP warp-spec kernel compiled: " + f"N={N}, T={T}, H={H}, HV={HV}, K={K}, V={V}, pool_size={pool_size}, " + f"tile_v={tile_v}, ilp_rows={ilp_rows}, use_packed_fma={use_packed_fma}, " + f"use_smem_v={use_smem_v}, cache_intermediate_states={cache_intermediate_states}" + ) + return compiled_kernel + + +def kda_decode_mtp_ws( + A_log: torch.Tensor, + dt_bias: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + scale: float | None = None, + use_qk_l2norm_in_kernel: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + out: torch.Tensor | None = None, + state_layout: str = "vk", + tile_v: int | None = None, + ilp_rows: int | None = None, + disable_state_update: bool = False, + use_packed_fma: bool | None = None, + use_smem_v: bool | None = None, + intermediate_states_buffer: torch.Tensor | None = None, +) -> torch.Tensor: + N, T, H, K = q.shape + HV = v.shape[2] + V = v.shape[3] + + if scale is None: + scale = K**-0.5 + else: + assert scale > 0, f"scale must be positive, got {scale}" + + assert K == TILE_K, f"KDA MTP (ws) kernel requires K={TILE_K}, got {K}" + + # Resolve tile_v / ilp_rows / use_smem_v from the work_units=N*HV heuristic + # where not given explicitly. An explicit tile_v can make the heuristic's ilp=4 + # illegal (needs tile_v % 16 == 0); the auto path then falls back to ilp=2. + if tile_v is None or ilp_rows is None or use_smem_v is None: + sel_tile_v, sel_ilp_rows, sel_use_smem_v = _select_mtp_config( + N, HV, V, T, disable_state_update=disable_state_update + ) + if tile_v is None: + if intermediate_states_buffer is not None and N >= 8 and V % 16 == 0: + # write-bound: smaller tile = more CTAs = more in-flight DRAM requests + tile_v = 16 + else: + tile_v = sel_tile_v + if ilp_rows is None: + ilp_rows = sel_ilp_rows + if ilp_rows == 4 and tile_v % 16 != 0: + ilp_rows = 2 + if use_smem_v is None: + use_smem_v = sel_use_smem_v + + if ilp_rows not in (2, 4): + raise NotImplementedError( + f"kda_decode_mtp_ws implements ilp_rows in {{2, 4}}, got {ilp_rows}" + ) + + # packed F32x2 FMA exists only on SM100+ (Blackwell) + if use_packed_fma is None: + major, _ = torch.cuda.get_device_capability(q.device) + use_packed_fma = major >= 10 + # The packed path only exists in the ilp=4 kernel branch; ilp=2 is scalar. + if ilp_rows != 4: + use_packed_fma = False + + state_layout = _canonicalize_state_layout(state_layout) + if state_layout != "vk": + raise NotImplementedError( + "kda_decode_mtp_ws only supports state_layout='vk'; " + f"got {state_layout!r}" + ) + + assert tile_v % 4 == 0, f"KDA MTP (ws) requires tile_v % 4 == 0, got tile_v={tile_v}" + assert V % tile_v == 0, f"KDA MTP (ws) requires V % tile_v == 0, got V={V}, tile_v={tile_v}" + + rows_per_group = tile_v // 4 + assert rows_per_group % ilp_rows == 0, ( + f"ilp_rows={ilp_rows} requires (tile_v//4) divisible by {ilp_rows}, " + f"got tile_v={tile_v} (tile_v//4={rows_per_group})" + ) + + # State is token-independent: reuse the single-token normalizer/validator. + h0_source, pool_size, state_layout_is_kv = _normalize_state_source( + initial_state_source, + N=N, + HV=HV, + K=K, + V=V, + device=q.device, + state_layout=state_layout, + ) + assert not state_layout_is_kv # guaranteed by the vk-only guard above + + a = _normalize_mtp_a(a, N=N, T=T, HV=HV, K=K) + if b.dim() != 3 or tuple(b.shape) != (N, T, HV): + raise ValueError(f"Unexpected b shape for MTP dense: {tuple(b.shape)}; expected {(N, T, HV)}") + + o = _prepare_output_tensor(q, out, (N, T, HV, V)) + + q = q if q.is_contiguous() else q.contiguous() + k = k if k.is_contiguous() else k.contiguous() + v = v if v.is_contiguous() else v.contiguous() + a = a if a.is_contiguous() else a.contiguous() + b = b if b.is_contiguous() else b.contiguous() + + A_log = _normalize_A_log(A_log, HV) + dt_bias = _normalize_dt_bias(dt_bias, HV, K) + initial_state_indices = _normalize_state_indices( + initial_state_indices, N=N, pool_size=pool_size, device=q.device + ) + + # Flatten the VK state pool [pool, HV, V, K] -> [pool*HV, V, K] + h0_source_flat = h0_source.view(pool_size * HV, V, K) + + # Stage D: resolve the snapshot cache. + cache_intermediate_states = intermediate_states_buffer is not None + if cache_intermediate_states: + if intermediate_states_buffer.dtype != torch.float32: + raise ValueError( + "intermediate_states_buffer must be float32, got " + f"{intermediate_states_buffer.dtype}" + ) + expected_buf_shape = (N, T, HV, V, K) + if tuple(intermediate_states_buffer.shape) != expected_buf_shape: + raise ValueError( + f"intermediate_states_buffer shape {tuple(intermediate_states_buffer.shape)} " + f"!= expected {expected_buf_shape} ([N, T, HV, V, K] vk / K-last)" + ) + intermediate_states_flat = intermediate_states_buffer.view(N * T * HV, V, K) + else: + intermediate_states_flat = torch.empty( + 1, 1, 1, dtype=torch.float32, device=q.device + ) + + stream = _get_cached_stream(q.device) + + compiled_kernel = _get_compiled_mtp_ws_kernel( + N, + T, + H, + HV, + K, + V, + pool_size, + scale=scale, + use_qk_l2norm=use_qk_l2norm_in_kernel, + disable_state_update=disable_state_update, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + tile_v=tile_v, + ilp_rows=ilp_rows, + use_packed_fma=use_packed_fma, + use_smem_v=use_smem_v, + cache_intermediate_states=cache_intermediate_states, + ) + + compiled_kernel( + h0_source_flat, + intermediate_states_flat, + A_log, + a, + dt_bias, + q, + k, + v, + b, + o, + initial_state_indices, + stream, + ) + + return o + + +# ============================================================================ +# small_batch kernel (1-warp/program):kv layout(lane=V)+ vk layout(lane=K) +# ============================================================================ + + +WARP_BV = 32 +VEC_SIZE = 4 + +_compiled_mtp_small_batch_kernels: dict[tuple, object] = {} + + +@cute.kernel +def kda_mtp_small_batch_kernel( + h0_source: cute.Tensor, # [pool*HV, K, V] fp32 (kv, V-last) + A_log: cute.Tensor, # [HV] fp32 + a: cute.Tensor, # [N, T, HV, K] + dt_bias: cute.Tensor, # [HV, K] + q: cute.Tensor, # [N, T, H, K] + k: cute.Tensor, # [N, T, H, K] + v: cute.Tensor, # [N, T, HV, V] + b: cute.Tensor, # [N, T, HV] + o: cute.Tensor, # [N, T, HV, V] + h0_indices: cute.Tensor, # [N] int32 + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + BV: cutlass.Constexpr[int], + k_split: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], +): + tidx, _, _ = cute.arch.thread_idx() + lane = tidx + + bidx, _, _ = cute.arch.block_idx() + i_v = bidx % num_v_tiles # flat CTA -> (i_n, i_hv, i_v V-block) + tmp = bidx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) + + cache_idx = h0_indices[i_n] + r_exp_A = cute.exp(cutlass.Float32(A_log[i_hv]), fastmath=fast_math) # per-head, shared across T + + # SMEM-broadcast q/k/g (shared across V-cols on K dim); XOR swizzle staggers k_split segments across banks. + smem_k = K + smem = cutlass.utils.SmemAllocator() + sQ = smem.allocate_tensor(cutlass.Float32, cute.make_layout((smem_k,), stride=(1,)), 16) + sK = smem.allocate_tensor(cutlass.Float32, cute.make_layout((smem_k,), stride=(1,)), 16) + sG = smem.allocate_tensor(cutlass.Float32, cute.make_layout((smem_k,), stride=(1,)), 16) + + # k_split lanes split one V-col's K (each holds k_per_lane), butterfly-merged after reduce. + k_per_lane = K // k_split + v_local = lane % BV + k_part = lane // BV + k_off = k_part * k_per_lane + + r_h = cute.make_rmem_tensor(cute.make_layout((k_per_lane,), stride=(1,)), cutlass.Float32) + r_q = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_k = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_q_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_k_bf16 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + + v_global = i_v * BV + v_local # global V-col this lane serves + k_start = lane * vec_size # prep: full warp, 32 lanes x 4 = all 128 K + + # constexpr k_split decisions hoisted to top level so they stay python + # constants inside the cache_idx>=0 block (else reboxed to Int32 -> error). + ks_single = cutlass.const_expr(k_split == 1) + ks_log2 = cutlass.const_expr(k_split.bit_length() - 1) + if cache_idx >= 0: + flat_state_idx = cache_idx * HV + i_hv + for j in cutlass.range_constexpr(k_per_lane): + r_h[j] = cutlass.Float32(h0_source[flat_state_idx, k_off + j, v_global]) + + for i_t in cutlass.range_constexpr(T): + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + if cutlass.const_expr(use_qk_l2norm): + sum_q = cutlass.Float32(0.0) + sum_k = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly(sum_q, offset=offset, mask=-1, mask_and_clamp=31) + sum_k += cute.arch.shuffle_sync_bfly(sum_k, offset=offset, mask=-1, mask_and_clamp=31) + inv_q = cute.rsqrt(sum_q + 1e-6, fastmath=fast_math) * scale + inv_k = cute.rsqrt(sum_k + 1e-6, fastmath=fast_math) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_q + r_k[i] = r_k[i] * inv_k + else: + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale + + for i in cutlass.range_constexpr(vec_size): + kk = k_start + i + sw = kk ^ (kk // k_per_lane) # XOR swizzle SMEM write addr (a/dt_bias read GMEM with raw kk) + x = cutlass.Float32(a[i_n, i_t, i_hv, kk]) + cutlass.Float32(dt_bias[i_hv, kk]) + beta_x = softplus_beta * x + exp_bx = cute.exp(beta_x, fastmath=fast_math) + sp_val = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + exp_bx, fastmath=fast_math + ) + use_sp = ( + cutlass.Float32(1.0) + if beta_x <= softplus_threshold + else cutlass.Float32(0.0) + ) + sp_x = use_sp * sp_val + (cutlass.Float32(1.0) - use_sp) * x + sG[sw] = cute.exp(-r_exp_A * sp_x, fastmath=fast_math) + sQ[sw] = r_q[i] + sK[sw] = r_k[i] + + r_beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + + cute.exp(-cutlass.Float32(b[i_n, i_t, i_hv]), fastmath=fast_math) + ) + + cute.arch.barrier() # publish prep's SMEM writes before recurrence reads + + r_v = cutlass.Float32(v[i_n, i_t, i_hv, v_global]) + # fused decay + s partial. + s = cutlass.Float32(0.0) + for j in cutlass.range_constexpr(k_per_lane): + sw = j if ks_single else (k_off + j) ^ k_part # XOR swizzle read addr = swz(k_off+j) + r_h[j] = r_h[j] * sG[sw] + s += r_h[j] * sK[sw] + for st in cutlass.range_constexpr(ks_log2): + s += cute.arch.shuffle_sync_bfly(s, offset=BV << st, mask=-1, mask_and_clamp=31) + v_new = (r_v - s) * r_beta + o_val = cutlass.Float32(0.0) + for j in cutlass.range_constexpr(k_per_lane): + sw = j if ks_single else (k_off + j) ^ k_part # XOR swizzle read addr + r_h[j] = r_h[j] + sK[sw] * v_new + o_val += r_h[j] * sQ[sw] + for st in cutlass.range_constexpr(ks_log2): + o_val += cute.arch.shuffle_sync_bfly(o_val, offset=BV << st, mask=-1, mask_and_clamp=31) + o[(i_n, i_t, i_hv, v_global)] = cutlass.BFloat16(o_val) + + cute.arch.barrier() + + if cutlass.const_expr(not disable_state_update): + flat_state_idx = cache_idx * HV + i_hv + for j in cutlass.range_constexpr(k_per_lane): + h0_source[(flat_state_idx, k_off + j, v_global)] = r_h[j] + + +@cute.jit +def run_kda_mtp_small_batch_kernel( + h0_source: cute.Tensor, + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + vec_size: cutlass.Constexpr[int], + BV: cutlass.Constexpr[int], + k_split: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], + stream: cuda.CUstream, +): + n_indices = h0_indices.layout.shape[0] + num_v_tiles = cute.ceil_div(V, BV) + grid_size = n_indices * HV * num_v_tiles + + smem_bytes = 3 * K * 4 + 256 # sQ + sK + sG + + kda_mtp_small_batch_kernel( + h0_source, + A_log, + a, + dt_bias, + q, + k, + v, + b, + o, + h0_indices, + vec_size, + num_v_tiles, + BV, + k_split, + softplus_beta, + softplus_threshold, + scale, + HV, + T, + H, + K, + V, + use_qk_l2norm, + disable_state_update, + fast_math, + ).launch( + grid=(grid_size, 1, 1), + block=[32, 1, 1], + smem=smem_bytes, + stream=stream, + ) + + +def _get_compiled_mtp_small_batch_kernel( + N, + T, + H, + HV, + K, + V, + pool_size, + BV, + k_split, + scale, + use_qk_l2norm, + disable_state_update, + softplus_beta, + softplus_threshold, + opt_level=3, + fast_math=True, +): + key = ( + N, + T, + H, + HV, + K, + V, + pool_size, + BV, + k_split, + scale, + use_qk_l2norm, + disable_state_update, + softplus_beta, + softplus_threshold, + opt_level, + fast_math, + ) + if key in _compiled_mtp_small_batch_kernels: + return _compiled_mtp_small_batch_kernels[key] + + q = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + k = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + v = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + a = torch.zeros(N, T, HV, K, dtype=torch.bfloat16, device="cuda") + b = torch.zeros(N, T, HV, dtype=torch.bfloat16, device="cuda") + o = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + A_log = torch.zeros(HV, dtype=torch.float32, device="cuda") + dt_bias = torch.zeros(HV, K, dtype=torch.float32, device="cuda") + h0_source = torch.zeros(pool_size * HV, K, V, dtype=torch.float32, device="cuda") # kv + h0_indices = torch.zeros(N, dtype=torch.int32, device="cuda") + + q_t = from_dlpack(q, assumed_align=16) + k_t = from_dlpack(k, assumed_align=16) + v_t = from_dlpack(v, assumed_align=16) + a_t = from_dlpack(a, assumed_align=16) + b_t = from_dlpack(b, assumed_align=16) + o_t = from_dlpack(o, assumed_align=16) + A_log_t = from_dlpack(A_log, assumed_align=16) + dt_bias_t = from_dlpack(dt_bias, assumed_align=16) + h0_source_t = from_dlpack(h0_source, assumed_align=16) + h0_indices_t = from_dlpack(h0_indices, assumed_align=16) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compiled_kernel = cute.compile( + run_kda_mtp_small_batch_kernel, + h0_source_t, + A_log_t, + a_t, + dt_bias_t, + q_t, + k_t, + v_t, + b_t, + o_t, + h0_indices_t, + vec_size=VEC_SIZE, + BV=BV, + k_split=k_split, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + HV=HV, + T=T, + H=H, + K=K, + V=V, + use_qk_l2norm=use_qk_l2norm, + disable_state_update=disable_state_update, + fast_math=fast_math, + stream=stream, + options=f"--enable-tvm-ffi --opt-level {opt_level}", + ) + + _compiled_mtp_small_batch_kernels[key] = compiled_kernel + logger.info( + "CuTe DSL KDA MTP small-batch kernel compiled: " + f"N={N}, T={T}, H={H}, HV={HV}, K={K}, V={V}, pool_size={pool_size}, BV={BV}, " + f"k_split={k_split}, opt_level={opt_level}, fast_math={fast_math}" + ) + return compiled_kernel + + +_KV_CTAS_PER_SM = {1: 8, 2: 12, 4: 16} + + +def _select_k_split(work_units, V, num_sms): + waves1 = work_units * (V // 32) / (num_sms * _KV_CTAS_PER_SM[1]) + for ks, thresh in ((4, 0.3), (2, 0.6)): + vcols = 32 // ks + if V % vcols == 0 and waves1 < thresh: + return ks + return 1 + + +def kda_decode_mtp_small_batch( + A_log: torch.Tensor, + dt_bias: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + scale: float | None = None, + use_qk_l2norm_in_kernel: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + out: torch.Tensor | None = None, + disable_state_update: bool = False, + variant: str = "kv", + bv: int = WARP_BV, + k_split: int = 1, + opt_level: int = 3, + fast_math: bool = True, + intermediate_states_buffer: torch.Tensor | None = None, +) -> torch.Tensor: + assert variant in ("kv", "vk"), f"variant only supports 'kv'/'vk',got {variant!r}" + N, T, H, K = q.shape + HV = v.shape[2] + V = v.shape[3] + + if scale is None: + scale = K**-0.5 + else: + assert scale > 0, f"scale must be positive, got {scale}" + + assert K == TILE_K, f"KDA MTP (small_batch) requires K={TILE_K}, got {K}" + assert K % VEC_SIZE == 0 and K // VEC_SIZE == 32, ( + f"small_batch assumes K//vec_size==32, got K={K}, vec_size={VEC_SIZE}" + ) + + if variant == "kv": + state_layout = "kv" + assert bv == WARP_BV, f"small_batch(kv) supports 1 warp,bv must be {WARP_BV},got {bv}" + if k_split <= 0: + num_sms = torch.cuda.get_device_properties(q.device).multi_processor_count + k_split = _select_k_split(N * HV, V, num_sms) + assert k_split in (1, 2, 4), f"k_split only supports 1/2/4 or <=0(auto),got {k_split}" + assert bv % k_split == 0 and K % k_split == 0, ( + f"requires bv%k_split==0 and K%k_split==0, got bv={bv}, K={K}, k_split={k_split}" + ) + vcols = bv // k_split + assert V % vcols == 0, f"small_batch(kv) requires V % (bv//k_split) == 0, got V={V}, vcols={vcols}" + else: # vk + state_layout = "vk" + if bv <= 0: + num_sms = torch.cuda.get_device_properties(q.device).multi_processor_count + bv = _select_vk_bv(N * HV, V, num_sms) + assert bv in (8, 16, 32), f"vk bv only supports 8/16/32 or <=0(auto),got {bv}" + assert V % bv == 0, f"vk requires V % bv == 0, got V={V}, bv={bv}" + + h0_source, pool_size, _ = _normalize_state_source( + initial_state_source, N=N, HV=HV, K=K, V=V, device=q.device, state_layout=state_layout, + ) + + a = _normalize_mtp_a(a, N=N, T=T, HV=HV, K=K) + if b.dim() != 3 or tuple(b.shape) != (N, T, HV): + raise ValueError(f"Unexpected b shape for MTP dense: {tuple(b.shape)}; expected {(N, T, HV)}") + + o = _prepare_output_tensor(q, out, (N, T, HV, V)) + + q = q if q.is_contiguous() else q.contiguous() + k = k if k.is_contiguous() else k.contiguous() + v = v if v.is_contiguous() else v.contiguous() + a = a if a.is_contiguous() else a.contiguous() + b = b if b.is_contiguous() else b.contiguous() + + A_log = _normalize_A_log(A_log, HV) + dt_bias = _normalize_dt_bias(dt_bias, HV, K) + initial_state_indices = _normalize_state_indices( + initial_state_indices, N=N, pool_size=pool_size, device=q.device + ) + + stream = _get_cached_stream(q.device) + + cache_intermediate_states = intermediate_states_buffer is not None + if cache_intermediate_states: + if variant != "vk": + raise NotImplementedError("intermediate_states_buffer only supported for variant='vk'") + if intermediate_states_buffer.dtype != torch.float32: + raise ValueError(f"intermediate_states_buffer must be float32, got {intermediate_states_buffer.dtype}") + if tuple(intermediate_states_buffer.shape) != (N, T, HV, V, K): + raise ValueError(f"intermediate_states_buffer shape {tuple(intermediate_states_buffer.shape)} != expected {(N, T, HV, V, K)} ([N,T,HV,V,K] vk)") + intermediate_states_flat = intermediate_states_buffer.view(N * T * HV, V, K) + else: + intermediate_states_flat = torch.empty(1, 1, 1, dtype=torch.float32, device=q.device) + + if variant == "kv": + h0_source_flat = h0_source.view(pool_size * HV, K, V) # kv + compiled_kernel = _get_compiled_mtp_small_batch_kernel( + N, T, H, HV, K, V, pool_size, vcols, k_split, + scale=scale, use_qk_l2norm=use_qk_l2norm_in_kernel, + disable_state_update=disable_state_update, + softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, + opt_level=opt_level, fast_math=fast_math, + ) + else: # vk + h0_source_flat = h0_source.view(pool_size * HV, V, K) # vk + compiled_kernel = _get_compiled_mtp_vk_kernel( + N, T, H, HV, K, V, pool_size, bv, + scale=scale, use_qk_l2norm=use_qk_l2norm_in_kernel, + disable_state_update=disable_state_update, + softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, + opt_level=opt_level, fast_math=fast_math, + cache_intermediate_states=cache_intermediate_states, + ) + + if variant == "vk": + compiled_kernel( + h0_source_flat, A_log, a, dt_bias, q, k, v, b, o, + intermediate_states_flat, initial_state_indices, stream, + ) + else: + compiled_kernel( + h0_source_flat, A_log, a, dt_bias, q, k, v, b, o, + initial_state_indices, stream, + ) + + return o + +@cute.kernel +def kda_mtp_small_batch_vk_kernel( + h0_source: cute.Tensor, # [pool*HV, V, K] fp32 (vk) + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + intermediate_states: cute.Tensor, + h0_indices: cute.Tensor, + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + BV: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + cache_intermediate_states: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], +): + tidx, _, _ = cute.arch.thread_idx() + lane = tidx # 1 warp = 32 lanes + + bidx, _, _ = cute.arch.block_idx() + i_v = bidx % num_v_tiles + tmp = bidx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) + + cache_idx = h0_indices[i_n] + r_exp_A = cute.exp(cutlass.Float32(A_log[i_hv]), fastmath=fast_math) + + # lane t holds vec_size contiguous K (K[4t:4t+4]) x all BV V-cols; r_h[vv*vec_size+c]=state[i_v*BV+vv, vec_size*lane+c]. + r_h = cute.make_rmem_tensor(cute.make_layout((BV * vec_size,), stride=(1,)), cutlass.Float32) + r_q = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_k = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_g = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_vbf = [cute.make_rmem_tensor(cute.make_layout((BV,), stride=(1,)), cutlass.BFloat16) for _ in range(2)] # v: bf16 double-buffer + r_red = cute.make_rmem_tensor(cute.make_layout((BV,), stride=(1,)), cutlass.Float32) # ILP: BV reduce partials, batched butterfly + r_gx = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) # gate: x=a+dtb + r_gexp = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) # gate: exp(beta_x) + r_h4 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) # float4 temp buffer (state load/store) + # ===== 2-stage software-pipeline double-buffer: prefetch token t+1's q/k/a/b while computing token t ===== + r_qbf = [cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) for _ in range(2)] + r_kbf = [cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) for _ in range(2)] + r_abf = [cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) for _ in range(2)] + r_bbf = [cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), cutlass.Float32) for _ in range(2)] + r_dtb = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) # dt_bias + + # ===== state load (contiguous + float4: lane t takes K[4t:4t+4] ===== + if cache_idx >= 0: + flat_state_idx = cache_idx * HV + i_hv + for vv in cutlass.range_constexpr(BV): + v_global = i_v * BV + vv + # local_tile 3rd coord = lane, tile=vec_size -> contiguous K -> autovec float4 + h_tile = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_global, lane)) + cute.autovec_copy(h_tile, r_h4) + for c in cutlass.range_constexpr(vec_size): + r_h[vv * vec_size + c] = r_h4[c] + + for c in cutlass.range_constexpr(vec_size): # dt_bias loaded once outside loop (contiguous K[4t:4t+4]) + r_dtb[c] = cutlass.Float32(dt_bias[i_hv, vec_size * lane + c]) + + # prefetch token 0's q/k/a/b into stage 0 (pipeline fill). + q_t0 = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, 0, i_h, lane)) + k_t0 = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, 0, i_h, lane)) + cute.autovec_copy(q_t0, r_qbf[0]) + cute.autovec_copy(k_t0, r_kbf[0]) + a_t0 = cute.local_tile(a, (1, 1, 1, vec_size), (i_n, 0, i_hv, lane)) + cute.autovec_copy(a_t0, r_abf[0]) + v_t0 = cute.local_tile(v, (1, 1, 1, BV), (i_n, 0, i_hv, i_v)) + cute.autovec_copy(v_t0, r_vbf[0]) + r_bbf[0][0] = cutlass.Float32(b[i_n, 0, i_hv]) + + for i_t in cutlass.range_constexpr(T): + cur = i_t % 2 + # ===== prefetch t+1's q/k/a/b ===== + if cutlass.const_expr(i_t + 1 < T): + nxt = (i_t + 1) % 2 + q_tn = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t + 1, i_h, lane)) + k_tn = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t + 1, i_h, lane)) + cute.autovec_copy(q_tn, r_qbf[nxt]) + cute.autovec_copy(k_tn, r_kbf[nxt]) + a_tn = cute.local_tile(a, (1, 1, 1, vec_size), (i_n, i_t + 1, i_hv, lane)) + cute.autovec_copy(a_tn, r_abf[nxt]) + v_tn = cute.local_tile(v, (1, 1, 1, BV), (i_n, i_t + 1, i_hv, i_v)) + cute.autovec_copy(v_tn, r_vbf[nxt]) + r_bbf[nxt][0] = cutlass.Float32(b[i_n, i_t + 1, i_hv]) + + # ===== prep: read q/k + gate<->l2norm cross-pipe interleave ===== + for c in cutlass.range_constexpr(vec_size): + r_q[c] = cutlass.Float32(r_qbf[cur][c]) + r_k[c] = cutlass.Float32(r_kbf[cur][c]) + + # gate stage 1: x=a+dtb + for c in cutlass.range_constexpr(vec_size): + r_gx[c] = cutlass.Float32(r_abf[cur][c]) + r_dtb[c] # x = a + dt_bias + for c in cutlass.range_constexpr(vec_size): + r_gexp[c] = cute.exp(softplus_beta * r_gx[c], fastmath=fast_math) # exp(beta_x) + + if cutlass.const_expr(use_qk_l2norm): + sum_q = cutlass.Float32(0.0) + sum_k = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(vec_size): + sum_q += r_q[c] * r_q[c] + sum_k += r_k[c] * r_k[c] + for off in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly(sum_q, offset=off, mask=-1, mask_and_clamp=31) + sum_k += cute.arch.shuffle_sync_bfly(sum_k, offset=off, mask=-1, mask_and_clamp=31) + inv_q = cute.rsqrt(sum_q + 1e-6, fastmath=fast_math) * scale + inv_k = cute.rsqrt(sum_k + 1e-6, fastmath=fast_math) + for c in cutlass.range_constexpr(vec_size): + r_q[c] = r_q[c] * inv_q + r_k[c] = r_k[c] * inv_k + else: + for c in cutlass.range_constexpr(vec_size): + r_q[c] = r_q[c] * scale + + # gate stage 2: log + softplus select -> sp_x stashed in r_g + for c in cutlass.range_constexpr(vec_size): + beta_x = softplus_beta * r_gx[c] + sp_val = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + r_gexp[c], fastmath=fast_math + ) + use_sp = ( + cutlass.Float32(1.0) + if beta_x <= softplus_threshold + else cutlass.Float32(0.0) + ) + r_g[c] = use_sp * sp_val + (cutlass.Float32(1.0) - use_sp) * r_gx[c] # stash sp_x + for c in cutlass.range_constexpr(vec_size): + r_g[c] = cute.exp(-r_exp_A * r_g[c], fastmath=fast_math) # final exp (batched) + + r_beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + + cute.exp(-r_bbf[cur][0], fastmath=fast_math) + ) + + # ===== recurrence (fused: decay+h@k in one pass / update+h@q in one pass) ===== + for vv in cutlass.range_constexpr(BV): + sv = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(vec_size): + r_h[vv * vec_size + c] = r_h[vv * vec_size + c] * r_g[c] # decay: h *= exp(g) (per K) + sv += r_h[vv * vec_size + c] * r_k[c] # s = sum_k h*k_norm + r_red[vv] = sv + for off in [16, 8, 4, 2, 1]: + for vv in cutlass.range_constexpr(BV): + r_red[vv] = r_red[vv] + cute.arch.shuffle_sync_bfly(r_red[vv], offset=off, mask=-1, mask_and_clamp=31) + for vv in cutlass.range_constexpr(BV): + v_new = (cutlass.Float32(r_vbf[cur][vv]) - r_red[vv]) * r_beta # v_new = beta*(v - s) + ovv = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(vec_size): + r_h[vv * vec_size + c] = r_h[vv * vec_size + c] + r_k[c] * v_new # rank-1 update: h += k*v_new + ovv += r_h[vv * vec_size + c] * r_q[c] # o = sum_k h*q_scaled (partial) + r_red[vv] = ovv + for off in [16, 8, 4, 2, 1]: + for vv in cutlass.range_constexpr(BV): + r_red[vv] = r_red[vv] + cute.arch.shuffle_sync_bfly(r_red[vv], offset=off, mask=-1, mask_and_clamp=31) + for vv in cutlass.range_constexpr(BV): + o[(i_n, i_t, i_hv, i_v * BV + vv)] = cutlass.BFloat16(r_red[vv]) + if cutlass.const_expr(cache_intermediate_states): # Stage-D snapshot: post-token-t state + flat_idx = i_n * T * HV + i_t * HV + i_hv + for vv in cutlass.range_constexpr(BV): + for c in cutlass.range_constexpr(vec_size): + r_h4[c] = r_h[vv * vec_size + c] + inter_tile = cute.local_tile(intermediate_states, (1, 1, vec_size), (flat_idx, i_v * BV + vv, lane)) + cute.autovec_copy(r_h4, inter_tile) + + # ===== epilogue: write state back ===== + if cutlass.const_expr(not disable_state_update): + flat_state_idx = cache_idx * HV + i_hv + for vv in cutlass.range_constexpr(BV): + v_global = i_v * BV + vv + for c in cutlass.range_constexpr(vec_size): + r_h4[c] = r_h[vv * vec_size + c] + h_out = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_global, lane)) + cute.autovec_copy(r_h4, h_out) + + +@cute.jit +def run_kda_mtp_small_batch_vk_kernel( + h0_source: cute.Tensor, + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + intermediate_states: cute.Tensor, + h0_indices: cute.Tensor, + vec_size: cutlass.Constexpr[int], + BV: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + cache_intermediate_states: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], + stream: cuda.CUstream, +): + """lane=K vk launcher:grid = N*HV*(V//BV),block = 32(1 warp)。无 SMEM。""" + n_indices = h0_indices.layout.shape[0] + num_v_tiles = cute.ceil_div(V, BV) + grid_size = n_indices * HV * num_v_tiles + + kda_mtp_small_batch_vk_kernel( + h0_source, + A_log, + a, + dt_bias, + q, + k, + v, + b, + o, + intermediate_states, + h0_indices, + vec_size, + num_v_tiles, + BV, + softplus_beta, + softplus_threshold, + scale, + HV, + T, + H, + K, + V, + use_qk_l2norm, + disable_state_update, + cache_intermediate_states, + fast_math, + ).launch( + grid=(grid_size, 1, 1), + block=[32, 1, 1], + smem=0, + stream=stream, + ) + + +_compiled_mtp_vk_kernels: dict[tuple, object] = {} + + +def _get_compiled_mtp_vk_kernel( + N, + T, + H, + HV, + K, + V, + pool_size, + BV, + scale, + use_qk_l2norm, + disable_state_update, + softplus_beta, + softplus_threshold, + opt_level=3, + fast_math=True, + cache_intermediate_states=False, +): + key = ( + N, + T, + H, + HV, + K, + V, + pool_size, + BV, + scale, + use_qk_l2norm, + disable_state_update, + cache_intermediate_states, + softplus_beta, + softplus_threshold, + opt_level, + fast_math, + ) + if key in _compiled_mtp_vk_kernels: + return _compiled_mtp_vk_kernels[key] + + q = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + k = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + v = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + a = torch.zeros(N, T, HV, K, dtype=torch.bfloat16, device="cuda") + b = torch.zeros(N, T, HV, dtype=torch.bfloat16, device="cuda") + o = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + A_log = torch.zeros(HV, dtype=torch.float32, device="cuda") + dt_bias = torch.zeros(HV, K, dtype=torch.float32, device="cuda") + h0_source = torch.zeros(pool_size * HV, V, K, dtype=torch.float32, device="cuda") + h0_indices = torch.zeros(N, dtype=torch.int32, device="cuda") + if cache_intermediate_states: + intermediate_states = torch.zeros(N * T * HV, V, K, dtype=torch.float32, device="cuda") + else: + intermediate_states = torch.empty(1, 1, 1, dtype=torch.float32, device="cuda") + + q_t = from_dlpack(q, assumed_align=16) + k_t = from_dlpack(k, assumed_align=16) + v_t = from_dlpack(v, assumed_align=16) + a_t = from_dlpack(a, assumed_align=16) + b_t = from_dlpack(b, assumed_align=16) + o_t = from_dlpack(o, assumed_align=16) + A_log_t = from_dlpack(A_log, assumed_align=16) + dt_bias_t = from_dlpack(dt_bias, assumed_align=16) + h0_source_t = from_dlpack(h0_source, assumed_align=16) + h0_indices_t = from_dlpack(h0_indices, assumed_align=16) + intermediate_states_t = from_dlpack(intermediate_states, assumed_align=16) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compiled_kernel = cute.compile( + run_kda_mtp_small_batch_vk_kernel, + h0_source_t, + A_log_t, + a_t, + dt_bias_t, + q_t, + k_t, + v_t, + b_t, + o_t, + intermediate_states_t, + h0_indices_t, + vec_size=VEC_SIZE, + BV=BV, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + HV=HV, + T=T, + H=H, + K=K, + V=V, + use_qk_l2norm=use_qk_l2norm, + disable_state_update=disable_state_update, + cache_intermediate_states=cache_intermediate_states, + fast_math=fast_math, + stream=stream, + options=f"--enable-tvm-ffi --opt-level {opt_level}", + ) + + _compiled_mtp_vk_kernels[key] = compiled_kernel + logger.info( + "CuTe DSL KDA MTP small-batch VK(lane=K) kernel compiled: " + f"N={N}, T={T}, H={H}, HV={HV}, K={K}, V={V}, pool_size={pool_size}, BV={BV}, " + f"opt_level={opt_level}, fast_math={fast_math}" + ) + return compiled_kernel + + +def _select_vk_bv(work_units, V, num_sms): + waves32 = work_units * (V // 32) / (num_sms * 12) + if V % 8 == 0 and waves32 < 3.0: + return 8 + return 32 + +def kda_decode_mtp( + A_log: torch.Tensor, + dt_bias: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + scale: float | None = None, + use_qk_l2norm_in_kernel: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + out: torch.Tensor | None = None, + state_layout: str = "vk", + disable_state_update: bool = False, + intermediate_states_buffer: torch.Tensor | None = None, +) -> torch.Tensor: + common = dict( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=initial_state_source, + initial_state_indices=initial_state_indices, + scale=scale, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, + out=out, disable_state_update=disable_state_update, + intermediate_states_buffer=intermediate_states_buffer, + ) + if state_layout == "kv": + return kda_decode_mtp_small_batch(**common, variant="kv", k_split=-1) # k_split auto + # T=1 single-token decode: vk small_batch is fastest at every batch (beats + # ws/packed across all N, see bench_kda_decode_t1_vs_sgl) -> route T=1 to vk. + T = q.shape[1] + work_units = q.shape[0] * v.shape[2] # N * HV + if T == 1 or work_units <= 512: + return kda_decode_mtp_small_batch(**common, variant="vk", bv=-1) # bv auto + return kda_decode_mtp_ws(**common, state_layout="vk") diff --git a/cula/ops/kda_decode_mtp_kvbuffer.py b/cula/ops/kda_decode_mtp_kvbuffer.py new file mode 100644 index 00000000..11f6b5d6 --- /dev/null +++ b/cula/ops/kda_decode_mtp_kvbuffer.py @@ -0,0 +1,1416 @@ +"""CuTe DSL KDA MTP decode — KVBuffer / chunkwise parallel-verification variant. + +KVBuffer paper's chunkwise verify form (https://arxiv.org/abs/2605.19049) as a new +operator vs the recurrent vk/kv ops in ``kda_decode_mtp.py``. The T draft tokens +are treated as ONE chunk: per-token outputs come from the FIXED input state S0 plus a +small T×T intra-chunk correction, and the state is updated once at the end — the +S0-matvecs are independent across tokens (no length-T serial chain), the latency angle +at small batch. Infra (grid N*HV*(V//BV), 1 warp/CTA, lane=K, float4 loads, butterfly +reduce-over-K) mirrors the production vk kernel for apples-to-apples comparison. + +Chunkwise math (state S0[v,k], decay-first; matches the recurrent op): + g_t[k] = exp(-exp(A_log) * softplus(a_t[k] + dt_bias[k])) # per channel + b_t[k] = prod_{i<=t} g_i[k] # cumulative decay + kdec_t = k_norm_t * b_t ; kinv_t = k_norm_t / b_t ; qdec_t = q_scaled_t * b_t + A[t,i] = (i (i<=t) + u_t[v] = beta_t * (v_t[v] - (S0 @ kdec_t)[v] - sum_{i16, <256->32, >=256->64 (H200 sweep; >=256 fixes the HV=64 N=4 case). +def _select_kvb_tile_v(V, N, HV): + """work-unit (N*HV) dependent tile_v. Returns the first candidate that divides V.""" + wu = N * HV + if wu <= 32: + order = (16, 32, 8, 64) + elif wu < 256: + order = (32, 64, 16, 8) + else: + order = (64, 32, 16, 8) + for tv in order: + if V % tv == 0: + return tv + return 8 + + +# flush kernel: read the compact u-buffer from verify, rank-m update over the first m accepted tokens: +# S_m[v,k] = b_m[k] * (S0[v,k] + sum_{i= 0: + flat_state_idx = cache_idx * HV + i_hv + m_n = m_buf[i_n] # this request's accept length (runtime; 1 <= m_n <= T) + + r_h = cute.make_rmem_tensor(cute.make_layout((BV * vec_size,), stride=(1,)), cutlass.Float32) + r_h4 = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_bm = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_kinv = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + + # load S0 (this lane's vec_size K channels x BV v-cols) + for vv in cutlass.range_constexpr(BV): + v_global = i_v * BV + vv + h_tile = cute.local_tile(h0_source, (1, 1, vec_size), (flat_state_idx, v_global, lane)) + cute.autovec_copy(h_tile, r_h4) + for c in cutlass.range_constexpr(vec_size): + r_h[vv * vec_size + c] = r_h4[c] + + # b_m: cumulative decay at token m-1 (this lane's channels) + bm_tile = cute.local_tile(b_buf, (1, 1, 1, vec_size), (i_n, m_n - 1, i_hv, lane)) + cute.autovec_copy(bm_tile, r_bm) + + # accumulate sum_{i torch.Tensor: + N, T, HV, V = u_buffer.shape + K = kinv_buffer.shape[3] + if isinstance(accept_len, torch.Tensor): + assert accept_len.numel() == N, f"per-request accept_len must have N={N} entries, got {accept_len.numel()}" + m_buf = accept_len.to(device=u_buffer.device, dtype=torch.int32).contiguous() + else: + m = int(accept_len) + assert 1 <= m <= T, f"accept_len must be in [1,{T}], got {m}" + m_buf = torch.full((N,), m, dtype=torch.int32, device=u_buffer.device) + + if bv <= 0: + num_sms = torch.cuda.get_device_properties(initial_state_source.device).multi_processor_count + bv = _select_vk_bv(N * HV, V, num_sms) + assert bv in (8, 16, 32) and V % bv == 0, f"flush bv must be 8/16/32 and divide V, got bv={bv}, V={V}" + + h0_source, pool_size, _ = _normalize_state_source( + initial_state_source, N=N, HV=HV, K=K, V=V, device=initial_state_source.device, state_layout="vk", + ) + initial_state_indices = _normalize_state_indices( + initial_state_indices, N=N, pool_size=pool_size, device=initial_state_source.device + ) + stream = _get_cached_stream(initial_state_source.device) + + h0_source_flat = h0_source.view(pool_size * HV, V, K) + compiled = _get_compiled_flush_kvbuffer_kernel(N, T, HV, K, V, pool_size, bv, opt_level=opt_level) + compiled(h0_source_flat, u_buffer, kinv_buffer, b_buffer, initial_state_indices, m_buf, stream) + return initial_state_source + + +# --------------------------------------------------------------------------- +# tp-kvbuffer: token-parallel chunkwise verify (structure B). UT-transform +# W = L^{-1} diag(beta) makes the consumer solve dependence-free: u = W @ (v - S0 kdec). +# --------------------------------------------------------------------------- +@cute.kernel +def kda_mtp_tp_kvbuffer_kernel( + h0_source: cute.Tensor, # [pool*HV, V, K] fp32 (vk) + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + u_buf: cute.Tensor, # [N, T, HV, V] fp32 + kinv_buf: cute.Tensor, # [N, T, HV, K] fp32 + b_buf: cute.Tensor, # [N, T, HV, K] fp32 + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + ilp_rows: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + emit_output: cutlass.Constexpr[bool], + write_ubuf: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], +): + tidx, _, _ = cute.arch.thread_idx() + lane_id = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + num_warps: cutlass.Constexpr[int] = 4 + + bidx, _, _ = cute.arch.block_idx() + i_v = bidx % num_v_tiles + tmp = bidx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) + + cache_idx = h0_indices[i_n] + r_exp_A = cute.exp(cutlass.Float32(A_log[i_hv]), fastmath=fast_math) + + # SMEM. sKdec/sQdec double as staging for k_norm/q_scaled between Stage 1 and 2. + smem = cutlass.utils.SmemAllocator() + sKdec = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16) + sKinv = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16) + sQdec = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16) + sG = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16) + sBeta = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T,)), 16) + sBlast = smem.allocate_tensor(cutlass.Float32, cute.make_layout((K,)), 16) # b_{T-1}[k] + sA = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, T), stride=(T, 1)), 16) + sP = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, T), stride=(T, 1)), 16) + sW = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, T), stride=(T, 1)), 16) + + r_qbf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_kbf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_qf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_kf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_dtb = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_tmp = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_h = cute.make_rmem_tensor(cute.make_layout((ilp_rows, vec_size), stride=(vec_size, 1)), cutlass.Float32) + # r_part: ilp_rows*T batched partials (Skdec, then reused as x = v - Skdec, then Sqdec). + r_part = cute.make_rmem_tensor(cute.make_layout((ilp_rows, T), stride=(T, 1)), cutlass.Float32) + r_u = cute.make_rmem_tensor(cute.make_layout((ilp_rows, T), stride=(T, 1)), cutlass.Float32) + # Stage-3 pair partials: ceil(T*T/4) per warp. + ppw: cutlass.Constexpr[int] = (T * T + num_warps - 1) // num_warps + r_red = cute.make_rmem_tensor(cute.make_layout((ppw,), stride=(1,)), cutlass.Float32) + + if cache_idx >= 0: + k_start = lane_id * vec_size + rows_per_group: cutlass.Constexpr[int] = tile_v // num_warps + flat_state_idx = cache_idx * HV + i_hv + + # ---- Stage 1: token-parallel gating/l2norm (warp w owns tokens w, w+4, ...) ---- + for c in cutlass.range_constexpr(vec_size): + r_dtb[c] = cutlass.Float32(dt_bias[i_hv, k_start + c]) + tokens_per_warp: cutlass.Constexpr[int] = (T + num_warps - 1) // num_warps + for tt in cutlass.range_constexpr(tokens_per_warp): + t_tok = tt * num_warps + warp_idx + if t_tok < T: + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, t_tok, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, t_tok, i_h, lane_id)) + cute.autovec_copy(q_tile, r_qbf) + cute.autovec_copy(k_tile, r_kbf) + for c in cutlass.range_constexpr(vec_size): + r_qf[c] = cutlass.Float32(r_qbf[c]) + r_kf[c] = cutlass.Float32(r_kbf[c]) + + if cutlass.const_expr(use_qk_l2norm): + sum_q = cutlass.Float32(0.0) + sum_k = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(vec_size): + sum_q += r_qf[c] * r_qf[c] + sum_k += r_kf[c] * r_kf[c] + for off in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly(sum_q, offset=off, mask=-1, mask_and_clamp=31) + sum_k += cute.arch.shuffle_sync_bfly(sum_k, offset=off, mask=-1, mask_and_clamp=31) + inv_q = cute.rsqrt(sum_q + 1e-6, fastmath=fast_math) * scale + inv_k = cute.rsqrt(sum_k + 1e-6, fastmath=fast_math) + for c in cutlass.range_constexpr(vec_size): + r_qf[c] = r_qf[c] * inv_q + r_kf[c] = r_kf[c] * inv_k + else: + for c in cutlass.range_constexpr(vec_size): + r_qf[c] = r_qf[c] * scale + + # gate g_t per channel; stage k_norm/q_scaled (decay applied in Stage 2) + for c in cutlass.range_constexpr(vec_size): + x = cutlass.Float32(a[i_n, t_tok, i_hv, k_start + c]) + r_dtb[c] + beta_x = softplus_beta * x + exp_bx = cute.exp(beta_x, fastmath=fast_math) + sp_val = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + exp_bx, fastmath=fast_math + ) + use_sp = ( + cutlass.Float32(1.0) + if beta_x <= softplus_threshold + else cutlass.Float32(0.0) + ) + sp_x = use_sp * sp_val + (cutlass.Float32(1.0) - use_sp) * x + sG[t_tok, k_start + c] = cute.exp(-r_exp_A * sp_x, fastmath=fast_math) + sKdec[t_tok, k_start + c] = r_kf[c] + sQdec[t_tok, k_start + c] = r_qf[c] + if lane_id == 0: + sBeta[t_tok] = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + + cute.exp(-cutlass.Float32(b[i_n, t_tok, i_hv]), fastmath=fast_math) + ) + cute.arch.barrier() + + # ---- Stage 2: K-parallel prefix-product scan (thread = one channel) ---- + kc = tidx # requires K == 128 == block size + b_run_s = cutlass.Float32(1.0) + for i_t in cutlass.range_constexpr(T): + kn = sKdec[i_t, kc] + b_run_s = b_run_s * sG[i_t, kc] + kinv_v = kn / b_run_s + sKdec[i_t, kc] = kn * b_run_s + sKinv[i_t, kc] = kinv_v + sQdec[i_t, kc] = sQdec[i_t, kc] * b_run_s + if cutlass.const_expr(write_ubuf): + if i_v == 0: + kinv_buf[i_n, i_t, i_hv, kc] = kinv_v + b_buf[i_n, i_t, i_hv, kc] = b_run_s + sBlast[kc] = b_run_s + cute.arch.barrier() + + # ---- Stage 3: (t,i)-parallel A/P, T^2 pairs round-robined over 4 warps, + # ONE batched butterfly per warp. Pair p: p < T*(T-1)/2 -> A, else P. ---- + for j in cutlass.range_constexpr(ppw): + r_red[j] = cutlass.Float32(0.0) + p_ctr = 0 + for i_t in cutlass.range_constexpr(T): + for i_i in cutlass.range_constexpr(i_t): # A[t,i], i no cross-lane sync needed. ---- + if warp_idx == 0: + if lane_id < T: + for i_t in cutlass.range_constexpr(T): + eq = cutlass.Float32(1.0) if lane_id == i_t else cutlass.Float32(0.0) + acc_w = eq + for i_i in cutlass.range_constexpr(i_t): + acc_w -= sA[i_t, i_i] * sW[i_i, lane_id] + sW[i_t, lane_id] = sBeta[i_t] * acc_w + cute.arch.barrier() + + # ---- Stage 4: consumer (4 warp groups over V rows), zero serial deps. ---- + n_row_groups: cutlass.Constexpr[int] = rows_per_group // ilp_rows + for rg in cutlass.range_constexpr(n_row_groups): + v_base = i_v * tile_v + warp_idx * rows_per_group + rg * ilp_rows + for r in cutlass.range_constexpr(ilp_rows): + h_tile = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v_base + r, lane_id) + ) + cute.autovec_copy(h_tile, cute.slice_(r_h, (r, None))) + # all T Skdec_t for all ilp_rows rows in ONE batched butterfly + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + s = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(vec_size): + s += r_h[r, c] * sKdec[i_t, k_start + c] + r_part[r, i_t] = s + for off in [16, 8, 4, 2, 1]: + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + r_part[r, i_t] += cute.arch.shuffle_sync_bfly(r_part[r, i_t], offset=off, mask=-1, mask_and_clamp=31) + # x = v - Skdec (r_part reused), then u = W @ x (token-parallel, no dep chain) + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + r_part[r, i_t] = cutlass.Float32(v[i_n, i_t, i_hv, v_base + r]) - r_part[r, i_t] + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + acc = cutlass.Float32(0.0) + for i_i in cutlass.range_constexpr(i_t + 1): + acc += sW[i_t, i_i] * r_part[r, i_i] + r_u[r, i_t] = acc + if cutlass.const_expr(write_ubuf): + if lane_id == 0: + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + u_buf[i_n, i_t, i_hv, v_base + r] = r_u[r, i_t] + # o_t = Sqdec_t + sum_{i<=t} P[t,i] u_i (Sqdec batched butterfly into r_part) + if cutlass.const_expr(emit_output): + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + s = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(vec_size): + s += r_h[r, c] * sQdec[i_t, k_start + c] + r_part[r, i_t] = s + for off in [16, 8, 4, 2, 1]: + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + r_part[r, i_t] += cute.arch.shuffle_sync_bfly(r_part[r, i_t], offset=off, mask=-1, mask_and_clamp=31) + for r in cutlass.range_constexpr(ilp_rows): + for i_t in cutlass.range_constexpr(T): + ov = r_part[r, i_t] + for i_i in cutlass.range_constexpr(i_t + 1): + ov += sP[i_t, i_i] * r_u[r, i_i] + if lane_id == 0: + o[(i_n, i_t, i_hv, v_base + r)] = cutlass.BFloat16(ov) + # final state S_T[v,k] = b_{T-1}[k]*(S0[v,k] + sum_t u_t kinv_t[k]) + if cutlass.const_expr(not disable_state_update): + for r in cutlass.range_constexpr(ilp_rows): + for c in cutlass.range_constexpr(vec_size): + acc = r_h[r, c] + for i_t in cutlass.range_constexpr(T): + acc += r_u[r, i_t] * sKinv[i_t, k_start + c] + r_tmp[c] = sBlast[k_start + c] * acc + h_out = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v_base + r, lane_id) + ) + cute.autovec_copy(r_tmp, h_out) + + +@cute.jit +def run_kda_mtp_tp_kvbuffer_kernel( + h0_source: cute.Tensor, + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + u_buf: cute.Tensor, + kinv_buf: cute.Tensor, + b_buf: cute.Tensor, + vec_size: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + ilp_rows: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + emit_output: cutlass.Constexpr[bool], + write_ubuf: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], + stream: cuda.CUstream, +): + """tp-kvbuffer launcher: grid = N*HV*(V//tile_v), block = 128 (4 warps).""" + n_indices = h0_indices.layout.shape[0] + num_v_tiles = cute.ceil_div(V, tile_v) + grid_size = n_indices * HV * num_v_tiles + smem_bytes = ( + 4 * 4 * T * (K + 8) # sKdec/sKinv/sQdec/sG + + 4 * T # sBeta + + 4 * K # sBlast + + 3 * 4 * T * T # sA/sP/sW + + 256 # alignment slack + ) + kda_mtp_tp_kvbuffer_kernel( + h0_source, A_log, a, dt_bias, q, k, v, b, o, h0_indices, + u_buf, kinv_buf, b_buf, + vec_size, num_v_tiles, tile_v, ilp_rows, + softplus_beta, softplus_threshold, scale, + HV, T, H, K, V, + use_qk_l2norm, disable_state_update, emit_output, write_ubuf, fast_math, + ).launch(grid=(grid_size, 1, 1), block=[128, 1, 1], smem=smem_bytes, stream=stream) + + +_compiled_mtp_tp_kvbuffer_kernels: dict[tuple, object] = {} + + +def _get_compiled_mtp_tp_kvbuffer_kernel( + N, T, H, HV, K, V, pool_size, tile_v, ilp_rows, scale, use_qk_l2norm, + disable_state_update, emit_output, write_ubuf, + softplus_beta, softplus_threshold, opt_level=3, fast_math=True, +): + key = ( + N, T, H, HV, K, V, pool_size, tile_v, ilp_rows, scale, use_qk_l2norm, + disable_state_update, emit_output, write_ubuf, + softplus_beta, softplus_threshold, opt_level, fast_math, + ) + if key in _compiled_mtp_tp_kvbuffer_kernels: + return _compiled_mtp_tp_kvbuffer_kernels[key] + + q = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + k = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + v = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + a = torch.zeros(N, T, HV, K, dtype=torch.bfloat16, device="cuda") + b = torch.zeros(N, T, HV, dtype=torch.bfloat16, device="cuda") + o = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + A_log = torch.zeros(HV, dtype=torch.float32, device="cuda") + dt_bias = torch.zeros(HV, K, dtype=torch.float32, device="cuda") + h0_source = torch.zeros(pool_size * HV, V, K, dtype=torch.float32, device="cuda") + h0_indices = torch.zeros(N, dtype=torch.int32, device="cuda") + u_buf = torch.zeros(N, T, HV, V, dtype=torch.float32, device="cuda") + kinv_buf = torch.zeros(N, T, HV, K, dtype=torch.float32, device="cuda") + b_buf = torch.zeros(N, T, HV, K, dtype=torch.float32, device="cuda") + + compiled_kernel = cute.compile( + run_kda_mtp_tp_kvbuffer_kernel, + from_dlpack(h0_source, assumed_align=16), + from_dlpack(A_log, assumed_align=16), + from_dlpack(a, assumed_align=16), + from_dlpack(dt_bias, assumed_align=16), + from_dlpack(q, assumed_align=16), + from_dlpack(k, assumed_align=16), + from_dlpack(v, assumed_align=16), + from_dlpack(b, assumed_align=16), + from_dlpack(o, assumed_align=16), + from_dlpack(h0_indices, assumed_align=16), + from_dlpack(u_buf, assumed_align=16), + from_dlpack(kinv_buf, assumed_align=16), + from_dlpack(b_buf, assumed_align=16), + vec_size=VEC_SIZE, + tile_v=tile_v, + ilp_rows=ilp_rows, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + HV=HV, T=T, H=H, K=K, V=V, + use_qk_l2norm=use_qk_l2norm, + disable_state_update=disable_state_update, + emit_output=emit_output, + write_ubuf=write_ubuf, + fast_math=fast_math, + stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + options=f"--enable-tvm-ffi --opt-level {opt_level}", + ) + _compiled_mtp_tp_kvbuffer_kernels[key] = compiled_kernel + logger.info( + "CuTe DSL KDA MTP tp-KVBuffer kernel compiled: " + f"N={N}, T={T}, HV={HV}, K={K}, V={V}, tile_v={tile_v}, ilp_rows={ilp_rows}, " + f"opt_level={opt_level}, fast_math={fast_math}" + ) + return compiled_kernel + + +def _select_tp_kvb_ilp_rows(tile_v, T): + """Largest ilp_rows in {4,2,1} dividing rows_per_group with ilp_rows*T <= 16 — the consumer + holds two (ilp_rows, T) fp32 register arrays (r_part + r_u), so cap their footprint.""" + rows_per_group = tile_v // 4 + for r in (4, 2, 1): + if rows_per_group % r == 0 and r * T <= 16: + return r + return 1 + + +def kda_decode_mtp_tp_kvbuffer( + A_log: torch.Tensor, + dt_bias: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + scale: float | None = None, + use_qk_l2norm_in_kernel: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + out: torch.Tensor | None = None, + disable_state_update: bool = True, + emit_output: bool = True, + u_buffer: torch.Tensor | None = None, + kinv_buffer: torch.Tensor | None = None, + b_buffer: torch.Tensor | None = None, + tile_v: int = -1, + ilp_rows: int = -1, + opt_level: int = 3, + fast_math: bool = True, +) -> torch.Tensor: + """KDA MTP tp-KVBuffer verify (token-parallel chunkwise; flush reuses kda_flush_kvbuffer).""" + N, T, H, K = q.shape + HV = v.shape[2] + V = v.shape[3] + write_ubuf = u_buffer is not None + + if scale is None: + scale = K**-0.5 + else: + assert scale > 0, f"scale must be positive, got {scale}" + + assert K == TILE_K, f"tp-kvbuffer requires K={TILE_K}, got {K}" + assert K == 128, f"tp-kvbuffer Stage-2 scan maps 128 threads to K channels; needs K=128, got {K}" + assert T <= 32, f"tp-kvbuffer W-build uses one lane per token column; needs T<=32, got {T}" + + if tile_v <= 0: + tile_v = _select_kvb_tile_v(V, N, HV) + assert V % tile_v == 0, f"tp-kvbuffer requires V % tile_v == 0, got V={V}, tile_v={tile_v}" + assert tile_v % 4 == 0, f"tp-kvbuffer requires tile_v % 4 == 0 (4 warps), got {tile_v}" + rows_per_group = tile_v // 4 + if ilp_rows <= 0: + ilp_rows = _select_tp_kvb_ilp_rows(tile_v, T) + assert rows_per_group % ilp_rows == 0, ( + f"tp-kvbuffer requires (tile_v/4) % ilp_rows == 0, got tile_v={tile_v}, ilp_rows={ilp_rows}" + ) + + h0_source, pool_size, _ = _normalize_state_source( + initial_state_source, N=N, HV=HV, K=K, V=V, device=q.device, state_layout="vk", + ) + + a = _normalize_mtp_a(a, N=N, T=T, HV=HV, K=K) + if b.dim() != 3 or tuple(b.shape) != (N, T, HV): + raise ValueError(f"Unexpected b shape for MTP dense: {tuple(b.shape)}; expected {(N, T, HV)}") + + o = _prepare_output_tensor(q, out, (N, T, HV, V)) + + q = q if q.is_contiguous() else q.contiguous() + k = k if k.is_contiguous() else k.contiguous() + v = v if v.is_contiguous() else v.contiguous() + a = a if a.is_contiguous() else a.contiguous() + b = b if b.is_contiguous() else b.contiguous() + + A_log = _normalize_A_log(A_log, HV) + dt_bias = _normalize_dt_bias(dt_bias, HV, K) + initial_state_indices = _normalize_state_indices( + initial_state_indices, N=N, pool_size=pool_size, device=q.device + ) + + if write_ubuf: + if tuple(u_buffer.shape) != (N, T, HV, V): + raise ValueError(f"u_buffer shape must be {(N, T, HV, V)}, got {tuple(u_buffer.shape)}") + if tuple(kinv_buffer.shape) != (N, T, HV, K) or tuple(b_buffer.shape) != (N, T, HV, K): + raise ValueError(f"kinv_buffer/b_buffer shape must be {(N, T, HV, K)}") + u_buf, kinv_buf, b_buf = u_buffer, kinv_buffer, b_buffer + else: + u_buf = torch.empty(N, T, HV, V, dtype=torch.float32, device=q.device) + kinv_buf = torch.empty(N, T, HV, K, dtype=torch.float32, device=q.device) + b_buf = torch.empty(N, T, HV, K, dtype=torch.float32, device=q.device) + + stream = _get_cached_stream(q.device) + + h0_source_flat = h0_source.view(pool_size * HV, V, K) + compiled_kernel = _get_compiled_mtp_tp_kvbuffer_kernel( + N, T, H, HV, K, V, pool_size, tile_v, ilp_rows, + scale=scale, use_qk_l2norm=use_qk_l2norm_in_kernel, + disable_state_update=disable_state_update, emit_output=emit_output, + write_ubuf=write_ubuf, + softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, + opt_level=opt_level, fast_math=fast_math, + ) + compiled_kernel( + h0_source_flat, A_log, a, dt_bias, q, k, v, b, o, + initial_state_indices, u_buf, kinv_buf, b_buf, stream, + ) + return o + + +# =========================================================================== +# gemm-kvbuffer (CuTe sm_90 tensor-core, flat-in-T): every reduction on warp-level +# mma.sync.m16n8k8.tf32 (llvm.inline_asm wrapper); verify = the BT=8 stacked kernel below. +# +# mma.sync m16n8k8 fragment mapping (PTX ISA), gid = lane>>2, tig = lane&3: +# A row-major [16,8]: a0=A[gid][tig] a1=A[gid+8][tig] a2=A[gid][tig+4] a3=A[gid+8][tig+4] +# B col-major [8,8]: b0=B[tig][gid] b1=B[tig+4][gid] +# C/D [16,8] f32: c0=C[gid][2tig] c1=C[gid][2tig+1] c2=C[gid+8][2tig] c3=C[gid+8][2tig+1] +# =========================================================================== + +from cutlass._mlir.dialects import arith as _arith +from cutlass._mlir.dialects import llvm as _llvm +from cutlass.cutlass_dsl import T as _T +from cutlass.cutlass_dsl import dsl_user_op + + +@dsl_user_op +def _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, c0, c1, c2, c3, *, loc=None, ip=None): + """One mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32; returns (d0, d1, d2, d3). + + a*/b* are Float32 values reinterpreted as tf32 (raw f32 bits; HW ignores the low + mantissa bits — same truncation semantics as Triton's tf32 dots).""" + f32 = _T.f32() + i32 = _T.i32() + + def _bits(v): + vv = v.ir_value(loc=loc, ip=ip) if hasattr(v, "ir_value") else v + return _arith.bitcast(i32, vv, loc=loc, ip=ip) + + def _f(v): + return v.ir_value(loc=loc, ip=ip) if hasattr(v, "ir_value") else v + + res_ty = _llvm.StructType.get_literal([f32, f32, f32, f32]) + res = _llvm.inline_asm( + res_ty, + [_bits(a0), _bits(a1), _bits(a2), _bits(a3), _bits(b0), _bits(b1), + _f(c0), _f(c1), _f(c2), _f(c3)], + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{$0,$1,$2,$3}, {$4,$5,$6,$7}, {$8,$9}, {$10,$11,$12,$13};", + "=f,=f,=f,=f,r,r,r,r,r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=_llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + d0 = cutlass.Float32(_llvm.extractvalue(f32, res, [0], loc=loc, ip=ip)) + d1 = cutlass.Float32(_llvm.extractvalue(f32, res, [1], loc=loc, ip=ip)) + d2 = cutlass.Float32(_llvm.extractvalue(f32, res, [2], loc=loc, ip=ip)) + d3 = cutlass.Float32(_llvm.extractvalue(f32, res, [3], loc=loc, ip=ip)) + return d0, d1, d2, d3 + + +_compiled_gemm_kvbuffer_cute_kernels: dict[tuple, object] = {} + + +def _get_compiled_gemm_kvbuffer_cute_kernel( + N, T, H, HV, K, V, pool_size, bv, num_v_tiles, scale, use_qk_l2norm, + disable_state_update, emit_output, write_ubuf, + softplus_beta, softplus_threshold, opt_level=3, fast_math=True, +): + key = ( + N, T, H, HV, K, V, pool_size, bv, num_v_tiles, scale, use_qk_l2norm, + disable_state_update, emit_output, write_ubuf, + softplus_beta, softplus_threshold, opt_level, fast_math, + ) + if key in _compiled_gemm_kvbuffer_cute_kernels: + return _compiled_gemm_kvbuffer_cute_kernels[key] + + q = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + k = torch.zeros(N, T, H, K, dtype=torch.bfloat16, device="cuda") + v = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + a = torch.zeros(N, T, HV, K, dtype=torch.bfloat16, device="cuda") + b = torch.zeros(N, T, HV, dtype=torch.bfloat16, device="cuda") + o = torch.zeros(N, T, HV, V, dtype=torch.bfloat16, device="cuda") + A_log = torch.zeros(HV, dtype=torch.float32, device="cuda") + dt_bias = torch.zeros(HV, K, dtype=torch.float32, device="cuda") + h0_source = torch.zeros(pool_size * HV, V, K, dtype=torch.float32, device="cuda") + h0_indices = torch.zeros(N, dtype=torch.int32, device="cuda") + u_buf = torch.zeros(N, T, HV, V, dtype=torch.float32, device="cuda") + kinv_buf = torch.zeros(N, T, HV, K, dtype=torch.float32, device="cuda") + b_buf = torch.zeros(N, T, HV, K, dtype=torch.float32, device="cuda") + + run_fn = run_kda_mtp_gemm_kvbuffer_cute_kernel + compiled_kernel = cute.compile( + run_fn, + from_dlpack(h0_source, assumed_align=16), + from_dlpack(A_log, assumed_align=16), + from_dlpack(a, assumed_align=16), + from_dlpack(dt_bias, assumed_align=16), + from_dlpack(q, assumed_align=16), + from_dlpack(k, assumed_align=16), + from_dlpack(v, assumed_align=16), + from_dlpack(b, assumed_align=16), + from_dlpack(o, assumed_align=16), + from_dlpack(h0_indices, assumed_align=16), + from_dlpack(u_buf, assumed_align=16), + from_dlpack(kinv_buf, assumed_align=16), + from_dlpack(b_buf, assumed_align=16), + vec_size=VEC_SIZE, + BV=bv, + num_v_tiles=num_v_tiles, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + HV=HV, T=T, H=H, K=K, V=V, + use_qk_l2norm=use_qk_l2norm, + disable_state_update=disable_state_update, + emit_output=emit_output, + write_ubuf=write_ubuf, + fast_math=fast_math, + stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + options=f"--enable-tvm-ffi --opt-level {opt_level}", + ) + _compiled_gemm_kvbuffer_cute_kernels[key] = compiled_kernel + logger.info( + "CuTe DSL KDA MTP gemm-KVBuffer (sm90 mma) kernel compiled: " + f"N={N}, T={T}, HV={HV}, K={K}, V={V}, BV={bv}, num_v_tiles={num_v_tiles}, opt_level={opt_level}" + ) + return compiled_kernel + + +def kda_decode_mtp_gemm_kvbuffer_cute( + A_log: torch.Tensor, + dt_bias: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + scale: float | None = None, + use_qk_l2norm_in_kernel: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + out: torch.Tensor | None = None, + disable_state_update: bool = True, + emit_output: bool = True, + u_buffer: torch.Tensor | None = None, + kinv_buffer: torch.Tensor | None = None, + b_buffer: torch.Tensor | None = None, + bv: int = 32, + num_v_tiles: int = -1, + opt_level: int = 3, + fast_math: bool = True, +) -> torch.Tensor: + """KDA MTP decode — CuTe sm_90 tensor-core kvbuffer VERIFY (port of the Triton gemm op).""" + N, T, H, K = q.shape + HV = v.shape[2] + V = v.shape[3] + write_ubuf = u_buffer is not None + + if scale is None: + scale = K**-0.5 + assert K == TILE_K == 128, f"cute-gemm-kvbuffer requires K=128, got {K}" + assert T <= 8, f"cute-gemm-kvbuffer (BT stacked) needs T<=8, got {T}" + assert bv == 32, f"cute-gemm-kvbuffer (BT) requires bv=32 (one n-tile per warp), got {bv}" + assert V % bv == 0 and bv % 16 == 0, f"bv must divide V and be 16-aligned, got {bv}" + if num_v_tiles <= 0: + # auto: split V across CTAs until the grid reaches ~512 (fills H200's 132 SMs + # at small batch); producer redundancy per extra slice is negligible. + num_v_tiles = 1 + while num_v_tiles < V // bv and N * HV * num_v_tiles < 512: + num_v_tiles *= 2 + assert (V // bv) % num_v_tiles == 0, f"num_v_tiles must divide V//bv, got num_v_tiles={num_v_tiles}" + + h0_source, pool_size, _ = _normalize_state_source( + initial_state_source, N=N, HV=HV, K=K, V=V, device=q.device, state_layout="vk", + ) + a = _normalize_mtp_a(a, N=N, T=T, HV=HV, K=K) + if b.dim() != 3 or tuple(b.shape) != (N, T, HV): + raise ValueError(f"Unexpected b shape for MTP dense: {tuple(b.shape)}; expected {(N, T, HV)}") + o = _prepare_output_tensor(q, out, (N, T, HV, V)) + q = q if q.is_contiguous() else q.contiguous() + k = k if k.is_contiguous() else k.contiguous() + v = v if v.is_contiguous() else v.contiguous() + a = a if a.is_contiguous() else a.contiguous() + b = b if b.is_contiguous() else b.contiguous() + A_log = _normalize_A_log(A_log, HV) + dt_bias = _normalize_dt_bias(dt_bias, HV, K) + initial_state_indices = _normalize_state_indices( + initial_state_indices, N=N, pool_size=pool_size, device=q.device + ) + + if write_ubuf: + if tuple(u_buffer.shape) != (N, T, HV, V): + raise ValueError(f"u_buffer shape must be {(N, T, HV, V)}, got {tuple(u_buffer.shape)}") + if tuple(kinv_buffer.shape) != (N, T, HV, K) or tuple(b_buffer.shape) != (N, T, HV, K): + raise ValueError(f"kinv_buffer/b_buffer shape must be {(N, T, HV, K)}") + u_buf, kinv_buf, b_buf = u_buffer, kinv_buffer, b_buffer + else: + u_buf = torch.empty(N, T, HV, V, dtype=torch.float32, device=q.device) + kinv_buf = torch.empty(N, T, HV, K, dtype=torch.float32, device=q.device) + b_buf = torch.empty(N, T, HV, K, dtype=torch.float32, device=q.device) + + stream = _get_cached_stream(q.device) + h0_source_flat = h0_source.view(pool_size * HV, V, K) + compiled_kernel = _get_compiled_gemm_kvbuffer_cute_kernel( + N, T, H, HV, K, V, pool_size, bv, num_v_tiles, + scale=scale, use_qk_l2norm=use_qk_l2norm_in_kernel, + disable_state_update=disable_state_update, emit_output=emit_output, + write_ubuf=write_ubuf, + softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, + opt_level=opt_level, fast_math=fast_math, + ) + compiled_kernel( + h0_source_flat, A_log, a, dt_bias, q, k, v, b, o, + initial_state_indices, u_buf, kinv_buf, b_buf, stream, + ) + return o + + +# --------------------------------------------------------------------------- +# BT=8 stacked variant of the cute-gemm kernel (T <= 8). mma.sync m16n8k8 has a +# hard M=16, so instead of padding tokens to 16 the spare 8 M-rows carry a +# SECOND matrix — pad waste becomes a ~2x instruction saving: +# P3: [kdec; qdec] @ kinv^T -> A (top) and P (bottom) in one GEMM chain +# P4: Neumann inverse in plain fp32 (precision); L_s is strictly-lower 8x8 so +# L_s^8 = 0 -> inv = (I+L_s)(I+L_s^2)(I+L_s^4), exactly 3 doubling steps +# P5: [kdec; qdec] @ S0^T -> Skdec + Sqdec together; u = inv @ (beta*x) on +# tensor cores; o-combine P@u in exact fp32 from SMEM (16 FMA/lane) +# Requires BV=32 (4 n-tiles = 1 per warp, keeps barriers warp-uniform). +# --------------------------------------------------------------------------- +BT = 8 + + +@cute.kernel +def kda_mtp_gemm_kvbuffer_cute_kernel( + h0_source: cute.Tensor, + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + u_buf: cute.Tensor, + kinv_buf: cute.Tensor, + b_buf: cute.Tensor, + vec_size: cutlass.Constexpr[int], + BV: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + emit_output: cutlass.Constexpr[bool], + write_ubuf: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], +): + tidx, _, _ = cute.arch.thread_idx() + lane_id = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + gid = lane_id // 4 + tig = lane_id % 4 + + num_warps: cutlass.Constexpr[int] = 4 + bidx, _, _ = cute.arch.block_idx() + i_v = bidx % num_v_tiles + tmp = bidx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) + + cache_idx = h0_indices[i_n] + r_exp_A = cute.exp(cutlass.Float32(A_log[i_hv]), fastmath=fast_math) + + smem = cutlass.utils.SmemAllocator() + # stacked feature maps: rows 0..7 = kdec(tokens, pad-zeroed), rows 8..15 = qdec + sKQ = smem.allocate_tensor(cutlass.Float32, cute.make_layout((2 * BT, K), stride=(K + 8, 1)), 16) + sKinv = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, K), stride=(K + 8, 1)), 16) + sG = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, K), stride=(K + 8, 1)), 16) + sBeta = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT,)), 16) + sBlast = smem.allocate_tensor(cutlass.Float32, cute.make_layout((K,)), 16) + # P3 cross-warp partial tiles: row = warp*16 + stacked-row + sPart = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4 * 16, 12), stride=(12, 1)), 16) + sL = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, BT), stride=(BT + 1, 1)), 16) + sP = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, BT), stride=(BT + 1, 1)), 16) + sInv = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, BT), stride=(BT + 1, 1)), 16) + sLp = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, BT), stride=(BT + 1, 1)), 16) + sX = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, BV), stride=(BV + 1, 1)), 16) + sU = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BT, BV), stride=(BV + 1, 1)), 16) + sS0 = smem.allocate_tensor(cutlass.Float32, cute.make_layout((BV, K), stride=(K + 8, 1)), 16) + + r_qbf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_kbf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16) + r_qf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_kf = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + r_s = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) + + if cache_idx >= 0: + k_start = lane_id * vec_size + flat_state_idx = cache_idx * HV + i_hv + + # ---- P1: token-parallel l2norm + staging (k_norm -> sKQ top, q_scaled -> bottom) ---- + tokens_per_warp: cutlass.Constexpr[int] = (T + num_warps - 1) // num_warps + for tt in cutlass.range_constexpr(tokens_per_warp): + t_tok = tt * num_warps + warp_idx + if t_tok < T: + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, t_tok, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, t_tok, i_h, lane_id)) + cute.autovec_copy(q_tile, r_qbf) + cute.autovec_copy(k_tile, r_kbf) + for c in cutlass.range_constexpr(vec_size): + r_qf[c] = cutlass.Float32(r_qbf[c]) + r_kf[c] = cutlass.Float32(r_kbf[c]) + if cutlass.const_expr(use_qk_l2norm): + sum_q = cutlass.Float32(0.0) + sum_k = cutlass.Float32(0.0) + for c in cutlass.range_constexpr(vec_size): + sum_q += r_qf[c] * r_qf[c] + sum_k += r_kf[c] * r_kf[c] + for off in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly(sum_q, offset=off, mask=-1, mask_and_clamp=31) + sum_k += cute.arch.shuffle_sync_bfly(sum_k, offset=off, mask=-1, mask_and_clamp=31) + inv_q = cute.rsqrt(sum_q + 1e-6, fastmath=fast_math) * scale + inv_k = cute.rsqrt(sum_k + 1e-6, fastmath=fast_math) + for c in cutlass.range_constexpr(vec_size): + r_qf[c] = r_qf[c] * inv_q + r_kf[c] = r_kf[c] * inv_k + else: + for c in cutlass.range_constexpr(vec_size): + r_qf[c] = r_qf[c] * scale + # gate g_t per channel into sG (decay applied in P2) + for c in cutlass.range_constexpr(vec_size): + x = cutlass.Float32(a[i_n, t_tok, i_hv, k_start + c]) + cutlass.Float32( + dt_bias[i_hv, k_start + c] + ) + beta_x = softplus_beta * x + exp_bx = cute.exp(beta_x, fastmath=fast_math) + sp_val = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + exp_bx, fastmath=fast_math + ) + use_sp = ( + cutlass.Float32(1.0) + if beta_x <= softplus_threshold + else cutlass.Float32(0.0) + ) + sp_x = use_sp * sp_val + (cutlass.Float32(1.0) - use_sp) * x + sG[t_tok, k_start + c] = cute.exp(-r_exp_A * sp_x, fastmath=fast_math) # g_t directly (exact prefix product in P2) + sKQ[t_tok, k_start + c] = r_kf[c] + sKQ[BT + t_tok, k_start + c] = r_qf[c] + if lane_id == 0: + sBeta[t_tok] = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + + cute.exp(-cutlass.Float32(b[i_n, t_tok, i_hv]), fastmath=fast_math) + ) + for rp in cutlass.range_constexpr(BT - T): + sKQ[T + rp, tidx] = cutlass.Float32(0.0) + sKQ[BT + T + rp, tidx] = cutlass.Float32(0.0) + sKinv[T + rp, tidx] = cutlass.Float32(0.0) + if tidx >= T: + if tidx < BT: + sBeta[tidx] = cutlass.Float32(0.0) + cute.arch.barrier() + + # ---- P2: K-parallel prefix-product scan (thread = channel kc) ---- + kc = tidx # requires K == 128 == block size + bcum = cutlass.Float32(1.0) + for i_t in cutlass.range_constexpr(T): + bcum = bcum * sG[i_t, kc] + binv = cutlass.Float32(1.0) / bcum + kn = sKQ[i_t, kc] + kinv_v = kn * binv + sKQ[i_t, kc] = kn * bcum + sKQ[BT + i_t, kc] = sKQ[BT + i_t, kc] * bcum + sKinv[i_t, kc] = kinv_v + if cutlass.const_expr(write_ubuf): + if i_v == 0: + kinv_buf[i_n, i_t, i_hv, kc] = kinv_v + b_buf[i_n, i_t, i_hv, kc] = bcum + sBlast[kc] = bcum + cute.arch.barrier() + + # ---- P3: stacked [kdec; qdec] @ kinv^T — 16 k-slabs, 4 per warp, partials in SMEM ---- + c0 = cutlass.Float32(0.0) + c1 = cutlass.Float32(0.0) + c2 = cutlass.Float32(0.0) + c3 = cutlass.Float32(0.0) + for ks in cutlass.range_constexpr(K // 8 // num_warps): + kb = (warp_idx * (K // 8 // num_warps) + ks) * 8 + a0 = sKQ[gid, kb + tig] + a1 = sKQ[gid + 8, kb + tig] + a2 = sKQ[gid, kb + tig + 4] + a3 = sKQ[gid + 8, kb + tig + 4] + b0 = sKinv[gid, kb + tig] + b1 = sKinv[gid, kb + tig + 4] + c0, c1, c2, c3 = _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, c0, c1, c2, c3) + for fi in cutlass.range_constexpr(4): + row = gid + (fi // 2) * 8 + col = 2 * tig + (fi % 2) + cv = c0 + if cutlass.const_expr(fi == 1): + cv = c1 + if cutlass.const_expr(fi == 2): + cv = c2 + if cutlass.const_expr(fi == 3): + cv = c3 + sPart[warp_idx * 16 + row, col] = cv + cute.arch.barrier() + # reduce 4 partials; top half -> L (strict lower, -beta), bottom -> P (lower) + rr = tidx // 8 + cc = tidx % 8 + psum = ( + sPart[rr, cc] + sPart[16 + rr, cc] + sPart[32 + rr, cc] + sPart[48 + rr, cc] + ) + if rr < BT: + keep = cutlass.Float32(1.0) if rr > cc else cutlass.Float32(0.0) + sL[rr, cc] = -sBeta[rr] * psum * keep + else: + tr = rr - BT + keep = cutlass.Float32(1.0) if tr >= cc else cutlass.Float32(0.0) + sP[tr, cc] = psum * keep + cute.arch.barrier() + if tidx < BT * BT: + ri = tidx // BT + ci = tidx % BT + one = cutlass.Float32(1.0) if ri == ci else cutlass.Float32(0.0) + sInv[ri, ci] = one # inv starts at I: each doubling step does inv += inv@Lp_old + # (with Lp_old = Ls^(2^step)), so I+Ls is produced by step 0 + sLp[ri, ci] = sL[ri, ci] + cute.arch.barrier() + + # ---- P4: doubling chain + Pinv on the 8x8 mats in PLAIN fp32 + ri = tidx // BT + ci = tidx % BT + for step in cutlass.range_constexpr(3): # 3 steps: (I+Ls)(I+Ls^2)(I+Ls^4), nilpotency 8 + if tidx < 2 * BT * BT: # rows 0..7 -> Lp@Lp, rows 8..15 -> inv@Lp + rr = ri % BT + acc = cutlass.Float32(0.0) + for l in cutlass.range_constexpr(BT): + if ri < BT: + acc += sLp[rr, l] * sLp[l, ci] + else: + acc += sInv[rr, l] * sLp[l, ci] + sPart[ri, ci] = acc + cute.arch.barrier() + if tidx < BT * BT: + sLp[ri, ci] = sPart[ri, ci] + sInv[ri, ci] = sInv[ri, ci] + sPart[BT + ri, ci] + cute.arch.barrier() + cute.arch.barrier() + + # ---- P5 consumer. V tiled 3 ways (outer->inner): + # num_v_tiles : V split across CTAs (grid=N*HV*num_v_tiles) + # BV=32 : V rows/block = 4 warps x mma-N(8); 1 n-tile/warp, uniform barriers + # num_v_blocks : BV-blocks each CTA walks serially + num_v_blocks: cutlass.Constexpr[int] = V // BV // num_v_tiles + for vb in cutlass.range_constexpr(num_v_blocks): + v_base = (i_v * num_v_blocks + vb) * BV # global V-row start of this block + row_vecs = K // vec_size # float4s per V row + # stage S0[BV,K] -> sS0: 128 threads (blockDim), one float4 each; + # passes = BV*K / (128*vec_size) + for j in cutlass.range_constexpr(BV * K // (128 * vec_size)): + flat = j * 128 + tidx # float4-group id + s_row = flat // row_vecs # V row + s_col = flat % row_vecs # float4 within row + h_tile = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v_base + s_row, s_col) + ) + cute.autovec_copy(h_tile, r_s) + for cc in cutlass.range_constexpr(vec_size): + sS0[s_row, s_col * vec_size + cc] = r_s[cc] + cute.arch.barrier() + + nb = warp_idx * 8 # current warp's n-tile = V rows [nb, nb+8) within the BV block + # the two adjacent V indices this lane owns (mma N-frag: 2*tig, 2*tig+1) + vc0 = nb + 2 * tig + vc1 = nb + 2 * tig + 1 + # GEMM1: [kdec; qdec] @ S0^T -> Skdec (rows 0..7) + Sqdec (rows 8..15) + e0 = cutlass.Float32(0.0) + e1 = cutlass.Float32(0.0) + e2 = cutlass.Float32(0.0) + e3 = cutlass.Float32(0.0) + for ks in cutlass.range_constexpr(K // 8): + kb = ks * 8 + a0 = sKQ[gid, kb + tig] + a1 = sKQ[gid + 8, kb + tig] + a2 = sKQ[gid, kb + tig + 4] + a3 = sKQ[gid + 8, kb + tig + 4] + b0 = sS0[nb + gid, kb + tig] + b1 = sS0[nb + gid, kb + tig + 4] + e0, e1, e2, e3 = _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, e0, e1, e2, e3) + # x = beta * (v - Skdec) from the top half; Sqdec (e2/e3) stays in registers + vmask = cutlass.Float32(1.0) if gid < T else cutlass.Float32(0.0) + vv0 = cutlass.Float32(v[i_n, gid % T, i_hv, v_base + vc0]) * vmask + vv1 = cutlass.Float32(v[i_n, gid % T, i_hv, v_base + vc1]) * vmask + sX[gid, vc0] = sBeta[gid] * (vv0 - e0) + sX[gid, vc1] = sBeta[gid] * (vv1 - e1) + cute.arch.barrier() + + # u = inv @ x in exact fp32 + f0 = cutlass.Float32(0.0) + f1 = cutlass.Float32(0.0) + for l in cutlass.range_constexpr(BT): + f0 += sInv[gid, l] * sX[l, vc0] + f1 += sInv[gid, l] * sX[l, vc1] + sU[gid, vc0] = f0 + sU[gid, vc1] = f1 + if cutlass.const_expr(write_ubuf): + if gid < T: + u_buf[i_n, gid, i_hv, v_base + vc0] = f0 + u_buf[i_n, gid, i_hv, v_base + vc1] = f1 + cute.arch.barrier() + # o = Sqdec + P@u combined in exact fp32 from sU (16 FMA/lane — removes the + # extra tf32 hop that the stacked [inv;Pinv]@x route put on the output path) + if cutlass.const_expr(emit_output): + if gid < T: + ov0 = e2 + ov1 = e3 + for l in cutlass.range_constexpr(BT): + ov0 += sP[gid, l] * sU[l, vc0] + ov1 += sP[gid, l] * sU[l, vc1] + o[(i_n, gid, i_hv, v_base + vc0)] = cutlass.BFloat16(ov0) + o[(i_n, gid, i_hv, v_base + vc1)] = cutlass.BFloat16(ov1) + + # state: S_T = b_last * (S0 + u^T @ kinv), M = v rows, single k-slab + if cutlass.const_expr(not disable_state_update): + m_tiles: cutlass.Constexpr[int] = BV // 16 + pairs: cutlass.Constexpr[int] = m_tiles * (K // 8) + for pp in cutlass.range_constexpr((pairs + num_warps - 1) // num_warps): + pidx = pp * num_warps + warp_idx + if pidx < pairs: + m_t = pidx % m_tiles + n_t = pidx // m_tiles + mb = m_t * 16 + nb = n_t * 8 + g0 = cutlass.Float32(0.0) + g1 = cutlass.Float32(0.0) + g2 = cutlass.Float32(0.0) + g3 = cutlass.Float32(0.0) + a0 = sU[tig, mb + gid] + a1 = sU[tig, mb + gid + 8] + a2 = sU[tig + 4, mb + gid] + a3 = sU[tig + 4, mb + gid + 8] + b0 = sKinv[tig, nb + gid] + b1 = sKinv[tig + 4, nb + gid] + g0, g1, g2, g3 = _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, g0, g1, g2, g3) + for fi in cutlass.range_constexpr(4): + vrow = mb + gid + (fi // 2) * 8 + kcol = nb + 2 * tig + (fi % 2) + gv = g0 + if cutlass.const_expr(fi == 1): + gv = g1 + if cutlass.const_expr(fi == 2): + gv = g2 + if cutlass.const_expr(fi == 3): + gv = g3 + h0_source[(flat_state_idx, v_base + vrow, kcol)] = ( + sBlast[kcol] * (sS0[vrow, kcol] + gv) + ) + cute.arch.barrier() + + +@cute.jit +def run_kda_mtp_gemm_kvbuffer_cute_kernel( + h0_source: cute.Tensor, + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + u_buf: cute.Tensor, + kinv_buf: cute.Tensor, + b_buf: cute.Tensor, + vec_size: cutlass.Constexpr[int], + BV: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + emit_output: cutlass.Constexpr[bool], + write_ubuf: cutlass.Constexpr[bool], + fast_math: cutlass.Constexpr[bool], + stream: cuda.CUstream, +): + """BT=8 stacked cute-gemm launcher: grid = N*HV*num_v_tiles, block = 128.""" + n_indices = h0_indices.layout.shape[0] + grid_size = n_indices * HV * num_v_tiles + smem_bytes = ( + 2 * 4 * BT * (K + 8) # sKQ (stacked) + + 2 * 4 * BT * (K + 8) # sKinv + sG + + 4 * BT + 4 * K # sBeta + sBlast + + 4 * 64 * 12 # sPart + + 4 * 4 * BT * (BT + 1) # sL/sP/sInv/sLp + + 2 * 4 * BT * (BV + 1) # sX/sU + + 4 * BV * (K + 8) # sS0 + + 512 + ) + kda_mtp_gemm_kvbuffer_cute_kernel( + h0_source, A_log, a, dt_bias, q, k, v, b, o, h0_indices, + u_buf, kinv_buf, b_buf, + vec_size, BV, num_v_tiles, + softplus_beta, softplus_threshold, scale, + HV, T, H, K, V, + use_qk_l2norm, disable_state_update, emit_output, write_ubuf, fast_math, + ).launch(grid=(grid_size, 1, 1), block=[128, 1, 1], smem=smem_bytes, stream=stream) + + +# --------------------------------------------------------------------------- +# KVBuffer verify dispatch: route between the two kvbuffer verify ops by T. +# --------------------------------------------------------------------------- +def kda_decode_mtp_kvbuffer( + A_log: torch.Tensor, + dt_bias: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + initial_state_indices: torch.Tensor, + scale: float | None = None, + use_qk_l2norm_in_kernel: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + out: torch.Tensor | None = None, + disable_state_update: bool = True, + emit_output: bool = True, + u_buffer: torch.Tensor | None = None, + kinv_buffer: torch.Tensor | None = None, + b_buffer: torch.Tensor | None = None, + t_crossover: int = 3, + opt_level: int = 3, + fast_math: bool = True, +) -> torch.Tensor: + """KDA MTP KVBuffer verify dispatch by T: < t_crossover (default 3) -> tp-kvbuffer + (token-parallel SIMT), else gemm-kvbuffer (CuTe tensor-core, flat-in-T; crossover T~3 + from H200). Routes only among kvbuffer ops; recurrent fallback is a higher-layer concern. + """ + T = q.shape[1] + common = dict( + A_log=A_log, dt_bias=dt_bias, q=q, k=k, v=v, a=a, b=b, + initial_state_source=initial_state_source, initial_state_indices=initial_state_indices, + scale=scale, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + softplus_beta=softplus_beta, softplus_threshold=softplus_threshold, out=out, + disable_state_update=disable_state_update, emit_output=emit_output, + u_buffer=u_buffer, kinv_buffer=kinv_buffer, b_buffer=b_buffer, + opt_level=opt_level, fast_math=fast_math, + ) + if T >= t_crossover: + return kda_decode_mtp_gemm_kvbuffer_cute(**common) + return kda_decode_mtp_tp_kvbuffer(**common) diff --git a/tests/test_kda_decode_mtp.py b/tests/test_kda_decode_mtp.py new file mode 100644 index 00000000..5a1bb79e --- /dev/null +++ b/tests/test_kda_decode_mtp.py @@ -0,0 +1,631 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +import sys + +import pytest +import torch +import torch.nn.functional as F + +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent)) # for sibling test import + +from cula.kda import ( + kda_decode, + kda_decode_mtp_ws, +) +from cula.ops.kda_decode_mtp import ( + _select_mtp_config, + _select_mtp_tile_v, + kda_decode_mtp_small_batch, +) +from test_kda_decode import torch_kda_decode_ref # trusted single-token reference +from cula.ops.kda_decode_mtp_kvbuffer import ( + _select_kvb_tile_v, + _select_tp_kvb_ilp_rows, + kda_decode_mtp_gemm_kvbuffer_cute, + kda_decode_mtp_kvbuffer, + kda_decode_mtp_tp_kvbuffer, + kda_flush_kvbuffer, +) + + +def torch_kda_mtp_ref(q, k, v, a, b, A_log, dt_bias, state, scale, + use_l2norm=True, softplus_beta=1.0, softplus_threshold=20.0): + """fp32 ground truth: the single-token KDA recurrence threaded over T. Returns (o, final_state).""" + N, T, HV, V = v.shape + K = q.shape[-1] + H = q.shape[2] + heads_per_group = HV // H + A = torch.exp(A_log) + state_cur = state.clone() + o = torch.zeros(N, T, HV, V, dtype=torch.float32, device=q.device) + for t in range(T): + for n in range(N): + for hv in range(HV): + i_h = hv // heads_per_group + x = a[n, t, hv, :] + dt_bias[hv, :] + sp = F.softplus(x, beta=softplus_beta, threshold=softplus_threshold) + gate = torch.exp(-A[hv] * sp) + if use_l2norm: + q_vec = F.normalize(q[n, t, i_h, :], dim=0) * scale + k_vec = F.normalize(k[n, t, i_h, :], dim=0) + else: + q_vec = q[n, t, i_h, :] * scale + k_vec = k[n, t, i_h, :] + Hk = state_cur[n, hv] @ (gate * k_vec) + beta_val = torch.sigmoid(b[n, t, hv]) + v_new = beta_val * (v[n, t, hv, :] - Hk) + state_cur[n, hv] = gate[None, :] * state_cur[n, hv] + v_new[:, None] * k_vec[None, :] + o[n, t, hv, :] = state_cur[n, hv] @ q_vec + return o, state_cur + + +def make_inputs_mtp(N, T, H, HV, K, V, device="cuda", seed=42): + """Random MTP inputs (q/k/v/a/b bf16, A_log/dt_bias/state fp32).""" + torch.manual_seed(seed) + q = torch.randn(N, T, H, K, device=device, dtype=torch.bfloat16) + k = torch.randn(N, T, H, K, device=device, dtype=torch.bfloat16) + v = torch.randn(N, T, HV, V, device=device, dtype=torch.bfloat16) + a = (torch.randn(N, T, HV, K, device=device, dtype=torch.float32) * 0.1).to(torch.bfloat16) + b = torch.randn(N, T, HV, device=device, dtype=torch.bfloat16) + A_log = -torch.rand(HV, device=device, dtype=torch.float32) * 2 # negative -> A < 1 + dt_bias = torch.randn(HV, K, device=device, dtype=torch.float32) * 0.1 + state = torch.randn(N, HV, V, K, device=device, dtype=torch.float32) * 0.01 + return q, k, v, a, b, A_log, dt_bias, state + + +def run_kda_decode_mtp_via_loop_dense(q, k, v, a, b, A_log, dt_bias, state, scale, opt_level=1): + """The "loop" baseline: T sequential single-token kda_decode calls, state carried across tokens.""" + N, T, H, K = q.shape + HV, V = v.shape[2], v.shape[3] + state_source = state.clone().contiguous() + indices = torch.arange(N, device=q.device, dtype=torch.int32) + o_all = torch.empty(N, T, HV, V, device=q.device, dtype=torch.bfloat16) + for t in range(T): + q_t = q[:, t].unsqueeze(1).contiguous() + k_t = k[:, t].unsqueeze(1).contiguous() + v_t = v[:, t].unsqueeze(1).contiguous() + a_t = a[:, t].unsqueeze(1).contiguous() + b_t = b[:, t].unsqueeze(1).contiguous() + o_t = kda_decode( + A_log=A_log, dt_bias=dt_bias, + q=q_t.to(torch.bfloat16), k=k_t.to(torch.bfloat16), v=v_t.to(torch.bfloat16), + a=a_t.to(torch.bfloat16), b=b_t.to(torch.bfloat16), + initial_state_source=state_source, initial_state_indices=indices, + scale=scale, use_qk_l2norm_in_kernel=True, opt_level=opt_level, + ) + o_all[:, t] = o_t.squeeze(1) + return o_all, state_source + + +def _assert_close(name, ref, actual, atol=3e-2, rtol=2e-2): + """allclose, printing the observed max/mean margin (pytest -s).""" + diff = (ref.float() - actual.float()).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + print(f" [{name}] max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f} (atol={atol}, rtol={rtol})") + ok = torch.allclose(ref.float(), actual.float(), atol=atol, rtol=rtol) + assert ok, f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, atol={atol}, rtol={rtol}" + + +def oracle_intermediate_states(q, k, v, a, b, A_log, dt_bias, state, scale): + """fp32 per-token state snapshots [N,T,HV,V,K] from the trusted single-token reference.""" + N, T = q.shape[0], q.shape[1] + HV, V, K = v.shape[2], v.shape[3], q.shape[3] + state_cur = state.clone() + inter = torch.zeros(N, T, HV, V, K, dtype=torch.float32, device=q.device) + for t in range(T): + _, state_cur = torch_kda_decode_ref( + q[:, t].float(), k[:, t].float(), v[:, t].float(), + a[:, t], b[:, t].float(), A_log, dt_bias, state_cur, scale, + ) + inter[:, t] = state_cur + return inter + + +def run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, *, tile_v=None, + ilp_rows=None, use_packed_fma=None, use_smem_v=None, + disable_state_update=False, intermediate=False): + """Run kda_decode_mtp_ws (vk). Returns (o, state) or (o, state, inter).""" + N, T, _, K = q.shape + HV, V = v.shape[2], v.shape[3] + st = state.clone().contiguous() + indices = torch.arange(N, device=q.device, dtype=torch.int32) + inter = torch.zeros(N, T, HV, V, K, device=q.device, dtype=torch.float32) if intermediate else None + o = kda_decode_mtp_ws( + A_log=A_log, dt_bias=dt_bias, + q=q.to(torch.bfloat16), k=k.to(torch.bfloat16), v=v.to(torch.bfloat16), + a=a.to(torch.bfloat16), b=b.to(torch.bfloat16), + initial_state_source=st, initial_state_indices=indices, + scale=scale, use_qk_l2norm_in_kernel=True, + tile_v=tile_v, ilp_rows=ilp_rows, use_packed_fma=use_packed_fma, + use_smem_v=use_smem_v, disable_state_update=disable_state_update, + intermediate_states_buffer=inter, + ) + return (o, st, inter) if intermediate else (o, st) + + +def run_small_batch(q, k, v, a, b, A_log, dt_bias, state, scale, *, variant, + bv=-1, k_split=-1, disable_state_update=False, intermediate=False): + """Run kda_decode_mtp_small_batch; state fed/returned in vk layout (kv transposed in and back).""" + N = q.shape[0] + indices = torch.arange(N, device=q.device, dtype=torch.int32) + T = q.shape[1]; HV, V, K = v.shape[2], v.shape[3], q.shape[3] + inter = torch.zeros(N, T, HV, V, K, device=q.device, dtype=torch.float32) if intermediate else None + st = state.clone().contiguous() + if variant == "kv": + st = st.transpose(-2, -1).contiguous() # vk -> kv + sb_kwargs = dict( + A_log=A_log, dt_bias=dt_bias, + q=q.to(torch.bfloat16), k=k.to(torch.bfloat16), v=v.to(torch.bfloat16), + a=a.to(torch.bfloat16), b=b.to(torch.bfloat16), + initial_state_source=st, initial_state_indices=indices, + scale=scale, use_qk_l2norm_in_kernel=True, + variant=variant, k_split=k_split, disable_state_update=disable_state_update, + intermediate_states_buffer=inter, + ) + if variant == "vk": + sb_kwargs["bv"] = bv # kv is fixed 1-warp; bv stays at the WARP_BV default + o = kda_decode_mtp_small_batch(**sb_kwargs) + state_vk = st.transpose(-2, -1).contiguous() if variant == "kv" else st + return (o, state_vk, inter) if intermediate else (o, state_vk) + + +@pytest.mark.parametrize("T", [1, 2, 4, 8]) +def test_mtp_ref_is_threaded_single_token(T): + """Pure-torch: the MTP oracle equals the trusted single-token ref threaded over T.""" + N, H, HV, K, V = 4, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + o_mtp, st_mtp = torch_kda_mtp_ref( + q.float(), k.float(), v.float(), a, b.float(), A_log, dt_bias, state.clone(), scale) + st_cur = state.clone() + o_manual = torch.zeros(N, T, HV, V, dtype=torch.float32, device=q.device) + for t in range(T): + o_t, st_cur = torch_kda_decode_ref( + q[:, t].float(), k[:, t].float(), v[:, t].float(), a[:, t], b[:, t].float(), + A_log, dt_bias, st_cur, scale) + o_manual[:, t] = o_t + torch.testing.assert_close(o_mtp, o_manual, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(st_mtp, st_cur, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("zero_state", [False, True], ids=["randstate", "zerostate"]) +@pytest.mark.parametrize( + "N,T,H,HV", + [ + pytest.param(*c, id="N{}-T{}-H{}-HV{}".format(*c)) + for c in [(1, 1, 8, 16), (4, 4, 8, 16), (16, 8, 8, 16), (64, 2, 16, 32), (4, 4, 16, 32)] + ], +) +def test_oracle_vs_loop(N, T, H, HV, zero_state): + """The looped single-token kernel matches the fp32 oracle (small N).""" + K, V = 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + if zero_state: + state = torch.zeros_like(state) + o_ref, st_ref = torch_kda_mtp_ref( + q.float(), k.float(), v.float(), a, b.float(), A_log, dt_bias, state.clone(), scale) + o_loop, st_loop = run_kda_decode_mtp_via_loop_dense(q, k, v, a, b, A_log, dt_bias, state, scale) + _assert_close("loop output", o_ref, o_loop.float()) + _assert_close("loop final state", st_ref, st_loop) + + +@pytest.mark.parametrize( + "N,T,H,HV,tile_v,ilp_rows,use_smem_v", + [ + pytest.param(*c, id="N{}-T{}-H{}-HV{}-tv{}-ilp{}-smem{}".format(*c)) + for c in [ + # auto (None) across N incl GQA and large batch + (1, 2, 8, 16, None, None, None), + (4, 4, 8, 16, None, None, None), + (16, 4, 16, 32, None, None, None), + (64, 8, 8, 16, None, None, None), + (1024, 2, 8, 16, None, None, None), + (2048, 2, 8, 16, None, None, None), + # explicit tile_v sweep, ilp=2 + (4, 4, 8, 16, 8, 2, False), + (4, 4, 8, 16, 16, 2, False), + (4, 4, 8, 16, 32, 2, False), + (4, 2, 8, 16, 64, 2, False), + # ilp=4 (tile_v % 16 == 0), fused steps + double-accumulator + (4, 4, 8, 16, 16, 4, False), + (4, 4, 8, 16, 32, 4, False), + (4, 2, 8, 16, 64, 4, False), + # use_smem_v on + (4, 4, 8, 16, 32, 4, True), + (16, 2, 16, 32, 64, 4, True), + ] + ], +) +def test_ws_decode(N, T, H, HV, tile_v, ilp_rows, use_smem_v): + """ws warp-spec vs loop: auto / tile_v / ilp 2,4 / use_smem_v / large N in one table.""" + K, V = 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + o_loop, st_loop = run_kda_decode_mtp_via_loop_dense(q, k, v, a, b, A_log, dt_bias, state, scale) + o_ws, st_ws = run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=tile_v, ilp_rows=ilp_rows, use_smem_v=use_smem_v) + tag = f"ws tv={tile_v} ilp={ilp_rows} smem={use_smem_v}" + _assert_close(f"{tag} output", o_loop.float(), o_ws.float()) + _assert_close(f"{tag} final state", st_loop, st_ws) + + +@pytest.mark.parametrize( + "N,T,H,HV,variant,bv,k_split", + [ + pytest.param(*c, id="N{}-T{}-H{}-HV{}-{}-bv{}-ks{}".format(*c)) + for c in [ + # vk: bv sweep + auto, incl T=1 and GQA + (1, 1, 8, 16, "vk", -1, 1), + (4, 4, 8, 16, "vk", -1, 1), + (8, 2, 8, 16, "vk", -1, 1), + (4, 4, 8, 16, "vk", 8, 1), + (4, 4, 8, 16, "vk", 16, 1), + (4, 2, 8, 16, "vk", 32, 1), + (16, 4, 16, 32, "vk", -1, 1), + # kv: k_split sweep + auto, incl T=1 and GQA + (1, 1, 8, 16, "kv", 32, -1), + (4, 4, 8, 16, "kv", 32, -1), + (8, 2, 8, 16, "kv", 32, -1), + (4, 4, 8, 16, "kv", 32, 1), + (4, 4, 8, 16, "kv", 32, 2), + (4, 4, 8, 16, "kv", 32, 4), + (16, 4, 16, 32, "kv", 32, -1), + ] + ], +) +def test_small_batch_decode(N, T, H, HV, variant, bv, k_split): + """small_batch vk + kv vs loop: bv / k_split / auto / GQA in one table.""" + K, V = 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + o_loop, st_loop = run_kda_decode_mtp_via_loop_dense(q, k, v, a, b, A_log, dt_bias, state, scale) + o_sb, st_sb = run_small_batch(q, k, v, a, b, A_log, dt_bias, state, scale, + variant=variant, bv=bv, k_split=k_split) + tag = f"sb {variant} bv={bv} ks={k_split}" + _assert_close(f"{tag} output", o_loop.float(), o_sb.float()) + _assert_close(f"{tag} final state", st_loop, st_sb) + + +@pytest.mark.parametrize("kernel", ["ws", "ws_ilp4", "sb_vk", "sb_kv"]) +def test_disable_state_update(kernel): + """disable_state_update leaves the state pool unchanged while output still matches the loop.""" + N, T, H, HV, K, V = 4, 4, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + o_loop, _ = run_kda_decode_mtp_via_loop_dense(q, k, v, a, b, A_log, dt_bias, state, scale) + + if kernel == "ws": + o, st = run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, disable_state_update=True) + elif kernel == "ws_ilp4": + o, st = run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=32, ilp_rows=4, disable_state_update=True) + else: + variant = "vk" if kernel == "sb_vk" else "kv" + o, st = run_small_batch(q, k, v, a, b, A_log, dt_bias, state, scale, + variant=variant, disable_state_update=True) + + assert torch.equal(st, state), f"{kernel}: state pool modified despite disable_state_update=True" + _assert_close(f"{kernel} dsu output", o_loop.float(), o.float()) + + +@pytest.mark.parametrize("kernel", ["ws", "ws_smem_v", "sb_vk", "sb_kv"]) +def test_determinism(kernel): + """Bit-exact determinism: repeat the state-writeback launch, assert identical output + state.""" + N, T, H, HV, K, V = 16, 4, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + + def launch(): + if kernel == "ws": + return run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=64, ilp_rows=4, use_packed_fma=False) + if kernel == "ws_smem_v": + return run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=64, ilp_rows=4, use_packed_fma=False, use_smem_v=True) + variant = "vk" if kernel == "sb_vk" else "kv" + return run_small_batch(q, k, v, a, b, A_log, dt_bias, state, scale, variant=variant) + + o_ref, st_ref = launch() + o_ref = o_ref.clone() + n_iters = int(os.environ.get("KDA_MTP_DET_ITERS", "10000")) + for i in range(n_iters): + o_i, st_i = launch() + assert torch.equal(o_i, o_ref), f"{kernel} output non-deterministic at iter {i}" + assert torch.equal(st_i, st_ref), f"{kernel} state non-deterministic at iter {i}" + + +@pytest.mark.parametrize( + "tile_v,ilp_rows", [(8, 2), (16, 2), (32, 2), (64, 2), (16, 4), (32, 4), (64, 4)] +) +def test_ws_smem_v_bit_identical(tile_v, ilp_rows): + """use_smem_v is pure data movement: byte-for-byte identical to the GMEM path.""" + N, T, H, HV, K, V = 4, 4, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + o_g, st_g = run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=tile_v, ilp_rows=ilp_rows, use_packed_fma=False, use_smem_v=False) + o_s, st_s = run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=tile_v, ilp_rows=ilp_rows, use_packed_fma=False, use_smem_v=True) + assert torch.equal(o_s, o_g), f"smem_v output != GMEM (tile_v={tile_v}, ilp={ilp_rows})" + assert torch.equal(st_s, st_g), f"smem_v state != GMEM (tile_v={tile_v}, ilp={ilp_rows})" + + +def test_ws_ilp4_rejects_bad_tile_v(): + """ilp=4 requires tile_v % 16 == 0; tile_v=8 must raise.""" + N, T, H, HV, K, V = 4, 2, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + with pytest.raises(AssertionError): + run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=8, ilp_rows=4, use_packed_fma=False) + + +@pytest.mark.parametrize( + "N,HV,V,T,expected", + [ + (1, 16, 128, 2, (8, 2, False)), + (4, 16, 128, 4, (8, 2, False)), + (1, 65, 128, 2, (16, 4, False)), + (8, 16, 128, 2, (16, 4, False)), + (16, 16, 128, 2, (16, 2, False)), + (16, 16, 128, 4, (32, 4, False)), + (7, 64, 128, 2, (16, 2, False)), + (7, 64, 128, 8, (32, 4, False)), + (16, 64, 128, 2, (32, 4, False)), + (64, 16, 128, 8, (32, 4, False)), + (17, 64, 128, 2, (64, 4, True)), + (256, 64, 128, 8, (64, 4, True)), + (8, 16, 8, 2, (8, 2, False)), + (8, 16, 16, 2, (16, 4, False)), + ], +) +def test_select_mtp_config(N, HV, V, T, expected): + """The joint (tile_v, ilp_rows, use_smem_v) heuristic returns the expected config.""" + assert _select_mtp_config(N, HV, V, T) == expected + assert _select_mtp_tile_v(N, HV, V, T) == expected[0] + + +def test_select_mtp_config_ilp_capped_at_4(): + """ilp is capped at 4 (no ilp=8 path) in every bucket.""" + for N in (1, 8, 16, 64, 256, 4096): + for HV in (16, 64): + for T in (1, 2, 4, 8): + for dsu in (False, True): + _, ilp, _ = _select_mtp_config(N, HV, 128, T, disable_state_update=dsu) + assert ilp in (2, 4), f"N={N},HV={HV},T={T},dsu={dsu} -> ilp={ilp}" + + +@pytest.mark.parametrize("use_smem_v", [False, True]) +@pytest.mark.parametrize("tile_v,ilp_rows", [(16, 2), (32, 4), (64, 4)]) +def test_intermediate_vs_oracle_and_final(use_smem_v, tile_v, ilp_rows): + """Each per-token snapshot == fp32 oracle state; the t=T-1 snapshot == final state pool.""" + N, T, H, HV, K, V = 4, 4, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + inter_ref = oracle_intermediate_states(q, k, v, a, b, A_log, dt_bias, state.clone(), scale) + _o, st_final, inter = run_ws(q, k, v, a, b, A_log, dt_bias, state, scale, + tile_v=tile_v, ilp_rows=ilp_rows, use_packed_fma=False, + use_smem_v=use_smem_v, intermediate=True) + tag = f"inter smem={use_smem_v} tv={tile_v} ilp={ilp_rows}" + for t in range(T): + _assert_close(f"{tag} snapshot[t={t}]", inter_ref[:, t], inter[:, t]) + assert torch.equal(inter[:, T - 1], st_final), f"{tag}: t=T-1 snapshot != final state pool" + + +def test_intermediate_disable_state_update(): + """disable_state_update leaves the pool untouched; snapshots still fire and match the oracle.""" + N, T, H, HV, K, V = 4, 4, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + inter_ref = oracle_intermediate_states(q, k, v, a, b, A_log, dt_bias, state.clone(), scale) + + st = state.clone().contiguous() + before = st.clone() + indices = torch.arange(N, device=q.device, dtype=torch.int32) + inter = torch.zeros(N, T, HV, V, K, device=q.device, dtype=torch.float32) + kda_decode_mtp_ws( + A_log=A_log, dt_bias=dt_bias, + q=q.to(torch.bfloat16), k=k.to(torch.bfloat16), v=v.to(torch.bfloat16), + a=a.to(torch.bfloat16), b=b.to(torch.bfloat16), + initial_state_source=st, initial_state_indices=indices, + scale=scale, use_qk_l2norm_in_kernel=True, tile_v=32, ilp_rows=4, + use_packed_fma=False, disable_state_update=True, intermediate_states_buffer=inter, + ) + assert torch.equal(st, before), "pool modified despite disable_state_update=True" + for t in range(T): + _assert_close(f"inter+dsu snapshot[t={t}]", inter_ref[:, t], inter[:, t]) + + +def test_intermediate_buffer_validation(): + """Bad intermediate_states_buffer shape / dtype must raise.""" + N, T, H, HV, K, V = 4, 2, 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + st = state.clone().contiguous() + indices = torch.arange(N, device=q.device, dtype=torch.int32) + + def _call(buf): + return kda_decode_mtp_ws( + A_log=A_log, dt_bias=dt_bias, + q=q.to(torch.bfloat16), k=k.to(torch.bfloat16), v=v.to(torch.bfloat16), + a=a.to(torch.bfloat16), b=b.to(torch.bfloat16), + initial_state_source=st, initial_state_indices=indices, + scale=scale, use_qk_l2norm_in_kernel=True, tile_v=32, ilp_rows=4, + use_packed_fma=False, intermediate_states_buffer=buf, + ) + + with pytest.raises((ValueError, AssertionError)): + _call(torch.zeros(N, T + 1, HV, V, K, device="cuda", dtype=torch.float32)) + with pytest.raises((ValueError, AssertionError)): + _call(torch.zeros(N, T, HV, V, K, device="cuda", dtype=torch.bfloat16)) + + +@pytest.mark.parametrize( + "N,T", [(1, 2), (4, 4), (8, 8), (4, 2), (16, 6)] +) +def test_intermediate_small_batch_vk(N, T): + """vk per-token snapshot == fp32 oracle; t=T-1 snapshot == final state pool.""" + H, HV, K, V = 8, 16, 128, 128 + scale = K**-0.5 + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K, V) + inter_ref = oracle_intermediate_states(q, k, v, a, b, A_log, dt_bias, state.clone(), scale) + o, st_vk, inter = run_small_batch(q, k, v, a, b, A_log, dt_bias, state.clone(), scale, + variant="vk", disable_state_update=False, intermediate=True) + for t in range(T): + _assert_close(f"sbvk inter snapshot[t={t}]", inter_ref[:, t], inter[:, t]) + assert torch.equal(inter[:, T - 1], st_vk), "sbvk: t=T-1 snapshot != final state" + + +K_DIM = 128 # kvbuffer ops hard-require K=128 + + +def _alloc_ubufs(N, T, HV, V, device="cuda"): + """u_buffer [N,T,HV,V], kinv/b_buffer [N,T,HV,K] — fp32, matching the kernel contract.""" + return ( + torch.zeros(N, T, HV, V, dtype=torch.float32, device=device), + torch.zeros(N, T, HV, K_DIM, dtype=torch.float32, device=device), + torch.zeros(N, T, HV, K_DIM, dtype=torch.float32, device=device), + ) + + +def _kvb_verify(which, q, k, v, a, b, A_log, dt_bias, state, scale, *, ubufs=None): + """Run a kvbuffer verify op (disable_state_update=True). Returns output o [N,T,HV,V].""" + N = q.shape[0] + indices = torch.arange(N, device=q.device, dtype=torch.int32) + u_b, kinv_b, b_b = ubufs if ubufs is not None else (None, None, None) + op = kda_decode_mtp_tp_kvbuffer if which == "tp" else kda_decode_mtp_gemm_kvbuffer_cute + return op( + A_log=A_log, dt_bias=dt_bias, + q=q.to(torch.bfloat16), k=k.to(torch.bfloat16), v=v.to(torch.bfloat16), + a=a.to(torch.bfloat16), b=b.to(torch.bfloat16), + initial_state_source=state.clone().contiguous(), initial_state_indices=indices, + scale=scale, use_qk_l2norm_in_kernel=True, + disable_state_update=True, emit_output=True, + u_buffer=u_b, kinv_buffer=kinv_b, b_buffer=b_b, + ) + + +def _kvb_oracle_out(q, k, v, a, b, A_log, dt_bias, state, scale): + o_ref, _ = torch_kda_mtp_ref( + q.float(), k.float(), v.float(), a, b.float(), A_log, dt_bias, state, scale, + ) + return o_ref + + +def _check_kvb_verify_and_flush(which, N, T, H, HV): + """verify output == oracle, u-buffer populated; flush(m) == m-th oracle snapshot (m=full/half/one).""" + V = K_DIM + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K_DIM, V) + scale = K_DIM ** -0.5 + o_ref = _kvb_oracle_out(q, k, v, a, b, A_log, dt_bias, state, scale) + inter_ref = oracle_intermediate_states(q, k, v, a, b, A_log, dt_bias, state.clone(), scale) + + indices = torch.arange(N, device=q.device, dtype=torch.int32) + ubufs = _alloc_ubufs(N, T, HV, V) + o = _kvb_verify(which, q, k, v, a, b, A_log, dt_bias, state, scale, ubufs=ubufs) + _assert_close(f"{which}_verify N{N}T{T}", o_ref, o) + assert ubufs[0].abs().sum() > 0, f"{which}: u_buffer was not written" + + # flush each accept length m -> rebuilt S_m == oracle state after m tokens (snapshot m-1) + for m in sorted({T, max(1, T // 2), 1}): + pool = state.clone().contiguous() + kda_flush_kvbuffer(pool, indices, ubufs[0], ubufs[1], ubufs[2], accept_len=m) + _assert_close(f"{which}_flush N{N}T{T}m{m}", inter_ref[:, m - 1], pool) + + +@pytest.mark.parametrize("N,T,H,HV", [(2, 2, 16, 16), (4, 4, 16, 16), (2, 4, 32, 32)]) +def test_tp_kvbuffer_verify_and_flush(N, T, H, HV): + """tp-kvbuffer (token-parallel SIMT) verify output + rank-m flush match the fp32 oracle.""" + _check_kvb_verify_and_flush("tp", N, T, H, HV) + + +@pytest.mark.parametrize("N,T,H,HV", [(2, 3, 16, 16), (4, 6, 16, 16), (1, 8, 32, 32)]) +def test_cg_kvbuffer_verify_and_flush(N, T, H, HV): + """cg-kvbuffer (CuTe tensor-core gemm) verify output + rank-m flush match the fp32 oracle.""" + _check_kvb_verify_and_flush("cg", N, T, H, HV) + + +@pytest.mark.parametrize("T,routed", [(2, "tp"), (4, "cg")]) +def test_kvbuffer_dispatch_routes_by_T(T, routed): + """kda_decode_mtp_kvbuffer routes T<3 -> tp, T>=3 -> cg (t_crossover=3); output matches oracle either way.""" + N, H, HV, V = 2, 16, 16, K_DIM + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K_DIM, V) + scale = K_DIM ** -0.5 + o_ref = _kvb_oracle_out(q, k, v, a, b, A_log, dt_bias, state, scale) + indices = torch.arange(N, device=q.device, dtype=torch.int32) + o = kda_decode_mtp_kvbuffer( + A_log=A_log, dt_bias=dt_bias, + q=q.to(torch.bfloat16), k=k.to(torch.bfloat16), v=v.to(torch.bfloat16), + a=a.to(torch.bfloat16), b=b.to(torch.bfloat16), + initial_state_source=state.clone().contiguous(), initial_state_indices=indices, + scale=scale, + ) + _assert_close(f"dispatch T{T}->{routed}", o_ref, o) + + +@pytest.mark.parametrize("which,N,T,H,HV", [("tp", 4, 4, 16, 16), ("cg", 4, 6, 16, 16)]) +def test_kvbuffer_verify_determinism(which, N, T, H, HV): + """Repeated kvbuffer verify launches produce a bit-identical output (and u-buffer).""" + V = K_DIM + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K_DIM, V) + scale = K_DIM ** -0.5 + ub_ref = _alloc_ubufs(N, T, HV, V) + o_ref = _kvb_verify(which, q, k, v, a, b, A_log, dt_bias, state, scale, ubufs=ub_ref) + for i in range(3): + ub_i = _alloc_ubufs(N, T, HV, V) + o_i = _kvb_verify(which, q, k, v, a, b, A_log, dt_bias, state, scale, ubufs=ub_i) + assert torch.equal(o_i, o_ref), f"{which} verify output non-deterministic at iter {i}" + assert torch.equal(ub_i[0], ub_ref[0]), f"{which} u-buffer non-deterministic at iter {i}" + + +@pytest.mark.parametrize("which,N,T,H,HV", [("tp", 4, 4, 16, 16), ("cg", 4, 6, 16, 16)]) +def test_kvbuffer_flush_determinism(which, N, T, H, HV): + """Repeated flush launches rebuild a bit-identical state.""" + V = K_DIM + q, k, v, a, b, A_log, dt_bias, state = make_inputs_mtp(N, T, H, HV, K_DIM, V) + scale = K_DIM ** -0.5 + indices = torch.arange(N, device=q.device, dtype=torch.int32) + ubufs = _alloc_ubufs(N, T, HV, V) + _kvb_verify(which, q, k, v, a, b, A_log, dt_bias, state, scale, ubufs=ubufs) + pool_ref = state.clone().contiguous() + kda_flush_kvbuffer(pool_ref, indices, ubufs[0], ubufs[1], ubufs[2], accept_len=T) + for i in range(3): + pool_i = state.clone().contiguous() + kda_flush_kvbuffer(pool_i, indices, ubufs[0], ubufs[1], ubufs[2], accept_len=T) + assert torch.equal(pool_i, pool_ref), f"{which} flush state non-deterministic at iter {i}" + + +@pytest.mark.parametrize("V,N,HV", [(128, 1, 16), (128, 4, 32), (128, 16, 64)]) +def test_select_kvb_tile_v_invariants(V, N, HV): + """The auto tile_v must divide V and be a multiple of 4 (4-warp consumer).""" + tile_v = _select_kvb_tile_v(V, N, HV) + assert V % tile_v == 0 and tile_v % 4 == 0, f"tile_v={tile_v} violates V%tile_v==0 & tile_v%4==0" + + +@pytest.mark.parametrize("tile_v,T", [(64, 2), (32, 4), (64, 8), (16, 6)]) +def test_select_tp_kvb_ilp_rows_invariants(tile_v, T): + """ilp_rows must divide rows_per_group = tile_v/4 (the wrapper asserts this).""" + ilp = _select_tp_kvb_ilp_rows(tile_v, T) + assert ilp >= 1 and (tile_v // 4) % ilp == 0, f"ilp_rows={ilp} must divide tile_v/4={tile_v // 4}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])