Skip to content

[perf] optimize compress & topk kernel#1421

Open
xsank wants to merge 12 commits into
hao-ai-lab:mainfrom
xsank:main
Open

[perf] optimize compress & topk kernel#1421
xsank wants to merge 12 commits into
hao-ai-lab:mainfrom
xsank:main

Conversation

@xsank

@xsank xsank commented Jun 1, 2026

Copy link
Copy Markdown

Purpose

improve compress & topk performance.

Changes

add fused kernels based on triton.

Test

command

python benchmarks/bench_fused_compress_topk.py --seq_lens 10240 49152 115200

result

Fused Compress & TopK Benchmark
device: NVIDIA L20D
batch=1, heads=12, head_dim=128, block_elements=64, dtype=bf16
warmup=10, rep=50

====================================================================================================
seq_len=10240, num_blocks=160, topk=16
----------------------------------------------------------------------------------------------------
  compress fwd | old:    0.232 ms | new:    0.056 ms | speedup:  4.18x | max_abs_err: 0.00e+00 | cos_sim: 1.00000000
  compress bwd | old:    0.249 ms | new:    0.078 ms | speedup:  3.17x | max_abs_err: 0.00e+00 | cos_sim: 1.00000000
  topk      | old:    0.127 ms | new:    0.036 ms | speedup:  3.48x | row_exact_match: 1.0000 | count_match: 1.0000

====================================================================================================
seq_len=40960, num_blocks=640, topk=64
----------------------------------------------------------------------------------------------------
  compress fwd | old:    0.685 ms | new:    0.200 ms | speedup:  3.42x | max_abs_err: 0.00e+00 | cos_sim: 1.00000012
  compress bwd | old:    0.602 ms | new:    0.075 ms | speedup:  8.04x | max_abs_err: 0.00e+00 | cos_sim: 1.00000000
  topk      | old:    0.654 ms | new:    0.238 ms | speedup:  2.75x | row_exact_match: 1.0000 | count_match: 1.0000

====================================================================================================
seq_len=102400, num_blocks=1600, topk=160
----------------------------------------------------------------------------------------------------
  compress fwd | old:    1.557 ms | new:    0.422 ms | speedup:  3.69x | max_abs_err: 0.00e+00 | cos_sim: 1.00000000
  compress bwd | old:    1.190 ms | new:    0.223 ms | speedup:  5.33x | max_abs_err: 0.00e+00 | cos_sim: 1.00000000
  topk      | old:    2.381 ms | new:    1.149 ms | speedup:  2.07x | row_exact_match: 1.0000 | count_match: 1.0000

====================================================================================================
TopK with -inf scores (masked positions)
----------------------------------------------------------------------------------------------------
  topk -inf | 10% masked | PASS | row_exact_match: 1.0000 | count_match: 1.0000
  topk -inf | 30% masked | PASS | row_exact_match: 1.0000 | count_match: 1.0000
  topk -inf | 50% masked | PASS | row_exact_match: 1.0000 | count_match: 1.0000
  topk -inf | 80% masked | PASS | row_exact_match: 1.0000 | count_match: 1.0000

====================================================================================================
TopK with large kv_blocks (Triton block size limit test)
----------------------------------------------------------------------------------------------------
  topk kv=1024  | triton   | old:    0.118 ms | new:    0.058 ms | speedup:  2.03x | row_exact_match: 1.0000
  topk kv=2048  | triton   | old:    0.221 ms | new:    0.028 ms | speedup:  7.92x | row_exact_match: 1.0000
  topk kv=4096  | triton   | old:    0.353 ms | new:    0.108 ms | speedup:  3.25x | row_exact_match: 1.0000
  topk kv=8192  | fallback | old:    0.526 ms | new:    0.558 ms | speedup:  0.94x | row_exact_match: 1.0000

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Welcome to FastVideo! Thanks for your first pull request.

How our CI works:

PRs run a two-tier CI system:

  1. Pre-commit — formatting (yapf), linting (ruff), type checking (mypy). Runs immediately on every PR.
  2. Fastcheck — core GPU tests (encoders, VAEs, transformers, kernels, unit tests). Runs automatically via Buildkite on relevant file changes (~10-15 min).
  3. Full Suite — integration tests, training pipelines, SSIM regression. Runs only when a reviewer adds the ready label.

