From d9cd0b925e9b222b6fbaf1d2105b4a990014f46f Mon Sep 17 00:00:00 2001 From: "xsank.mz" Date: Mon, 1 Jun 2026 14:59:56 +0800 Subject: [PATCH 1/7] opt compress&topk kernel --- .../python/fastvideo_kernel/ops.py | 37 +--- .../triton_kernels/fused_compress_topk.py | 194 ++++++++++++++++++ 2 files changed, 205 insertions(+), 26 deletions(-) create mode 100644 fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py diff --git a/fastvideo-kernel/python/fastvideo_kernel/ops.py b/fastvideo-kernel/python/fastvideo_kernel/ops.py index 0f5458856..28bc87d71 100644 --- a/fastvideo-kernel/python/fastvideo_kernel/ops.py +++ b/fastvideo-kernel/python/fastvideo_kernel/ops.py @@ -1,11 +1,12 @@ import math import torch -from .block_sparse_attn import block_sparse_attn, block_sparse_attn_from_indices +from .block_sparse_attn import block_sparse_attn from .block_sparse_attn_256 import ( block_sparse_attn_256, block_sparse_attn_256_bshd, ) from .triton_kernels.st_attn_triton import sliding_tile_attention_triton +from .triton_kernels.fused_compress_topk import fused_block_mean, fused_topk_mask # Try to load the C++ extension try: @@ -118,13 +119,10 @@ def video_sparse_attn( f"got {q_variable_block_sizes.numel()}" ) - # Compression branch (token-level average per block + dense block-level attn). - q_c = q.view(batch, heads, q_num_blocks, block_elements, dim) - k_c = k.view(batch, heads, kv_num_blocks, block_elements, dim) - v_c = v.view(batch, heads, kv_num_blocks, block_elements, dim) - q_c = (q_c.float().sum(dim=3) / q_variable_block_sizes.view(1, 1, -1, 1)).to(q.dtype) - k_c = (k_c.float().sum(dim=3) / variable_block_sizes.view(1, 1, -1, 1)).to(k.dtype) - v_c = (v_c.float().sum(dim=3) / variable_block_sizes.view(1, 1, -1, 1)).to(v.dtype) + # Compression branch (fused Triton: bf16 read → fp32 accumulate → div → bf16 write) + q_c = fused_block_mean(q, q_variable_block_sizes, block_elements) + k_c = fused_block_mean(k, variable_block_sizes, block_elements) + v_c = fused_block_mean(v, variable_block_sizes, block_elements) scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (dim ** 0.5) attn = torch.softmax(scores, dim=-1) @@ -132,25 +130,13 @@ def video_sparse_attn( out_c = out_c.view(batch, heads, q_num_blocks, 1, dim) out_c = out_c.repeat(1, 1, 1, block_elements, 1).view(batch, heads, q_seq_len, dim) - # Sparse branch. - topk_idx = torch.topk(scores, topk, dim=-1).indices + # Sparse branch (fused Triton topk mask) + mask = fused_topk_mask(scores, topk) if block_elements == 256: - # CuTe path consumes a bool mask (full/partial split inside the wrapper). - mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True) out_s = block_sparse_attn_256(q, k, v, mask, variable_block_sizes)[0] else: - # Index-native path for 64-block (TK/Triton). - q2k_idx = topk_idx.to(torch.int32).contiguous() - q2k_num = torch.full( - (batch, heads, q_num_blocks), - topk, - dtype=torch.int32, - device=q.device, - ) - out_s = block_sparse_attn_from_indices( - q, k, v, q2k_idx, q2k_num, variable_block_sizes - )[0] + out_s = block_sparse_attn(q, k, v, mask, variable_block_sizes)[0] if compress_attn_weight is not None: return out_c * compress_attn_weight + out_s @@ -236,9 +222,8 @@ def video_sparse_attn_bshd( out_c_ch = torch.matmul(attn, v_ch) out_c_blk = out_c_ch.permute(0, 2, 1, 3).contiguous() - # Sparse branch (CuTe BSHD). - topk_idx = torch.topk(scores, topk, dim=-1).indices - mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True) + # Sparse branch (fused Triton topk mask + CuTe BSHD). + mask = fused_topk_mask(scores, topk) out_s, _ = block_sparse_attn_256_bshd(q, k, v, mask, variable_block_sizes) out = out_s diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py new file mode 100644 index 000000000..50e156b60 --- /dev/null +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py @@ -0,0 +1,194 @@ +""" +Fused Triton kernels for VSA compress (block mean) and topk mask construction. + +Replaces the multi-kernel PyTorch pipeline: + Original compress: .view() -> .float() -> .sum(dim=3) -> / vbs -> .to(bf16) + Original topk: torch.topk() -> zeros() -> scatter_() + +With single-pass fused kernels: + fused_block_mean: read bf16, accumulate fp32, div by vbs, write bf16 + fused_topk_mask: read scores, find k-th value, write bool mask +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fused_block_mean_kernel( + X_ptr, + Out_ptr, + VBS_ptr, + stride_x_bh, + stride_x_seq, + stride_o_bh, + stride_o_blk, + num_blocks, + BLOCK_ELEMENTS: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + """Fused block mean: one program computes mean of one block for one (b,h). + + X is viewed as [B*H, num_blocks*BLOCK_ELEMENTS, HEAD_DIM] contiguous. + Out is [B*H, num_blocks, HEAD_DIM] contiguous. + Accumulates in fp32, outputs in original dtype. + """ + block_idx = tl.program_id(0) + bh_idx = tl.program_id(1) + + if block_idx >= num_blocks: + return + + vbs = tl.load(VBS_ptr + block_idx).to(tl.float32) + + x_base = X_ptr + bh_idx * stride_x_bh + block_idx * BLOCK_ELEMENTS * stride_x_seq + + dim_offsets = tl.arange(0, HEAD_DIM) + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + + for i in range(BLOCK_ELEMENTS): + row_ptr = x_base + i * stride_x_seq + dim_offsets + x_val = tl.load(row_ptr).to(tl.float32) + acc += x_val + + acc = acc / vbs + + out_base = Out_ptr + bh_idx * stride_o_bh + block_idx * stride_o_blk + dim_offsets + tl.store(out_base, acc.to(tl.bfloat16)) + + +def fused_block_mean( + x: torch.Tensor, + variable_block_sizes: torch.Tensor, + block_elements: int, +) -> torch.Tensor: + """Compute block-wise mean with fp32 accumulation, fused in one kernel. + + Args: + x: [B, H, seq_len, D] in bf16 + variable_block_sizes: [num_blocks] number of valid tokens per block + block_elements: tokens per block (e.g. 64) + + Returns: + [B, H, num_blocks, D] in bf16 + """ + B, H, seq_len, D = x.shape + num_blocks = seq_len // block_elements + assert seq_len % block_elements == 0 + + x = x.contiguous() + out = torch.empty(B, H, num_blocks, D, dtype=x.dtype, device=x.device) + + x_flat = x.view(B * H, seq_len, D) + out_flat = out.view(B * H, num_blocks, D) + + grid = (num_blocks, B * H) + + _fused_block_mean_kernel[grid]( + x_flat, out_flat, variable_block_sizes, + x_flat.stride(0), x_flat.stride(1), + out_flat.stride(0), out_flat.stride(1), + num_blocks, + BLOCK_ELEMENTS=block_elements, + HEAD_DIM=D, + ) + + return out + + +@triton.jit +def _fused_topk_mask_kernel( + Scores_ptr, + Mask_ptr, + stride_s_bh, + stride_s_q, + stride_s_kv, + stride_m_bh, + stride_m_q, + stride_m_kv, + kv_blocks: tl.constexpr, + topk: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, +): + """Build topk boolean mask via randomized pivot selection (quickselect-style). + + For each (b,h,q_block) row: find the k-th largest score using iterative + pivot-based partitioning, then build mask by comparing against threshold. + + Grid: (num_q_blocks, B * H) + """ + q_idx = tl.program_id(0) + bh_idx = tl.program_id(1) + + kv_offsets = tl.arange(0, KV_BLOCK_SIZE) + score_base = Scores_ptr + bh_idx * stride_s_bh + q_idx * stride_s_q + mask_base = Mask_ptr + bh_idx * stride_m_bh + q_idx * stride_m_q + + valid_mask = kv_offsets < kv_blocks + scores = tl.load(score_base + kv_offsets * stride_s_kv, mask=valid_mask, other=-float("inf")) + scores_f32 = scores.to(tl.float32) + + # Binary search for threshold: find value T such that count(scores > T) <= topk + # and count(scores >= T) >= topk + # Use +inf/-inf sentinels so min/max ignore padding positions + lo = tl.min(tl.where(valid_mask, scores_f32, float("inf")), axis=0) + hi = tl.max(tl.where(valid_mask, scores_f32, float("-inf")), axis=0) + + for _i in range(32): + mid = (lo + hi) * 0.5 + count_ge = tl.sum(((scores_f32 >= mid) & valid_mask).to(tl.int32), axis=0) + # If count >= topk, threshold is at or above mid + lo = tl.where(count_ge >= topk, mid, lo) + hi = tl.where(count_ge >= topk, hi, mid) + + # lo is our threshold: count(scores >= lo) >= topk + threshold = lo + above_threshold = scores_f32 > threshold + at_threshold = scores_f32 == threshold + n_above = tl.sum(above_threshold.to(tl.int32), axis=0) + n_needed_at_thresh = topk - n_above + + at_thresh_cumsum = tl.cumsum(at_threshold.to(tl.int32), axis=0) + at_thresh_selected = at_threshold & (at_thresh_cumsum <= n_needed_at_thresh) + + final_mask = above_threshold | at_thresh_selected + + tl.store(mask_base + kv_offsets * stride_m_kv, final_mask, mask=valid_mask) + + +def fused_topk_mask( + scores: torch.Tensor, + topk: int, +) -> torch.Tensor: + """Build topk boolean mask from scores using fused Triton kernel. + + Args: + scores: [B, H, q_blocks, kv_blocks] block-level attention scores + topk: number of top blocks to select per q-block + + Returns: + mask: [B, H, q_blocks, kv_blocks] bool tensor with exactly topk True per row + """ + B, H, q_blocks, kv_blocks = scores.shape + topk = min(topk, kv_blocks) + + mask = torch.zeros(B, H, q_blocks, kv_blocks, dtype=torch.bool, device=scores.device) + + KV_BLOCK_SIZE = triton.next_power_of_2(kv_blocks) + + scores_flat = scores.contiguous().view(B * H, q_blocks, kv_blocks) + mask_flat = mask.view(B * H, q_blocks, kv_blocks) + + grid = (q_blocks, B * H) + + _fused_topk_mask_kernel[grid]( + scores_flat, mask_flat, + scores_flat.stride(0), scores_flat.stride(1), scores_flat.stride(2), + mask_flat.stride(0), mask_flat.stride(1), mask_flat.stride(2), + kv_blocks=kv_blocks, + topk=topk, + KV_BLOCK_SIZE=KV_BLOCK_SIZE, + ) + + return mask From c95ac1a90513e8ef1764197fbd6f08071a361927 Mon Sep 17 00:00:00 2001 From: "xsank.mz" Date: Mon, 1 Jun 2026 15:24:18 +0800 Subject: [PATCH 2/7] add bench scirpt --- .../benchmarks/bench_fused_compress_topk.py | 198 ++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 fastvideo-kernel/benchmarks/bench_fused_compress_topk.py diff --git a/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py b/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py new file mode 100644 index 000000000..b76abebd7 --- /dev/null +++ b/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +""" +Benchmark fused compress (block mean) and topk mask kernels vs. PyTorch baselines. + +Compares: + compress: + - Old: .view() -> .float() -> .sum(dim=3) -> / vbs -> .to(bf16) + - New: fused_block_mean (Triton: bf16 read -> fp32 accumulate -> div -> bf16 write) + + topk: + - Old: torch.topk() -> zeros() -> scatter_() + - New: fused_topk_mask (Triton: binary-search pivot -> bool mask) + +Reports per-kernel latency (ms), speedup, and numerical accuracy (max abs error, +cosine similarity for compress; mask match rate for topk). +""" + +from __future__ import annotations + +import argparse +import random + +import numpy as np +import torch + +try: + from triton.testing import do_bench +except ImportError as e: + raise ImportError("This benchmark requires triton (for triton.testing.do_bench).") from e + +from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_block_mean, fused_topk_mask + + +def set_seed(seed: int = 42) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def parse_arguments() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Benchmark fused compress & topk vs. PyTorch baselines") + p.add_argument("--batch_size", type=int, default=1) + p.add_argument("--num_heads", type=int, default=12) + p.add_argument("--head_dim", type=int, default=128, choices=[64, 128]) + p.add_argument("--seq_lens", type=int, nargs="+", default=[49152], + help="Sequence lengths to benchmark (must be divisible by block_elements)") + p.add_argument("--block_elements", type=int, default=64, choices=[64, 256], + help="Tokens per block (64 or 256)") + p.add_argument("--topk", type=int, default=None, + help="KV blocks to select per Q block (default: ~10%% of kv_blocks)") + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--rep", type=int, default=50) + p.add_argument("--seed", type=int, default=42) + p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"]) + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Old PyTorch baselines (extracted from ops.py before commit d9cd0b9) +# --------------------------------------------------------------------------- + +def pytorch_block_mean( + x: torch.Tensor, + variable_block_sizes: torch.Tensor, + block_elements: int, +) -> torch.Tensor: + """Old PyTorch compress: .view() -> .float() -> .sum(dim=3) -> / vbs -> .to(dtype)""" + B, H, seq_len, D = x.shape + num_blocks = seq_len // block_elements + x_c = x.view(B, H, num_blocks, block_elements, D) + x_c = (x_c.float().sum(dim=3) / variable_block_sizes.view(1, 1, -1, 1)).to(x.dtype) + return x_c + + +def pytorch_topk_mask( + scores: torch.Tensor, + topk: int, +) -> torch.Tensor: + """Old PyTorch topk: torch.topk() -> zeros() -> scatter_()""" + topk = min(topk, scores.shape[-1]) + topk_idx = torch.topk(scores, topk, dim=-1).indices + mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True) + return mask + + +# --------------------------------------------------------------------------- +# Accuracy helpers +# --------------------------------------------------------------------------- + +def accuracy_compress(ref: torch.Tensor, test: torch.Tensor) -> dict: + ref_f = ref.float() + test_f = test.float() + abs_err = (ref_f - test_f).abs() + cos_sim = torch.nn.functional.cosine_similarity(ref_f.flatten(), test_f.flatten(), dim=0) + return { + "max_abs_err": abs_err.max().item(), + "mean_abs_err": abs_err.mean().item(), + "cosine_sim": cos_sim.item(), + } + + +def accuracy_topk(ref_mask: torch.Tensor, test_mask: torch.Tensor) -> dict: + match = (ref_mask == test_mask).all(dim=-1).float().mean().item() + true_per_row_ref = ref_mask.sum(dim=-1).float() + true_per_row_test = test_mask.sum(dim=-1).float() + count_match = (true_per_row_ref == true_per_row_test).float().mean().item() + return { + "row_exact_match": match, + "count_match": count_match, + } + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + +def bench_compress( + B: int, H: int, seq_len: int, D: int, block_elements: int, + dtype: torch.dtype, warmup: int, rep: int, +) -> None: + num_blocks = seq_len // block_elements + x = torch.randn(B, H, seq_len, D, dtype=dtype, device="cuda") + vbs = torch.full((num_blocks,), block_elements, dtype=torch.int32, device="cuda") + # Make a few blocks partially filled to exercise variable block sizes + if num_blocks > 4: + vbs[1] = block_elements - 2 + vbs[-2] = block_elements - 5 + + # Accuracy + ref = pytorch_block_mean(x, vbs, block_elements) + fused = fused_block_mean(x, vbs, block_elements) + acc = accuracy_compress(ref, fused) + + # Latency + old_ms = do_bench(lambda: pytorch_block_mean(x, vbs, block_elements), warmup=warmup, rep=rep) + new_ms = do_bench(lambda: fused_block_mean(x, vbs, block_elements), warmup=warmup, rep=rep) + + speedup = old_ms / new_ms if new_ms > 0 else float("inf") + print(f" compress | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x " + f"| max_abs_err: {acc['max_abs_err']:.2e} | cos_sim: {acc['cosine_sim']:.8f}") + + +def bench_topk( + B: int, H: int, num_blocks: int, topk: int, + dtype: torch.dtype, warmup: int, rep: int, +) -> None: + scores = torch.randn(B, H, num_blocks, num_blocks, dtype=dtype, device="cuda") + + # Accuracy + ref = pytorch_topk_mask(scores, topk) + fused = fused_topk_mask(scores, topk) + acc = accuracy_topk(ref, fused) + + # Latency + old_ms = do_bench(lambda: pytorch_topk_mask(scores, topk), warmup=warmup, rep=rep) + new_ms = do_bench(lambda: fused_topk_mask(scores, topk), warmup=warmup, rep=rep) + + speedup = old_ms / new_ms if new_ms > 0 else float("inf") + print(f" topk | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x " + f"| row_exact_match: {acc['row_exact_match']:.4f} | count_match: {acc['count_match']:.4f}") + + +def main() -> None: + args = parse_arguments() + set_seed(args.seed) + + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 + B, H, D = args.batch_size, args.num_heads, args.head_dim + block_elements = args.block_elements + + print("Fused Compress & TopK Benchmark") + print(f"device: {torch.cuda.get_device_name(0)}") + print(f"batch={B}, heads={H}, head_dim={D}, block_elements={block_elements}, dtype={args.dtype}") + print(f"warmup={args.warmup}, rep={args.rep}") + + for seq_len in args.seq_lens: + if seq_len % block_elements != 0: + print(f"\n[skip] seq_len={seq_len} not divisible by block_elements={block_elements}") + continue + + num_blocks = seq_len // block_elements + topk = args.topk if args.topk is not None else max(1, num_blocks // 10) + topk = min(topk, num_blocks) + + print(f"\n{'=' * 100}") + print(f"seq_len={seq_len}, num_blocks={num_blocks}, topk={topk}") + print("-" * 100) + + bench_compress(B, H, seq_len, D, block_elements, dtype, args.warmup, args.rep) + bench_topk(B, H, num_blocks, topk, dtype, args.warmup, args.rep) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark.") + main() From 8271e9d81f66974cdd8f6918990125df9b6a2db1 Mon Sep 17 00:00:00 2001 From: "xsank.mz" Date: Mon, 1 Jun 2026 16:08:03 +0800 Subject: [PATCH 3/7] add backward support --- .../benchmarks/bench_fused_compress_topk.py | 61 +++++++++- .../triton_kernels/fused_compress_topk.py | 110 ++++++++++++++++-- 2 files changed, 157 insertions(+), 14 deletions(-) diff --git a/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py b/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py index b76abebd7..cf59baf18 100644 --- a/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py +++ b/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py @@ -13,6 +13,8 @@ Reports per-kernel latency (ms), speedup, and numerical accuracy (max abs error, cosine similarity for compress; mask match rate for topk). + +Also benchmarks backward pass of compress (fused Triton bwd kernel vs. PyTorch autograd). """ from __future__ import annotations @@ -116,14 +118,13 @@ def accuracy_topk(ref_mask: torch.Tensor, test_mask: torch.Tensor) -> dict: # Benchmark runner # --------------------------------------------------------------------------- -def bench_compress( +def bench_compress_fwd( B: int, H: int, seq_len: int, D: int, block_elements: int, dtype: torch.dtype, warmup: int, rep: int, ) -> None: num_blocks = seq_len // block_elements x = torch.randn(B, H, seq_len, D, dtype=dtype, device="cuda") vbs = torch.full((num_blocks,), block_elements, dtype=torch.int32, device="cuda") - # Make a few blocks partially filled to exercise variable block sizes if num_blocks > 4: vbs[1] = block_elements - 2 vbs[-2] = block_elements - 5 @@ -138,7 +139,58 @@ def bench_compress( new_ms = do_bench(lambda: fused_block_mean(x, vbs, block_elements), warmup=warmup, rep=rep) speedup = old_ms / new_ms if new_ms > 0 else float("inf") - print(f" compress | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x " + print(f" compress fwd | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x " + f"| max_abs_err: {acc['max_abs_err']:.2e} | cos_sim: {acc['cosine_sim']:.8f}") + + +def bench_compress_bwd( + B: int, H: int, seq_len: int, D: int, block_elements: int, + dtype: torch.dtype, warmup: int, rep: int, +) -> None: + num_blocks = seq_len // block_elements + vbs = torch.full((num_blocks,), block_elements, dtype=torch.int32, device="cuda") + if num_blocks > 4: + vbs[1] = block_elements - 2 + vbs[-2] = block_elements - 5 + + # --- Accuracy: compare gradients --- + x_old = torch.randn(B, H, seq_len, D, dtype=dtype, device="cuda", requires_grad=True) + grad_out = torch.randn(B, H, num_blocks, D, dtype=dtype, device="cuda") + + out_old = pytorch_block_mean(x_old, vbs, block_elements) + out_old.backward(grad_out) + grad_ref = x_old.grad.clone() + + x_new = x_old.detach().clone().requires_grad_(True) + out_new = fused_block_mean(x_new, vbs, block_elements) + out_new.backward(grad_out) + grad_fused = x_new.grad.clone() + + acc = accuracy_compress(grad_ref, grad_fused) + + # --- Latency: isolate backward-only via retain_graph --- + x_o = x_old.detach().clone().requires_grad_(True) + out_o = pytorch_block_mean(x_o, vbs, block_elements) + loss_o = (out_o * grad_out).sum() + for _ in range(warmup): + torch.autograd.grad(loss_o, x_o, retain_graph=True) + old_ms = do_bench( + lambda: torch.autograd.grad(loss_o, x_o, retain_graph=True), + warmup=0, rep=rep, + ) + + x_n = x_old.detach().clone().requires_grad_(True) + out_n = fused_block_mean(x_n, vbs, block_elements) + loss_n = (out_n * grad_out).sum() + for _ in range(warmup): + torch.autograd.grad(loss_n, x_n, retain_graph=True) + new_ms = do_bench( + lambda: torch.autograd.grad(loss_n, x_n, retain_graph=True), + warmup=0, rep=rep, + ) + + speedup = old_ms / new_ms if new_ms > 0 else float("inf") + print(f" compress bwd | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x " f"| max_abs_err: {acc['max_abs_err']:.2e} | cos_sim: {acc['cosine_sim']:.8f}") @@ -188,7 +240,8 @@ def main() -> None: print(f"seq_len={seq_len}, num_blocks={num_blocks}, topk={topk}") print("-" * 100) - bench_compress(B, H, seq_len, D, block_elements, dtype, args.warmup, args.rep) + bench_compress_fwd(B, H, seq_len, D, block_elements, dtype, args.warmup, args.rep) + bench_compress_bwd(B, H, seq_len, D, block_elements, dtype, args.warmup, args.rep) bench_topk(B, H, num_blocks, topk, dtype, args.warmup, args.rep) diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py index 50e156b60..131040433 100644 --- a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py @@ -58,21 +58,75 @@ def _fused_block_mean_kernel( tl.store(out_base, acc.to(tl.bfloat16)) -def fused_block_mean( - x: torch.Tensor, +@triton.jit +def _fused_block_mean_bwd_kernel( + GradOut_ptr, + GradX_ptr, + VBS_ptr, + stride_go_bh, + stride_go_blk, + stride_gx_bh, + stride_gx_seq, + num_blocks, + BLOCK_ELEMENTS: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + """Backward of block mean: broadcast grad_out / vbs to each token in the block. + + Mirrors the forward kernel: one program per (block, bh). + GradOut is [B*H, num_blocks, HEAD_DIM]. + GradX is [B*H, num_blocks*BLOCK_ELEMENTS, HEAD_DIM]. + """ + block_idx = tl.program_id(0) + bh_idx = tl.program_id(1) + + if block_idx >= num_blocks: + return + + vbs = tl.load(VBS_ptr + block_idx).to(tl.float32) + + go_base = GradOut_ptr + bh_idx * stride_go_bh + block_idx * stride_go_blk + dim_offsets = tl.arange(0, HEAD_DIM) + grad_val = tl.load(go_base + dim_offsets).to(tl.float32) / vbs + + gx_base = GradX_ptr + bh_idx * stride_gx_bh + block_idx * BLOCK_ELEMENTS * stride_gx_seq + grad_out_cast = grad_val.to(tl.bfloat16) + for i in range(BLOCK_ELEMENTS): + tl.store(gx_base + i * stride_gx_seq + dim_offsets, grad_out_cast) + + +def _fused_block_mean_bwd( + grad_output: torch.Tensor, variable_block_sizes: torch.Tensor, block_elements: int, ) -> torch.Tensor: - """Compute block-wise mean with fp32 accumulation, fused in one kernel. + B, H, num_blocks, D = grad_output.shape + seq_len = num_blocks * block_elements - Args: - x: [B, H, seq_len, D] in bf16 - variable_block_sizes: [num_blocks] number of valid tokens per block - block_elements: tokens per block (e.g. 64) + grad_x = torch.empty(B, H, seq_len, D, dtype=grad_output.dtype, device=grad_output.device) - Returns: - [B, H, num_blocks, D] in bf16 - """ + go_flat = grad_output.contiguous().view(B * H, num_blocks, D) + gx_flat = grad_x.view(B * H, seq_len, D) + + grid = (num_blocks, B * H) + + _fused_block_mean_bwd_kernel[grid]( + go_flat, gx_flat, variable_block_sizes, + go_flat.stride(0), go_flat.stride(1), + gx_flat.stride(0), gx_flat.stride(1), + num_blocks, + BLOCK_ELEMENTS=block_elements, + HEAD_DIM=D, + ) + + return grad_x + + +def _fused_block_mean_fwd( + x: torch.Tensor, + variable_block_sizes: torch.Tensor, + block_elements: int, +) -> torch.Tensor: B, H, seq_len, D = x.shape num_blocks = seq_len // block_elements assert seq_len % block_elements == 0 @@ -97,6 +151,42 @@ def fused_block_mean( return out +class _FusedBlockMeanAutograd(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, variable_block_sizes, block_elements): + ctx.save_for_backward(variable_block_sizes) + ctx.block_elements = block_elements + return _fused_block_mean_fwd(x, variable_block_sizes, block_elements) + + @staticmethod + def backward(ctx, grad_output): + variable_block_sizes, = ctx.saved_tensors + block_elements = ctx.block_elements + return _fused_block_mean_bwd(grad_output, variable_block_sizes, block_elements), None, None + + +def fused_block_mean( + x: torch.Tensor, + variable_block_sizes: torch.Tensor, + block_elements: int, +) -> torch.Tensor: + """Compute block-wise mean with fp32 accumulation, fused in one kernel. + + Forward: fused Triton kernel (bf16 read → fp32 accumulate → div → bf16 write). + Backward: broadcasts grad_output / vbs back to each token position. + + Args: + x: [B, H, seq_len, D] in bf16 + variable_block_sizes: [num_blocks] number of valid tokens per block + block_elements: tokens per block (e.g. 64) + + Returns: + [B, H, num_blocks, D] in bf16 + """ + return _FusedBlockMeanAutograd.apply(x, variable_block_sizes, block_elements) + + @triton.jit def _fused_topk_mask_kernel( Scores_ptr, From 78e6b85ec59da9322395df45ba60a7c7175fc3a4 Mon Sep 17 00:00:00 2001 From: "xsank.mz" Date: Mon, 1 Jun 2026 19:02:51 +0800 Subject: [PATCH 4/7] add more benchmark case --- .../benchmarks/bench_fused_compress_topk.py | 64 ++++++++++++++++++ .../triton_kernels/fused_compress_topk.py | 66 ++++++++++++++----- 2 files changed, 112 insertions(+), 18 deletions(-) diff --git a/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py b/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py index cf59baf18..79d598a38 100644 --- a/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py +++ b/fastvideo-kernel/benchmarks/bench_fused_compress_topk.py @@ -214,6 +214,52 @@ def bench_topk( f"| row_exact_match: {acc['row_exact_match']:.4f} | count_match: {acc['count_match']:.4f}") +def bench_topk_neginf( + B: int, H: int, num_blocks: int, topk: int, + dtype: torch.dtype, inf_ratio: float = 0.3, +) -> None: + """Test topk correctness when a fraction of scores are -inf (masked positions).""" + scores = torch.randn(B, H, num_blocks, num_blocks, dtype=dtype, device="cuda") + inf_mask = torch.rand(B, H, num_blocks, num_blocks, device="cuda") < inf_ratio + scores[inf_mask] = float("-inf") + + ref = pytorch_topk_mask(scores, topk) + fused = fused_topk_mask(scores, topk) + acc = accuracy_topk(ref, fused) + + status = "PASS" if acc["row_exact_match"] == 1.0 else "FAIL" + print(f" topk -inf | {inf_ratio*100:.0f}% masked | {status} " + f"| row_exact_match: {acc['row_exact_match']:.4f} | count_match: {acc['count_match']:.4f}") + + +_MAX_KV_BLOCK_SIZE = 4096 + + +def bench_topk_large_kv( + B: int, H: int, kv_blocks: int, topk: int, + dtype: torch.dtype, warmup: int, rep: int, +) -> None: + """Test topk with large kv_blocks that may exceed Triton block size limit.""" + import triton + kv_block_size = triton.next_power_of_2(kv_blocks) + fallback = kv_block_size > _MAX_KV_BLOCK_SIZE + + q_blocks = max(1, kv_blocks // 8) + scores = torch.randn(B, H, q_blocks, kv_blocks, dtype=dtype, device="cuda") + + ref = pytorch_topk_mask(scores, topk) + fused = fused_topk_mask(scores, topk) + acc = accuracy_topk(ref, fused) + + old_ms = do_bench(lambda: pytorch_topk_mask(scores, topk), warmup=warmup, rep=rep) + new_ms = do_bench(lambda: fused_topk_mask(scores, topk), warmup=warmup, rep=rep) + + speedup = old_ms / new_ms if new_ms > 0 else float("inf") + path = "fallback" if fallback else "triton" + print(f" topk kv={kv_blocks:<5d} | {path:8s} | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms " + f"| speedup: {speedup:5.2f}x | row_exact_match: {acc['row_exact_match']:.4f}") + + def main() -> None: args = parse_arguments() set_seed(args.seed) @@ -244,6 +290,24 @@ def main() -> None: bench_compress_bwd(B, H, seq_len, D, block_elements, dtype, args.warmup, args.rep) bench_topk(B, H, num_blocks, topk, dtype, args.warmup, args.rep) + # --- Edge case: topk with -inf scores --- + print(f"\n{'=' * 100}") + print("TopK with -inf scores (masked positions)") + print("-" * 100) + kv = args.seq_lens[0] // block_elements + tk = args.topk if args.topk is not None else max(1, kv // 10) + tk = min(tk, kv) + for inf_ratio in [0.1, 0.3, 0.5, 0.8]: + bench_topk_neginf(B, H, kv, tk, dtype, inf_ratio) + + # --- Edge case: topk with large kv_blocks (Triton block size limit) --- + print(f"\n{'=' * 100}") + print("TopK with large kv_blocks (Triton block size limit test)") + print("-" * 100) + for kv_blocks in [1024, 2048, 4096, 8192]: + tk = max(1, kv_blocks // 10) + bench_topk_large_kv(1, 1, kv_blocks, tk, dtype, args.warmup, args.rep) + if __name__ == "__main__": if not torch.cuda.is_available(): diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py index 131040433..cdcdbbb0d 100644 --- a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py @@ -27,12 +27,13 @@ def _fused_block_mean_kernel( num_blocks, BLOCK_ELEMENTS: tl.constexpr, HEAD_DIM: tl.constexpr, + OUTPUT_DTYPE: tl.constexpr, ): """Fused block mean: one program computes mean of one block for one (b,h). X is viewed as [B*H, num_blocks*BLOCK_ELEMENTS, HEAD_DIM] contiguous. Out is [B*H, num_blocks, HEAD_DIM] contiguous. - Accumulates in fp32, outputs in original dtype. + 2D load + parallel tl.sum reduction, accumulates in fp32. """ block_idx = tl.program_id(0) bh_idx = tl.program_id(1) @@ -44,18 +45,14 @@ def _fused_block_mean_kernel( x_base = X_ptr + bh_idx * stride_x_bh + block_idx * BLOCK_ELEMENTS * stride_x_seq + row_offsets = tl.arange(0, BLOCK_ELEMENTS) dim_offsets = tl.arange(0, HEAD_DIM) - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - - for i in range(BLOCK_ELEMENTS): - row_ptr = x_base + i * stride_x_seq + dim_offsets - x_val = tl.load(row_ptr).to(tl.float32) - acc += x_val - - acc = acc / vbs + offsets = row_offsets[:, None] * stride_x_seq + dim_offsets[None, :] + block_data = tl.load(x_base + offsets).to(tl.float32) + acc = tl.sum(block_data, axis=0) / vbs out_base = Out_ptr + bh_idx * stride_o_bh + block_idx * stride_o_blk + dim_offsets - tl.store(out_base, acc.to(tl.bfloat16)) + tl.store(out_base, acc.to(OUTPUT_DTYPE)) @triton.jit @@ -70,12 +67,14 @@ def _fused_block_mean_bwd_kernel( num_blocks, BLOCK_ELEMENTS: tl.constexpr, HEAD_DIM: tl.constexpr, + OUTPUT_DTYPE: tl.constexpr, ): """Backward of block mean: broadcast grad_out / vbs to each token in the block. Mirrors the forward kernel: one program per (block, bh). GradOut is [B*H, num_blocks, HEAD_DIM]. GradX is [B*H, num_blocks*BLOCK_ELEMENTS, HEAD_DIM]. + 2D store writes all BLOCK_ELEMENTS rows in parallel. """ block_idx = tl.program_id(0) bh_idx = tl.program_id(1) @@ -85,14 +84,22 @@ def _fused_block_mean_bwd_kernel( vbs = tl.load(VBS_ptr + block_idx).to(tl.float32) - go_base = GradOut_ptr + bh_idx * stride_go_bh + block_idx * stride_go_blk dim_offsets = tl.arange(0, HEAD_DIM) + go_base = GradOut_ptr + bh_idx * stride_go_bh + block_idx * stride_go_blk grad_val = tl.load(go_base + dim_offsets).to(tl.float32) / vbs + row_offsets = tl.arange(0, BLOCK_ELEMENTS) gx_base = GradX_ptr + bh_idx * stride_gx_bh + block_idx * BLOCK_ELEMENTS * stride_gx_seq - grad_out_cast = grad_val.to(tl.bfloat16) - for i in range(BLOCK_ELEMENTS): - tl.store(gx_base + i * stride_gx_seq + dim_offsets, grad_out_cast) + offsets = row_offsets[:, None] * stride_gx_seq + dim_offsets[None, :] + grad_2d = tl.broadcast_to(grad_val[None, :], [BLOCK_ELEMENTS, HEAD_DIM]) + tl.store(gx_base + offsets, grad_2d.to(OUTPUT_DTYPE)) + + +_TORCH_TO_TRITON_DTYPE = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32, +} def _fused_block_mean_bwd( @@ -117,6 +124,7 @@ def _fused_block_mean_bwd( num_blocks, BLOCK_ELEMENTS=block_elements, HEAD_DIM=D, + OUTPUT_DTYPE=_TORCH_TO_TRITON_DTYPE[grad_output.dtype], ) return grad_x @@ -146,6 +154,7 @@ def _fused_block_mean_fwd( num_blocks, BLOCK_ELEMENTS=block_elements, HEAD_DIM=D, + OUTPUT_DTYPE=_TORCH_TO_TRITON_DTYPE[x.dtype], ) return out @@ -221,9 +230,14 @@ def _fused_topk_mask_kernel( # Binary search for threshold: find value T such that count(scores > T) <= topk # and count(scores >= T) >= topk - # Use +inf/-inf sentinels so min/max ignore padding positions - lo = tl.min(tl.where(valid_mask, scores_f32, float("inf")), axis=0) + # Exclude -inf from lo so the binary search can converge when masked + # scores are present (mid = (-inf + hi) * 0.5 = -inf would stall). + finite_mask = valid_mask & (scores_f32 > float("-inf")) + lo = tl.min(tl.where(finite_mask, scores_f32, float("inf")), axis=0) hi = tl.max(tl.where(valid_mask, scores_f32, float("-inf")), axis=0) + # If all valid scores are -inf, lo > hi; threshold stays at -inf and + # the tie-breaking logic below selects the first topk positions. + lo = tl.minimum(lo, hi) for _i in range(32): mid = (lo + hi) * 0.5 @@ -247,12 +261,26 @@ def _fused_topk_mask_kernel( tl.store(mask_base + kv_offsets * stride_m_kv, final_mask, mask=valid_mask) +MAX_KV_BLOCK_SIZE = 4096 + + +def _pytorch_topk_mask_fallback( + scores: torch.Tensor, + topk: int, +) -> torch.Tensor: + topk_idx = torch.topk(scores, topk, dim=-1).indices + return torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True) + + def fused_topk_mask( scores: torch.Tensor, topk: int, ) -> torch.Tensor: """Build topk boolean mask from scores using fused Triton kernel. + Falls back to PyTorch when kv_blocks exceeds the Triton block size + limit (MAX_KV_BLOCK_SIZE) to avoid compilation failures. + Args: scores: [B, H, q_blocks, kv_blocks] block-level attention scores topk: number of top blocks to select per q-block @@ -263,9 +291,11 @@ def fused_topk_mask( B, H, q_blocks, kv_blocks = scores.shape topk = min(topk, kv_blocks) - mask = torch.zeros(B, H, q_blocks, kv_blocks, dtype=torch.bool, device=scores.device) - KV_BLOCK_SIZE = triton.next_power_of_2(kv_blocks) + if KV_BLOCK_SIZE > MAX_KV_BLOCK_SIZE: + return _pytorch_topk_mask_fallback(scores, topk) + + mask = torch.zeros(B, H, q_blocks, kv_blocks, dtype=torch.bool, device=scores.device) scores_flat = scores.contiguous().view(B * H, q_blocks, kv_blocks) mask_flat = mask.view(B * H, q_blocks, kv_blocks) From 82f6c86e554565b12890052f752d2eaf386985a3 Mon Sep 17 00:00:00 2001 From: "xsank.mz" Date: Mon, 8 Jun 2026 09:55:48 +0800 Subject: [PATCH 5/7] add compress&topk test case --- CLAUDE.md | 1 - .../triton_kernels/fused_compress_topk.py | 6 +- .../tests/test_fused_compress_topk.py | 277 ++++++++++++++++++ 3 files changed, 282 insertions(+), 2 deletions(-) delete mode 100644 CLAUDE.md create mode 100644 fastvideo-kernel/tests/test_fused_compress_topk.py diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 43c994c2d..000000000 --- a/CLAUDE.md +++ /dev/null @@ -1 +0,0 @@ -@AGENTS.md diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py index cdcdbbb0d..ea1b68d7b 100644 --- a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py @@ -239,10 +239,14 @@ def _fused_topk_mask_kernel( # the tie-breaking logic below selects the first topk positions. lo = tl.minimum(lo, hi) + # 32 iterations of fp32 bisection gives precision ~2^-32 ≈ 2.3e-10, which is + # far below the minimum gap between distinct bf16 values (~2^-8 ≈ 3.9e-3). + # This guarantees threshold converges exactly to a bf16 score value in fp32 + # representation, so the > / == comparisons below are exact with no risk of + # the threshold landing between tied bf16 values. for _i in range(32): mid = (lo + hi) * 0.5 count_ge = tl.sum(((scores_f32 >= mid) & valid_mask).to(tl.int32), axis=0) - # If count >= topk, threshold is at or above mid lo = tl.where(count_ge >= topk, mid, lo) hi = tl.where(count_ge >= topk, hi, mid) diff --git a/fastvideo-kernel/tests/test_fused_compress_topk.py b/fastvideo-kernel/tests/test_fused_compress_topk.py new file mode 100644 index 000000000..28f183b14 --- /dev/null +++ b/fastvideo-kernel/tests/test_fused_compress_topk.py @@ -0,0 +1,277 @@ +"""Tests for fused_block_mean and fused_topk_mask equivalence against PyTorch reference.""" + +import torch +import pytest + + +def _require_cuda(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + +# --------------------------------------------------------------------------- +# fused_topk_mask: tie-boundary correctness +# --------------------------------------------------------------------------- + + +class TestFusedTopkMaskTies: + """Verify that fused_topk_mask selects exactly topk per row, even with ties.""" + + def _run_topk_mask(self, scores, topk): + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_topk_mask + return fused_topk_mask(scores.cuda(), topk) + + def _ref_topk_mask(self, scores, topk): + """PyTorch reference: topk + scatter.""" + idx = torch.topk(scores, topk, dim=-1).indices + mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, idx, True) + return mask + + def test_all_equal_scores(self): + """All scores identical — must still select exactly topk.""" + _require_cuda() + scores = torch.full((1, 1, 4, 8), 0.5, dtype=torch.bfloat16) + for topk in [1, 3, 5, 8]: + mask = self._run_topk_mask(scores, topk) + row_counts = mask.sum(dim=-1) + assert (row_counts == topk).all(), ( + f"topk={topk}, row_counts={row_counts}" + ) + + def test_reviewer_example(self): + """Exact case from the reviewer: two 0.8875 with topk=1 → must select 1.""" + _require_cuda() + vals = torch.tensor( + [0.5222, 0.5222, 0.8875, 0.5222, 0.5222, 0.5222, 0.8875], + dtype=torch.bfloat16, + ) + scores = vals.view(1, 1, 1, -1) + mask = self._run_topk_mask(scores, topk=1) + assert mask.sum().item() == 1, f"Expected 1 selected, got {mask.sum().item()}" + + def test_tie_at_boundary_various_topk(self): + """Scores with a large tied cluster at the k-th boundary.""" + _require_cuda() + scores = torch.tensor( + [1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.2], + dtype=torch.bfloat16, + ).view(1, 1, 1, -1) + for topk in [1, 2, 3, 4, 7, 8]: + mask = self._run_topk_mask(scores, topk) + assert mask.sum().item() == topk, ( + f"topk={topk}, selected={mask.sum().item()}" + ) + + def test_batch_heads_ties(self): + """Multiple batch/heads with tie-heavy scores.""" + _require_cuda() + B, H, Q, KV = 2, 4, 8, 16 + scores = torch.zeros(B, H, Q, KV, dtype=torch.bfloat16) + scores[:, :, :, :4] = 1.0 + scores[:, :, :, 4:] = 0.5 + for topk in [1, 4, 8]: + mask = self._run_topk_mask(scores, topk) + row_counts = mask.sum(dim=-1) + assert (row_counts == topk).all(), ( + f"topk={topk}, row_counts min={row_counts.min()}, max={row_counts.max()}" + ) + + +# --------------------------------------------------------------------------- +# fused_topk_mask: equivalence with PyTorch topk+scatter +# --------------------------------------------------------------------------- + + +class TestFusedTopkMaskEquivalence: + """fused_topk_mask must select the same set as torch.topk (modulo tie order).""" + + def test_random_scores_row_count(self): + """Random scores (unlikely ties) — row count must match exactly.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_topk_mask + + B, H, Q, KV = 2, 8, 32, 64 + topk = 8 + scores = torch.randn(B, H, Q, KV, dtype=torch.bfloat16, device="cuda") + mask = fused_topk_mask(scores, topk) + assert (mask.sum(dim=-1) == topk).all() + + def test_random_scores_value_match(self): + """With distinct scores, fused and reference must select identical indices.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_topk_mask + + B, H, Q, KV = 1, 2, 16, 32 + topk = 4 + scores = torch.arange(KV, dtype=torch.float32).view(1, 1, 1, KV) + scores = scores.expand(B, H, Q, KV).contiguous().to(torch.bfloat16).cuda() + + fused_mask = fused_topk_mask(scores, topk) + + ref_idx = torch.topk(scores.cuda(), topk, dim=-1).indices + ref_mask = torch.zeros_like(scores, dtype=torch.bool, device="cuda").scatter_(-1, ref_idx, True) + + assert (fused_mask == ref_mask).all() + + +# --------------------------------------------------------------------------- +# fused_block_mean: equivalence with PyTorch reference +# --------------------------------------------------------------------------- + + +class TestFusedBlockMeanEquivalence: + + def test_against_pytorch_reference(self): + """fused_block_mean must match the original view→float→sum→div→bf16 path.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_block_mean + + B, H, D = 2, 4, 128 + num_blocks = 16 + block_elements = 64 + seq_len = num_blocks * block_elements + + x = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device="cuda") + vbs = torch.randint(16, 65, (num_blocks,), dtype=torch.int32, device="cuda") + + fused_out = fused_block_mean(x, vbs, block_elements) + + x_blocks = x.view(B, H, num_blocks, block_elements, D) + ref_out = (x_blocks.float().sum(dim=3) / vbs.view(1, 1, -1, 1).float()).to(torch.bfloat16) + + assert fused_out.shape == ref_out.shape + assert torch.allclose(fused_out, ref_out, atol=1e-2, rtol=1e-2), ( + f"max diff={( fused_out - ref_out).abs().max().item()}" + ) + + +# --------------------------------------------------------------------------- +# End-to-end: fused video_sparse_attn vs old PyTorch pipeline +# --------------------------------------------------------------------------- + + +def _old_pytorch_video_sparse_attn( + q, k, v, variable_block_sizes, q_variable_block_sizes, + topk, block_elements, +): + """Reproduce the pre-optimization PyTorch pipeline from ops.py.""" + batch, heads, q_seq_len, dim = q.shape + kv_seq_len = k.shape[2] + q_num_blocks = q_seq_len // block_elements + kv_num_blocks = kv_seq_len // block_elements + + q_c = q.view(batch, heads, q_num_blocks, block_elements, dim) + k_c = k.view(batch, heads, kv_num_blocks, block_elements, dim) + v_c = v.view(batch, heads, kv_num_blocks, block_elements, dim) + q_c = (q_c.float().sum(dim=3) / q_variable_block_sizes.view(1, 1, -1, 1)).to(q.dtype) + k_c = (k_c.float().sum(dim=3) / variable_block_sizes.view(1, 1, -1, 1)).to(k.dtype) + v_c = (v_c.float().sum(dim=3) / variable_block_sizes.view(1, 1, -1, 1)).to(v.dtype) + + scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (dim ** 0.5) + attn = torch.softmax(scores, dim=-1) + out_c = torch.matmul(attn, v_c) + out_c = out_c.view(batch, heads, q_num_blocks, 1, dim) + out_c = out_c.repeat(1, 1, 1, block_elements, 1).view(batch, heads, q_seq_len, dim) + + topk_idx = torch.topk(scores, topk, dim=-1).indices + mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True) + + return out_c, scores, mask + + +class TestVideoSparseAttnEquivalence: + """End-to-end: fused compress+topk path vs old PyTorch pipeline.""" + + def _make_inputs(self, B, H, num_blocks, D, block_elements, device="cuda"): + seq_len = num_blocks * block_elements + q = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device=device) + k = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device=device) + v = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device=device) + vbs = torch.full((num_blocks,), block_elements, dtype=torch.int32, device=device) + return q, k, v, vbs + + def test_compress_branch_equivalence(self): + """Compression branch (block_mean → matmul → softmax → matmul) must match.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_block_mean + + B, H, num_blocks, D = 1, 4, 32, 128 + block_elements = 64 + topk = 4 + q, k, v, vbs = self._make_inputs(B, H, num_blocks, D, block_elements) + + q_c_fused = fused_block_mean(q, vbs, block_elements) + k_c_fused = fused_block_mean(k, vbs, block_elements) + v_c_fused = fused_block_mean(v, vbs, block_elements) + scores_fused = torch.matmul(q_c_fused, k_c_fused.transpose(-2, -1)) / (D ** 0.5) + attn_fused = torch.softmax(scores_fused, dim=-1) + out_c_fused = torch.matmul(attn_fused, v_c_fused) + + old_out_c, old_scores, _ = _old_pytorch_video_sparse_attn( + q, k, v, vbs, vbs, topk, block_elements, + ) + + assert torch.allclose(scores_fused, old_scores, atol=1e-2, rtol=1e-2), ( + f"scores max diff={( scores_fused - old_scores).abs().max().item()}" + ) + + def test_topk_mask_row_count_matches(self): + """Fused topk mask must select exactly topk per row, same as torch.topk.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import ( + fused_block_mean, + fused_topk_mask, + ) + + B, H, num_blocks, D = 1, 8, 48, 128 + block_elements = 64 + topk = 6 + q, k, v, vbs = self._make_inputs(B, H, num_blocks, D, block_elements) + + q_c = fused_block_mean(q, vbs, block_elements) + k_c = fused_block_mean(k, vbs, block_elements) + scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5) + + fused_mask = fused_topk_mask(scores, topk) + row_counts = fused_mask.sum(dim=-1) + assert (row_counts == topk).all(), ( + f"row_counts min={row_counts.min()}, max={row_counts.max()}" + ) + + _, _, ref_mask = _old_pytorch_video_sparse_attn( + q, k, v, vbs, vbs, topk, block_elements, + ) + ref_counts = ref_mask.sum(dim=-1) + assert (ref_counts == topk).all() + + def test_e2e_mask_equivalence_realistic_shape(self): + """On realistic bf16 q_c@k_c scores, fused and ref mask should agree.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import ( + fused_block_mean, + fused_topk_mask, + ) + + B, H, num_blocks, D = 1, 16, 32, 128 + block_elements = 64 + topk = 4 + q, k, v, vbs = self._make_inputs(B, H, num_blocks, D, block_elements) + + q_c = fused_block_mean(q, vbs, block_elements) + k_c = fused_block_mean(k, vbs, block_elements) + scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5) + + fused_mask = fused_topk_mask(scores, topk) + + ref_idx = torch.topk(scores, topk, dim=-1).indices + ref_mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, ref_idx, True) + + total_rows = B * H * num_blocks + matching_rows = (fused_mask == ref_mask).all(dim=-1).sum().item() + match_rate = matching_rows / total_rows + + print(f"\n[e2e mask equivalence] {matching_rows}/{total_rows} rows match " + f"({match_rate:.2%})") + + assert (fused_mask.sum(dim=-1) == topk).all(), "row count mismatch" + assert match_rate > 0.99, f"match rate too low: {match_rate:.2%}" From 90bb9a1e6671d78635e9fd418baf975fc7d22c33 Mon Sep 17 00:00:00 2001 From: "xsank.mz" Date: Mon, 8 Jun 2026 10:08:34 +0800 Subject: [PATCH 6/7] restore CLAUDE file --- CLAUDE.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..43c994c2d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +@AGENTS.md From 62ce3aa3a1d30d19616e4486d2e597640c0c05d9 Mon Sep 17 00:00:00 2001 From: "xsank.mz" Date: Mon, 8 Jun 2026 19:36:52 +0800 Subject: [PATCH 7/7] add backward test case & improve the code comment and log --- .../triton_kernels/fused_compress_topk.py | 26 +++- .../tests/test_fused_compress_topk.py | 124 ++++++++++++++++++ 2 files changed, 145 insertions(+), 5 deletions(-) diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py index ea1b68d7b..68a4acd04 100644 --- a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py @@ -10,10 +10,14 @@ fused_topk_mask: read scores, find k-th value, write bool mask """ +import logging + import torch import triton import triton.language as tl +logger = logging.getLogger(__name__) + @triton.jit def _fused_block_mean_kernel( @@ -239,11 +243,11 @@ def _fused_topk_mask_kernel( # the tie-breaking logic below selects the first topk positions. lo = tl.minimum(lo, hi) - # 32 iterations of fp32 bisection gives precision ~2^-32 ≈ 2.3e-10, which is - # far below the minimum gap between distinct bf16 values (~2^-8 ≈ 3.9e-3). - # This guarantees threshold converges exactly to a bf16 score value in fp32 - # representation, so the > / == comparisons below are exact with no risk of - # the threshold landing between tied bf16 values. + # 32 iterations of fp32 bisection give range-relative resolution (hi-lo)/2^32. + # For VSA's softmax-input scores (q_c@k_c/sqrt(d), O(1) magnitude, range < ~10), + # this resolves to ~2e-9, well below the bf16 ULP at that magnitude (~8e-3), + # so the threshold converges exactly to the k-th bf16 score value and the + # > / == comparisons below are exact. for _i in range(32): mid = (lo + hi) * 0.5 count_ge = tl.sum(((scores_f32 >= mid) & valid_mask).to(tl.int32), axis=0) @@ -265,6 +269,13 @@ def _fused_topk_mask_kernel( tl.store(mask_base + kv_offsets * stride_m_kv, final_mask, mask=valid_mask) +# Triton kernel loads the entire kv row into registers via tl.arange(0, KV_BLOCK_SIZE). +# Each row spawns multiple same-sized register arrays (scores_f32, valid_mask, +# above_threshold, at_threshold, cumsum, etc.). GPU SMs have a fixed register file +# (e.g. 65536 × 32-bit), so the maximum array length per program is bounded. +# 4096 (2^12) is an empirical power-of-2 cap that avoids register spilling on +# mainstream GPUs. Beyond this the kernel either fails to compile or spills to +# local memory with severe perf regression, so we fall back to torch.topk. MAX_KV_BLOCK_SIZE = 4096 @@ -297,6 +308,11 @@ def fused_topk_mask( KV_BLOCK_SIZE = triton.next_power_of_2(kv_blocks) if KV_BLOCK_SIZE > MAX_KV_BLOCK_SIZE: + logger.debug( + "fused_topk_mask: kv_blocks=%d exceeds Triton limit %d, " + "falling back to PyTorch topk (slower)", + kv_blocks, MAX_KV_BLOCK_SIZE, + ) return _pytorch_topk_mask_fallback(scores, topk) mask = torch.zeros(B, H, q_blocks, kv_blocks, dtype=torch.bool, device=scores.device) diff --git a/fastvideo-kernel/tests/test_fused_compress_topk.py b/fastvideo-kernel/tests/test_fused_compress_topk.py index 28f183b14..db0afd40a 100644 --- a/fastvideo-kernel/tests/test_fused_compress_topk.py +++ b/fastvideo-kernel/tests/test_fused_compress_topk.py @@ -145,6 +145,130 @@ def test_against_pytorch_reference(self): ) +# --------------------------------------------------------------------------- +# fused_block_mean: backward (gradient) parity +# --------------------------------------------------------------------------- + + +def _ref_block_mean_with_grad(x, vbs, block_elements): + """Eager PyTorch block-mean that supports autograd for gradient reference.""" + B, H, seq_len, D = x.shape + num_blocks = seq_len // block_elements + x_blocks = x.view(B, H, num_blocks, block_elements, D) + return (x_blocks.float().sum(dim=3) / vbs.view(1, 1, -1, 1).float()).to(x.dtype) + + +class TestFusedBlockMeanBackward: + + def test_gradient_parity(self): + """Gradient through fused_block_mean must match the eager view→sum→div path.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_block_mean + + B, H, D = 2, 4, 128 + num_blocks = 16 + block_elements = 64 + seq_len = num_blocks * block_elements + + vbs = torch.randint(16, 65, (num_blocks,), dtype=torch.int32, device="cuda") + grad_out = torch.randn(B, H, num_blocks, D, dtype=torch.bfloat16, device="cuda") + + x_fused = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device="cuda") + x_ref = x_fused.clone() + x_fused.requires_grad_(True) + x_ref.requires_grad_(True) + + out_fused = fused_block_mean(x_fused, vbs, block_elements) + out_fused.backward(grad_out) + + out_ref = _ref_block_mean_with_grad(x_ref, vbs, block_elements) + out_ref.backward(grad_out) + + assert x_fused.grad is not None, "fused path did not produce gradient" + assert x_ref.grad is not None, "ref path did not produce gradient" + assert x_fused.grad.shape == x_ref.grad.shape + + max_diff = (x_fused.grad - x_ref.grad).abs().max().item() + assert torch.allclose(x_fused.grad, x_ref.grad, atol=1e-2, rtol=1e-2), ( + f"gradient max diff={max_diff}" + ) + + def test_gradient_parity_variable_block_sizes(self): + """Backward parity with non-uniform variable_block_sizes.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_block_mean + + B, H, D = 1, 8, 64 + num_blocks = 32 + block_elements = 64 + seq_len = num_blocks * block_elements + + vbs = torch.randint(1, 65, (num_blocks,), dtype=torch.int32, device="cuda") + grad_out = torch.randn(B, H, num_blocks, D, dtype=torch.bfloat16, device="cuda") + + x_fused = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device="cuda") + x_ref = x_fused.clone() + x_fused.requires_grad_(True) + x_ref.requires_grad_(True) + + out_fused = fused_block_mean(x_fused, vbs, block_elements) + out_fused.backward(grad_out) + + out_ref = _ref_block_mean_with_grad(x_ref, vbs, block_elements) + out_ref.backward(grad_out) + + max_diff = (x_fused.grad - x_ref.grad).abs().max().item() + assert torch.allclose(x_fused.grad, x_ref.grad, atol=1e-2, rtol=1e-2), ( + f"gradient max diff={max_diff}" + ) + + def test_gradient_through_matmul_chain(self): + """End-to-end backward through the compress branch: block_mean → matmul → softmax → matmul.""" + _require_cuda() + from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_block_mean + + B, H, D = 1, 4, 128 + num_blocks = 16 + block_elements = 64 + seq_len = num_blocks * block_elements + + vbs = torch.full((num_blocks,), block_elements, dtype=torch.int32, device="cuda") + + def _forward_compress(x_q, x_k, x_v, use_fused): + mean_fn = fused_block_mean if use_fused else ( + lambda x, v, be: _ref_block_mean_with_grad(x, v, be) + ) + q_c = mean_fn(x_q, vbs, block_elements) + k_c = mean_fn(x_k, vbs, block_elements) + v_c = mean_fn(x_v, vbs, block_elements) + scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5) + attn = torch.softmax(scores, dim=-1) + return torch.matmul(attn, v_c) + + base_q = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device="cuda") + base_k = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device="cuda") + base_v = torch.randn(B, H, seq_len, D, dtype=torch.bfloat16, device="cuda") + grad_out = torch.randn(B, H, num_blocks, D, dtype=torch.bfloat16, device="cuda") + + grads = {} + for label, use_fused in [("fused", True), ("ref", False)]: + q = base_q.clone().requires_grad_(True) + k = base_k.clone().requires_grad_(True) + v = base_v.clone().requires_grad_(True) + out = _forward_compress(q, k, v, use_fused) + out.backward(grad_out) + grads[label] = (q.grad, k.grad, v.grad) + + for name, (g_fused, g_ref) in zip( + ["dQ", "dK", "dV"], + zip(grads["fused"], grads["ref"]), + ): + max_diff = (g_fused - g_ref).abs().max().item() + assert torch.allclose(g_fused, g_ref, atol=1e-2, rtol=1e-2), ( + f"{name} gradient max diff={max_diff}" + ) + + # --------------------------------------------------------------------------- # End-to-end: fused video_sparse_attn vs old PyTorch pipeline # ---------------------------------------------------------------------------