Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/api/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ GEMM Kernels
MoE (Mixture-of-Experts) Kernels
----------------------------------

- ``kernels.moe_gemm_2stage`` -- MoE GEMM with 2-stage pipeline (stage1 + stage2)
- ``kernels.moe_gemm_2stage`` -- CDNA / MFMA MoE GEMM with 2-stage pipeline
- ``kernels.rdna_moe_gemm_2stage`` -- RDNA4 (``gfx120x`` / ``gfx1201``) MoE
GEMM 2-stage, fp16/bf16 WMMA
- ``kernels.moe_gemm_2stage_wmma_gfx1250`` -- gfx1250 (MI450) MoE GEMM
2-stage, fp16/bf16 WMMA with TDM
- ``kernels.mixed_moe_gemm_2stage`` -- Mixed-precision MoE GEMM
- ``kernels.moe_blockscale_2stage`` -- MoE with block-scale quantization (MXFP4)
- ``kernels.moe_reduce`` -- MoE reduction kernel: sums over the topk dimension
Expand Down
1 change: 1 addition & 0 deletions docs/architecture_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ FlyDSL/
│ ├── blockscale_preshuffle_gemm.py # Blockscale GEMM
│ ├── hgemm_splitk.py # FP16 GEMM split-K
│ ├── moe_gemm_2stage.py # MoE GEMM (2-stage gate/up + reduce)
│ ├── rdna_moe_gemm_2stage.py # RDNA4 (gfx120x) MoE GEMM (fp16/bf16 WMMA)
│ ├── moe_blockscale_2stage.py # MoE Blockscale GEMM
│ ├── mixed_moe_gemm_2stage.py # Mixed-precision MoE GEMM
│ ├── pa_decode_fp8.py # Paged attention decode (FP8)
Expand Down
30 changes: 27 additions & 3 deletions docs/prebuilt_kernels_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,13 @@ What operation do you need?
├── MoE (Mixture of Experts)
│ ├── Blockscale MoE (gate+up+reduce)
│ │ └── → kernels/moe_blockscale_2stage.py
│ └── Standard MoE (fp8/f16/bf16/int8/int4)
│ └── → kernels/moe_gemm_2stage.py
│ ├── Standard MoE (CDNA / MFMA, fp8/f16/bf16/int8/int4)
│ │ └── → kernels/moe_gemm_2stage.py
│ ├── RDNA4 MoE (gfx120x / gfx1201, fp16/bf16 WMMA)
│ │ └── → kernels/rdna_moe_gemm_2stage.py
│ └── GFX1250 MoE (MI450, WMMA fp16/bf16 + MXScale fp4/fp8/a8w4)
│ ├── → kernels/moe_gemm_2stage_wmma_gfx1250.py
│ └── → kernels/moe_gemm_2stage_mxscale_gfx1250.py
└── Building blocks
├── Warp/block reduction → kernels_common.py
Expand All @@ -301,7 +306,10 @@ What operation do you need?
| `kernels/preshuffle_gemm.py` | GEMM (preshuffle layout) |
| `kernels/blockscale_preshuffle_gemm.py` | Blockscale GEMM |
| `kernels/hgemm_splitk.py` | FP16 GEMM split-K |
| `kernels/moe_gemm_2stage.py` | MoE GEMM 2-stage (gate/up + reduce) |
| `kernels/moe_gemm_2stage.py` | MoE GEMM 2-stage (gate/up + reduce), CDNA / MFMA |
| `kernels/rdna_moe_gemm_2stage.py` | RDNA4 (gfx120x) MoE GEMM 2-stage, fp16/bf16 WMMA |
| `kernels/moe_gemm_2stage_wmma_gfx1250.py` | gfx1250 MoE GEMM 2-stage, fp16/bf16 WMMA |
| `kernels/moe_gemm_2stage_mxscale_gfx1250.py` | gfx1250 MoE GEMM 2-stage, fp4/fp8/a8w4 MXScale |
| `kernels/moe_blockscale_2stage.py` | MoE Blockscale 2-stage |
| `kernels/mixed_moe_gemm_2stage.py` | Mixed-precision MoE GEMM |
| `kernels/pa_decode_fp8.py` | Paged attention decode (FP8) |
Expand Down Expand Up @@ -330,6 +338,7 @@ What operation do you need?
| `tests/kernels/test_blockscale_preshuffle_gemm.py` | Blockscale GEMM |
| `tests/kernels/test_hgemm_splitk.py` | FP16 GEMM split-K |
| `tests/kernels/test_moe_gemm.py` | MoE GEMM |
| `tests/kernels/test_moe_gemm_rdna4.py` | RDNA4 MoE GEMM |
| `tests/kernels/test_moe_blockscale.py` | MoE Blockscale GEMM |
| `tests/kernels/test_moe_reduce.py` | MoE reduce kernel |
| `tests/kernels/test_pa.py` | Paged attention decode |
Expand All @@ -345,3 +354,18 @@ What operation do you need?
| `tests/kernels/test_vec_add.py` | Vector addition |
| `tests/kernels/test_quant.py` | Quantization utilities |
| `tests/kernels/benchmark_common.py` | Shared benchmark infrastructure |