Before your PR is reviewed:

  • pre-commit run --all-files passes locally
  • You've added or updated tests for your changes
  • The PR description explains what and why

If pre-commit fails, a bot comment will explain how to fix it. Fastcheck and Full Suite results appear in the Checks section below.

Useful links:

@mergify mergify Bot added type: perf Performance improvement scope: kernel CUDA kernels, fastvideo-kernel labels Jun 1, 2026
@mergify

mergify Bot commented Jun 1, 2026

Copy link
Copy Markdown
Contributor

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🟠 PR merge requirements

Waiting for

  • check-success~=pre-commit
Waiting checks: pre-commit.
  • check-success~=pre-commit
  • #approved-reviews-by>=1
  • check-success=fastcheck-passed
  • check-success=full-suite-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces fused Triton kernels (fused_block_mean and fused_topk_mask) to replace the PyTorch-based pipeline for block mean compression and topk mask construction, along with a benchmark script to verify performance and accuracy. The review feedback highlights several critical improvements for the Triton kernels: resolving hardcoded tl.bfloat16 types to support other dtypes like float16, optimizing sequential loops using 2D block loads/stores and parallel reductions, handling -inf values in the topk binary search to prevent convergence failure, and adding a fallback to PyTorch's torch.topk when sequence lengths exceed Triton's maximum block size limit.

Comment thread fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py Outdated
Comment thread fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py Outdated
Comment thread fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py Outdated
@xsank

xsank commented Jun 4, 2026

Copy link
Copy Markdown
Author

@alexzms @Davids048 @SolitaryThinker could you please take a look at this change, thanks

@alexzms alexzms self-requested a review June 4, 2026 20:43
@alexzms

alexzms commented Jun 4, 2026

Copy link
Copy Markdown
Collaborator

@xsank Thanks for contribution! I will take a look

@alexzms

alexzms commented Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Thanks for the speedup! Two quick correctness checks:

1. Have you considered exact ties at the k-th boundary in fused_topk_mask? When several scores tie at the threshold, the bisection >threshold path can grab the whole tied cluster and select more than topk (e.g. two equal top scores with topk=1 → 2 selected). FWIW I couldn't trigger it on realistic bf16 q_c@k_c scores, so impact looks low — but a tie test + a clamp to topk would lock the contract down.

2. The 64-block path now goes through block_sparse_attn(mask) instead of block_sparse_attn_from_indices. Does the mask path assume exactly topk per row anywhere (buffer sizing / index compaction)? If so, the tie case above could feed it an unexpected count. An end-to-end video_sparse_attn equivalence check (fused vs. old index path) would cover both.

@xsank

xsank commented Jun 6, 2026

Copy link
Copy Markdown
Author

@alexzms Thansk for reply.

  1. Even if multiple values are tied at the threshold, after applying at_threshold & (at_thresh_cumsum <= n_needed_at_thresh), only topk - n_above values will be retained, so the final topk is strictly correct.
  2. block_sparse_attn is a thin wrapper of block_sparse_attn_from_indices , I don’t think we need to guarantee that the topk per row is strictly consistent either.

As for the e2e testing, python fastvideo-kernel/benchmarks/bench_vsa.py also works well. Should i add more benchmark cases?

@alexzms

alexzms commented Jun 7, 2026

Copy link
Copy Markdown
Collaborator

Quick follow-up on (1): I ran it on a real GPU and exact boundary ties do over-select. The cumsum clamp only bounds the == threshold set; the overshoot comes from > threshold itself — when the float threshold converges just below a tied cluster, n_above already exceeds topk, n_needed goes negative, and nothing trims it.

scores=[0.5222, 0.5222, 0.8875, 0.5222, 0.5222, 0.5222, 0.8875], topk=1  ->  selects 2

That said — on realistic bf16 q_c @ k_c scores I got 0 failures in 35k rows, so I agree the real-world impact is negligible and I'm fine not gating on it. Up to you whether to add a one-line clamp + a tie unit test, or just leave a comment noting the assumption. No need for more perf benchmark cases — an e2e video_sparse_attn equivalence check (fused vs. old index path) on one real shape is the only thing I'd still suggest. Nice speedup either way. 👍

