
feat: intracard cp for sm90 #86

Open
Hyaloid wants to merge 1 commit into inclusionAI:main from Hyaloid:intcd-cp

@Hyaloid Hyaloid commented Jun 4, 2026


📌 Description

The serial bottleneck

kda_prefill_hopper (cuLA's SM90 KDA prefill) launches one CTA per (seq, head) and runs a strictly
sequential chunk recurrence inside each sequence: h_t = decay(g_t) · h_{t-1} + k_t^T @ (u_t − w_t·h_{t-1}).
Within one sequence, work cannot parallelize across chunks — only across the (raw_batch × H) grid.
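Expanding the product makes the dependence on h_{t-1} explicit. This is only an algebraic rearrangement of the formula above; reading decay(g_t) as a matrix applied to the state from the left is an assumption, since the description does not spell out tensor shapes.

```latex
\begin{aligned}
h_t &= \operatorname{decay}(g_t)\, h_{t-1} + k_t^{\top}\left(u_t - w_t\, h_{t-1}\right) \\
    &= \underbrace{\left(\operatorname{decay}(g_t) - k_t^{\top} w_t\right)}_{A_t}\, h_{t-1}
       + \underbrace{k_t^{\top} u_t}_{b_t}
\end{aligned}
```

Under that reading each chunk is an affine map of the incoming state, and chunk t cannot start until h_{t-1} is available, which is the serial chain described above.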

This becomes a bottleneck when both:

  1. raw_batch × H is small, so the baseline grid under‑utilizes the SMs. A single long sequence at
    H=8 occupies only 8 CTAs on a 132‑SM H100 (~6% of the SMs); the remaining 124 SMs sit idle while those 8 serial chains run.
  2. The shape has a long‑tail sequence (e.g. 128K+1K packed) — the long seq's serial recurrence
    dominates wall time while short seqs finish in microseconds and leave SMs idle.
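The two conditions above can be checked with back-of-the-envelope arithmetic. The sketch below is illustrative only: the function names are hypothetical, and the chunk length of 64 tokens is an assumption, not a value taken from this PR.

```python
import math

def grid_utilization(raw_batch: int, num_heads: int, num_sms: int = 132) -> float:
    """Fraction of SMs that receive a CTA when one CTA is launched per (seq, head)."""
    ctas = raw_batch * num_heads
    return min(ctas, num_sms) / num_sms

def serial_chunk_steps(seq_lens, chunk_len: int = 64):
    """Number of sequential chunk steps each sequence needs (chunk_len is assumed)."""
    return [math.ceil(length / chunk_len) for length in seq_lens]

# Condition 1: a single long sequence at H=8 on a 132-SM H100.
print(f"{grid_utilization(raw_batch=1, num_heads=8):.1%}")  # prints 6.1%

# Condition 2: a 128K+1K packed batch. The long sequence needs 128x more
# serial steps than the short one, so it sets the wall time.
steps = serial_chunk_steps([128 * 1024, 1024])
print(steps, steps[0] // steps[1])  # prints [2048, 16] 128
```

The benchmark report in the Performance section is consistent with the second point: at H=4 the CP-off time for T=128K (23.7868 ms) and for 128K+1K (23.8150 ms) is nearly identical.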

Approach

Mirroring FLA's intra‑card CP design (and the SM100 cuLA path in cula/ops/cp/chunk_delta_h.py),
this PR splits long sequences into CP‑chunks on the same card and produces per‑CP‑chunk initial
states so the main C++ kernel can run all CP‑chunks in parallel.
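One way to see why per‑CP‑chunk initial states can be produced cheaply: the recurrence is affine in the previous state, so a whole span of steps collapses into a single affine map. The sketch below uses a scalar stand-in, h_t = a_t · h_{t-1} + b_t, in place of the matrix-valued state; all function names are hypothetical, and it does not reproduce the PR's kernels or FLA's implementation.

```python
import math

def serial_scan(a, b, h0=0.0):
    """Reference: run h_t = a_t * h_{t-1} + b_t step by step, return the final state."""
    h = h0
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
    return h

def summarize(a, b):
    """Collapse a span of steps into one affine map: h_out = A * h_in + B."""
    A, B = 1.0, 0.0
    for a_t, b_t in zip(a, b):
        A, B = a_t * A, a_t * B + b_t
    return A, B

def cp_initial_states(a, b, cp_len, h0=0.0):
    """Return (spans, initial state per span, final state)."""
    spans = [(s, min(s + cp_len, len(a))) for s in range(0, len(a), cp_len)]
    # Phase 1: each span is summarized independently, so this part can run in parallel.
    summaries = [summarize(a[s:e], b[s:e]) for s, e in spans]
    # Phase 2: a short serial merge, one step per span rather than per token.
    inits, h = [], h0
    for A, B in summaries:
        inits.append(h)
        h = A * h + B
    return spans, inits, h

a = [0.9, 0.8, 0.95, 0.7, 0.85, 0.99, 0.6]
b = [1.0, -2.0, 0.5, 3.0, -1.0, 0.25, 2.0]
spans, inits, final = cp_initial_states(a, b, cp_len=3, h0=0.5)

# Every span, started from its own initial state, reproduces the serial result.
assert math.isclose(final, serial_scan(a, b, h0=0.5))
for i, (s, e) in enumerate(spans[:-1]):
    assert math.isclose(serial_scan(a[s:e], b[s:e], inits[i]), inits[i + 1])
```

Once the initial states are known, each span can be processed as if it were an independent sequence, which is what lets the main kernel treat CP‑chunks as extra grid entries.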

🔍 Related Issues

Similar to issue #20, but for SM90.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
clang-format.............................................................Passed
ruff (legacy alias)......................................................Passed
ruff format..............................................................Passed

🧪 Tests

python -m pytest tests/test_intracard_cp_sm90.py -v

platform linux -- Python 3.12.3, pytest-9.1.0, pluggy-1.6.0 -- /opt/torch/bin/python
cachedir: .pytest_cache
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 63 items                                                                                                                                               

tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens0-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens1-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens2-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens3-8-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens4-8-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens5-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens6-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens7-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens0-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens1-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens2-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens3-8-False] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens4-8-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens5-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens6-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens7-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_off_matches_basic_baseline PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens0-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens1-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens2-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens3-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens4-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens5-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens6-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens7-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens8-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens9-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens10-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens11-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens12-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens13-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens0-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens1-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens2-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens3-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens4-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens5-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens6-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens7-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens8-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens9-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens10-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens11-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens12-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens13-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens0-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens1-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens2-8] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens3-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens4-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens5-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens0-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens1-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens2-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens3-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens4-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[single-64K-H4-h0] PASSED
tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[multi-64K+4K-H4-h0] PASSED
tests/test_intracard_cp_sm90.py::test_cp_h0_none_equiv_h0_zeros PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens0-8] PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens1-64] PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens2-8] PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens3-8] PASSED

============================================================================= PASSES =============================================================================
====================================================================== slowest 15 durations ======================================================================
10.53s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens0-4-False]
9.77s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens3-8-False]
7.66s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens5-4-True]
0.20s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[multi-64K+4K-H4-h0]
0.17s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[single-64K-H4-h0]
0.12s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens0-4]
0.06s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens3-8]
0.04s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens7-4-False]
0.03s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens7-4-False]
0.03s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens4-4-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens4-8-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens4-8-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens2-4-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens2-4-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens1-64]
==================================================================== short test summary info =====================================================================
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens0-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens1-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens2-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens3-8-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens4-8-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens5-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens6-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens7-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens0-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens1-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens2-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens3-8-False]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens4-8-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens5-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens6-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens7-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_off_matches_basic_baseline
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens0-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens1-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens2-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens3-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens4-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens5-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens6-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens7-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens8-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens9-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens10-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens11-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens12-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens13-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens0-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens1-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens2-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens3-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens4-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens5-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens6-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens7-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens8-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens9-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens10-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens11-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens12-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens13-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens0-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens1-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens2-8]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens3-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens4-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens5-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens0-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens1-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens2-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens3-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens4-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[single-64K-H4-h0]
PASSED tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[multi-64K+4K-H4-h0]
PASSED tests/test_intracard_cp_sm90.py::test_cp_h0_none_equiv_h0_zeros
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens0-8]
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens1-64]
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens2-8]
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens3-8]
====================================================================== 63 passed in 35.27s =======================================================================


  • Tests have been added or updated as needed.
  • All tests are passing.

⚡ Performance

python benchmarks/bench_intracard_cp_sm90.py

==============================================================================================================
                       BENCHMARK REPORT: Intracard CP (SM90)
                       CP-on (kda_prefill_hopper_auto) vs CP-off (kda_prefill_hopper)
                       D=128  dtype=bf16  safe_gate=True
                       Warmup=10  Iters=10
==============================================================================================================

  [H=4]
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub fused_pre  │         o max/mean        ht max/mean  │  CP_off(ms)   CP_on(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  4x256                       1024     N     0     Y  │  2.4e-04/2.4e-07  2.0e-03/3.5e-07  │      0.3059      0.1411     2.17x
  8x256                       2048     N     0     Y  │  2.4e-04/2.4e-07  2.0e-03/2.7e-07  │      0.2820      0.1376     2.05x
  16x256                      4096     N     0     Y  │  2.4e-04/2.1e-07  9.8e-04/2.5e-07  │      0.2815      0.1478     1.91x
  4x1K                        4096     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.3090      0.2768     1.12x
  8x1K                        8192     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.3021      0.2709     1.11x
  4x2K                        8192     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.4381      0.4377     1.00x
  1K+512+256+128              1920     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.2960      0.2348     1.26x
  2K+1K+512+256               3840     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.4188      0.4186     1.00x
  1K+1+63+65+129              1282     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.3014      0.2289     1.32x
  T=4K                        4096     Y    16     Y  │  2.4e-04/7.1e-07  1.5e-04/9.8e-08  │      0.7917      0.4744     1.67x
  T=8K                        8192     Y    32     Y  │  2.4e-04/2.0e-07  1.8e-07/6.4e-12  │      1.5285      0.5210     2.93x
  T=32K                      32768     Y    32     N  │  2.4e-04/7.7e-08  0.0e+00/0.0e+00  │      6.0181      1.3304     4.52x
  T=64K                      65536     Y    32     N  │  3.1e-04/4.1e-07  5.6e-06/5.0e-10  │     11.9152      2.4489     4.87x
  T=128K                    131072     Y    32     N  │  2.4e-04/7.1e-09  0.0e+00/0.0e+00  │     23.7868      4.7113     5.05x
  8x4K                       32768     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.8597      0.8704     0.99x
  4x8K                       32768     Y    32     N  │  2.4e-04/6.9e-08  0.0e+00/0.0e+00  │      1.6107      1.4507     1.11x
  2x16K                      32768     Y    32     N  │  2.4e-04/7.5e-08  0.0e+00/0.0e+00  │      3.1052      1.4304     2.17x
  16K+16K                    32768     Y    32     N  │  2.4e-04/7.5e-08  0.0e+00/0.0e+00  │      3.1048      1.4500     2.14x
  24K+8K                     32768     Y    32     N  │  2.4e-04/7.4e-08  0.0e+00/0.0e+00  │      4.5269      1.4289     3.17x
  28K+4K                     32768     Y    32     N  │  2.4e-04/7.6e-08  0.0e+00/0.0e+00  │      5.2484      1.4333     3.66x
  32K+256+256                33280     Y    34     N  │  2.4e-04/7.6e-08  0.0e+00/0.0e+00  │      5.9927      1.4946     4.01x
  40K+1K+8K                  50176     Y    25     N  │  3.7e-04/2.4e-07  0.0e+00/0.0e+00  │      7.5075      2.1744     3.45x
  64K+512+256+128            66432     Y    35     N  │  3.1e-04/4.1e-07  5.6e-06/1.2e-10  │     11.9276      2.7482     4.34x
  128K+1K                   132096     Y    33     N  │  2.4e-04/7.1e-09  0.0e+00/0.0e+00  │     23.8150      5.1974     4.58x
  128K+2x1K                 133120     Y    34     N  │  1.2e-04/4.0e-10  0.0e+00/0.0e+00  │     23.8056      5.3161     4.48x
  128K+5x1K                 136192     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     23.8974     23.8564     1.00x
  128K+10x1K                141312     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     23.8520     23.8534     1.00x
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

  [H=8]
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub fused_pre  │         o max/mean        ht max/mean  │  CP_off(ms)   CP_on(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  4x256                       1024     N     0     Y  │  1.2e-04/1.9e-07  8.5e-04/1.6e-07  │      0.2952      0.1422     2.08x
  8x256                       2048     N     0     Y  │  4.9e-04/2.2e-07  1.2e-03/1.7e-07  │      0.2859      0.1445     1.98x
  16x256                      4096     N     0     Y  │  2.4e-04/1.8e-07  9.5e-04/1.0e-07  │      0.2875      0.1556     1.85x
  4x1K                        4096     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.3081      0.2728     1.13x
  8x1K                        8192     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.2959      0.2733     1.08x
  4x2K                        8192     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.4547      0.4533     1.00x
  1K+512+256+128              1920     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.3020      0.2770     1.09x
  2K+1K+512+256               3840     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.4401      0.4399     1.00x
  1K+1+63+65+129              1282     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.3053      0.2761     1.11x
  T=4K                        4096     Y    16     Y  │  2.4e-04/2.4e-07  1.3e-04/5.5e-08  │      0.8058      0.4939     1.63x
  T=8K                        8192     Y    16     Y  │  2.4e-04/1.9e-07  6.3e-05/1.7e-08  │      1.5636      0.6008     2.60x
  T=32K                      32768     Y    16     N  │  4.9e-04/2.2e-07  3.4e-05/3.4e-10  │      6.1141      1.8119     3.37x
  T=64K                      65536     Y    16     N  │  1.2e-04/4.2e-09  0.0e+00/0.0e+00  │     12.1999      3.4392     3.55x
  T=128K                    131072     Y    16     N  │  2.4e-04/6.0e-08  1.5e-06/1.3e-11  │     24.3839      6.6673     3.66x
  8x4K                       32768     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      0.9406      0.9419     1.00x
  4x8K                       32768     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      1.6618      1.6637     1.00x
  2x16K                      32768     Y    16     N  │  4.9e-04/1.9e-07  3.4e-05/1.7e-10  │      3.1538      1.9019     1.66x
  16K+16K                    32768     Y    16     N  │  4.9e-04/1.9e-07  3.4e-05/1.7e-10  │      3.1532      1.9073     1.65x
  24K+8K                     32768     Y    16     N  │  4.9e-04/2.0e-07  3.4e-05/1.7e-10  │      4.8696      1.9045     2.56x
  28K+4K                     32768     Y    16     N  │  4.9e-04/2.0e-07  3.4e-05/3.8e-10  │      5.5071      1.8876     2.92x
  32K+256+256                33280     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      6.1257      6.1434     1.00x
  40K+1K+8K                  50176     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │      7.7079      7.6964     1.00x
  64K+512+256+128            66432     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     12.2289     12.2281     1.00x
  128K+1K                   132096     Y    17     N  │  2.4e-04/5.2e-09  0.0e+00/0.0e+00  │     24.4154      7.4628     3.27x
  128K+2x1K                 133120     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     24.4315     24.4260     1.00x
  128K+5x1K                 136192     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     24.4402     24.4436     1.00x
  128K+10x1K                141312     N     0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     24.4705     24.4655     1.00x
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================
  Summary
==============================================================================================================
  CP triggered (25 configs): geo-mean=2.94x  best=5.05x  worst=1.11x
  CP bypassed  (29 configs): mean overhead=0.862x  max=1.012x  (1.00 = no regression)
  Accuracy (CP-on vs CP-off): o  max=4.88e-04 avg=1.65e-04   ht max=2.01e-03 avg=1.57e-04
==============================================================================================================
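The "CP triggered" line of the summary can be cross-checked from the table. The snippet below copies the Speedup column for every row with pred=Y (15 rows at H=4, 10 at H=8) and recomputes the geometric mean; because the column is rounded to two decimals, a small deviation from the benchmark's own figure is expected.

```python
import math

# Speedup values for rows with pred=Y, copied from the report above.
h4 = [1.67, 2.93, 4.52, 4.87, 5.05, 1.11, 2.17, 2.14,
      3.17, 3.66, 4.01, 3.45, 4.34, 4.58, 4.48]
h8 = [1.63, 2.60, 3.37, 3.55, 3.66, 1.66, 1.65, 2.56, 2.92, 3.27]
speedups = h4 + h8

# Geometric mean: exponential of the mean log speedup.
geo_mean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))

print(f"configs={len(speedups)} geo-mean={geo_mean:.2f}x "
      f"best={max(speedups):.2f}x worst={min(speedups):.2f}x")
```

The count, best, and worst values match the summary (25 configs, 5.05x, 1.11x), and the recomputed geometric mean lands within rounding of the reported 2.94x.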

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces an optimized Hopper (SM90) KDA prefill path featuring fused gate and L2-norm preprocessing, along with intra-card CP (chunk-parallel) scheduling. Key feedback includes optimizing cp_context.py to avoid a synchronous D2H copy by computing sequence mappings on the CPU, adding device validation checks in the C++ API to prevent illegal memory accesses, and using an if/else block in the fused L2-norm Triton kernel to eliminate redundant load instructions.
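As an illustration of the first suggestion, sequence mappings can be built from host-side integers alone, so no device tensor has to be read back. This is a hypothetical sketch: `build_cp_layout`, its arguments, and the fixed-length split policy are assumptions and do not reflect the actual contents of cp_context.py.

```python
def build_cp_layout(seq_lens, cp_len):
    """Split each sequence into CP-chunks of at most cp_len tokens.

    Works on plain Python ints, so it never triggers a device-to-host copy.
    Returns (cu_seqlens, parent_seq): cumulative chunk boundaries and, for
    each chunk, the index of the sequence it came from.
    """
    cu_seqlens, parent_seq = [0], []
    offset = 0
    for seq_idx, length in enumerate(seq_lens):
        start = 0
        while start < length:
            size = min(cp_len, length - start)
            offset += size
            cu_seqlens.append(offset)
            parent_seq.append(seq_idx)
            start += size
    return cu_seqlens, parent_seq

cu_seqlens, parent_seq = build_cp_layout([10, 3], cp_len=4)
print(cu_seqlens)  # prints [0, 4, 8, 10, 13]
print(parent_seq)  # prints [0, 0, 0, 1]
```

The resulting lists can then be uploaded to the device once, rather than deriving them from a device tensor and synchronizing to read the result.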


Comment thread cula/kda/cp_context.py Outdated
Comment thread cula/kda/cp_context.py Outdated
Comment thread cula/kda/cp_context.py Outdated
Comment thread csrc/api/kda_sm90.cu
Comment thread cula/kda/l2norm_qk_fused.py Outdated
@Hyaloid Hyaloid mentioned this pull request Jun 8, 2026
5 tasks
@icavan icavan requested review from cherhh and icavan June 12, 2026 16:17

@icavan icavan left a comment

Collaborator


@Hyaloid Thanks for your contribution. The main idea LGTM. Could you add more test cases for varlen settings?

Comment thread benchmarks/bench_intracard_cp_sm90.py Outdated
Comment thread cula/kda/hopper_fused_fwd_opt.py
Comment thread tests/test_intracard_cp_sm90.py
pre-commit

adopt cr suggestions

support varlen fuse l2norm+gate cumsum & fix irregular input
cherhh commented Jun 15, 2026

Collaborator

@Hyaloid Thanks a lot for this contribution! Could you also add some performance numbers comparing this SM90 intra-card CP path against FLA with intra-card CP enabled?
