
Adds Grouped and Batched GEMM kernels with blockscaling matching DeepGEMM API #433

Open

aryaman-gupta wants to merge 58 commits into main from aryaman/group-gemm

Conversation


aryaman-gupta commented Apr 23, 2026

Summary

Adds two new grouped and batched FP8 GEMM kernels with blockscaling that mirror the DeepGEMM API on AMD CDNA GPUs:

| FlyDSL kernel | Compile entry point | DeepGEMM op |
|---|---|---|
| `kernels/grouped_gemm_blockscale_contiguous.py` | `compile_grouped_gemm_blockscale_contiguous` | `m_grouped_fp8_gemm_nt_contiguous` |
| `kernels/grouped_gemm_blockscale_masked.py` | `compile_grouped_gemm_blockscale_masked` | `m_grouped_fp8_gemm_nt_masked` |

These ops are core to MoE inference workloads where the gate/up/down projections of multiple experts are batched into a single grouped GEMM. DeepSeek-V3 is the most prominent example — its expert MLPs use FP8 with per-token activation scaling and per-block weight scaling, exactly the configuration these kernels accept. Adding FlyDSL implementations unblocks running such models on AMD hardware via the same call sites that already use DeepGEMM on NVIDIA.

The Python signatures are designed to match the DeepGEMM ops exactly, so call-site code can switch backends by changing an import alone: tensor shapes, dtypes, the scale_a / scale_b layouts (including the transposed [scale_k, M] activation-scale layout), grouped_layout semantics with -1 for padding, masked_m semantics for the masked variant, and the (1, 128) × (128, 128) block-scale granularity.
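For illustration, a minimal sketch of the tensor layouts this implies, under the conventions described above (all names, sizes, and the FP8 dtype here are illustrative, not the kernels' actual API):

```python
import torch

# Hypothetical shapes following the conventions described above.
num_groups, m_per_group, n, k = 4, 256, 2048, 7168
scale_k = (k + 127) // 128          # one activation scale per 1x128 block
m_total = num_groups * m_per_group

# FP8 dtype is illustrative; the exact variant is hardware-dependent.
a = torch.empty(m_total, k, dtype=torch.float8_e4m3fn, device="cuda")
scale_a = torch.empty(scale_k, m_total, dtype=torch.float32, device="cuda")  # transposed [scale_k, M]
b = torch.empty(num_groups, n, k, dtype=torch.float8_e4m3fn, device="cuda")
scale_b = torch.empty(num_groups, (n + 127) // 128, scale_k,
                      dtype=torch.float32, device="cuda")                    # (128, 128) weight blocks
d = torch.empty(m_total, n, dtype=torch.bfloat16, device="cuda")

# Contiguous variant: per-row group index along M; -1 marks padding rows.
grouped_layout = torch.arange(num_groups, device="cuda") \
    .repeat_interleave(m_per_group).to(torch.int32)

# Masked variant: activations are 3D [num_groups, max_m, K] instead, and
# masked_m gives the per-group count of valid rows.
masked_m = torch.full((num_groups,), m_per_group, dtype=torch.int32, device="cuda")
```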

Test plan

  • Unit / correctness: tests/kernels/test_grouped_gemm_blockscale_contiguous.py and ..._masked.py. Coverage spans 1–8 groups, m_per_group from 100 (unaligned) to 1024, plus DeepSeek-V3 shapes (N=2048 K=7168 and N=7168 K=2304) at both bf16 and f16 outputs.
  • 30 / 30 tests pass on MI350 (gfx950) at logits_diff_threshold=1e-3.
  • Tests registered in tests/arch_compat.py:CDNA_ONLY_TESTS so non-CDNA CI auto-skips.
  • pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] for correct CI bucketing.

Run locally:

PYTHONPATH=./ FLYDSL_RUNTIME_ENABLE_CACHE=0 pytest \
  tests/kernels/test_grouped_gemm_blockscale_contiguous.py \
  tests/kernels/test_grouped_gemm_blockscale_masked.py

DeepGEMM conformance details

A few things were chosen explicitly to match DeepGEMM's behavior so any divergence in numerical results vs the NVIDIA path would surface as a real bug rather than a tolerance / convention mismatch:

  • Tolerance. The correctness threshold uses the same calc_diff formula as DeepGEMM (a cosine-similarity-style normalized error) at logits_diff_threshold=1e-3, matching DeepGEMM's own FP8 GEMM tests (tests/test_legacy.py:35, tests/test_fp8_fp4.py:194).
  • E8M0 quantization on the host. Scale tensors are rounded with ceil_to_ue8m0 (matching DeepGEMM's deep_gemm/utils/math.py:13). Truncation would shrink the scale and cause FP8 saturation on every block; the ceiling is what keeps x / scale_e8m0 ≤ fp8_max. Both helpers are sketched after this list.
  • Reference convention. Tests compute the reference from pre-quantization FP32 inputs (a_f32 @ b_f32.T) — same as DeepGEMM's tests. The reference contains zero quantization error, so the diff measured against it is the actual end-to-end FP8 → ground-truth error budget.
  • Hardware-vs-software scaling, matching DeepGEMM's CUDA / SM strategy. On MI350 (gfx950) the kernel uses the hardware E8M0 path via mfma_scale_f32_16x16x128_f8f6f4, with E8M0 bytes pre-extracted on the host and loaded as uint8 (analogous to DeepGEMM's pack_ue8m0_to_int for the SM100 path). On MI300 (gfx942) where the MFMA-scale instruction is unavailable, scaling is applied in software.
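For reference, minimal sketches of the two helpers named above, following the cited DeepGEMM formulas (the authoritative versions live in the DeepGEMM repo):

```python
import torch

def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # Cosine-similarity-style normalized error: 1 - 2<x,y> / (<x,x> + <y,y>).
    # Returns 0 for identical tensors; compared against logits_diff_threshold=1e-3.
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    return 1.0 - (2.0 * (x * y).sum() / denom).item()

def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Round each scale *up* to the nearest power of two (unsigned E8M0).
    # The ceiling keeps x / scale <= fp8_max; truncation would saturate.
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
```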

Notes for reviewers

  • Deliberate test-style divergence from test_blockscale_preshuffle_gemm.py. Reference uses pre-quantization FP32 instead of dequant-then-matmul, and tolerance is 1e-3 instead of the repo default 2e-3 — both deliberately chosen to match DeepGEMM's convention as described above.
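To make the divergence concrete, a small sketch of the two reference conventions (quantize/dequantize stand in for the tests' block-scale helpers and are hypothetical here):

```python
import torch

torch.manual_seed(0)
a_f32 = torch.randn(256, 128)
b_f32 = torch.randn(64, 128)

# This PR / DeepGEMM convention: reference from pre-quantization FP32, so the
# measured diff is the full end-to-end FP8-vs-ground-truth error budget.
ref = a_f32 @ b_f32.T

# test_blockscale_preshuffle_gemm.py convention: dequantize the FP8 operands
# first, so quantization error cancels and only kernel error is measured:
#   ref = dequantize(quantize(a_f32)) @ dequantize(quantize(b_f32)).T
```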

aryaman-gupta and others added 30 commits March 27, 2026 18:58
Replace hardcoded test calls with argparse-based __main__ matching
the pattern used by other kernel tests (blockscale, moe, preshuffle).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace inline correctness checks and manual CUDA event benchmarking
with shared utilities from tests.test_common, matching the pattern
used by blockscale_preshuffle_gemm and other kernel tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
tile_m/n/k, out_dtype, num_iters, and num_warmup were parsed but
never passed to the test functions, which hardcoded their own values.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Explicitly create the CPU reference output tensor with device='cpu'
to avoid a conflict when torch.set_default_device('cuda') is active.
The reference matmul stays on CPU due to hipBLAS issues on this ROCm build.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Split load_a_tile into prefetch_a_tile (Global→VGPR) and
store_a_tile_to_lds (VGPR→LDS). Moves ds_write after compute_tile
to match the MoE blockscale 2-stage pipeline, enabling future
instruction scheduling to interleave ds_write with trailing MFMAs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds hot_loop_scheduler() with coarse-grained sched_group_barrier
hints matching the moe_blockscale_2stage pattern. Placed after
store_a_tile_to_lds and before gpu.barrier(), only emitted when
a next tile actually exists (avoids LLVM assertion from mismatched
instruction counts on tail iterations).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds optional waves_per_eu parameter to compile functions. When set,
applies rocdl.waves_per_eu attribute to gpu.func for occupancy
tuning. Matches the pattern from blockscale_preshuffle_gemm and
moe_blockscale_2stage kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contiguous test:
- Generate unaligned M group sizes with -1 padding rows (DeepGEMM convention)
- Add unaligned M test cases (2g-100m, 4g-200m)
- Add DeepSeek-V3 shapes (2112x7168, 7168x2304)
- Add out_dtype parametrization (bf16 + f16)
- Zero out padding rows before comparison
- Add --waves_per_eu CLI arg

Masked test:
- Fix output buffer dtype bug (was hardcoded bf16, now respects out_dtype)
- Add sparse masking test (4g-512max-50m)
- Add DeepSeek-V3 shapes
- Add out_dtype parametrization (bf16 + f16)
- Wire out_dtype through generate_masked_grouped_gemm_inputs
- Add --waves_per_eu CLI arg

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
aryaman-gupta and others added 6 commits April 23, 2026 15:46
The s_a_vecs f32 loads are unused on the gfx950 HW path and would index
out-of-bounds against the int8-sized scale buffer if MLIR DCE ever
failed to eliminate them. Gating makes the gfx950/942 split explicit.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Add `prefetch_scales(k_tile_idx_py)` helper that loads the E8M0 byte for
each (mi, ni) of the next K-tile into VGPRs ahead of `compute_tile`.
Issued before `load_b_tile` in the ping-pong loop so scale-VMEM latency
overlaps the prior tile's MFMAs and the next B-tile load.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Register both tests in tests/arch_compat.py:CDNA_ONLY_TESTS so RDNA CI
  auto-skips them.
- Add `pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]`
  so CI buckets them correctly.
- Add `torch.cuda.is_available()` module-level skip guard.
- Drop dead `--waves_per_eu` argparse arg (was accepted but never
  forwarded).
- Merge per-file `*_correctness` and `*_performance` into a single
  `test_grouped_fp8_gemm` / `test_masked_grouped_fp8_gemm` matching the
  test_blockscale_preshuffle_gemm convention.
- Move the per-group reference matmul from CPU to GPU (hipBLASLt). Test
  suite runtime drops ~70s → ~34s.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Remove unused `import os` from both kernel files.
- Remove orphan `# Helper: compute one K-tile from LDS + B tile` banner
  from both kernels (the function it labeled was renamed/refactored).
- Remove duplicate `c_scale_k = fx.Index(scale_k)` reassignment in the
  masked kernel (already in scope from the earlier definition).
- Drop the drift-prone "Optimizations applied:" lists from kernel
  module docstrings; correct the now-stale `scale_a` / `scale_b` dtype
  to reflect uint8 on gfx950 / FP32 on gfx942.
- Simplify the "Per-group matmul" comment in both test files; drop the
  specific backend (hipBLASLt) claim.
- Add missing `device`, `scale_block_k`, `scale_block_n` entries to the
  masked test's `generate_*_inputs` Args docstring.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- compile_grouped_fp8_gemm           -> compile_grouped_gemm_blockscale_contiguous
- compile_masked_grouped_fp8_gemm    -> compile_grouped_gemm_blockscale_masked

The new names mirror the file names exactly (drop "fp8_gemm",
incorporate "blockscale" + the contiguous/masked variant), making
the call site self-documenting. Internal kernel/launcher symbols and
the JIT cache-key strings are renamed in lockstep. Test imports and
call sites updated. DeepGEMM op references in the docstrings
(`m_grouped_fp8_gemm_nt_contiguous` / `..._masked`) are unchanged —
those are DeepGEMM's actual symbol names.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
aryaman-gupta changed the title from "Adds Grouped GEMM kernels matching DeepGEMM API" to "Adds Grouped and Batched GEMM kernels matching DeepGEMM API" on Apr 24, 2026
aryaman-gupta changed the title from "Adds Grouped and Batched GEMM kernels matching DeepGEMM API" to "Adds Grouped and Batched GEMM kernels with blockscaling matching DeepGEMM API" on Apr 24, 2026
aryaman-gupta and others added 11 commits April 30, 2026 11:39
Pulls the byte-identical helpers (ARCH/DTYPE_FP8/USE_UE8M0 constants,
ceil_div, align, fp32_to_e8m0, fp32_e8m0_to_byte, quantize_b_to_fp8,
_as_i8) out of the two test files into a new shared module.

quantize_to_fp8 (2D contig) and quantize_a_masked_to_fp8 (3D masked)
stay in their respective test files — parallel implementations for
different tensor ranks, not duplicates.

Step 1 of the modularity refactor. Test-only change; correctness 30/30.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
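A plausible sketch of the E8M0 byte extraction these helpers perform (the actual implementations live in the shared test module and may differ):

```python
import torch

def fp32_e8m0_to_byte(scale: torch.Tensor) -> torch.Tensor:
    # For a positive power-of-two FP32 scale, the E8M0 byte is exactly the
    # biased 8-bit exponent field of the IEEE-754 float32 encoding.
    return ((scale.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)
```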
Adds kernels/grouped_gemm_blockscale_common.py with validate_params,
out_mlir_for, and compute_compile_constants (returning a namedtuple of
the byte-identical compile-time scalars). Both compile_* functions call
these at the top instead of inlining.

Step 2 of the modularity refactor. Pure-Python helpers, no MLIR emit;
ISA byte-identical for both kernels at the canonical N=2048 K=7168 m=256
shape; correctness 30/30.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the ping-pong A + CShuffle epilogue LDS allocation into a shared
setup_lds_allocation helper in grouped_gemm_blockscale_common, called
by both compile_grouped_gemm_blockscale_{contiguous,masked}. ISA
byte-identical for both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the prefetch_a_tile + store_a_tile_to_lds closures (and the
tile-coordinate precomputation that feeds them) into a new
make_a_tile_loaders factory in grouped_gemm_blockscale_common.

Optional m_in/group_idx parameters control the 3D group offset that
the masked path adds inside prefetch_a_tile. When both are None
(contig) the offset addition is skipped entirely so the contig path
emits no extra MLIR ops. Computing the offset inside the helper
shares the underlying _k_div4_factor with the row->idx mapping,
preserving ISA byte-identity for both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the byte-identical inner helpers out of both compile_* into
grouped_gemm_blockscale_common:
  - make_lds_loader → lds_load_packs_k64 closure
  - make_b_loader → load_b_tile closure (with load_b_pack inlined)
  - pack_i64x4_to_i32x8 (pure function, no closure capture)
  - make_hot_loop_scheduler → hot_loop_scheduler closure

The kernel-local row_a_lds_base / col_offset_base_bytes / mfma_res_ty /
ku_per_sb / sched_barrier(0) stay inline since they are one-liners
that the next steps still need to reference. ISA byte-identical for
both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the cross-tile E8M0 scale prefetch and the per-K-tile MFMA
compute closure into make_prefetch_scales / make_compute_tile in
grouped_gemm_blockscale_common — the substantive bulk of the inner
kernel logic.

Optional sa_group_off parameter controls the 3D scale_a offset that
the masked path needs (group_idx * c_scale_k * m_in). When None
(contig) the addition is skipped entirely so the contig path emits
no extra MLIR ops. ISA byte-identical for both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the prologue + alternating ping/pong K-tile loop into
make_pingpong_kloop in grouped_gemm_blockscale_common. The loop body
is byte-identical between contig and masked, so this factory has no
offset parameters — it just wires up the prefetch/store/compute
closures the caller already built. ISA byte-identical for both
kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move write_row_to_lds + store_pair into make_epilogue_writers in
grouped_gemm_blockscale_common. Optional d_group_off parameter
controls the 3D D offset (group_idx * m_in * n_in) that the masked
path adds inside store_pair; None for contig keeps the contig MLIR
identical. The mfma_epilog call stays at the call-site since it is
just one line of plumbing. ISA byte-identical for both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
After the helper extractions, math_dialect, ArithValue, crd2idx,
lds_store_16b_xor16, load_b_pack_k32, swizzle_xor16,
tile_chunk_coord_i32 and pack_i64x4_to_i32x8 are only referenced
inside grouped_gemm_blockscale_common; the per-kernel files no
longer need them. ISA byte-identical for both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Move the byte-identical post-prologue setup into two helpers in
grouped_gemm_blockscale_common:
  - compute_mfma_tiling: pure-Python derivation of m_repeat /
    num_waves / n_per_wave / num_acc_n / num_accs from tile_m, tile_n.
  - init_accumulators: emit the FP32 zero accumulator constant and
    replicate it for all MFMA result slots.
  - make_n_block_coords: per-wave N-tile base, scale_b N-block index
    list, preshuffle B layout, and the per-MFMA (n_blk, n_intra)
    coordinate lists for the all-groups-concatenated B layout.

c_scale_k is returned from make_n_block_coords so the downstream
prefetch_scales / compute_tile call sites reuse the same constant.
ISA byte-identical for both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@coderfeli (Collaborator)

@aryaman-gupta flops and TB/s data compared with aiter ck moe?

is_valid = arith.cmpi(arith.CmpIPredicate.sge, group_id_i32, fx.Int32(0))

# Early exit for invalid blocks (padding rows)
_if_valid = scf.IfOp(is_valid)
Collaborator:

use normal if instead of scf?

Author:

050e26d refactors this to use the _if_then helper-function pattern from moe_blockscale_2stage.py. I kept scf rather than a normal if for consistency with existing kernels, several of which use scf.if for early exit, including moe_blockscale_2stage.py.

i32_num_groups: fx.Int32,
):
# Convert runtime parameters to index type
m_in = arith.index_cast(T.index, i32_m)
Collaborator:

use fx.Int32 directly?

Author:

Done in 40f3cfd and 96946d4

aryaman-gupta and others added 3 commits May 4, 2026 15:51
…it guard

Hide the ir.InsertionPoint(if_op.then_block) boilerplate behind a
shared scf_then_region context manager in grouped_gemm_blockscale_common.
The helper auto-appends a scf.YieldOp([]) terminator if the body did not
add one, so the explicit yield at the end of the if-body in both
compile_* functions is removed.

ISA byte-identical for both kernels. Addresses PR review feedback
asking for the early-exit guard to use a more idiomatic syntax.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
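A minimal sketch of what such a context manager can look like with the upstream MLIR Python bindings (the real helper lives in grouped_gemm_blockscale_common and may differ in detail):

```python
from contextlib import contextmanager
from mlir import ir
from mlir.dialects import scf

@contextmanager
def scf_then_region(if_op: scf.IfOp):
    """Insert into the then-block of an scf.if; auto-append the terminator."""
    with ir.InsertionPoint(if_op.then_block):
        yield
        ops = list(if_op.then_block.operations)
        # Append scf.yield only if the body did not terminate the block itself.
        if not ops or not isinstance(ops[-1], scf.YieldOp):
            scf.YieldOp([])

# usage:
#   with scf_then_region(scf.IfOp(is_valid)):
#       ...  # early-exit-guarded body; no explicit scf.YieldOp needed
```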
Replace arith.index_cast(T.index, i32_*) with fx.Index(i32_*) for the
runtime-parameter conversions at the kernel entry and grid-dimension
setup in the launcher. fx.Index accepts an i32 ir.Value and routes
through the same index_cast under the hood, so MLIR is unchanged
(ISA byte-identical for both kernels). Matches the convention in
blockscale_preshuffle_gemm.py and moe_blockscale_2stage.py.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…s_bytes

create_buffer_resource auto-casts an index-typed num_records_bytes to
i64 internally (buffer_ops.py:295-303), so the explicit
arith.index_cast(T.i64, *_nbytes) wrapper at every call site is
redundant boilerplate. Pass the *_nbytes index value directly,
matching the convention in blockscale_preshuffle_gemm.py:218-227.

Also use fx.Index(group_id_i32) for the contig kernel's group ID
buffer-load conversion, mirroring the kernel-arg fx.Index(i32_*)
pattern. ISA byte-identical for both kernels.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@aryaman-gupta (Author)

> @aryaman-gupta flops and TB/s data compared with aiter ck moe?

@coderfeli here are the comparisons with aiter.ck_moe_stage1 from aiter/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh.

TFLOPS (2·M·N·K / time)

| E | M_per_E | N | K | FlyDSL masked | FlyDSL contig | CK ck_moe_stage1 | masked / CK | contig / CK |
|---|---------|------|------|--------|--------|-------|-------|-------|
| 8 | 64  | 2048 | 7168 | 202.3  | 196.1  | 149.8 | 1.35× | 1.31× |
| 8 | 128 | 2048 | 7168 | 392.5  | 389.9  | 300.2 | 1.31× | 1.30× |
| 8 | 256 | 2048 | 7168 | 776.6  | 766.7  | 586.8 | 1.32× | 1.31× |
| 8 | 512 | 2048 | 7168 | 1359.2 | 1353.2 | 830.3 | 1.64× | 1.63× |
| 8 | 64  | 7168 | 2048 | 434.9  | 410.6  | 394.8 | 1.10× | 1.04× |
| 8 | 128 | 7168 | 2048 | 821.2  | 829.5  | 519.9 | 1.58× | 1.60× |
| 8 | 256 | 7168 | 2048 | 992.6  | 990.2  | 627.3 | 1.58× | 1.58× |
| 8 | 512 | 7168 | 2048 | 1357.5 | 1349.2 | 743.2 | 1.83× | 1.82× |

Bandwidth (TB/s, matmul-input bytes only: (M·K + E·N·K) / time)

| E | M_per_E | N | K | FlyDSL masked | FlyDSL contig | CK ck_moe_stage1 |
|---|---------|------|------|------|------|------|
| 8 | 64  | 2048 | 7168 | 1.63 | 1.58 | 1.21 |
| 8 | 128 | 2048 | 7168 | 1.63 | 1.62 | 1.25 |
| 8 | 256 | 2048 | 7168 | 1.71 | 1.68 | 1.29 |
| 8 | 512 | 2048 | 7168 | 1.66 | 1.65 | 1.01 |
| 8 | 64  | 7168 | 2048 | 3.43 | 3.24 | 3.11 |
| 8 | 128 | 7168 | 2048 | 3.27 | 3.30 | 2.07 |
| 8 | 256 | 7168 | 2048 | 2.01 | 2.00 | 1.27 |
| 8 | 512 | 7168 | 2048 | 1.42 | 1.41 | 0.78 |

Configuration: gfx950 (MI350), block-scaled FP8 inputs (1×128 blocks for activations, 128×128 for weights), BF16 output, balanced expert assignment with topk=1, FlyDSL tile_m=128, CK block_m=64 (CK's production heuristic for tokens > 32). 100 timed iterations / 10 warmup, max_m_factor=1.0 so all experts receive the same number of tokens. The FlyDSL contiguous numbers are essentially identical to masked at this factor; the contiguous kernel's additional advantage on imbalanced or heavily padded inputs is not exercised here.

TFLOPS counts only matmul work (2·M·N·K). ck_moe_stage1 additionally fuses a SiLU+multiply epilogue (~0.01–0.04% of matmul FLOPs but ~1–5% of wall-clock time at these shapes), so CK's TFLOPS figures are understated by that fraction.

BW counts matmul-input bytes only; output writes are excluded.
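For clarity, the two metrics as defined above (a sketch; the helper names are ours, not the harness's):

```python
def tflops(e: int, m_per_e: int, n: int, k: int, time_s: float) -> float:
    # Matmul work only: 2*M*N*K with M = E * M_per_E; fused epilogues excluded.
    return 2.0 * (e * m_per_e) * n * k / time_s / 1e12

def tb_per_s(e: int, m_per_e: int, n: int, k: int, time_s: float) -> float:
    # Matmul-input bytes only (FP8 = 1 byte/element): activations M*K plus
    # all-expert weights E*N*K; scale tensors and output writes excluded.
    return ((e * m_per_e) * k + e * n * k) / time_s / 1e12
```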
