[perf] optimize compress & topk kernel#1421
Conversation
There was a problem hiding this comment.
Welcome to FastVideo! Thanks for your first pull request.
How our CI works:
PRs run a two-tier CI system:
- Pre-commit — formatting (yapf), linting (ruff), type checking (mypy). Runs immediately on every PR.
- Fastcheck — core GPU tests (encoders, VAEs, transformers, kernels, unit tests). Runs automatically via Buildkite on relevant file changes (~10-15 min).
- Full Suite — integration tests, training pipelines, SSIM regression. Runs only when a reviewer adds the
readylabel.
Before your PR is reviewed:
-
pre-commit run --all-filespasses locally - You've added or updated tests for your changes
- The PR description explains what and why
If pre-commit fails, a bot comment will explain how to fix it. Fastcheck and Full Suite results appear in the Checks section below.
Useful links:
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🟠 PR merge requirementsWaiting for
Waiting checks:
|
There was a problem hiding this comment.
Code Review
This pull request introduces fused Triton kernels (fused_block_mean and fused_topk_mask) to replace the PyTorch-based pipeline for block mean compression and topk mask construction, along with a benchmark script to verify performance and accuracy. The review feedback highlights several critical improvements for the Triton kernels: resolving hardcoded tl.bfloat16 types to support other dtypes like float16, optimizing sequential loops using 2D block loads/stores and parallel reductions, handling -inf values in the topk binary search to prevent convergence failure, and adding a fallback to PyTorch's torch.topk when sequence lengths exceed Triton's maximum block size limit.
|
@alexzms @Davids048 @SolitaryThinker could you please take a look at this change, thanks |
|
@xsank Thanks for contribution! I will take a look |
|
Thanks for the speedup! Two quick correctness checks: 1. Have you considered exact ties at the k-th boundary in 2. The 64-block path now goes through |
|
@alexzms Thansk for reply.
As for the e2e testing, |
|
Quick follow-up on (1): I ran it on a real GPU and exact boundary ties do over-select. The That said — on realistic bf16 |
|
@alexzms thanks. |
SolitaryThinker
left a comment
There was a problem hiding this comment.
Hi @xsank — automated review from Gob, one of @SolitaryThinker's AI reviewers. Findings aren't all human-verified; ping @SolitaryThinker if anything looks off.
Open PR, kernel/perf scope. fastcheck + microscope-kernel-tests are green; Mergify is just the merge gate (needs 1 approval + pre-commit + full-suite). Nice speedup, and the methodology is solid — a real do_bench benchmark plus a CUDA test that actually runs in the kernel-tests lane. I verified the tie fix is correct for the real bf16 score path (more below). One real gap before this is a clean approve.
Verdict: COMMENT — no blocker; correctness verified for the production path. One MAJOR test gap.
Major
fastvideo-kernel/tests/test_fused_compress_topk.py— no backward coverage.fused_block_meanships a hand-written Triton backward (fused_compress_topk.py:58-95,163-175) and VSA backprops through it in training (video_sparse_attnis the train-time VSA backend,fastvideo/attention/backends/video_sparse_attn.py:308). The test file exercises only forward equivalence — there's norequires_grad/.backward()parity test, so the bwd is validated only by the PR-body benchmark, which doesn't run in CI. A forward-only kernel silently breaking training is the classic trap here. Please add a gradient-parity test (autogradfused_block_meanvs the eagerview→float→sum→div→bf16path,atol=1e-2bf16) so the backward lands inmicroscope-kernel-tests.
Minor
fused_compress_topk.py:242-246— the convergence comment overstates the guarantee. "precision ~2^-32 ≈ 2.3e-10" is range-relative ((hi-lo)/2^32), not absolute. I simulated the exact fp32 bisection: for realisticq_c@k_c/sqrt(d)bf16 scores (range ≈ 8, values O(1)) the per-step resolution (~2e-9) is far below the bf16 gap (~8e-3), so it converges exactly to the k-th bf16 value and selects exactlytopk— 0/160 rows wrong, matching @alexzms's "0 failures in 35k rows." But a constructed tie cluster at the k-th boundary with a wide range (max ≈ 1e7) over-selects (8 vs 6), because the resolution no longer beats the bf16 gap. That's unreachable for softmax-input scores so it isn't a VSA correctness bug — but I'd reword the comment to state the real precondition ("scores are O(1) so 32 iters resolves below the bf16 ULP") rather than implying unconditional exact convergence.fused_compress_topk.py:298-300— thekv_blocks>4096PyTorch fallback is silent, and the body's own data shows it regresses (kv=8192 | fallback | 0.94x). It's the right safety valve; a one-linelogger.debugor a docstring note would save a future profiler the surprise. Optional.
Confirmed @alexzms's point #2 is fine: block_sparse_attn is a thin wrapper over block_sparse_attn_from_indices (block_sparse_attn.py:420-422) and q2k_num is derived from the mask in map_to_index, so an over-selected row can't overflow buffers — it just attends to a few extra blocks.
— Gob (@SolitaryThinker's AI reviewer).
|
@SolitaryThinker I've made some changes based on your AI suggestions. Please take another look, Thanks. |
|
@alexzms @SolitaryThinker Thank you all for review. Let me see how to further optimize the BSA. |
Purpose
improve compress & topk performance.
Changes
add fused kernels based on triton.
Test
command
result