## 9. RDNA4 MoE Notes

`kernels/rdna_moe_gemm_2stage.py` targets `gfx120x` only (Radeon RDNA4,
including `gfx1201`). It uses ``wmma_f32_16x16x16_{f16,bf16}`` with a simple
LDS pipeline and reuses the public `compile_moe_gemm1` / `compile_moe_gemm2`
/ `compile_moe_gemm2_ex` contract via the `make_moe_public_api` factory in
`kernels/moe_gemm_2stage.py`.

Measured starting points on `gfx1201`:

- Stage1: `tile_k=128`, `tile_n=64` for `tile_m` 16/32, and `tile_n=128` for `tile_m=64`
- Stage2: `tile_k=128`, `tile_n=64`
- `waves_per_eu=2` often helps stage1, while stage2 remains workload-dependent
- Reduce mode can outperform atomic mode for medium and large routed workloads, so both modes should be benchmarked on target shapes
97 changes: 97 additions & 0 deletions kernels/moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3392,6 +3392,103 @@ def mode(self) -> str:
return MoeGemm2Mode.REDUCE


# ---------------------------------------------------------------------------
# Arch-agnostic MoE public API factory
# ---------------------------------------------------------------------------
#
# Arch-specific MoE kernel modules (CDNA MFMA here, RDNA4 in
# ``rdna_moe_gemm_2stage.py``, gfx1250 in ``moe_gemm_2stage_wmma_gfx1250.py``)
# share the same public builder shape. ``make_moe_public_api`` generates
# ``compile_moe_gemm1`` / ``compile_moe_gemm2`` / ``compile_moe_gemm2_ex``
# bound to a given arch-specific ``compile_impl`` so each arch file does not
# have to hand-roll the same wrappers.


# Extra kwargs accepted at the public API layer so callers can stay uniform
# across CDNA / gfx1250 / RDNA4 even if some options are arch-specific; we
# strip the ones the target ``compile_impl`` does not actually use.
_MOE_PUBLIC_EXTRA_KWARGS = (
"group_size",
"use_cshuffle_epilog",
"num_buffers",
"use_tdm_gather",
"use_tdm_store",
"inst_prefetch",
"wave_specialized_tdm",
"cluster_m",
"cluster_n",
)


def _moe_strip_extras(kw: dict, allowed_extras: tuple = ()) -> dict:
result = dict(kw)
for key in _MOE_PUBLIC_EXTRA_KWARGS:
if key in allowed_extras:
continue
result.pop(key, None)
return result


