Add A16W4 MoE GEMM stage2 kernel (BF16 activations x MXFP4 weights) #431
Conversation
Force-pushed from 6777ab1 to 6f59d49
Pull request overview
Ports an A16W4 MoE GEMM stage-2 kernel into FlyDSL to support BF16 activations × MXFP4 (FP4 E2M1 + E8M0 scales) weights, along with new load/unpack helpers and a correctness test.
Changes:
- Add MXFP4 (FP4 E2M1) B-load + unpack helpers (including a GFX950 HW conversion intrinsic path).
- Add `compile_a16w4_moe_gemm2` (stage-2 kernel) using a ping-pong MFMA pipeline over preshuffled MXFP4 weights/scales.
- Add a new parametrized correctness test for the A16W4 stage-2 kernel on gfx950+.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `kernels/mfma_preshuffle_pipeline.py` | Adds raw MXFP4 load helpers and FP4→BF16 unpack (HW + SW paths). |
| `kernels/mixed_moe_gemm_2stage.py` | Adds `compile_a16w4_moe_gemm2` implementation and integrates MXFP4 helpers/epilogues. |
| `tests/kernels/test_moe_gemm.py` | Adds `test_moe_gemm2_a16w4` parametrized correctness coverage for the new kernel. |
```python
)
bb0, bb1 = unpack_b_mxfp4_bf16(
    b_raw_ku, arith, vector,
    scale_f32=b_scales[ku][ni],
)
```
unpack_b_mxfp4_bf16(...) defaults use_hw_cvt=True, and this call site passes a non-None scale_f32, so it will always try to use the GFX950-only llvm.amdgcn.cvt.scalef32.pk.bf16.fp4 intrinsic. If this kernel is ever compiled on a non-gfx95 target, compilation will likely fail instead of cleanly using the software fallback. Plumb an arch check (gpu_arch) to disable the hw path on unsupported targets (or explicitly raise a clearer arch requirement here).
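A minimal sketch of the arch-gated call this comment describes; `gpu_arch` and the exact plumbing are assumptions taken from the review thread, not the PR's actual fix:

```python
# Sketch only: gate the GFX950-only hardware convert on the target arch and
# fall back to the software path elsewhere. Assumes `gpu_arch` (e.g. "gfx950")
# is available at this call site and that use_hw_cvt is forwarded unchanged.
use_hw_cvt = gpu_arch.startswith("gfx95")

bb0, bb1 = unpack_b_mxfp4_bf16(
    b_raw_ku, arith, vector,
    scale_f32=b_scales[ku][ni],
    use_hw_cvt=use_hw_cvt,  # software fallback on non-gfx95 targets
)
```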
```python
elem_type=elem_type,
vec_elems=16,
elem_bytes=1,
offset_in_bytes=True,
```
load_b_raw_mxfp4_dwordx4 accepts a cache_modifier argument but never forwards it into _buffer_load_vec/buffer_load, so callers cannot actually control the cache policy (even though compile_a16w4_moe_gemm2 passes cache_modifier=2). Either plumb cache_modifier through to _buffer_load_vec(..., cache_modifier=cache_modifier) or remove the parameter (and the non-default call sites) to avoid misleading behavior.
```diff
 offset_in_bytes=True,
+cache_modifier=cache_modifier,
```
```python
Weights are MXFP4 E2M1 with E8M0 per-1x32 block scales,
pre-shuffled via shuffle_weight_a16w4 / shuffle_scale_a16w4 from aiter.
Result is compared against torch_moe_gemm2 in FP32.
"""
```
This test unconditionally imports and uses aiter (aiter.ops.shuffle, aiter.fused_moe.moe_sorting, aiter.get_torch_quant). In environments where aiter isn't installed (or failed to import and HAS_AITER is False earlier in this file), this will error instead of skipping. Add a guard like if not HAS_AITER: pytest.skip(...) before the aiter imports/usage so the suite degrades gracefully like the other MoE tests in this file.
| """ | |
| """ | |
| if not HAS_AITER: | |
| pytest.skip("aiter is not available") |
```python
tokens,
model_dim,
inter_dim,
sorted_ids.shape[0],
```
The i32_size_expert_ids_in argument passed to the kernel should be the length of sorted_expert_ids (number of M-blocks / expert-id entries), but the test currently passes sorted_ids.shape[0] (number of padded token slots). This inflates the grid Y dimension and, more importantly, programs an oversized buffer descriptor for arg_expert_ids/arg_sorted_token_ids, which can mask OOB issues. Pass sorted_expert_ids.shape[0] (and keep sorted_ids.shape[0] only where the kernel actually expects the sorted-token buffer size).
```diff
-sorted_ids.shape[0],
+sorted_expert_ids.shape[0],
```
```python
if out_is_bf16:
    if not supports_bf16_global_atomics(gpu_arch):
        raise ValueError(
            f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}"
        )
```
compile_a16w4_moe_gemm2 calls supports_bf16_global_atomics(gpu_arch), but supports_bf16_global_atomics is not imported or defined anywhere in this module, so this will raise a NameError the first time this path runs (e.g., when out_dtype='bf16'). Import it from flydsl.runtime.device (with the same try/except fallback pattern used in kernels/moe_gemm_2stage.py), or otherwise define it in this file.
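A possible shape for the missing import, mirroring the try/except fallback pattern the comment attributes to kernels/moe_gemm_2stage.py; the fallback body here is an assumption:

```python
# Sketch only: import the capability check, with a conservative fallback for
# builds where flydsl.runtime.device does not expose it.
try:
    from flydsl.runtime.device import supports_bf16_global_atomics
except ImportError:
    def supports_bf16_global_atomics(gpu_arch):
        # Assumption: treat bf16 global atomics as unsupported when unknown.
        return False
```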
```python
if out_is_bf16:
    if not supports_bf16_global_atomics(gpu_arch):
        raise ValueError(
            f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}"
```
The bf16-atomics capability check is currently unconditional for out_dtype='bf16', but the kernel only requires bf16 atomics when accumulate=True. When accumulate=False, the epilogue does plain stores, so this check can incorrectly reject valid configurations on arches without bf16 atomics. Gate the check on accumulate (or on the actual atomic path taken).
```diff
-if out_is_bf16:
+if out_is_bf16 and bool(accumulate):
     if not supports_bf16_global_atomics(gpu_arch):
         raise ValueError(
-            f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}"
+            f"out_dtype='bf16' with accumulate=True requires bf16 global atomics, got arch={gpu_arch!r}"
```
|
Please first try `split_k_intra` when testing the small-token cases.
* Add `_cvt_scalef32_pk_bf16_fp4`: GFX950 hardware path for converting 2 FP4 E2M1 nibbles to 2 bf16 via `v_cvt_scalef32_pk_bf16_fp4` (1 VALU instruction vs ~36 in software).
* Add `_fp4x4_in_i32_to_bf16x4_i64`: software fallback converting 4 FP4 nibbles (packed in 4 bytes of an i32) to 4 bf16 packed as i64.
* Add `load_b_raw_mxfp4`: loads 4 bytes (8 FP4 nibbles) from a kpack=16 preshuffle layout (`shuffle_weight_a16w4` format) using ku-based k0/klane addressing.
* Add `load_b_raw_mxfp4_dwordx4`: dwordx4 variant loading the full 16-byte kpack for one sub-lane in a single buffer_load.
* Add `unpack_b_mxfp4_bf16`: dispatches to the hw or sw path, returning a `(b0, b1)` i64 pair for `mfma_f32_16x16x32_bf16`.
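For reference, a standalone sketch of the FP4 E2M1 nibble decode these helpers implement on packed register values (the mapping follows the OCP MX microscaling format; the kernel helpers themselves operate on i32/i64 lanes, not Python ints):

```python
import numpy as np

# E2M1 nibble layout: bit 3 = sign, bits 2:1 = exponent (bias 1), bit 0 = mantissa.
# The 8 magnitudes an E2M1 code can encode:
_E2M1_MAGNITUDES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def decode_fp4_e2m1(nibble: int) -> float:
    """Decode one FP4 E2M1 nibble to a float (reference only, not the kernel path)."""
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    return sign * float(_E2M1_MAGNITUDES[nibble & 0x7])
```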
* Add `_decode_e8m0_byte_to_f32`: converts an E8M0 byte (i8) to f32 = 2^(e-127) via a bit shift into exponent position 23.
* Add `_barrier`: emits s_waitcnt + s_barrier as inline asm, bypassing LLVM SIInsertWaitcnts' conservative insertion.
* Add `compile_a16w4_moe_gemm2`: stage2 down-projection GEMM for BF16 activations x MXFP4 (FP4 E2M1) weights with E8M0 block scales, using `mfma_f32_16x16x32_bf16`. Ported from aiter a16w4_moe_gemm_2stage.py with mechanical adaptations (`fx.idx2crd`, tuple indexing in place of `layout_get`).
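Similarly, a reference sketch of the E8M0 decode described above (the scale byte is shifted into the f32 exponent field, giving 2^(e-127)):

```python
import numpy as np

def decode_e8m0(scale_byte: int) -> float:
    """Reference decode of an E8M0 scale byte: 2**(e - 127), built by shifting
    the byte into the f32 exponent field (bit position 23)."""
    bits = np.array([(scale_byte & 0xFF) << 23], dtype=np.uint32)
    return float(bits.view(np.float32)[0])

# decode_e8m0(127) == 1.0, decode_e8m0(128) == 2.0, decode_e8m0(126) == 0.5
```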
* Add `test_moe_gemm2_a16w4`: exercises `compile_a16w4_moe_gemm2` with BF16 activations and MXFP4 E2M1 weights pre-shuffled via `shuffle_weight_a16w4`/`shuffle_scale_a16w4`. Compares kernel output against the `torch_moe_gemm2` reference. Gated on gfx950+.
…ordx4
* Fix IntTuple component extraction: `layout_get(coord, i)` must map to `fx.get(coord, i)`, not `coord[i]`. FlyDSL's `fx.idx2crd` returns a `!fly.int_tuple` MLIR value; Python indexing returns another IntTuple, not a scalar index.
* Drop `cache_modifier` from `load_b_raw_mxfp4_dwordx4`: FlyDSL's `_buffer_load_vec` does not support this parameter. The argument is an optional cache-policy hint with no correctness impact.
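A small sketch of the extraction fix described in this commit; the variable names (`linear_idx`, `shape`) are placeholders, and `fx` is assumed to be the FlyDSL expression module already imported by the kernel:

```python
# fx.idx2crd yields a !fly.int_tuple MLIR value, so components must be
# pulled out with fx.get; Python indexing would produce another IntTuple
# rather than a scalar index usable in address math.
coord = fx.idx2crd(linear_idx, shape)

k0 = fx.get(coord, 0)   # correct: scalar component
# k0 = coord[0]         # wrong: another IntTuple, not a scalar
```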
* Use inter_dim=3072 shapes so inter_dim/tile_k >= 2 (the ping-pong pipeline requires at least 2 K-tile iterations).
* View float4 weight and scale tensors as uint8 before passing them to the kernel (DLPack does not support float4_e2m1fn_x2 or float8_e8m0fnu).
* Remove the `test_name` kwarg from the `verify_output` call (not part of its signature).
* Add `k_batch` parameter to `_A16W4_SHAPES` and `test_moe_gemm2_a16w4`.
* New a16w4-s2-kbatch2 shape: tokens=4, tile_m=16, tile_n=128, k_batch=2.
* Pass `k_batch` through to `compile_a16w4_moe_gemm2` in the test body.
* Add `bench_a16w4_moe_gemm2.py`: two-phase standalone benchmark harness.
Force-pushed from e38ea4c to c5e44fc
Motivation
Ports the A16W4 MoE GEMM stage2 kernel from aiter into FlyDSL: BF16 activations x MXFP4
(FP4 E2M1, per-1x32 block scale) weights, down-projection in a MoE FFN block.
Stage1 (`compile_a16w4_moe_gemm1`) is a follow-up.
Technical Details
- `kernels/mfma_preshuffle_pipeline.py` — MXFP4 kpack=16 load/unpack helpers (Zan Zhang):
  - `load_b_raw_mxfp4` / `load_b_raw_mxfp4_dwordx4`: load 4 / 16 bytes of FP4 nibbles via buffer load
  - `unpack_b_mxfp4_bf16`: unpacks 8 FP4 E2M1 nibbles to bf16 pairs; uses the GFX950 hardware path (`v_cvt_scalef32_pk_bf16_fp4`) or a software fallback
- `kernels/mixed_moe_gemm_2stage.py` — new `compile_a16w4_moe_gemm2` (Zan Zhang):
  - ping-pong double-buffered `mfma_f32_16x16x32_bf16` pipeline over pre-shuffled MXFP4 weights
  - supports atomic accumulation, per-row routing weights, optional bias, and split-K (`k_batch`)
  - ported from `aiter/aiter/ops/flydsl/kernels/a16w4_moe_gemm_2stage.py` with two adaptations:
    - `idx2crd` / `layout_get` → `fx.idx2crd` / `fx.get` (FlyDSL returns `!fly.int_tuple`, not a Python tuple)
    - `cache_modifier` dropped from `_buffer_load_vec` (unsupported; the call site passed 0, no impact)
- `tests/kernels/test_moe_gemm.py` — new parametrised test `test_moe_gemm2_a16w4`
- `bench_a16w4_moe_gemm2.py` — standalone FlyDSL vs CK Tile benchmark, two shapes, tile_n sweep
Performance vs CK Tile stage-2 baseline
gfx950 (MI355X), 100 iters / 50 warmup. Inputs built via `torch_moe_stage1` + `moe_sorting`.
TFLOPS = 2 x tokens x topk x inter_dim x model_dim / latency. fly/ck > 1.0 = FlyDSL faster.
GPT-OSS — model_dim=3072, inter_dim=3072, E=128, topk=4, tile_m=16
DeepSeek-R1 — model_dim=7168, inter_dim=1024, E=384, topk=8, tile_m=16
- tokens <= 4 (GPT-OSS) / tokens = 1 (DeepSeek): CK Tile faster — FlyDSL launches a full expert grid even when most experts are idle. Closing this gap is a follow-up.
- tokens >= 8: FlyDSL 4–30% faster on GPT-OSS; 1–21% faster on DeepSeek (tile_n=256 preferred there).
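For reference, a minimal sketch of the TFLOPS metric quoted above, assuming latency is measured in seconds:

```python
def stage2_tflops(tokens: int, topk: int, inter_dim: int, model_dim: int,
                  latency_s: float) -> float:
    """2 * tokens * topk * inter_dim * model_dim FLOPs per stage-2 launch."""
    return 2 * tokens * topk * inter_dim * model_dim / latency_s / 1e12
```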
Test Plan
`test_moe_gemm2_a16w4` in `tests/kernels/test_moe_gemm.py`: three parametrised shapes (small, medium, k_batch=2), MXFP4-quantised weights via `aiter.get_torch_quant`, output compared to an FP32 dequantised reference. Gated on gfx950+.
Test Result
Follow-ups (out of scope)
- Stage1 (`compile_a16w4_moe_gemm1`): not yet in FlyDSL.
- `get_flydsl_stage2_kernels` needs k_batch enumeration and passthrough.
Submission Checklist