Adds Grouped and Batched GEMM kernels with blockscaling matching DeepGEMM API #433
aryaman-gupta wants to merge 58 commits into main from
Conversation
Replace hardcoded test calls with argparse-based __main__ matching the pattern used by other kernel tests (blockscale, moe, preshuffle). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace inline correctness checks and manual CUDA event benchmarking with shared utilities from tests.test_common, matching the pattern used by blockscale_preshuffle_gemm and other kernel tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
tile_m/n/k, out_dtype, num_iters, and num_warmup were parsed but never passed to the test functions, which hardcoded their own values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Explicitly create the CPU reference output tensor with device='cpu' to avoid a conflict when torch.set_default_device('cuda') is active. The reference matmul stays on CPU due to hipBLAS issues on this ROCm build.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
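For context, a minimal sketch of this pattern (shapes are illustrative, not the test's actual sizes):

```python
import torch

# As in the test harness: CUDA is made the default device for kernel tensors.
# (Setting the default does not itself allocate on the GPU.)
torch.set_default_device("cuda")

# Explicit device="cpu" keeps the reference path off the default device, so
# the reference matmul runs on CPU regardless of the global default.
a = torch.randn(4, 8, device="cpu")
b = torch.randn(6, 8, device="cpu")
ref = a @ b.T  # CPU reference matmul
```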
Split load_a_tile into prefetch_a_tile (Global→VGPR) and store_a_tile_to_lds (VGPR→LDS). Moves ds_write after compute_tile to match the MoE blockscale 2-stage pipeline, enabling future instruction scheduling to interleave ds_write with trailing MFMAs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds hot_loop_scheduler() with coarse-grained sched_group_barrier hints matching the moe_blockscale_2stage pattern. Placed after store_a_tile_to_lds and before gpu.barrier(), only emitted when a next tile actually exists (avoids LLVM assertion from mismatched instruction counts on tail iterations). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds optional waves_per_eu parameter to compile functions. When set, applies rocdl.waves_per_eu attribute to gpu.func for occupancy tuning. Matches the pattern from blockscale_preshuffle_gemm and moe_blockscale_2stage kernels. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contiguous test: - Generate unaligned M group sizes with -1 padding rows (DeepGEMM convention) - Add unaligned M test cases (2g-100m, 4g-200m) - Add DeepSeek-V3 shapes (2112x7168, 7168x2304) - Add out_dtype parametrization (bf16 + f16) - Zero out padding rows before comparison - Add --waves_per_eu CLI arg Masked test: - Fix output buffer dtype bug (was hardcoded bf16, now respects out_dtype) - Add sparse masking test (4g-512max-50m) - Add DeepSeek-V3 shapes - Add out_dtype parametrization (bf16 + f16) - Wire out_dtype through generate_masked_grouped_gemm_inputs - Add --waves_per_eu CLI arg Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
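The "-1 padding rows" convention mentioned above can be sketched as follows. `make_grouped_layout` and `align=128` are hypothetical names chosen for illustration; the DeepGEMM convention is that each group's rows are padded up to an alignment boundary and the per-row group index is -1 on padding rows:

```python
def make_grouped_layout(group_sizes, align=128):
    """Per-row group indices; -1 marks padding rows (DeepGEMM contiguous convention)."""
    idx = []
    for g, m in enumerate(group_sizes):
        padded = -(-m // align) * align          # round m up to the alignment boundary
        idx += [g] * m + [-1] * (padded - m)     # real rows, then padding rows
    return idx

# Unaligned group sizes like the 2g-100m / 200m test cases:
layout = make_grouped_layout([100, 200])
```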
The s_a_vecs f32 loads are unused on the gfx950 HW path and would index out-of-bounds against the int8-sized scale buffer if MLIR DCE ever failed to eliminate them. Gating makes the gfx950/942 split explicit. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Add `prefetch_scales(k_tile_idx_py)` helper that loads the E8M0 byte for each (mi, ni) of the next K-tile into VGPRs ahead of `compute_tile`. Issued before `load_b_tile` in the ping-pong loop so scale-VMEM latency overlaps the prior tile's MFMAs and the next B-tile load. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Register both tests in tests/arch_compat.py:CDNA_ONLY_TESTS so RDNA CI auto-skips them. - Add `pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]` so CI buckets them correctly. - Add `torch.cuda.is_available()` module-level skip guard. - Drop dead `--waves_per_eu` argparse arg (was accepted but never forwarded). - Merge per-file `*_correctness` and `*_performance` into a single `test_grouped_fp8_gemm` / `test_masked_grouped_fp8_gemm` matching the test_blockscale_preshuffle_gemm convention. - Move the per-group reference matmul from CPU to GPU (hipBLASLt). Test suite runtime drops ~70s → ~34s. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Remove unused `import os` from both kernel files. - Remove orphan `# Helper: compute one K-tile from LDS + B tile` banner from both kernels (the function it labeled was renamed/refactored). - Remove duplicate `c_scale_k = fx.Index(scale_k)` reassignment in the masked kernel (already in scope from the earlier definition). - Drop the drift-prone "Optimizations applied:" lists from kernel module docstrings; correct the now-stale `scale_a` / `scale_b` dtype to reflect uint8 on gfx950 / FP32 on gfx942. - Simplify the "Per-group matmul" comment in both test files; drop the specific backend (hipBLASLt) claim. - Add missing `device`, `scale_block_k`, `scale_block_n` entries to the masked test's `generate_*_inputs` Args docstring. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- compile_grouped_fp8_gemm -> compile_grouped_gemm_blockscale_contiguous - compile_masked_grouped_fp8_gemm -> compile_grouped_gemm_blockscale_masked The new names mirror the file names exactly (drop "fp8_gemm", incorporate "blockscale" + the contiguous/masked variant), making the call site self-documenting. Internal kernel/launcher symbols and the JIT cache-key strings are renamed in lockstep. Test imports and call sites updated. DeepGEMM op references in the docstrings (`m_grouped_fp8_gemm_nt_contiguous` / `..._masked`) are unchanged — those are DeepGEMM's actual symbol names. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Pulls the byte-identical helpers (ARCH/DTYPE_FP8/USE_UE8M0 constants, ceil_div, align, fp32_to_e8m0, fp32_e8m0_to_byte, quantize_b_to_fp8, _as_i8) out of the two test files into a new shared module. quantize_to_fp8 (2D contig) and quantize_a_masked_to_fp8 (3D masked) stay in their respective test files — parallel implementations for different tensor ranks, not duplicates. Step 1 of the modularity refactor. Test-only change; correctness 30/30. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
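Sketches of some of the shared helpers named above. `ceil_div`/`align` are the usual integer round-up idioms; `fp32_to_e8m0` here is one plausible reading of the ceil-to-power-of-two rounding the conformance notes describe (UE8M0 stores only an exponent, so the scale must be a power of two), not necessarily the exact implementation:

```python
import math

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def align(a: int, b: int) -> int:
    return ceil_div(a, b) * b

def fp32_to_e8m0(x: float) -> float:
    # Round UP to the next power of two: keeps x / scale <= 1, avoiding the
    # FP8 saturation that floor/truncation would cause on every block.
    return 2.0 ** math.ceil(math.log2(x))
```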
Adds kernels/grouped_gemm_blockscale_common.py with validate_params, out_mlir_for, and compute_compile_constants (returning a namedtuple of the byte-identical compile-time scalars). Both compile_* functions call these at the top instead of inlining. Step 2 of the modularity refactor. Pure-Python helpers, no MLIR emit; ISA byte-identical for both kernels at the canonical N=2048 K=7168 m=256 shape; correctness 30/30. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the ping-pong A + CShuffle epilogue LDS allocation into a shared
setup_lds_allocation helper in grouped_gemm_blockscale_common, called
by both compile_grouped_gemm_blockscale_{contiguous,masked}. ISA
byte-identical for both kernels.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the prefetch_a_tile + store_a_tile_to_lds closures (and the tile-coordinate precomputation that feeds them) into a new make_a_tile_loaders factory in grouped_gemm_blockscale_common. Optional m_in/group_idx parameters control the 3D group offset that the masked path adds inside prefetch_a_tile. When both are None (contig) the offset addition is skipped entirely so the contig path emits no extra MLIR ops. Computing the offset inside the helper shares the underlying _k_div4_factor with the row->idx mapping, preserving ISA byte-identity for both kernels. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the byte-identical inner helpers out of both compile_* into grouped_gemm_blockscale_common: - make_lds_loader → lds_load_packs_k64 closure - make_b_loader → load_b_tile closure (with load_b_pack inlined) - pack_i64x4_to_i32x8 (pure function, no closure capture) - make_hot_loop_scheduler → hot_loop_scheduler closure The kernel-local row_a_lds_base / col_offset_base_bytes / mfma_res_ty / ku_per_sb / sched_barrier(0) stay inline since they are one-liners that the next steps still need to reference. ISA byte-identical for both kernels. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the cross-tile E8M0 scale prefetch and the per-K-tile MFMA compute closure into make_prefetch_scales / make_compute_tile in grouped_gemm_blockscale_common — the substantive bulk of the inner kernel logic. Optional sa_group_off parameter controls the 3D scale_a offset that the masked path needs (group_idx * c_scale_k * m_in). When None (contig) the addition is skipped entirely so the contig path emits no extra MLIR ops. ISA byte-identical for both kernels. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the prologue + alternating ping/pong K-tile loop into make_pingpong_kloop in grouped_gemm_blockscale_common. The loop body is byte-identical between contig and masked, so this factory has no offset parameters — it just wires up the prefetch/store/compute closures the caller already built. ISA byte-identical for both kernels. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
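A hypothetical Python sketch of the control flow `make_pingpong_kloop` wires up. The `prefetch`/`store_to_lds`/`compute` callables stand in for the closures the caller builds; `trace` just records issue order to show the overlap (ds_write after compute, as in the 2-stage pipeline described earlier):

```python
def pingpong_kloop(num_k_tiles, prefetch, store_to_lds, compute):
    prefetch(0)                        # prologue: first A tile Global -> VGPR
    store_to_lds(0)                    # ... then VGPR -> LDS buffer 0
    for k in range(num_k_tiles):
        if k + 1 < num_k_tiles:
            prefetch(k + 1)            # overlap next load with this tile's MFMAs
        compute(k % 2)                 # MFMAs read the current LDS buffer
        if k + 1 < num_k_tiles:
            store_to_lds((k + 1) % 2)  # ds_write after compute (2-stage)

trace = []
pingpong_kloop(
    3,
    lambda k: trace.append(f"prefetch{k}"),
    lambda b: trace.append(f"store{b}"),
    lambda b: trace.append(f"compute{b}"),
)
```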
Move write_row_to_lds + store_pair into make_epilogue_writers in grouped_gemm_blockscale_common. Optional d_group_off parameter controls the 3D D offset (group_idx * m_in * n_in) that the masked path adds inside store_pair; None for contig keeps the contig MLIR identical. The mfma_epilog call stays at the call-site since it is just one line of plumbing. ISA byte-identical for both kernels. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
After the helper extractions, math_dialect, ArithValue, crd2idx, lds_store_16b_xor16, load_b_pack_k32, swizzle_xor16, tile_chunk_coord_i32 and pack_i64x4_to_i32x8 are only referenced inside grouped_gemm_blockscale_common; the per-kernel files no longer need them. ISA byte-identical for both kernels. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the byte-identical post-prologue setup into three helpers in
grouped_gemm_blockscale_common:
- compute_mfma_tiling: pure-Python derivation of m_repeat /
num_waves / n_per_wave / num_acc_n / num_accs from tile_m, tile_n.
- init_accumulators: emit the FP32 zero accumulator constant and
replicate it for all MFMA result slots.
- make_n_block_coords: per-wave N-tile base, scale_b N-block index
list, preshuffle B layout, and the per-MFMA (n_blk, n_intra)
coordinate lists for the all-groups-concatenated B layout.
c_scale_k is returned from make_n_block_coords so the downstream
prefetch_scales / compute_tile call sites reuse the same constant.
ISA byte-identical for both kernels.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@aryaman-gupta FLOPS and TB/s data compared with aiter CK MoE?
```python
is_valid = arith.cmpi(arith.CmpIPredicate.sge, group_id_i32, fx.Int32(0))

# Early exit for invalid blocks (padding rows)
_if_valid = scf.IfOp(is_valid)
```
use normal if instead of scf?
050e26d refactors to use the `_if_then` helper function pattern from moe_blockscale_2stage.py. I kept scf rather than a plain if for consistency with existing kernels, several of which use scf.if for early exit, including the moe_blockscale_2stage.py kernel.
```python
    i32_num_groups: fx.Int32,
):
    # Convert runtime parameters to index type
    m_in = arith.index_cast(T.index, i32_m)
```
…it guard Hide the ir.InsertionPoint(if_op.then_block) boilerplate behind a shared scf_then_region context manager in grouped_gemm_blockscale_common. The helper auto-appends a scf.YieldOp([]) terminator if the body did not add one, so the explicit yield at the end of the if-body in both compile_* functions is removed. ISA byte-identical for both kernels. Addresses PR review feedback asking for the early-exit guard to use a more idiomatic syntax. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
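A minimal stand-in illustrating the auto-terminator pattern described above: enter the then-region, run the body, and append a yield terminator only if the body did not end with one. The real helper would use `ir.InsertionPoint` and `scf.YieldOp` from the MLIR Python bindings; the `Block` class and `"yield"` string here are stand-ins for illustration only:

```python
from contextlib import contextmanager

class Block:
    """Toy stand-in for an MLIR block: just an ordered list of ops."""
    def __init__(self):
        self.ops = []

@contextmanager
def scf_then_region(block):
    yield block                       # body emits ops into `block`
    if not block.ops or block.ops[-1] != "yield":
        block.ops.append("yield")     # auto-append the terminator

then = Block()
with scf_then_region(then) as b:
    b.ops.append("store")             # body without an explicit yield
```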
Replace arith.index_cast(T.index, i32_*) with fx.Index(i32_*) for the runtime-parameter conversions at the kernel entry and grid-dimension setup in the launcher. fx.Index accepts an i32 ir.Value and routes through the same index_cast under the hood, so MLIR is unchanged (ISA byte-identical for both kernels). Matches the convention in blockscale_preshuffle_gemm.py and moe_blockscale_2stage.py. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…s_bytes create_buffer_resource auto-casts an index-typed num_records_bytes to i64 internally (buffer_ops.py:295-303), so the explicit arith.index_cast(T.i64, *_nbytes) wrapper at every call site is redundant boilerplate. Pass the *_nbytes index value directly, matching the convention in blockscale_preshuffle_gemm.py:218-227. Also use fx.Index(group_id_i32) for the contig kernel's group ID buffer-load conversion, mirroring the kernel-arg fx.Index(i32_*) pattern. ISA byte-identical for both kernels. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@coderfeli here are the comparisons with aiter's ck_moe_stage1. TFLOPS:
| E | M_per_E | N | K | FlyDSL masked | FlyDSL contig | CK ck_moe_stage1 | masked / CK | contig / CK |
|---|---|---|---|---|---|---|---|---|
| 8 | 64 | 2048 | 7168 | 202.3 | 196.1 | 149.8 | 1.35× | 1.31× |
| 8 | 128 | 2048 | 7168 | 392.5 | 389.9 | 300.2 | 1.31× | 1.30× |
| 8 | 256 | 2048 | 7168 | 776.6 | 766.7 | 586.8 | 1.32× | 1.31× |
| 8 | 512 | 2048 | 7168 | 1359.2 | 1353.2 | 830.3 | 1.64× | 1.63× |
| 8 | 64 | 7168 | 2048 | 434.9 | 410.6 | 394.8 | 1.10× | 1.04× |
| 8 | 128 | 7168 | 2048 | 821.2 | 829.5 | 519.9 | 1.58× | 1.60× |
| 8 | 256 | 7168 | 2048 | 992.6 | 990.2 | 627.3 | 1.58× | 1.58× |
| 8 | 512 | 7168 | 2048 | 1357.5 | 1349.2 | 743.2 | 1.83× | 1.82× |
Bandwidth (TB/s, matmul-input bytes only: (M·K + E·N·K) / time)
| E | M_per_E | N | K | FlyDSL masked | FlyDSL contig | CK ck_moe_stage1 |
|---|---|---|---|---|---|---|
| 8 | 64 | 2048 | 7168 | 1.63 | 1.58 | 1.21 |
| 8 | 128 | 2048 | 7168 | 1.63 | 1.62 | 1.25 |
| 8 | 256 | 2048 | 7168 | 1.71 | 1.68 | 1.29 |
| 8 | 512 | 2048 | 7168 | 1.66 | 1.65 | 1.01 |
| 8 | 64 | 7168 | 2048 | 3.43 | 3.24 | 3.11 |
| 8 | 128 | 7168 | 2048 | 3.27 | 3.30 | 2.07 |
| 8 | 256 | 7168 | 2048 | 2.01 | 2.00 | 1.27 |
| 8 | 512 | 7168 | 2048 | 1.42 | 1.41 | 0.78 |
Configuration: gfx950 (MI350), per-1×128 blockscale FP8 inputs (1×128 for activations, 128×128 for weights), BF16 output, balanced expert assignment with topk=1, FlyDSL tile_m=128, CK block_m=64 (CK's production heuristic for tokens > 32). 100 iters / 10 warmup, max_m_factor=1.0 so all Experts receive the same number of tokens. FlyDSL contiguous numbers are essentially identical to masked at this factor; the contiguous kernel's additional advantage on imbalanced/heavily-padded inputs is not exercised here.
TFLOPS: counts only matmul work (2·M·N·K). ck_moe_stage1 additionally fuses a SiLU+multiply epilogue (~0.01–0.04% of matmul FLOPs but ~1–5% of wall-clock time at these shapes); CK TFLOPS is understated by that fraction.
BW: matmul-input bytes only — output writes excluded
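As a cross-check of the tables above, both metrics follow directly from the shape and the measured time using the formulas stated in the captions (2·M·N·K matmul FLOPs; matmul-input bytes only, 1 byte per FP8 element). The ~88.5 µs per-call time below is back-derived from the table's own TFLOPS figure, not an independent measurement:

```python
def moe_gemm_metrics(E, M_per_E, N, K, time_s, bytes_per_elt=1):
    """Return (TFLOPS, TB/s): matmul work only, matmul-input bytes only."""
    flops = 2 * E * M_per_E * N * K
    in_bytes = (E * M_per_E * K + E * N * K) * bytes_per_elt
    return flops / time_s / 1e12, in_bytes / time_s / 1e12

# e.g. the E=8, M_per_E=512, N=2048, K=7168 row at ~88.5 us per call:
tflops, tbs = moe_gemm_metrics(8, 512, 2048, 7168, 88.5e-6)
```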
Summary
Adds two new grouped and batched FP8 GEMM kernels with blockscaling that mirror the DeepGEMM API on AMD CDNA GPUs:
- `kernels/grouped_gemm_blockscale_contiguous.py` → `compile_grouped_gemm_blockscale_contiguous` (DeepGEMM: `m_grouped_fp8_gemm_nt_contiguous`)
- `kernels/grouped_gemm_blockscale_masked.py` → `compile_grouped_gemm_blockscale_masked` (DeepGEMM: `m_grouped_fp8_gemm_nt_masked`)

These ops are core to MoE inference workloads where the gate/up/down projections of multiple experts are batched into a single grouped GEMM. DeepSeek-V3 is the most prominent example: its expert MLPs use FP8 with per-token activation scaling and per-block weight scaling, exactly the configuration these kernels accept. Adding FlyDSL implementations unblocks running such models on AMD hardware via the same call sites that already use DeepGEMM on NVIDIA.
The Python signatures (tensor shapes, dtypes, `scale_a`/`scale_b` layouts including the transposed `[scale_k, M]` activation-scale layout, `grouped_layout` semantics with `-1` for padding, `masked_m` semantics for the masked variant, `(1, 128)` × `(128, 128)` block-scale granularity) are designed to match the DeepGEMM ops byte-for-byte so call-site code can switch backend by import alone.

Test plan
- New tests: `tests/kernels/test_grouped_gemm_blockscale_contiguous.py` and `..._masked.py`. Coverage spans 1–8 groups, `m_per_group` from 100 (unaligned) to 1024, plus DeepSeek-V3 shapes (N=2048 K=7168 and N=7168 K=2304) at both bf16 and f16 outputs.
- Correctness checked at `logits_diff_threshold=1e-3`.
- Registered in `tests/arch_compat.py:CDNA_ONLY_TESTS` so non-CDNA CI auto-skips.
- `pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]` for correct CI bucketing.

Run locally:
DeepGEMM conformance details
A few things were chosen explicitly to match DeepGEMM's behavior so any divergence in numerical results vs the NVIDIA path would surface as a real bug rather than a tolerance / convention mismatch:
- Same `calc_diff` formula as DeepGEMM (a cosine-similarity-style normalized error) at `logits_diff_threshold=1e-3`, matching DeepGEMM's own FP8 GEMM tests (`tests/test_legacy.py:35`, `tests/test_fp8_fp4.py:194`).
- E8M0 scales are rounded with `ceil_to_ue8m0` (matching DeepGEMM's `deep_gemm/utils/math.py:13`). Truncation would shrink the scale and cause FP8 saturation on every block; the ceiling is what keeps `x / scale_e8m0 ≤ fp8_max`.
- The reference is a pre-quantization FP32 matmul (`a_f32 @ b_f32.T`), same as DeepGEMM's tests. The reference contains zero quantization error, so the diff measured against it is the actual end-to-end FP8 → ground-truth error budget.
- On gfx950 the kernels use `mfma_scale_f32_16x16x128_f8f6f4`, with E8M0 bytes pre-extracted on the host and loaded as `uint8` (analogous to DeepGEMM's `pack_ue8m0_to_int` for the SM100 path). On MI300 (gfx942), where the MFMA-scale instruction is unavailable, scaling is applied in software.

Notes for reviewers
The tests follow the structure of `test_blockscale_preshuffle_gemm.py`. The reference uses pre-quantization FP32 instead of dequant-then-matmul, and the tolerance is `1e-3` instead of the repo default `2e-3` — both deliberately chosen to match DeepGEMM's convention as described above.
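For reference, the normalized error metric is of this cosine-similarity flavor. A minimal pure-Python sketch (DeepGEMM's actual `calc_diff` operates on torch tensors, typically in double precision):

```python
def calc_diff(x, y):
    # 1 - 2*<x, y> / (<x, x> + <y, y>): 0 for identical inputs, and it grows
    # with both angular and magnitude mismatch between x and y.
    num = 2.0 * sum(a * b for a, b in zip(x, y))
    den = sum(a * a for a in x) + sum(b * b for b in y)
    return 1.0 - num / den
```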