[Kernel][Dialect] Support RDNA 3 / 3.5 WMMA (PoC): fix atom + add baseline GEMM kernel #485
KerwinTsaiii wants to merge 5 commits into ROCm:main from
Conversation
Empirically validated against gfx1151 hardware that the f32/i32
accumulator's per-lane natural M ordering is *interleaved* between lane
groups (lane n, val v -> M = 2*v + n/16), not contiguous as the previous
getThrValLayoutCD declared. The mismatch was being papered over by
calling reorderAccLaneValues 4 times per WMMA, costing ~32 ds_bpermute
per call.
Updating getThrValLayoutCD to match the real hardware layout and
dropping the runtime swizzles takes the per-WMMA cross-lane op count
from ~28 ds_bpermute down to 8 (the irreducible A/B operand expansion
for the WMMA256b 16-wide K input). Empirical speedup on gfx1151
runtime tests: up to 1.99x at 4096^3 GEMM.
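The interleaved C/D mapping described above can be illustrated with a small Python model (the function names here are ours, for illustration only — this is not FlyDSL code):

```python
# Model of the RDNA3 f32/i32 WMMA accumulator layout on wave32, as
# described above: each lane holds 8 values, and the natural M ordering
# is interleaved between lane groups (M = 2*v + lane//16), not two
# contiguous halves (M = v + 8*(lane//16)) as previously declared.

def m_interleaved(lane: int, v: int) -> int:
    """Row index M held by value v of `lane` (validated hardware layout)."""
    return 2 * v + lane // 16

def m_contiguous(lane: int, v: int) -> int:
    """Row index M under the old (incorrect) contiguous declaration."""
    return v + 8 * (lane // 16)

# Both layouts cover M = 0..15 exactly once across a lane pair, so
# single-tile tests can pass, but the lane groups own different rows:
rows_hw = sorted(m_interleaved(lane, v) for lane in (0, 16) for v in range(8))
rows_old = sorted(m_contiguous(lane, v) for lane in (0, 16) for v in range(8))
assert rows_hw == rows_old == list(range(16))

# The mismatch: lane 16, value 0 holds M=1 on hardware but M=8 under the
# old declaration -- hence the runtime reorder that was papered over it.
assert m_interleaved(16, 0) == 1
assert m_contiguous(16, 0) == 8
```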
Other fixes:
* The native bf16-acc WMMA variant uses op_sel-packed 16-wide C/D
lanes which conflicted with FlyDSL's vec8 fragment representation.
Now lowers through the f32 WMMA op + fptrunc, matching the f16-acc
fallback. (BF16 has the same exponent range as f32 so the
promote/truncate round-trip is overflow-safe.)
* Remove unused duplicateTo16WideSimple and reorderAccLaneValues
helpers (~110 LOC of dead code).
* Update the bf16 -> bf16 FileCheck test to match the new lowering.
* Expand runtime test coverage from 17 -> 25 cases:
- bf16-acc multi-K-tile (was only 16x16x16) and multi-block sizes
- f16-acc multi-K-tile
- iu8 -> i32 multi-block (was only single-tile contract test)
Co-authored-by: Cursor <cursoragent@cursor.com>
Adds kernels/rdna3_gemm.py — a production-shape GEMM kernel exercising
the MmaOpRDNA3_WMMAType atom path, intended as a starting point for
gfx110x / gfx115x / gfx1151 (Strix Halo iGPU). Compared to the existing
single-warp PoC kernels.rdna3_f16_gemm (which is for atom validation),
this kernel applies the standard atom-API optimizations:
* 2x2 wave layout (4 wave32 = 128 threads per workgroup), per-warp
reg_m × reg_n = 4x4 WMMA tile.
* BLOCK = 128 x 128 x 32. reg_m / reg_n / reg_k are implicit from the
BLOCK / (waves * WMMA_*) ratio, NOT passed as the make_tiled_mma
permutation argument (which is for K-direction permutation, a
distinct concept — initial attempt with that kept producing layout
mismatches).
* LDS staging for both A and B with a 2-stage ping-pong buffer.
LDS layout is row-major (K-fastest) to match the row-major GMEM
source; the layout-lowering pass handles the reshape into the
column-major WMMA fragment convention.
* Software pipelining: each loop iteration prefetches the next
K-tile from GMEM into LDS[(k+1)&1] while computing from
LDS[k&1], overlapping GMEM latency with WMMA compute.
* Runtime SCF outer loop over K-tile *pairs* (not constexpr
range_constexpr), so the IR doesn't blow up at K=4096 (which
would unroll 128 K-iterations).
* Tunable workgroup swizzle (group_m). On gfx1151 with 4 MB L2
no benefit was measured at 4096^3 so the default is 1; left as
a knob for other RDNA3 SKUs.
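The tile arithmetic and ping-pong schedule from the bullets above can be sketched in plain Python (a model of the control flow only, not the FlyDSL API; all names here are ours):

```python
# Tile sizes from the kernel description: BLOCK = 128x128x32, 2x2 wave
# layout, WMMA 16x16x16. The per-warp register tile is implicit from
# the BLOCK / (waves * WMMA_*) ratios rather than passed as an argument.
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
WAVES_M, WAVES_N = 2, 2
WMMA_M = WMMA_N = WMMA_K = 16

reg_m = BLOCK_M // (WAVES_M * WMMA_M)   # per-warp WMMA tiles along M
reg_n = BLOCK_N // (WAVES_N * WMMA_N)   # per-warp WMMA tiles along N
assert (reg_m, reg_n) == (4, 4)         # the 4x4 per-warp tile above

def pingpong_schedule(num_ktiles):
    """Order of (action, k-tile) events in the 2-stage pipeline:
    iteration k computes from lds[k & 1] while the next K-tile is
    prefetched into lds[(k + 1) & 1], overlapping GMEM latency."""
    events = [("load", 0)]                       # prologue fill of lds[0]
    for k in range(num_ktiles):
        if k + 1 < num_ktiles:
            events.append(("prefetch", k + 1))   # GMEM -> lds[(k+1) & 1]
        events.append(("compute", k))            # WMMA from lds[k & 1]
    return events

assert pingpong_schedule(3) == [
    ("load", 0),
    ("prefetch", 1), ("compute", 0),
    ("prefetch", 2), ("compute", 1),
    ("compute", 2),
]
```

On hardware a barrier sits between the prefetch issue and the compute of each iteration; the model only captures the buffer alternation and event order.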
Performance on gfx1151 (Strix Halo iGPU, Radeon 8060S; 40 CUs;
theoretical WMMA F16 peak ~59 TFLOPS; rocBLAS hgemm ~30 TFLOPS):
* 1024^3 f16->f32: 11.8 TFLOPS (57.7% of rocBLAS, 19.9% of peak)
* 2048^3 f16->f32: 21.9 TFLOPS (62.3% of rocBLAS, 37.0% of peak)
* 4096^3 f16->f32: 17.5 TFLOPS (56.0% of rocBLAS, 29.5% of peak)
Compared to the single-warp PoC at ~7 TFLOPS this is roughly a
2.5x improvement; remaining gap to rocBLAS comes from missing
LDS bank-conflict swizzle, CShuffle epilogue, and hot-loop
scheduler hints — left for follow-up PRs.
Also adds:
* tests/kernels/test_rdna3_gemm.py — 19 correctness tests
covering f16/bf16 inputs, f32/f16/bf16 accumulators, multi-block
sizes (128 to 1024), and a wave/reg-config sweep.
* tests/bench_rdna3_gemm.py — rocBLAS-comparison benchmark
that reports TFLOPS, %rocBLAS, and %WMMA-peak. Includes
chip-table for back-calculating theoretical peak from per-CU
per-cycle throughput.
Co-authored-by: Cursor <cursoragent@cursor.com>
Two small follow-ups to the previous commit:
* Update the kernel docstring to clearly mark this as a PoC baseline
(~56% of rocBLAS at 4096^3 on gfx1151) rather than "production
starting point". List the optimizations that *would* close the gap
to rocBLAS for follow-up work (LDS XOR swizzle, CShuffle epilogue,
sched intrinsics, occupancy tuning, tail-tile predication) so users
don't mistake this for a peak-perf kernel.
* Remove the group_m workgroup-swizzle knob. Empirically every value
> 1 was slower than the trivial bid_m / bid_n mapping on gfx1151
(4 MB L2). Keeping a knob that only ever hurts perf is a code
smell; reverting to the simpler 2D grid drops a handful of div /
mod instructions per workgroup and makes the kernel match the
structure of `examples/03-tiledMma.py`.
44/44 tests still pass (25 atom validation + 19 production-shape
GEMM); benchmark unchanged at 17.7 TFLOPS @ 4096^3.
Co-authored-by: Cursor <cursoragent@cursor.com>
// RDNA 3 / 3.5 uses `WMMA256bInsts`: the intrinsic selector expects
// 16-wide lane operands. FlyDSL's per-lane fragment is 8 elements where
// lane-group 0 carries K=0..7 and lane-group 1 carries K=8..15.
//
// Build a 16-wide lane vector as:
//   low half  = lane-local 8 values
//   high half = values pulled from the paired lane (lane xor 16)
//
// `rocdl.ds_bpermute` is 32-bit granularity, so for f16/bf16 we:
//   vector<8x{f16|bf16}> -> bitcast vector<4xi32>
//   bpermute each dword from lane^16
//   bitcast back to vector<8x{f16|bf16}>
//   concatenate [self8, pair8] to vector<16x...>
auto expandTo16WideWmma256 = [&](Value v) -> Value {
  auto vt = cast<VectorType>(v.getType());
  if (vt.getShape().size() != 1 || vt.getShape()[0] != 8)
    return v;
  auto wideTy = VectorType::get({16}, vt.getElementType());

  Value pairedHalf = v;
  if (vt.getElementType().isF16() || vt.getElementType().isBF16()) {
    auto i32Ty = IntegerType::get(ctx, 32);
    auto i64Ty = IntegerType::get(ctx, 64);
    auto packedTy = VectorType::get({4}, i32Ty);

    Value lane = ROCDL::ThreadIdXOp::create(builder, loc, i32Ty).getResult();
    Value c31 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                         builder.getI32IntegerAttr(31));
    Value cNeg32 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                            builder.getI32IntegerAttr(-32));
    Value c16 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                         builder.getI32IntegerAttr(16));
    Value c2 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                        builder.getI32IntegerAttr(2));

    Value laneInWave = LLVM::AndOp::create(builder, loc, lane, c31);
    Value waveBase = LLVM::AndOp::create(builder, loc, lane, cNeg32);
    Value pairedLaneInWave = LLVM::XOrOp::create(builder, loc, laneInWave, c16);
    Value pairedLane = LLVM::OrOp::create(builder, loc, waveBase, pairedLaneInWave);
    Value pairedLaneByteAddr = LLVM::ShlOp::create(builder, loc, pairedLane, c2);

    Value packed = LLVM::BitcastOp::create(builder, loc, packedTy, v);
    Value pairedPacked = LLVM::UndefOp::create(builder, loc, packedTy);
    for (int i = 0; i < 4; ++i) {
      Value idx = LLVM::ConstantOp::create(builder, loc, i64Ty,
                                           builder.getI64IntegerAttr(i));
      Value laneWord = LLVM::ExtractElementOp::create(builder, loc, packed, idx);
      Value pairedWord = ROCDL::DsBpermuteOp::create(builder, loc, i32Ty,
                                                     pairedLaneByteAddr, laneWord)
                             .getResult();
      pairedPacked = LLVM::InsertElementOp::create(builder, loc, pairedPacked,
                                                   pairedWord, idx);
    }
    pairedHalf = LLVM::BitcastOp::create(builder, loc, vt, pairedPacked);
  }

  Value lowHalf = v;
  Value highHalf = pairedHalf;

  SmallVector<int32_t> concatMask = {0, 1, 2,  3,  4,  5,  6,  7,
                                     8, 9, 10, 11, 12, 13, 14, 15};
  return LLVM::ShuffleVectorOp::create(builder, loc, wideTy, lowHalf, highHalf,
                                       concatMask);
};

// For IU8 / IU4, each lane carries <2xi32> (64 bits). WMMA256b expects
// a 128-bit source per operand lane, where the upper half comes from lane^16.
auto expandPackedI32To4WideWmma256 = [&](Value v) -> Value {
  auto vt = dyn_cast<VectorType>(v.getType());
  if (!vt || vt.getShape().size() != 1 || vt.getShape()[0] != 2 ||
      !vt.getElementType().isInteger(32))
    return v;

  auto i32Ty = IntegerType::get(ctx, 32);
  auto i64Ty = IntegerType::get(ctx, 64);
  auto wideTy = VectorType::get({4}, i32Ty);

  Value lane = ROCDL::ThreadIdXOp::create(builder, loc, i32Ty).getResult();
  Value c31 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                       builder.getI32IntegerAttr(31));
  Value cNeg32 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                          builder.getI32IntegerAttr(-32));
  Value c16 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                       builder.getI32IntegerAttr(16));
  Value c2 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                      builder.getI32IntegerAttr(2));

  Value laneInWave = LLVM::AndOp::create(builder, loc, lane, c31);
  Value waveBase = LLVM::AndOp::create(builder, loc, lane, cNeg32);
  Value pairedLaneInWave = LLVM::XOrOp::create(builder, loc, laneInWave, c16);
  Value pairedLane = LLVM::OrOp::create(builder, loc, waveBase, pairedLaneInWave);
  Value pairedLaneByteAddr = LLVM::ShlOp::create(builder, loc, pairedLane, c2);

  Value pairedHalf = LLVM::UndefOp::create(builder, loc, vt);
  for (int i = 0; i < 2; ++i) {
    Value idx = LLVM::ConstantOp::create(builder, loc, i64Ty,
                                         builder.getI64IntegerAttr(i));
    Value laneWord = LLVM::ExtractElementOp::create(builder, loc, v, idx);
    Value pairedWord = ROCDL::DsBpermuteOp::create(builder, loc, i32Ty,
                                                   pairedLaneByteAddr, laneWord)
                           .getResult();
    pairedHalf = LLVM::InsertElementOp::create(builder, loc, pairedHalf,
                                               pairedWord, idx);
  }

  SmallVector<int32_t> concatMask = {0, 1, 2, 3};
  return LLVM::ShuffleVectorOp::create(builder, loc, wideTy, v, pairedHalf,
                                       concatMask);
};

// F32 <- F16/F16
if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32()) {
  Value aDup = expandTo16WideWmma256(a);
  Value bDup = expandTo16WideWmma256(b);
  Value res = ROCDL::wmma_f32_16x16x16_f16::create(builder, loc, accTy, aDup,
                                                   bDup, c)
                  .getResult();
  return res;
}
// F32 <- BF16/BF16 (op A/B input ty is AnyInteger -> bitcast bf16 to i16)
if (elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32()) {
  Value aI = castToI16Vec(expandTo16WideWmma256(a));
  Value bI = castToI16Vec(expandTo16WideWmma256(b));
  Value res =
      ROCDL::wmma_f32_16x16x16_bf16::create(builder, loc, accTy, aI, bI, c)
          .getResult();
  return res;
}
// F16 <- F16/F16.
//
// The native f16-acc WMMA variant uses op_sel-packed 16-wide C/D lanes, and
// keeping that packed state across loop-carried accumulation in the current
// vec8 fragment model is error-prone. For correctness, we use the f32-acc
// WMMA path and truncate the final lane-local result to f16.
if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16()) {
  Value aDup = expandTo16WideWmma256(a);
  Value bDup = expandTo16WideWmma256(b);
  auto accF32Ty = VectorType::get({accVecSize}, builder.getF32Type());
  Value cF32 = LLVM::FPExtOp::create(builder, loc, accF32Ty, c);
  Value resF32 = ROCDL::wmma_f32_16x16x16_f16::create(
                     builder, loc, accF32Ty, aDup, bDup, cF32)
                     .getResult();
  return LLVM::FPTruncOp::create(builder, loc, accTy, resF32).getResult();
}
// BF16 <- BF16/BF16.
//
// The native bf16-acc WMMA variant uses op_sel-packed 16-wide C/D lanes,
// which conflicts with the FlyDSL vec8 fragment representation. For
// correctness and consistency with the f16-acc path, we promote to f32
// acc, run the F32_BF16 WMMA, and truncate back to bf16. (BF16 has the
// same exponent range as f32, so no value can overflow during the
// promote/truncate round-trip.)
if (elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16()) {
  Value aI = castToI16Vec(expandTo16WideWmma256(a));
  Value bI = castToI16Vec(expandTo16WideWmma256(b));
  auto accF32Ty = VectorType::get({accVecSize}, builder.getF32Type());
  Value cF32 = LLVM::FPExtOp::create(builder, loc, accF32Ty, c);
  Value resF32 =
      ROCDL::wmma_f32_16x16x16_bf16::create(builder, loc, accF32Ty, aI, bI, cF32)
          .getResult();
  return LLVM::FPTruncOp::create(builder, loc, accTy, resF32).getResult();
}
// I32 <- IU8 (signA = signB = 0 = unsigned, clamp = 0).
// For RDNA 3 / 3.5 WMMA256b, the upper half must come from lane^16.
if (elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32)) {
  Value aWide = expandPackedI32To4WideWmma256(a);
  Value bWide = expandPackedI32To4WideWmma256(b);
  Value res = ROCDL::wmma_i32_16x16x16_iu8::create(
                  builder, loc, accTy,
                  /*signA=*/false, aWide, /*signB=*/false, bWide, c,
                  /*clamp=*/false)
                  .getResult();
  return res;
}
// I32 <- IU4. Same WMMA256b lane^16 upper-half rule as IU8.
if (elemTyA.isInteger(4) && elemTyB.isInteger(4) && elemTyAcc.isInteger(32)) {
  Value aWide = expandPackedI32To4WideWmma256(a);
  Value bWide = expandPackedI32To4WideWmma256(b);
  Value res = ROCDL::wmma_i32_16x16x16_iu4::create(
                  builder, loc, accTy,
                  /*signA=*/false, aWide, /*signB=*/false, bWide, c,
                  /*clamp=*/false)
                  .getResult();
  return res;
}
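For intuition, the cross-lane step that `expandTo16WideWmma256` implements (the dword-wise `ds_bpermute` exchange with lane^16) can be modeled in pure Python — an illustrative sketch of the data movement, not the lowering itself:

```python
# Each wave32 lane holds an 8-element fragment; the 16-wide WMMA256b
# operand is built as [own 8 values, 8 values fetched from lane ^ 16],
# which is what the per-dword ds_bpermute loop in the lowering does.

def expand_to_16_wide(fragments):
    """fragments: 32 per-lane lists of 8 values (one wave).
    Returns 32 per-lane lists of 16 values: [self8, pair8]."""
    assert len(fragments) == 32 and all(len(f) == 8 for f in fragments)
    return [list(fragments[lane]) + list(fragments[lane ^ 16])
            for lane in range(32)]

# Tag every value with its producing (lane, slot) to check the exchange.
frags = [[(lane, v) for v in range(8)] for lane in range(32)]
wide = expand_to_16_wide(frags)
assert wide[3][:8] == frags[3]    # low half stays lane-local
assert wide[3][8:] == frags[19]   # high half comes from lane 3 ^ 16 = 19
assert wide[19][8:] == frags[3]   # and symmetrically for the partner lane
```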
I am not familiar with WMMA on RDNA3, but this part looks strange. Generating so many ops inside AtomCall is not a good idea, since they may be replicated many times when gemmOp is unpacked into atom calls.
If this wrapping step is the common way to drive RDNA3 WMMA, I would suggest giving FlyROCDL_MmaOpRDNA3_WMMA a wrapper-specific suffix/name, and reserving the instruction-matching name for an unwrapped version in which AtomCall is basically a raw rocdl.wmma call. That would leave room in the future for users to perform these cross-lane operations explicitly at the Python level if they want.
Thanks for pointing this out.
Checking the lane mapping with matrix_calculator.py, RDNA3 is the odd one here: in wave32, A/B need replication between lane i and lane i+16. For example:
./matrix_calculator.py --architecture rdna3 --instruction v_wmma_f16_16x16x16_f16 --A-matrix --get-register --I-coordinate 3 --K-coordinate 2
Architecture: RDNA3
Instruction: V_WMMA_F16_16X16X16_F16
A[3][2] = v1{3}.[15:0]
A[3][2] = v1{19}.[15:0]
Same query on RDNA4 and CDNA3 returns only one lane:
./matrix_calculator.py --architecture rdna4 --instruction v_wmma_f16_16x16x16_f16 --A-matrix --get-register --I-coordinate 3 --K-coordinate 2
Architecture: RDNA4
Instruction: V_WMMA_F16_16X16X16_F16
A[3][2] = v1{3}.[15:0]
./matrix_calculator.py --architecture cdna3 --instruction v_mfma_f32_16x16x16_f16 --A-matrix --get-register --I-coordinate 3 --K-coordinate 2
Architecture: CDNA3
Instruction: V_MFMA_F32_16X16X16_F16
A[3][2] = v1{3}.[15:0]
The reason I implemented it this way is that I wanted to keep the current FlyDSL compact fragment model (vec8 split across lane groups) unchanged and compatible with the existing tiling/copy pipeline. Given that model, reconstructing the operand cross-lane inside emitAtomCallSSA was the most direct way to satisfy RDNA3 WMMA operand ABI requirements.
But I also agree this weakens the design contract of MmaOpRDNA3_WMMA, because it is no longer a clean instruction-matching atom and now embeds cross-lane wrapper behavior in the atom call path.
Given this affects atom semantics (instruction-matching vs wrapper behavior), I don’t think I should make that call solo in this PR. I’m going to close this for now, align on direction first (raw-only / wrapper-only / dual-mode), then continue in a follow-up.
Motivation
This PR adds initial PoC-level support for AMD RDNA 3 / 3.5 wave32 WMMA on FlyDSL's atom path, validated on gfx1151 (Strix Halo iGPU, Radeon 8060S).
The goals are:
Make the `MmaOpRDNA3_WMMAType` atom functionally and numerically correct on gfx110x / gfx115x — the previous registration had a C/D layout mismatch with the actual hardware ordering and a suspect bf16-acc lowering, both papered over at runtime by ~28 `ds_bpermute` instructions per WMMA.
This is a PoC baseline, not a competitive flagship kernel. For peak GEMM perf on RDNA 3 / 3.5 today, users should fall back to PyTorch / rocBLAS hgemm. Closing the gap to rocBLAS is left for follow-up PRs (see the Out of scope list below).
Technical Details
Atom-level fixes (`MmaOpRDNA3_WMMAType`)
Three issues found while reviewing the existing atom registration against the gfx11 ISA and empirically validating on gfx1151 hardware:
* C/D layout mismatch. The f32/i32 accumulator's per-lane natural M ordering is interleaved between lane groups (lane n, val v → M = 2·v + n/16), not contiguous halves as `getThrValLayoutCD` declared. The mismatch was being papered over by calling `reorderAccLaneValues` four times per WMMA call (3× on the C input as a hand-rolled inverse permutation, 1× on the D output), costing ~32 `ds_bpermute` per call. Updating `getThrValLayoutCD` to match the real layout and dropping the runtime swizzles takes the per-WMMA cross-lane op count from ~28 → 8 (the irreducible A/B expansion for the WMMA256b 16-wide K input).
* bf16-acc path. The native bf16-acc WMMA variant uses op_sel-packed 16-wide C/D lanes, which conflicted with FlyDSL's vec8 fragment representation. Tests passed because `verify_output` uses cosine similarity with atol=0.1 and only exercised 16×16×16 — almost certainly silently wrong on larger shapes. Now lowers through the f32 WMMA op + `fptrunc`, matching the f16-acc fallback. (BF16 has the same exponent range as f32, so the promote/truncate round-trip is overflow-safe.)
* Dead code. Remove the now-unused `duplicateTo16WideSimple` and `reorderAccLaneValues` helpers (~110 LOC in MmaAtom.cpp).
The MLIR FileCheck test (`tests/mlir/Conversion/wmma_rdna3.mlir`) is updated for the new bf16 → f32 lowering.
New PoC baseline GEMM kernel — `kernels/rdna3_gemm.py`
A multi-wave atom-based GEMM exercising `MmaOpRDNA3_WMMAType` on production-shape sizes. Reaches ≈ 56% of rocBLAS hgemm at 4096³ on gfx1151 (≈ 30% of WMMA F16 peak). Explicitly framed as a PoC baseline, not a peak-perf kernel — the docstring lists the optimizations needed to close the gap to rocBLAS (LDS XOR swizzle, CShuffle epilogue, sched intrinsics, occupancy tuning, tail-tile predication) for follow-up PRs.
Atom-API optimizations applied:
* `make_tiled_mma`'s third argument is the K-direction permutation, not a per-warp reg-tile knob. reg_m / reg_n / reg_k are implicit from the ratio BLOCK_M / (waves_m * WMMA_M).
Tests + benchmark infrastructure
* Atom validation (`tests/kernels/test_rdna3_wmma_gemm.py`): expanded 17 → 25 cases — bf16-acc multi-K-tile (was only 16×16×16) and multi-block sizes; f16-acc multi-K-tile; iu8 → i32 multi-block (was only a single-tile contract test).
* Production-shape correctness (`tests/kernels/test_rdna3_gemm.py`, new): 19 tests covering f16/bf16 inputs, f32/f16/bf16 accumulators, multi-block sizes (128³ → 1024³), and a wave / reg-config sweep.
* rocBLAS-comparison benchmark (`tests/bench_rdna3_gemm.py`, new): reports TFLOPS, %rocBLAS, and %WMMA-peak. Includes a chip table for back-calculating theoretical peak from per-CU per-cycle WMMA throughput (511 ops/cycle/CU on RDNA 3, derived from the RX 7900 XTX's published 122.6 TFLOPS / 96 CUs / 2.5 GHz).
Test Plan
Validated on gfx1151 (AMD Ryzen AI Max+ 395, Radeon 8060S iGPU, ROCm 7.12).
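The chip-table back-calculation of theoretical peak mentioned above can be sanity-checked numerically. Note the ~2.9 GHz gfx1151 clock below is our illustrative assumption, not a figure stated in this PR:

```python
# Back-calculate RDNA3 WMMA F16 per-CU per-cycle throughput from the
# RX 7900 XTX datapoint quoted in the benchmark description, then
# project the gfx1151 (40 CU) peak from it.
XTX_TFLOPS, XTX_CUS, XTX_CLOCK_GHZ = 122.6, 96, 2.5
ops_per_cycle_per_cu = XTX_TFLOPS * 1e12 / (XTX_CUS * XTX_CLOCK_GHZ * 1e9)
assert round(ops_per_cycle_per_cu) == 511    # the chip-table constant

GFX1151_CUS = 40
GFX1151_CLOCK_GHZ = 2.9                      # assumed boost clock (ours)
peak_tflops = 511 * GFX1151_CUS * GFX1151_CLOCK_GHZ * 1e9 / 1e12
assert 58 < peak_tflops < 60                 # consistent with "~59 TFLOPS"
```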