def make_moe_public_api(compile_impl, *, pass_through_kwargs: tuple = ()):
    """Bind the public MoE builder trio to an arch-specific ``compile_impl``.

    Returns ``(compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex)``.
    ``compile_impl`` must accept ``stage``, ``doweight``, ``accumulate`` and
    the usual MoE kwargs (``model_dim``, ``inter_dim``, ``experts``, ``topk``,
    ``tile_m``, ``tile_n``, ``tile_k``, ``in_dtype``, ``out_dtype``,
    ``waves_per_eu``, ``expert_sched_mode``).  ``pass_through_kwargs`` names
    public-layer extras (e.g. gfx1250 TDM / cluster knobs) that should be
    forwarded to ``compile_impl`` instead of stripped.
    """

    def compile_moe_gemm1(*, doweight_stage1, **kw):
        # Stage-1 (gate/up) builder.
        forwarded = _moe_strip_extras(kw, pass_through_kwargs)
        return compile_impl(stage=1, doweight=doweight_stage1, **forwarded)

    def compile_moe_gemm2(*, doweight_stage2, accumulate=True, **kw):
        # Stage-2 (down-projection) builder.
        forwarded = _moe_strip_extras(kw, pass_through_kwargs)
        return compile_impl(
            stage=2, doweight=doweight_stage2, accumulate=accumulate, **forwarded
        )

    def compile_moe_gemm2_ex(
        *,
        mode=MoeGemm2Mode.ATOMIC,
        valid_mask=None,
        zero_intermediate=True,
        **kw,
    ):
        # Atomic mode is simply the accumulating stage-2 builder.
        if mode != MoeGemm2Mode.REDUCE:
            return compile_moe_gemm2(accumulate=True, **kw)
        # Reduce mode: non-accumulating GEMM followed by a topk reduction.
        gemm2_exe = compile_moe_gemm2(accumulate=False, **kw)
        out_s = str(kw.get("out_dtype", "f16")).strip().lower()
        aliases = {
            "f16": "f16",
            "fp16": "f16",
            "half": "f16",
            "bf16": "bf16",
            "bfloat16": "bf16",
        }
        dtype_str = aliases.get(out_s, "f32")
        masked = valid_mask is not None
        reduce_exe = compile_moe_reduction(
            topk=kw["topk"],
            model_dim=kw["model_dim"],
            dtype_str=dtype_str,
            use_mask=masked,
        )
        return _MoeGemm2ReduceWrapper(
            gemm2_exe=gemm2_exe,
            reduce_exe=reduce_exe,
            topk=kw["topk"],
            model_dim=kw["model_dim"],
            out_dtype_str=dtype_str,
            use_mask=masked,
            zero_intermediate=zero_intermediate,
        )

    return compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex


def compile_moe_gemm2_ex(
*,
model_dim: int,
Expand Down
122 changes: 71 additions & 51 deletions kernels/moe_gemm_2stage_common_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
# Copyright (c) 2025 FlyDSL Project Contributors


"""Shared utilities for gfx1250 MoE 2-stage kernels.
"""Shared utilities for gfx1250 (MI450 / GFX12) MoE 2-stage kernels.

Common helpers used by both the fp16 WMMA kernels and the mxscale
(fp4/fp8/a8w4) kernels.
Common helpers used by the gfx1250 fp16 WMMA kernels
(``moe_gemm_2stage_wmma_gfx1250.py``) and the gfx1250 mxscale fp4/fp8/a8w4
kernels (``moe_gemm_2stage_mxscale_gfx1250.py``).

RDNA4 (gfx120x) MoE helpers live in ``rdna_moe_gemm_2stage_common.py``.
The arch-agnostic MoE public API factory (``make_moe_public_api``) lives in
``kernels/moe_gemm_2stage.py``.
"""

from __future__ import annotations

import inspect
from typing import Any
from typing import Any, Optional, Tuple

from flydsl.runtime.device import get_rocm_arch as get_hip_arch

Expand All @@ -26,6 +31,68 @@ def _align_up(v: int, a: int) -> int:
return ((int(v) + int(a) - 1) // int(a)) * int(a)


def _moe_out_elem_ty(out_dtype: str, T):
"""gfx1250 MoE output element type mapping (f16 or bf16)."""
return T.f16 if out_dtype == "f16" else T.bf16


def _make_moe_wave_layout(*, m_warp: int, n_warp: int, WAVE_SIZE: int, fx):
return fx.make_layout(
(int(m_warp), int(n_warp), 2, 16),
(int(n_warp) * WAVE_SIZE, WAVE_SIZE, 16, 1),
)


def _make_wmma_sub_tiles(
*, wmma_m_rep: int, wmma_n_rep: int, WMMA_M: int, is_fp4: bool
) -> list:
sub_tiles = []
for wm in range(wmma_m_rep):
for wn in range(wmma_n_rep):
if is_fp4:
for half in range(2):
sub_tiles.append(
(wm * wmma_n_rep + wn, half * 8, wm * WMMA_M, wn * 2 + half)
)
else:
sub_tiles.append((wm * wmma_n_rep + wn, 0, wm * WMMA_M, wn))
return sub_tiles


def _finalize_alloc_and_launch_2d(
    *,
    ctx,
    alloc,
    launcher,
    gx,
    gy,
    block_threads: int,
    stream,
    waves_per_eu,
    ir,
    cluster: Optional[Tuple[int, int, int]] = None,
):
    """Finalize shared-memory allocations, stamp ROCDL attrs, and launch.

    NOTE(review): indentation reconstructed from a flattened diff — assumes
    the attribute loop and the launch run under the same insertion point;
    confirm against the original file.
    """
    with ir.InsertionPoint(ctx.gpu_module_body):
        # Re-run allocation finalization for this compilation.
        alloc.finalized = False
        alloc.finalize()
        for op in ctx.gpu_module_body.operations:
            if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func":
                if waves_per_eu is not None and int(waves_per_eu) >= 1:
                    # Occupancy hint consumed by the AMDGPU backend.
                    op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(
                        ir.IntegerType.get_signless(32), int(waves_per_eu)
                    )
                if cluster is not None:
                    # Workgroup-cluster dims encoded as "x,y,z".
                    op.attributes["rocdl.cluster_dims"] = ir.StringAttr.get(
                        f"{cluster[0]},{cluster[1]},{cluster[2]}"
                    )
        launcher.launch(
            grid=(gx, gy, 1),
            block=(block_threads, 1, 1),
            stream=stream,
            cluster=cluster,
        )


def _pick_fp4_warp_shape(tile_m: int, tile_n: int) -> tuple[int, int]:
"""Pick a legal (m_warp, n_warp) for compile_mxfp4_gemm constraints."""
for m_warp in (4, 2, 1):
Expand Down Expand Up @@ -108,59 +175,12 @@ def _pick_mxscale_launch_shape(data_format: str, route_tile_m: int, tile_n: int)
return _pick_fp16_single_launch_shape(int(route_tile_m), int(tile_n), max_total_warps=8)


def _make_moe_wave_layout(*, m_warp: int, n_warp: int, WAVE_SIZE: int, fx):
return fx.make_layout(
(int(m_warp), int(n_warp), 2, 16),
(int(n_warp) * WAVE_SIZE, WAVE_SIZE, 16, 1),
)


def _make_wmma_sub_tiles(
*, wmma_m_rep: int, wmma_n_rep: int, WMMA_M: int, is_fp4: bool
) -> list[tuple[int, int, int, int]]:
sub_tiles = []
for wm in range(wmma_m_rep):
for wn in range(wmma_n_rep):
if is_fp4:
for half in range(2):
sub_tiles.append((wm * wmma_n_rep + wn, half * 8, wm * WMMA_M, wn * 2 + half))
else:
sub_tiles.append((wm * wmma_n_rep + wn, 0, wm * WMMA_M, wn))
return sub_tiles


def _moe_out_elem_ty(out_dtype: str, T):
return T.f16 if out_dtype == "f16" else T.bf16


def _extract_sub8(acc, vec_base: int, *, vector, range_constexpr, ACC_VEC_SIZE: int):
if ACC_VEC_SIZE == 8:
return acc
return vector.shuffle(acc, acc, [vec_base + i for i in range_constexpr(8)])


def _finalize_alloc_and_launch_2d(*, ctx, alloc, launcher, gx, gy, block_threads: int, stream, waves_per_eu, ir,
                                  cluster=None):
    """Finalize shared-memory allocations, stamp ROCDL attrs, and launch.

    NOTE(review): indentation reconstructed from a flattened diff — assumes
    the attribute loop and the launch run under the same insertion point;
    confirm against the original file.
    """
    with ir.InsertionPoint(ctx.gpu_module_body):
        # Re-run allocation finalization for this compilation.
        alloc.finalized = False
        alloc.finalize()
        for op in ctx.gpu_module_body.operations:
            if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func":
                if waves_per_eu is not None and int(waves_per_eu) >= 1:
                    # Occupancy hint consumed by the AMDGPU backend.
                    op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(
                        ir.IntegerType.get_signless(32), int(waves_per_eu)
                    )
                if cluster is not None:
                    # Workgroup-cluster dims encoded as "x,y,z".
                    op.attributes["rocdl.cluster_dims"] = ir.StringAttr.get(
                        f"{cluster[0]},{cluster[1]},{cluster[2]}")
        launcher.launch(
            grid=(gx, gy, 1),
            block=(block_threads, 1, 1),
            stream=stream,
            cluster=cluster,
        )


def _emit_stage1_gate_up_epilogue(
*,
sub_tiles,
Expand Down
Loading
Loading