@xsank

xsank commented Jun 8, 2026

Copy link
Copy Markdown
Author

@alexzms thanks.
I added e2e tests for video_sparse_attn , specifically the test file fastvideo-kernel/tests/test_fused_compress_topk.py, and also added comments in the topk implementation.

@SolitaryThinker SolitaryThinker left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @xsank — automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off.

Open PR, kernel/perf scope. fastcheck + microscope-kernel-tests are green; Mergify is just the merge gate (needs 1 approval + pre-commit + full-suite). Nice speedup, and the methodology is solid — a real do_bench benchmark plus a CUDA test that actually runs in the kernel-tests lane. I verified the tie fix is correct for the real bf16 score path (more below). One real gap before this is a clean approve.

Verdict: COMMENT — no blocker; correctness verified for the production path. One MAJOR test gap.

Major

  • fastvideo-kernel/tests/test_fused_compress_topk.py — no backward coverage. fused_block_mean ships a hand-written Triton backward (fused_compress_topk.py:58-95,163-175) and VSA backprops through it in training (video_sparse_attn is the train-time VSA backend, fastvideo/attention/backends/video_sparse_attn.py:308). The test file exercises only forward equivalence — there's no requires_grad/.backward() parity test, so the bwd is validated only by the PR-body benchmark, which doesn't run in CI. A forward-only kernel silently breaking training is the classic trap here. Please add a gradient-parity test (autograd fused_block_mean vs the eager view→float→sum→div→bf16 path, atol=1e-2 bf16) so the backward lands in microscope-kernel-tests.

Minor

  • fused_compress_topk.py:242-246 — the convergence comment overstates the guarantee. "precision ~2^-32 ≈ 2.3e-10" is range-relative ((hi-lo)/2^32), not absolute. I simulated the exact fp32 bisection: for realistic q_c@k_c/sqrt(d) bf16 scores (range ≈ 8, values O(1)) the per-step resolution (~2e-9) is far below the bf16 gap (~8e-3), so it converges exactly to the k-th bf16 value and selects exactly topk — 0/160 rows wrong, matching @alexzms's "0 failures in 35k rows." But a constructed tie cluster at the k-th boundary with a wide range (max ≈ 1e7) over-selects (8 vs 6), because the resolution no longer beats the bf16 gap. That's unreachable for softmax-input scores so it isn't a VSA correctness bug — but I'd reword the comment to state the real precondition ("scores are O(1) so 32 iters resolves below the bf16 ULP") rather than implying unconditional exact convergence.
  • fused_compress_topk.py:298-300 — the kv_blocks>4096 PyTorch fallback is silent, and the body's own data shows it regresses (kv=8192 | fallback | 0.94x). It's the right safety valve; a one-line logger.debug or a docstring note would save a future profiler the surprise. Optional.

Confirmed @alexzms's point #2 is fine: block_sparse_attn is a thin wrapper over block_sparse_attn_from_indices (block_sparse_attn.py:420-422) and q2k_num is derived from the mask in map_to_index, so an over-selected row can't overflow buffers — it just attends to a few extra blocks.

— Gob (@SolitaryThinker's AI reviewer).

@xsank

xsank commented Jun 8, 2026

Copy link
Copy Markdown
Author

@SolitaryThinker I've made some changes based on your AI suggestions. Please take another look, Thanks.

@xsank xsank requested a review from SolitaryThinker June 8, 2026 11:58

@alexzms alexzms left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM — thanks for the fast iterations. Both of my earlier points are addressed, approving. 🚀

@xsank xsank requested a review from alexzms June 9, 2026 00:11
@xsank

xsank commented Jun 9, 2026

Copy link
Copy Markdown
Author

@alexzms @SolitaryThinker Thank you all for review. Let me see how to further optimize the BSA.

@alexzms alexzms added the ready PR is ready to merge label Jun 18, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready PR is ready to merge scope: kernel CUDA kernels, fastvideo-kernel type: perf Performance improvement

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants