
Add A16W4 MoE GEMM stage2 kernel (BF16 activations x MXFP4 weights)#431

Open
apicciau wants to merge 7 commits into main from apicciau/a16w4-moe-gemm2

Conversation

@apicciau commented Apr 23, 2026

Motivation

Ports the A16W4 MoE GEMM stage2 kernel from aiter into FlyDSL: BF16 activations x MXFP4
(FP4 E2M1, per-1x32 block scale) weights, i.e. the down-projection in a MoE FFN block.
Stage1 (compile_a16w4_moe_gemm1) is a follow-up.

Technical Details

kernels/mfma_preshuffle_pipeline.py — MXFP4 kpack=16 load/unpack helpers (Zan Zhang):

  • load_b_raw_mxfp4 / load_b_raw_mxfp4_dwordx4: load 4 / 16 bytes of FP4 nibbles via buffer load
  • unpack_b_mxfp4_bf16: unpacks 8 FP4 E2M1 nibbles to bf16 pairs; uses the GFX950 hardware path (v_cvt_scalef32_pk_bf16_fp4) or a software fallback (see the decode sketch after this list)
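
For reference, the decode the software fallback performs amounts to the following CPU-side sketch (illustration only, not the kernel code; the low-nibble-first packing order is an assumption):

```python
# Reference-only sketch of FP4 E2M1 + E8M0 dequantisation (not the FlyDSL kernel code).
_E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes encodable by E2M1

def decode_fp4_e2m1(nibble: int) -> float:
    """Decode one FP4 E2M1 nibble: bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * _E2M1_LUT[nibble & 0x7]

def decode_e8m0(scale_byte: int) -> float:
    """Decode an E8M0 scale byte to 2**(e - 127), the per-1x32 block scale."""
    return 2.0 ** (scale_byte - 127)

def dequant_block(packed: bytes, scale_byte: int) -> list[float]:
    """Dequantise 32 FP4 values packed two per byte (low nibble first, assumed)."""
    scale = decode_e8m0(scale_byte)
    out = []
    for b in packed:
        out.append(decode_fp4_e2m1(b & 0xF) * scale)
        out.append(decode_fp4_e2m1(b >> 4) * scale)
    return out
```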

kernels/mixed_moe_gemm_2stage.py — new compile_a16w4_moe_gemm2 (Zan Zhang):
ping-pong double-buffered mfma_f32_16x16x32_bf16 pipeline over pre-shuffled MXFP4 weights.
Supports atomic accumulation, per-row routing weights, optional bias, and split-K (k_batch).
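
For orientation, the main loop corresponds to the following schematic (pure illustration in Python, not the FlyDSL kernel code):

```python
# Schematic of a ping-pong double-buffered GEMM main loop: while the MFMAs consume
# K-tile i from one buffer, the buffer loads for K-tile i+1 fill the other buffer.
def pingpong_gemm(num_k_tiles, load_tile, mfma_tile, acc):
    buf = [load_tile(0), None]           # prologue: prefetch the first K-tile
    for i in range(num_k_tiles):
        cur, nxt = i % 2, (i + 1) % 2
        if i + 1 < num_k_tiles:
            buf[nxt] = load_tile(i + 1)  # issue the next load while computing
        acc = mfma_tile(acc, buf[cur])   # consume the tile that is already resident
    return acc
```

This structure is also why the correctness test uses shapes with at least two K-tile iterations (inter_dim / tile_k >= 2).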

Ported from aiter/aiter/ops/flydsl/kernels/a16w4_moe_gemm_2stage.py with two adaptations:

  • idx2crd / layout_get → fx.idx2crd / fx.get (FlyDSL's fx.idx2crd returns a !fly.int_tuple, not a Python tuple); see the sketch after this list
  • cache_modifier dropped from _buffer_load_vec (unsupported; call site passed 0, no impact)
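
A minimal before/after sketch of the first adaptation (fx.idx2crd and fx.get are the real FlyDSL calls; the surrounding variable names and argument list are assumptions for illustration):

```python
# aiter original: components of the decomposed coordinate were read by Python indexing.
#   coord = idx2crd(lin, layout)
#   k0, klane = coord[0], coord[1]
#
# FlyDSL port: fx.idx2crd returns a !fly.int_tuple MLIR value, and Python indexing on it
# yields another IntTuple rather than a scalar, so components are extracted with fx.get.
coord = fx.idx2crd(lin, layout)
k0 = fx.get(coord, 0)
klane = fx.get(coord, 1)
```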

tests/kernels/test_moe_gemm.py — new parametrised test test_moe_gemm2_a16w4.

bench_a16w4_moe_gemm2.py — standalone FlyDSL vs CK Tile benchmark, two shapes, tile_n sweep.

Performance vs CK Tile stage-2 baseline

gfx950 (MI355X), 100 iters / 50 warmup. Inputs built via torch_moe_stage1 + moe_sorting.
TFLOPS = 2 x tokens x topk x inter_dim x model_dim / latency. fly/ck > 1.0 = FlyDSL faster.
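
As a quick sanity check, plugging the GPT-OSS tokens=8, tile_n=128 row into this formula reproduces the reported number:

```python
# Spot-check of the TFLOPS formula against the GPT-OSS tokens=8, tile_n=128 row.
tokens, topk, inter_dim, model_dim = 8, 4, 3072, 3072
latency_s = 28.42e-6  # FlyDSL latency: 28.42 us
tflops = 2 * tokens * topk * inter_dim * model_dim / latency_s / 1e12
print(round(tflops, 2))  # 21.25, matching the table
```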

GPT-OSS — model_dim=3072, inter_dim=3072, E=128, topk=4, tile_m=16

| tokens | tile_n | fly (us) | fly (TFLOPS) | ck (us) | ck (TFLOPS) | fly/ck |
| ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 1 | 128 | 30.44 | 2.48 | 10.46 | 7.22 | 0.34x |
| 1 | 256 | 27.16 | 2.78 | 15.09 | 5.00 | 0.56x |
| 4 | 128 | 30.94 | 9.76 | 21.81 | 13.85 | 0.70x |
| 4 | 256 | 26.37 | 11.45 | 21.74 | 13.89 | 0.82x |
| 8 | 128 | 28.42 | 21.25 | 35.48 | 17.02 | 1.25x |
| 8 | 256 | 27.51 | 21.96 | 35.66 | 16.94 | 1.30x |
| 16 | 128 | 49.01 | 24.65 | 60.76 | 19.88 | 1.24x |
| 16 | 256 | 49.15 | 24.58 | 61.16 | 19.75 | 1.24x |
| 32 | 128 | 102.63 | 23.54 | 110.15 | 21.93 | 1.07x |
| 32 | 256 | 106.52 | 22.68 | 110.37 | 21.89 | 1.04x |
| 64 | 128 | 105.31 | 45.88 | 109.66 | 44.06 | 1.04x |
| 64 | 256 | 107.18 | 45.08 | 109.59 | 44.09 | 1.02x |
| 128 | 128 | 105.10 | 91.94 | 110.83 | 87.19 | 1.05x |
| 128 | 256 | 108.13 | 89.37 | 110.85 | 87.17 | 1.03x |
| 256 | 128 | 107.36 | 180.02 | 107.39 | 179.97 | 1.00x |
| 256 | 256 | 110.90 | 174.28 | 107.11 | 180.45 | 0.97x |

DeepSeek-R1 — model_dim=7168, inter_dim=1024, E=384, topk=8, tile_m=16

| tokens | tile_n | fly (us) | fly (TFLOPS) | ck (us) | ck (TFLOPS) | fly/ck |
| ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 1 | 128 | 29.30 | 4.01 | 11.44 | 10.26 | 0.39x |
| 1 | 256 | 26.57 | 4.42 | 11.51 | 10.20 | 0.43x |
| 4 | 128 | 30.08 | 15.62 | 30.42 | 15.45 | 1.01x |
| 4 | 256 | 27.40 | 17.14 | 30.26 | 15.52 | 1.10x |
| 8 | 128 | 48.02 | 19.56 | 51.62 | 18.20 | 1.07x |
| 8 | 256 | 42.56 | 22.08 | 51.35 | 18.30 | 1.21x |
| 16 | 128 | 93.05 | 20.19 | 94.49 | 19.89 | 1.02x |
| 16 | 256 | 87.29 | 21.53 | 94.30 | 19.93 | 1.08x |
| 32 | 128 | 167.89 | 22.39 | 177.29 | 21.20 | 1.06x |
| 32 | 256 | 163.51 | 22.98 | 177.55 | 21.17 | 1.09x |
| 64 | 128 | 243.72 | 30.84 | 268.87 | 27.96 | 1.10x |
| 64 | 256 | 241.31 | 31.15 | 268.25 | 28.02 | 1.11x |
| 128 | 128 | 247.91 | 60.64 | 266.13 | 56.49 | 1.07x |
| 128 | 256 | 243.48 | 61.74 | 266.38 | 56.43 | 1.09x |
| 256 | 128 | 251.83 | 119.39 | 260.93 | 115.22 | 1.04x |
| 256 | 256 | 247.42 | 121.51 | 261.29 | 115.07 | 1.06x |

  • tokens <= 4 (GPT-OSS) / tokens = 1 (DeepSeek): CK Tile is faster — FlyDSL launches a full expert grid even when most experts are idle. Closing this gap is a follow-up.
  • tokens >= 8: FlyDSL is 4–30% faster on GPT-OSS and 1–21% faster on DeepSeek (tile_n=256 is preferred there).

Test Plan

test_moe_gemm2_a16w4 in tests/kernels/test_moe_gemm.py: three parametrised shapes
(small, medium, k_batch=2), MXFP4-quantised weights via aiter.get_torch_quant, output
compared to FP32 dequantised reference. Gated on gfx950+.
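
For one expert, that reference is essentially the following (a sketch under assumed tensor names and layouts; the actual test builds it with aiter.get_torch_quant and torch_moe_gemm2):

```python
import torch

def reference_one_expert(x_bf16, w_fp32_dequant, row_weights):
    """FP32 stage-2 reference for a single expert (sketch, not the test code).

    x_bf16:         [m, inter_dim]          stage-1 activations routed to this expert
    w_fp32_dequant: [model_dim, inter_dim]  FP4 nibbles * E8M0 block scales, in FP32
    row_weights:    [m]                     per-row routing weights
    """
    out = x_bf16.float() @ w_fp32_dequant.t()  # [m, model_dim] down-projection in FP32
    return out * row_weights[:, None]
```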

Test Result

tests/kernels/test_moe_gemm.py::test_moe_gemm2_a16w4[a16w4-s2-small]   PASSED
tests/kernels/test_moe_gemm.py::test_moe_gemm2_a16w4[a16w4-s2-medium]  PASSED
tests/kernels/test_moe_gemm.py::test_moe_gemm2_a16w4[a16w4-s2-kbatch2] PASSED

Follow-ups (out of scope)

  • Low-token gap (tokens <= 4): reduce grid dispatch overhead for sparse expert grids.
  • Stage-1 A16W4 (compile_a16w4_moe_gemm1): not yet in FlyDSL.
  • aiter: get_flydsl_stage2_kernels needs k_batch enumeration and passthrough.

Submission Checklist

@apicciau self-assigned this Apr 23, 2026
@apicciau requested a review from coderfeli April 28, 2026 09:03
@apicciau force-pushed the apicciau/a16w4-moe-gemm2 branch from 6777ab1 to 6f59d49 on April 29, 2026 10:00
@apicciau marked this pull request as ready for review April 29, 2026 10:00
Copilot AI review requested due to automatic review settings April 29, 2026 10:00
@apicciau requested a review from Zzz9990 April 29, 2026 10:04

Copilot AI left a comment


Pull request overview

Ports an A16W4 MoE GEMM stage-2 kernel into FlyDSL to support BF16 activations × MXFP4 (FP4 E2M1 + E8M0 scales) weights, along with new load/unpack helpers and a correctness test.

Changes:

  • Add MXFP4 (FP4 E2M1) B-load + unpack helpers (including a GFX950 HW conversion intrinsic path).
  • Add compile_a16w4_moe_gemm2 (stage-2 kernel) using a ping-pong MFMA pipeline over preshuffled MXFP4 weights/scales.
  • Add a new parametrized correctness test for the A16W4 stage-2 kernel on gfx950+.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| kernels/mfma_preshuffle_pipeline.py | Adds raw MXFP4 load helpers and FP4→BF16 unpack (HW + SW paths). |
| kernels/mixed_moe_gemm_2stage.py | Adds the compile_a16w4_moe_gemm2 implementation and integrates MXFP4 helpers/epilogues. |
| tests/kernels/test_moe_gemm.py | Adds test_moe_gemm2_a16w4 parametrized correctness coverage for the new kernel. |


Comment thread: kernels/mixed_moe_gemm_2stage.py (Outdated)
Comment on lines +5281 to +5286
)
bb0, bb1 = unpack_b_mxfp4_bf16(
b_raw_ku, arith, vector,
scale_f32=b_scales[ku][ni],
)

Copilot AI Apr 29, 2026

unpack_b_mxfp4_bf16(...) defaults use_hw_cvt=True, and this call site passes a non-None scale_f32, so it will always try to use the GFX950-only llvm.amdgcn.cvt.scalef32.pk.bf16.fp4 intrinsic. If this kernel is ever compiled on a non-gfx95 target, compilation will likely fail instead of cleanly using the software fallback. Plumb an arch check (gpu_arch) to disable the hw path on unsupported targets (or explicitly raise a clearer arch requirement here).
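
For example, something along these lines at the call site (sketch only: use_hw_cvt and gpu_arch already exist in this PR, while the gfx95 prefix check is an assumed way to test the arch):

```python
# Sketch: gate the hardware conversion path on the target so non-gfx95 targets
# take the software fallback instead of failing to compile.
use_hw = str(gpu_arch).startswith("gfx95")
bb0, bb1 = unpack_b_mxfp4_bf16(
    b_raw_ku, arith, vector,
    scale_f32=b_scales[ku][ni],
    use_hw_cvt=use_hw,
)
```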

elem_type=elem_type,
vec_elems=16,
elem_bytes=1,
offset_in_bytes=True,

Copilot AI Apr 29, 2026

load_b_raw_mxfp4_dwordx4 accepts a cache_modifier argument but never forwards it into _buffer_load_vec/buffer_load, so callers cannot actually control the cache policy (even though compile_a16w4_moe_gemm2 passes cache_modifier=2). Either plumb cache_modifier through to _buffer_load_vec(..., cache_modifier=cache_modifier) or remove the parameter (and the non-default call sites) to avoid misleading behavior.

Suggested change:
- offset_in_bytes=True,
+ offset_in_bytes=True,
+ cache_modifier=cache_modifier,

Weights are MXFP4 E2M1 with E8M0 per-1x32 block scales,
pre-shuffled via shuffle_weight_a16w4 / shuffle_scale_a16w4 from aiter.
Result is compared against torch_moe_gemm2 in FP32.
"""

Copilot AI Apr 29, 2026

This test unconditionally imports and uses aiter (aiter.ops.shuffle, aiter.fused_moe.moe_sorting, aiter.get_torch_quant). In environments where aiter isn't installed (or failed to import and HAS_AITER is False earlier in this file), this will error instead of skipping. Add a guard like if not HAS_AITER: pytest.skip(...) before the aiter imports/usage so the suite degrades gracefully like the other MoE tests in this file.

Suggested change:
- """
+ """
+ if not HAS_AITER:
+     pytest.skip("aiter is not available")

tokens,
model_dim,
inter_dim,
sorted_ids.shape[0],

Copilot AI Apr 29, 2026

The i32_size_expert_ids_in argument passed to the kernel should be the length of sorted_expert_ids (number of M-blocks / expert-id entries), but the test currently passes sorted_ids.shape[0] (number of padded token slots). This inflates the grid Y dimension and, more importantly, programs an oversized buffer descriptor for arg_expert_ids/arg_sorted_token_ids, which can mask OOB issues. Pass sorted_expert_ids.shape[0] (and keep sorted_ids.shape[0] only where the kernel actually expects the sorted-token buffer size).

Suggested change:
- sorted_ids.shape[0],
+ sorted_expert_ids.shape[0],

Comment thread: kernels/mixed_moe_gemm_2stage.py (Outdated)
Comment on lines +4625 to +4629
if out_is_bf16:
    if not supports_bf16_global_atomics(gpu_arch):
        raise ValueError(
            f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}"
        )

Copilot AI Apr 29, 2026

compile_a16w4_moe_gemm2 calls supports_bf16_global_atomics(gpu_arch), but supports_bf16_global_atomics is not imported or defined anywhere in this module, so this will raise a NameError the first time this path runs (e.g., when out_dtype='bf16'). Import it from flydsl.runtime.device (with the same try/except fallback pattern used in kernels/moe_gemm_2stage.py), or otherwise define it in this file.
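
For example (sketch; the conservative return False fallback is an assumption about what the existing pattern does):

```python
# Import the capability helper, falling back gracefully if the runtime module
# does not provide it (mirroring the try/except pattern in kernels/moe_gemm_2stage.py).
try:
    from flydsl.runtime.device import supports_bf16_global_atomics
except ImportError:
    def supports_bf16_global_atomics(arch: str) -> bool:
        return False  # conservative fallback: assume no bf16 global atomics
```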

Comment thread: kernels/mixed_moe_gemm_2stage.py (Outdated)
Comment on lines +4625 to +4628
if out_is_bf16:
    if not supports_bf16_global_atomics(gpu_arch):
        raise ValueError(
            f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}"

Copilot AI Apr 29, 2026

The bf16-atomics capability check is currently unconditional for out_dtype='bf16', but the kernel only requires bf16 atomics when accumulate=True. When accumulate=False, the epilogue does plain stores, so this check can incorrectly reject valid configurations on arches without bf16 atomics. Gate the check on accumulate (or on the actual atomic path taken).

Suggested change:
- if out_is_bf16:
-     if not supports_bf16_global_atomics(gpu_arch):
-         raise ValueError(
-             f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}"
+ if out_is_bf16 and bool(accumulate):
+     if not supports_bf16_global_atomics(gpu_arch):
+         raise ValueError(
+             f"out_dtype='bf16' with accumulate=True requires bf16 global atomics, got arch={gpu_arch!r}"

@coderfeli (Collaborator) commented:

@Zzz9990

Zzz9990 (Contributor) commented Apr 30, 2026

Please first use split_k_intra when testing the small-token cases.

Zzz9990 and others added 6 commits May 7, 2026 10:41
* Add _cvt_scalef32_pk_bf16_fp4: GFX950 hardware path for converting 2 FP4 E2M1 nibbles to 2 bf16 via v_cvt_scalef32_pk_bf16_fp4 (1 VALU vs ~36 software).
* Add _fp4x4_in_i32_to_bf16x4_i64: software fallback converting 4 FP4 nibbles (packed in 4 bytes of i32) to 4 bf16 packed as i64.
* Add load_b_raw_mxfp4: loads 4 bytes (8 FP4 nibbles) from a kpack=16 preshuffle layout (shuffle_weight_a16w4 format) using ku-based k0/klane addressing.
* Add load_b_raw_mxfp4_dwordx4: dwordx4 variant loading the full 16-byte kpack for one sub-lane in a single buffer_load.
* Add unpack_b_mxfp4_bf16: dispatches to hw or sw path, returning (b0, b1) i64 pair for mfma_f32_16x16x32_bf16.
* Add _decode_e8m0_byte_to_f32: converts an E8M0 byte (i8) to f32 = 2^(e-127) via bit shift into position 23.
* Add _barrier: emits s_waitcnt + s_barrier as inline asm, bypassing LLVM SIInsertWaitcnts conservative insertion.
* Add compile_a16w4_moe_gemm2: stage2 down-projection GEMM for BF16 activations x MXFP4 (FP4 E2M1) weights with E8M0 block scales, using mfma_f32_16x16x32_bf16. Ported from aiter a16w4_moe_gemm_2stage.py with mechanical adaptations (fx.idx2crd, tuple indexing in place of layout_get).
* Add test_moe_gemm2_a16w4: exercises compile_a16w4_moe_gemm2 with BF16 activations and MXFP4 E2M1 weights pre-shuffled via shuffle_weight_a16w4/shuffle_scale_a16w4. Compares kernel output against torch_moe_gemm2 reference. Gated on gfx950+.
…ordx4

* Fix IntTuple component extraction: layout_get(coord, i) must map to fx.get(coord, i), not coord[i]. FlyDSL's fx.idx2crd returns a !fly.int_tuple MLIR value; Python indexing returns another IntTuple, not a scalar index.
* Drop cache_modifier from load_b_raw_mxfp4_dwordx4: FlyDSL's _buffer_load_vec does not support this parameter. The argument is an optional cache-policy hint with no correctness impact.
* Use inter_dim=3072 shapes so inter_dim/tile_k >= 2 (ping-pong pipeline requires at least 2 K-tile iterations).
* View float4 weight and scale tensors as uint8 before passing to kernel (DLPack does not support float4_e2m1fn_x2 or float8_e8m0fnu).
* Remove test_name kwarg from verify_output call (not part of its signature).
* Add k_batch parameter to _A16W4_SHAPES and test_moe_gemm2_a16w4
* New a16w4-s2-kbatch2 shape: tokens=4, tile_m=16, tile_n=128, k_batch=2
* Pass k_batch through to compile_a16w4_moe_gemm2 in test body
* Add bench_a16w4_moe_gemm2.py: two-phase standalone benchmark harness
@apicciau force-pushed the apicciau/a16w4-moe-gemm2 branch from e38ea4c to c5e44fc on May 7, 2026 10:58