Skip to content
315 changes: 315 additions & 0 deletions fastvideo-kernel/benchmarks/bench_fused_compress_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
#!/usr/bin/env python3
"""
Benchmark fused compress (block mean) and topk mask kernels vs. PyTorch baselines.

Compares:
compress:
- Old: .view() -> .float() -> .sum(dim=3) -> / vbs -> .to(bf16)
- New: fused_block_mean (Triton: bf16 read -> fp32 accumulate -> div -> bf16 write)

topk:
- Old: torch.topk() -> zeros() -> scatter_()
- New: fused_topk_mask (Triton: binary-search pivot -> bool mask)

Reports per-kernel latency (ms), speedup, and numerical accuracy (max abs error,
cosine similarity for compress; mask match rate for topk).

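Also benchmarks the backward pass of compress (fused Triton bwd kernel vs. PyTorch autograd),
and runs two topk edge-case sections: scores containing -inf entries, and large kv_blocks counts.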
"""

from __future__ import annotations

import argparse
import random

import numpy as np
import torch

try:
from triton.testing import do_bench
except ImportError as e:
raise ImportError("This benchmark requires triton (for triton.testing.do_bench).") from e

from fastvideo_kernel.triton_kernels.fused_compress_topk import fused_block_mean, fused_topk_mask


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this benchmark draws from.

    Args:
        seed: Value passed to Python's ``random``, NumPy's global generator,
            and the PyTorch CPU/CUDA seeding functions.
    """
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch seeding is applied in the same order as before: default
    # generator, current CUDA device, then all CUDA devices.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in torch_seeders:
        apply_seed(seed)


def parse_arguments() -> argparse.Namespace:
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        The parsed namespace with ``batch_size``, ``num_heads``, ``head_dim``,
        ``seq_lens``, ``block_elements``, ``topk``, ``warmup``, ``rep``,
        ``seed`` and ``dtype`` attributes.
    """
    parser = argparse.ArgumentParser(description="Benchmark fused compress & topk vs. PyTorch baselines")
    # (flag, add_argument keyword arguments), registered in display order.
    option_specs = (
        ("--batch_size", {"type": int, "default": 1}),
        ("--num_heads", {"type": int, "default": 12}),
        ("--head_dim", {"type": int, "default": 128, "choices": [64, 128]}),
        ("--seq_lens", {
            "type": int,
            "nargs": "+",
            "default": [49152],
            "help": "Sequence lengths to benchmark (must be divisible by block_elements)",
        }),
        ("--block_elements", {
            "type": int,
            "default": 64,
            "choices": [64, 256],
            "help": "Tokens per block (64 or 256)",
        }),
        ("--topk", {
            "type": int,
            "default": None,
            "help": "KV blocks to select per Q block (default: ~10%% of kv_blocks)",
        }),
        ("--warmup", {"type": int, "default": 10}),
        ("--rep", {"type": int, "default": 50}),
        ("--seed", {"type": int, "default": 42}),
        ("--dtype", {"type": str, "default": "bf16", "choices": ["bf16", "fp16"]}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


# ---------------------------------------------------------------------------
# Old PyTorch baselines (extracted from ops.py before commit d9cd0b9)
# ---------------------------------------------------------------------------

def pytorch_block_mean(
    x: torch.Tensor,
    variable_block_sizes: torch.Tensor,
    block_elements: int,
) -> torch.Tensor:
    """Baseline compress: per-block sum in fp32 divided by a per-block size.

    Args:
        x: 4-D tensor unpacked as ``(B, H, seq_len, D)``.
        variable_block_sizes: One divisor per block; it is reshaped to
            ``(1, 1, -1, 1)`` so it broadcasts over the block axis.
        block_elements: Number of consecutive rows grouped into one block.

    Returns:
        Tensor of shape ``(B, H, seq_len // block_elements, D)`` in ``x.dtype``.
    """
    batch, heads, seq_len, head_dim = x.shape
    blocked = x.view(batch, heads, seq_len // block_elements, block_elements, head_dim)
    # Accumulate in fp32 so the sum does not lose precision in half types.
    block_sums = blocked.float().sum(dim=3)
    divisors = variable_block_sizes.view(1, 1, -1, 1)
    return (block_sums / divisors).to(x.dtype)


def pytorch_topk_mask(
    scores: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    """Baseline topk: mark the ``topk`` largest entries of each last-dim row.

    Args:
        scores: Tensor whose last dimension is ranked.
        topk: Number of entries to select; clamped to the last-dim size.

    Returns:
        Bool tensor shaped like ``scores`` with True at the selected positions.
    """
    k = min(topk, scores.shape[-1])
    winners = torch.topk(scores, k, dim=-1).indices
    selected = torch.zeros_like(scores, dtype=torch.bool)
    selected.scatter_(-1, winners, True)
    return selected


# ---------------------------------------------------------------------------
# Accuracy helpers
# ---------------------------------------------------------------------------

def accuracy_compress(ref: torch.Tensor, test: torch.Tensor) -> dict:
    """Compare two tensors after casting both to fp32.

    Returns:
        Dict with ``max_abs_err`` and ``mean_abs_err`` (element-wise absolute
        difference) and ``cosine_sim`` (cosine similarity of the flattened
        tensors), all as Python floats.
    """
    ref32, test32 = ref.float(), test.float()
    delta = (ref32 - test32).abs()
    cosine = torch.nn.functional.cosine_similarity(
        ref32.flatten(), test32.flatten(), dim=0
    )
    metrics = {}
    metrics["max_abs_err"] = delta.max().item()
    metrics["mean_abs_err"] = delta.mean().item()
    metrics["cosine_sim"] = cosine.item()
    return metrics


def accuracy_topk(ref_mask: torch.Tensor, test_mask: torch.Tensor) -> dict:
    """Compare two selection masks row by row along the last dimension.

    Returns:
        Dict with ``row_exact_match`` (fraction of rows whose masks are
        identical) and ``count_match`` (fraction of rows with the same number
        of True entries), both as Python floats.
    """
    rows_identical = (ref_mask == test_mask).all(dim=-1)
    ref_counts = ref_mask.sum(dim=-1).float()
    test_counts = test_mask.sum(dim=-1).float()
    counts_equal = ref_counts == test_counts
    return {
        "row_exact_match": rows_identical.float().mean().item(),
        "count_match": counts_equal.float().mean().item(),
    }


# ---------------------------------------------------------------------------
# Benchmark runner
# ---------------------------------------------------------------------------

def bench_compress_fwd(
    B: int, H: int, seq_len: int, D: int, block_elements: int,
    dtype: torch.dtype, warmup: int, rep: int,
) -> None:
    """Time and validate the forward compress kernel against the baseline.

    Builds a random ``(B, H, seq_len, D)`` CUDA input, checks the fused output
    against ``pytorch_block_mean``, times both with ``do_bench``, and prints
    one summary line.
    """
    x = torch.randn(B, H, seq_len, D, dtype=dtype, device="cuda")

    num_blocks = seq_len // block_elements
    vbs = torch.full((num_blocks,), block_elements, dtype=torch.int32, device="cuda")
    if num_blocks > 4:
        # Give two blocks a smaller divisor so the sizes are not all equal.
        vbs[1] = block_elements - 2
        vbs[-2] = block_elements - 5

    def run_baseline() -> torch.Tensor:
        return pytorch_block_mean(x, vbs, block_elements)

    def run_fused() -> torch.Tensor:
        return fused_block_mean(x, vbs, block_elements)

    # Accuracy first, then latency (baseline before fused in both cases).
    acc = accuracy_compress(run_baseline(), run_fused())
    old_ms = do_bench(run_baseline, warmup=warmup, rep=rep)
    new_ms = do_bench(run_fused, warmup=warmup, rep=rep)

    if new_ms > 0:
        speedup = old_ms / new_ms
    else:
        speedup = float("inf")
    print(f" compress fwd | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x "
          f"| max_abs_err: {acc['max_abs_err']:.2e} | cos_sim: {acc['cosine_sim']:.8f}")


def bench_compress_bwd(
    B: int, H: int, seq_len: int, D: int, block_elements: int,
    dtype: torch.dtype, warmup: int, rep: int,
) -> None:
    """Time and validate the backward pass of compress against autograd.

    Input gradients of ``fused_block_mean`` are compared with those of
    ``pytorch_block_mean`` for the same input and upstream gradient. The
    backward pass is then timed on a graph that is built once and retained,
    so the forward pass is not part of the measured region.
    """
    num_blocks = seq_len // block_elements
    vbs = torch.full((num_blocks,), block_elements, dtype=torch.int32, device="cuda")
    if num_blocks > 4:
        # Give two blocks a smaller divisor so the sizes are not all equal.
        vbs[1] = block_elements - 2
        vbs[-2] = block_elements - 5

    x_old = torch.randn(B, H, seq_len, D, dtype=dtype, device="cuda", requires_grad=True)
    grad_out = torch.randn(B, H, num_blocks, D, dtype=dtype, device="cuda")

    def fresh_leaf() -> torch.Tensor:
        # Independent copy of the input so gradients never accumulate across runs.
        return x_old.detach().clone().requires_grad_(True)

    def input_grad(block_mean_fn, leaf: torch.Tensor) -> torch.Tensor:
        block_mean_fn(leaf, vbs, block_elements).backward(grad_out)
        return leaf.grad.clone()

    def time_backward(block_mean_fn) -> float:
        leaf = fresh_leaf()
        loss = (block_mean_fn(leaf, vbs, block_elements) * grad_out).sum()

        def backward_step():
            # retain_graph keeps the graph alive so the step can be repeated.
            return torch.autograd.grad(loss, leaf, retain_graph=True)

        # Warmup is done by hand; do_bench is called with warmup=0.
        for _ in range(warmup):
            backward_step()
        return do_bench(backward_step, warmup=0, rep=rep)

    # --- Accuracy: compare gradients ---
    grad_ref = input_grad(pytorch_block_mean, x_old)
    grad_fused = input_grad(fused_block_mean, fresh_leaf())
    acc = accuracy_compress(grad_ref, grad_fused)

    # --- Latency: backward only, baseline first ---
    old_ms = time_backward(pytorch_block_mean)
    new_ms = time_backward(fused_block_mean)

    if new_ms > 0:
        speedup = old_ms / new_ms
    else:
        speedup = float("inf")
    print(f" compress bwd | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x "
          f"| max_abs_err: {acc['max_abs_err']:.2e} | cos_sim: {acc['cosine_sim']:.8f}")


def bench_topk(
    B: int, H: int, num_blocks: int, topk: int,
    dtype: torch.dtype, warmup: int, rep: int,
) -> None:
    """Time and validate the fused topk mask on square random scores.

    Scores are ``(B, H, num_blocks, num_blocks)`` random normals on CUDA. The
    fused mask is compared with ``pytorch_topk_mask`` and both are timed with
    ``do_bench``; one summary line is printed.
    """
    scores = torch.randn(B, H, num_blocks, num_blocks, dtype=dtype, device="cuda")

    def run_baseline() -> torch.Tensor:
        return pytorch_topk_mask(scores, topk)

    def run_fused() -> torch.Tensor:
        return fused_topk_mask(scores, topk)

    # Accuracy first, then latency (baseline before fused in both cases).
    acc = accuracy_topk(run_baseline(), run_fused())
    old_ms = do_bench(run_baseline, warmup=warmup, rep=rep)
    new_ms = do_bench(run_fused, warmup=warmup, rep=rep)

    if new_ms > 0:
        speedup = old_ms / new_ms
    else:
        speedup = float("inf")
    print(f" topk | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms | speedup: {speedup:5.2f}x "
          f"| row_exact_match: {acc['row_exact_match']:.4f} | count_match: {acc['count_match']:.4f}")


def bench_topk_neginf(
    B: int, H: int, num_blocks: int, topk: int,
    dtype: torch.dtype, inf_ratio: float = 0.3,
) -> None:
    """Check topk correctness when a fraction of the scores is -inf.

    Each score is independently replaced by -inf with probability
    ``inf_ratio``. The fused mask must match the baseline on every row to be
    reported as PASS. No timing is done here.
    """
    shape = (B, H, num_blocks, num_blocks)
    scores = torch.randn(*shape, dtype=dtype, device="cuda")
    inf_mask = torch.rand(*shape, device="cuda") < inf_ratio
    scores.masked_fill_(inf_mask, float("-inf"))

    acc = accuracy_topk(
        pytorch_topk_mask(scores, topk),
        fused_topk_mask(scores, topk),
    )

    if acc["row_exact_match"] == 1.0:
        status = "PASS"
    else:
        status = "FAIL"
    print(f" topk -inf | {inf_ratio*100:.0f}% masked | {status} "
          f"| row_exact_match: {acc['row_exact_match']:.4f} | count_match: {acc['count_match']:.4f}")


# Threshold used below only to label a run as "fallback" vs "triton".
# NOTE(review): presumably mirrors a limit inside fused_topk_mask — confirm
# the two stay in sync; this script does not read the kernel's own value.
_MAX_KV_BLOCK_SIZE = 4096


def bench_topk_large_kv(
    B: int, H: int, kv_blocks: int, topk: int,
    dtype: torch.dtype, warmup: int, rep: int,
) -> None:
    """Time and validate topk on wide rows that may exceed the Triton block limit.

    Scores are ``(B, H, max(1, kv_blocks // 8), kv_blocks)``. The printed path
    label is derived here from ``next_power_of_2(kv_blocks)`` compared with
    ``_MAX_KV_BLOCK_SIZE``; it is not reported by the kernel itself.
    """
    import triton

    padded_width = triton.next_power_of_2(kv_blocks)
    path = "fallback" if padded_width > _MAX_KV_BLOCK_SIZE else "triton"

    q_blocks = max(1, kv_blocks // 8)
    scores = torch.randn(B, H, q_blocks, kv_blocks, dtype=dtype, device="cuda")

    def run_baseline() -> torch.Tensor:
        return pytorch_topk_mask(scores, topk)

    def run_fused() -> torch.Tensor:
        return fused_topk_mask(scores, topk)

    # Accuracy first, then latency (baseline before fused in both cases).
    acc = accuracy_topk(run_baseline(), run_fused())
    old_ms = do_bench(run_baseline, warmup=warmup, rep=rep)
    new_ms = do_bench(run_fused, warmup=warmup, rep=rep)

    if new_ms > 0:
        speedup = old_ms / new_ms
    else:
        speedup = float("inf")
    print(f" topk kv={kv_blocks:<5d} | {path:8s} | old: {old_ms:8.3f} ms | new: {new_ms:8.3f} ms "
          f"| speedup: {speedup:5.2f}x | row_exact_match: {acc['row_exact_match']:.4f}")


def _resolve_topk(requested: int | None, num_blocks: int) -> int:
    """Return the topk to use for ``num_blocks`` KV blocks.

    Uses ``requested`` when given, otherwise about 10% of the blocks (at least
    1). The result never exceeds ``num_blocks``.
    """
    chosen = requested if requested is not None else max(1, num_blocks // 10)
    return min(chosen, num_blocks)


def main() -> None:
    """Run every benchmark section and print the results.

    Sections, in order:
      1. Per ``--seq_lens`` entry: compress forward, compress backward, topk.
         Entries that are non-positive or not divisible by ``block_elements``
         are skipped with a message.
      2. Topk with -inf scores, sized from the first seq_len that was actually
         benchmarked in section 1. Skipped when every entry was rejected.
      3. Topk with large ``kv_blocks`` counts, always with batch=1 and heads=1.
    """
    args = parse_arguments()
    set_seed(args.seed)

    dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
    B, H, D = args.batch_size, args.num_heads, args.head_dim
    block_elements = args.block_elements

    print("Fused Compress & TopK Benchmark")
    print(f"device: {torch.cuda.get_device_name(0)}")
    print(f"batch={B}, heads={H}, head_dim={D}, block_elements={block_elements}, dtype={args.dtype}")
    print(f"warmup={args.warmup}, rep={args.rep}")

    # Lengths that passed validation; the -inf section reuses the first one so
    # it never runs on a length the main loop rejected.
    benchmarked_seq_lens = []
    for seq_len in args.seq_lens:
        if seq_len <= 0:
            # A non-positive length would give zero blocks and empty tensors.
            print(f"\n[skip] seq_len={seq_len} must be positive")
            continue
        if seq_len % block_elements != 0:
            print(f"\n[skip] seq_len={seq_len} not divisible by block_elements={block_elements}")
            continue
        benchmarked_seq_lens.append(seq_len)

        num_blocks = seq_len // block_elements
        topk = _resolve_topk(args.topk, num_blocks)

        print(f"\n{'=' * 100}")
        print(f"seq_len={seq_len}, num_blocks={num_blocks}, topk={topk}")
        print("-" * 100)

        bench_compress_fwd(B, H, seq_len, D, block_elements, dtype, args.warmup, args.rep)
        bench_compress_bwd(B, H, seq_len, D, block_elements, dtype, args.warmup, args.rep)
        bench_topk(B, H, num_blocks, topk, dtype, args.warmup, args.rep)

    # --- Edge case: topk with -inf scores ---
    print(f"\n{'=' * 100}")
    print("TopK with -inf scores (masked positions)")
    print("-" * 100)
    if benchmarked_seq_lens:
        kv = benchmarked_seq_lens[0] // block_elements
        tk = _resolve_topk(args.topk, kv)
        for inf_ratio in [0.1, 0.3, 0.5, 0.8]:
            bench_topk_neginf(B, H, kv, tk, dtype, inf_ratio)
    else:
        print("[skip] no valid seq_len to derive kv_blocks from")

    # --- Edge case: topk with large kv_blocks (Triton block size limit) ---
    print(f"\n{'=' * 100}")
    print("TopK with large kv_blocks (Triton block size limit test)")
    print("-" * 100)
    for kv_blocks in [1024, 2048, 4096, 8192]:
        # This section always uses its own ~10% topk and ignores --topk.
        tk = max(1, kv_blocks // 10)
        bench_topk_large_kv(1, 1, kv_blocks, tk, dtype, args.warmup, args.rep)


if __name__ == "__main__":
    # Every benchmark allocates on device="cuda", so fail early with a clear
    # message instead of erroring inside the first tensor allocation.
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        raise RuntimeError("CUDA is required for this benchmark.")
    main()
37 changes: 11 additions & 26 deletions fastvideo-kernel/python/fastvideo_kernel/ops.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
import torch
from .block_sparse_attn import block_sparse_attn, block_sparse_attn_from_indices
from .block_sparse_attn import block_sparse_attn
from .block_sparse_attn_256 import (
block_sparse_attn_256,
block_sparse_attn_256_bshd,
)
from .triton_kernels.st_attn_triton import sliding_tile_attention_triton
from .triton_kernels.fused_compress_topk import fused_block_mean, fused_topk_mask

# Try to load the C++ extension
try:
Expand Down Expand Up @@ -118,39 +119,24 @@ def video_sparse_attn(
f"got {q_variable_block_sizes.numel()}"
)

# Compression branch (token-level average per block + dense block-level attn).
q_c = q.view(batch, heads, q_num_blocks, block_elements, dim)
k_c = k.view(batch, heads, kv_num_blocks, block_elements, dim)
v_c = v.view(batch, heads, kv_num_blocks, block_elements, dim)
q_c = (q_c.float().sum(dim=3) / q_variable_block_sizes.view(1, 1, -1, 1)).to(q.dtype)
k_c = (k_c.float().sum(dim=3) / variable_block_sizes.view(1, 1, -1, 1)).to(k.dtype)
v_c = (v_c.float().sum(dim=3) / variable_block_sizes.view(1, 1, -1, 1)).to(v.dtype)
# Compression branch (fused Triton: bf16 read → fp32 accumulate → div → bf16 write)
q_c = fused_block_mean(q, q_variable_block_sizes, block_elements)
k_c = fused_block_mean(k, variable_block_sizes, block_elements)
v_c = fused_block_mean(v, variable_block_sizes, block_elements)

scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (dim ** 0.5)
attn = torch.softmax(scores, dim=-1)
out_c = torch.matmul(attn, v_c)
out_c = out_c.view(batch, heads, q_num_blocks, 1, dim)
out_c = out_c.repeat(1, 1, 1, block_elements, 1).view(batch, heads, q_seq_len, dim)

# Sparse branch.
topk_idx = torch.topk(scores, topk, dim=-1).indices
# Sparse branch (fused Triton topk mask)
mask = fused_topk_mask(scores, topk)

if block_elements == 256:
# CuTe path consumes a bool mask (full/partial split inside the wrapper).
mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)
out_s = block_sparse_attn_256(q, k, v, mask, variable_block_sizes)[0]
else:
# Index-native path for 64-block (TK/Triton).
q2k_idx = topk_idx.to(torch.int32).contiguous()
q2k_num = torch.full(
(batch, heads, q_num_blocks),
topk,
dtype=torch.int32,
device=q.device,
)
out_s = block_sparse_attn_from_indices(
q, k, v, q2k_idx, q2k_num, variable_block_sizes
)[0]
out_s = block_sparse_attn(q, k, v, mask, variable_block_sizes)[0]

if compress_attn_weight is not None:
return out_c * compress_attn_weight + out_s
Expand Down Expand Up @@ -236,9 +222,8 @@ def video_sparse_attn_bshd(
out_c_ch = torch.matmul(attn, v_ch)
out_c_blk = out_c_ch.permute(0, 2, 1, 3).contiguous()

# Sparse branch (CuTe BSHD).
topk_idx = torch.topk(scores, topk, dim=-1).indices
mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)
# Sparse branch (fused Triton topk mask + CuTe BSHD).
mask = fused_topk_mask(scores, topk)
out_s, _ = block_sparse_attn_256_bshd(q, k, v, mask, variable_block_sizes)

out = out_s
Expand Down
Loading
Loading