
[Kernel][Dialect] Support RDNA 3 / 3.5 WMMA (PoC): fix atom + add baseline GEMM kernel#485

Closed
KerwinTsaiii wants to merge 5 commits into ROCm:main from KerwinTsaiii:support-rdna3.5-core

Conversation

@KerwinTsaiii

Motivation

This PR adds initial PoC-level support for AMD RDNA 3 / 3.5 wave32 WMMA on FlyDSL's atom path, validated on gfx1151 (Strix Halo iGPU, Radeon 8060S).

The goals are:

  1. Make the MmaOpRDNA3_WMMAType atom functionally and numerically correct on gfx110x / gfx115x — the previous registration had a C/D layout mismatch with the actual hardware ordering and a suspect bf16-acc lowering, both papered over at runtime by ~28 ds_bpermute instructions per WMMA.
  2. Demonstrate the atom path works end-to-end on production-shape GEMM (≤ 4096³) by adding a multi-wave + LDS double-buffered kernel, so other RDNA 3 / 3.5 kernels can be built on the same primitives.
  3. Provide a rocBLAS-comparison benchmark so future perf work has a tracked baseline (currently ≈ 56% of rocBLAS hgemm at 4096³).

This is a PoC baseline, not a competitive flagship kernel. For peak GEMM perf on RDNA 3 / 3.5 today, users should fall back to PyTorch / rocBLAS hgemm. Closing the gap to rocBLAS is left for follow-up PRs (see the Out of scope list below).

Technical Details

Atom-level fixes (MmaOpRDNA3_WMMAType)

Three issues were found while reviewing the existing atom registration against the gfx11 ISA and empirically validating on gfx1151 hardware:

  • C/D layout mismatch. The f32/i32 accumulator's per-lane natural M ordering is interleaved between lane groups (lane n, val v → M = 2·v + n/16), not contiguous halves as getThrValLayoutCD declared. The mismatch was being papered over by calling reorderAccLaneValues four times per WMMA call (3× on the C input as a hand-rolled inverse permutation, 1× on the D output), costing ~32 ds_bpermute per call. Updating getThrValLayoutCD to match the real layout and dropping the runtime swizzles takes the per-WMMA cross-lane op count from ~28 → 8 (the irreducible A/B expansion for the WMMA256b 16-wide K input).

  • bf16-acc path. The native bf16-acc WMMA variant uses op_sel-packed 16-wide C/D lanes which conflicted with FlyDSL's vec8 fragment representation. Tests passed because verify_output uses cosine-similarity with atol=0.1 and only exercised 16×16×16 — almost certainly silently wrong on larger shapes. Now lowers through the f32 WMMA op + fptrunc, matching the f16-acc fallback. (BF16 has the same exponent range as f32, so the promote/truncate round-trip is overflow-safe.)

  • Dead code. Remove the now-unused duplicateTo16WideSimple and reorderAccLaneValues helpers (~110 LOC of MmaAtom.cpp).
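The interleaved C/D mapping described in the first bullet can be sanity-checked in a few lines of plain Python (a standalone sketch, not FlyDSL code; `cd_coords` is a hypothetical helper name):

```python
# RDNA3 wave32 f32/i32 accumulator fragment: a 16x16 tile spread over
# 32 lanes with 8 value slots per lane. Per the fix:
#   lane n, slot v  ->  M = 2*v + n//16,  N = n % 16
# i.e. M is interleaved between the two lane groups, not contiguous halves.
def cd_coords(lane, v):
    return 2 * v + lane // 16, lane % 16  # (M, N)

# All 256 (M, N) positions of the tile are covered exactly once.
coords = {cd_coords(lane, v) for lane in range(32) for v in range(8)}
assert coords == {(m, n) for m in range(16) for n in range(16)}

# Lane group 0 (lanes 0..15) holds the even M rows, group 1 the odd rows:
assert [cd_coords(0, v)[0] for v in range(8)] == [0, 2, 4, 6, 8, 10, 12, 14]
assert [cd_coords(16, v)[0] for v in range(8)] == [1, 3, 5, 7, 9, 11, 13, 15]
```

Declaring this interleaving directly in getThrValLayoutCD is what lets the runtime ds_bpermute swizzles be dropped.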

The MLIR FileCheck test (tests/mlir/Conversion/wmma_rdna3.mlir) is updated for the new bf16 → f32 lowering.
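The overflow-safety claim behind the bf16 → f32 lowering can be illustrated with a small stdlib-only Python sketch (illustrative only; it emulates bf16 by truncating the low 16 bits of an f32, whereas the actual fptrunc rounds to nearest-even — the exponent-range argument is the same either way):

```python
import struct

def f32_to_bf16_trunc(x: float) -> float:
    # bf16 keeps the sign bit, the full 8-bit f32 exponent, and the top
    # 7 mantissa bits; dropping the low 16 bits can therefore lose
    # precision but never overflow, since the exponent range is shared.
    bits, = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

x = 3.0e38                      # near the f32 (and thus bf16) max
y = f32_to_bf16_trunc(x)
assert y != float("inf")        # round-trip stays finite
assert abs(y - x) / x < 2**-7   # only mantissa precision is lost
```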

New PoC baseline GEMM kernel — kernels/rdna3_gemm.py

A multi-wave atom-based GEMM exercising MmaOpRDNA3_WMMAType on production-shape sizes. Reaches ≈ 56% of rocBLAS hgemm at 4096³ on gfx1151 (≈ 30% of WMMA F16 peak). Explicitly framed as a PoC baseline, not a peak-perf kernel — the docstring lists the optimizations needed to close the gap to rocBLAS (LDS XOR swizzle, CShuffle epilogue, sched intrinsics, occupancy tuning, tail-tile predication) for follow-up PRs.

Atom-API optimizations applied:

  • 2×2 wave layout (4 × wave32 = 128 threads/workgroup), per-warp reg_m × reg_n = 4×4 WMMA tile, BLOCK = 128 × 128 × 32.
  • LDS staging for both A and B with a 2-stage ping-pong buffer; LDS layout is row-major (K-fastest) to match the row-major GMEM source, and the layout-lowering pass handles the reshape into the column-major WMMA fragment convention.
  • Software pipelining: each loop iteration prefetches the next K-tile from GMEM into LDS[(k+1)&1] while computing from LDS[k&1], overlapping GMEM latency with WMMA compute.
  • Outer loop is a runtime SCF for-op iterating over K-tile pairs (so the IR doesn't blow up at K=4096 from unrolling 128 iterations); the pair body unrolls stage 0 and stage 1 explicitly so the LDS-stage index stays constexpr.
  • Implementation note: make_tiled_mma's third argument is the K-direction permutation, not a per-warp reg-tile knob. reg_m / reg_n / reg_k are implicit from the ratio BLOCK_M / (waves_m * WMMA_M).
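The double-buffered schedule in the bullets above can be sketched in plain Python (the names `pipelined_gemm_k_loop`, `load_tile`, and `compute_tile` are placeholders for illustration, not FlyDSL API):

```python
# Hypothetical skeleton of the 2-stage ping-pong pipeline: the stage
# index (k & 1) alternates between the two LDS halves, and the next
# K-tile is prefetched into the other half before computing the
# current one, overlapping GMEM latency with WMMA compute.
def pipelined_gemm_k_loop(num_k_tiles, load_tile, compute_tile):
    lds = [None, None]                 # two LDS stages (ping-pong)
    lds[0] = load_tile(0)              # prologue: prefetch tile 0
    for k in range(num_k_tiles):
        if k + 1 < num_k_tiles:
            lds[(k + 1) & 1] = load_tile(k + 1)   # prefetch next tile
        compute_tile(lds[k & 1], k)               # compute current tile

order = []
pipelined_gemm_k_loop(
    4,
    load_tile=lambda k: order.append(("load", k)) or k,
    compute_tile=lambda buf, k: order.append(("mma", buf)),
)
# Each tile's load is issued before the compute on the previous tile:
assert order == [("load", 0), ("load", 1), ("mma", 0),
                 ("load", 2), ("mma", 1), ("load", 3),
                 ("mma", 2), ("mma", 3)]
```

In the real kernel the outer loop runs over K-tile pairs so the stage index stays constexpr inside the unrolled pair body.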

Tests + benchmark infrastructure

  • Atom validation (tests/kernels/test_rdna3_wmma_gemm.py): expanded 17 → 25 cases — bf16-acc multi-K-tile (was only 16×16×16) and multi-block sizes; f16-acc multi-K-tile; iu8 → i32 multi-block (was only single-tile contract test).

  • Production-shape correctness (tests/kernels/test_rdna3_gemm.py, new): 19 tests covering f16/bf16 inputs, f32/f16/bf16 accumulators, multi-block sizes (128³ → 1024³), and a wave / reg-config sweep.

  • rocBLAS-comparison benchmark (tests/bench_rdna3_gemm.py, new): reports TFLOPS, %rocBLAS, and %WMMA-peak. Includes a chip table for back-calculating theoretical peak from per-CU per-cycle WMMA throughput (511 ops/cycle/CU on RDNA 3, derived from RX 7900 XTX's published 122.6 TFLOPS / 96 CUs / 2.5 GHz).
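The chip-table derivation is simple arithmetic and can be reproduced directly (the ~2.9 GHz gfx1151 clock below is an assumption chosen to recover the ~59 TFLOPS peak quoted in the commit message, not a number stated in this PR):

```python
# RX 7900 XTX published FP16 WMMA throughput, back-calculated to a
# per-CU, per-cycle rate: 122.6 TFLOPS over 96 CUs at 2.5 GHz.
ops_per_cycle_per_cu = 122.6e12 / (96 * 2.5e9)
assert round(ops_per_cycle_per_cu) == 511

# Scaling that rate to gfx1151 (40 CUs, assumed ~2.9 GHz clock) lands
# near the ~59 TFLOPS theoretical peak used elsewhere in this PR.
gfx1151_peak_tflops = 511 * 40 * 2.9e9 / 1e12
assert 58 < gfx1151_peak_tflops < 61
```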

Test Plan

Validated on gfx1151 (AMD Ryzen AI Max+ 395, Radeon 8060S iGPU, ROCm 7.12).

# Atom validation (25 tests covering all 6 dtype combos + multi-tile)
python -m pytest tests/kernels/test_rdna3_wmma_gemm.py -v

# Production-shape GEMM correctness (19 tests covering dtypes,
# accumulators, multi-block sizes, and a wave/reg-config sweep)
python -m pytest tests/kernels/test_rdna3_gemm.py -v

# Benchmark 
bash scripts/run_benchmark.sh 
[run_benchmark] GPU arch: gfx1151 (CDNA=false, RDNA4=false)
========================================================================
Benchmarks (logs under /tmp/flydsl_bench)
========================================================================

op                     shape                              dtype            TB/s     TFLOPS
---------------------- ---------------------------------- ---------- ---------- ----------
softmax                32768x8192                         bf16            0.209          -
layernorm              32768x8192                         bf16            0.226          -
rmsnorm                32768x8192                         bf16            0.225          -

========================================================================
Benchmark Summary
========================================================================
Total: 3 tests
Success: 3
Failed: 0
Logs: /tmp/flydsl_bench

All benchmarks passed! 

KerwinTsaiii and others added 5 commits May 8, 2026 12:22
Empirically validated against gfx1151 hardware that the f32/i32
accumulator's per-lane natural M ordering is *interleaved* between lane
groups (lane n, val v -> M = 2*v + n/16), not contiguous as the previous
getThrValLayoutCD declared. The mismatch was being papered over by
calling reorderAccLaneValues 4 times per WMMA, costing ~32 ds_bpermute
per call.

Updating getThrValLayoutCD to match the real hardware layout and
dropping the runtime swizzles takes the per-WMMA cross-lane op count
from ~28 ds_bpermute down to 8 (the irreducible A/B operand expansion
for the WMMA256b 16-wide K input). Empirical speedup on gfx1151
runtime tests: up to 1.99x at 4096^3 GEMM.

Other fixes:

  * The native bf16-acc WMMA variant uses op_sel-packed 16-wide C/D
    lanes which conflicted with FlyDSL's vec8 fragment representation.
    Now lowers through the f32 WMMA op + fptrunc, matching the f16-acc
    fallback. (BF16 has the same exponent range as f32 so the
    promote/truncate round-trip is overflow-safe.)

  * Remove unused duplicateTo16WideSimple and reorderAccLaneValues
    helpers (~110 LOC of dead code).

  * Update the bf16 -> bf16 FileCheck test to match the new lowering.

  * Expand runtime test coverage from 17 -> 25 cases:
    - bf16-acc multi-K-tile (was only 16x16x16) and multi-block sizes
    - f16-acc multi-K-tile
    - iu8 -> i32 multi-block (was only single-tile contract test)

Co-authored-by: Cursor <cursoragent@cursor.com>
Adds kernels/rdna3_gemm.py — a production-shape GEMM kernel exercising
the MmaOpRDNA3_WMMAType atom path, intended as a starting point for
gfx110x / gfx115x / gfx1151 (Strix Halo iGPU). Compared to the existing
single-warp PoC kernels.rdna3_f16_gemm (which is for atom validation),
this kernel applies the standard atom-API optimizations:

  * 2x2 wave layout (4 wave32 = 128 threads per workgroup), per-warp
    reg_m × reg_n = 4x4 WMMA tile.
  * BLOCK = 128 x 128 x 32. reg_m / reg_n / reg_k are implicit from the
    BLOCK / (waves * WMMA_*) ratio, NOT passed as the make_tiled_mma
    permutation argument (which is for K-direction permutation, a
    distinct concept — initial attempt with that kept producing layout
    mismatches).
  * LDS staging for both A and B with a 2-stage ping-pong buffer.
    LDS layout is row-major (K-fastest) to match the row-major GMEM
    source; the layout-lowering pass handles the reshape into the
    column-major WMMA fragment convention.
  * Software pipelining: each loop iteration prefetches the next
    K-tile from GMEM into LDS[(k+1)&1] while computing from
    LDS[k&1], overlapping GMEM latency with WMMA compute.
  * Runtime SCF outer loop over K-tile *pairs* (not constexpr
    range_constexpr), so the IR doesn't blow up at K=4096 (which
    would unroll 128 K-iterations).
  * Tunable workgroup swizzle (group_m). On gfx1151 with 4 MB L2
    no benefit was measured at 4096^3 so the default is 1; left as
    a knob for other RDNA3 SKUs.

Performance on gfx1151 (Strix Halo iGPU, Radeon 8060S; 40 CUs;
theoretical WMMA F16 peak ~59 TFLOPS; rocBLAS hgemm ~30 TFLOPS):

  * 1024^3 f16->f32:  11.8 TFLOPS (57.7% of rocBLAS, 19.9% of peak)
  * 2048^3 f16->f32:  21.9 TFLOPS (62.3% of rocBLAS, 37.0% of peak)
  * 4096^3 f16->f32:  17.5 TFLOPS (56.0% of rocBLAS, 29.5% of peak)

Compared to the single-warp PoC at ~7 TFLOPS, this is roughly a
2.5x improvement; the remaining gap to rocBLAS comes from the
missing LDS bank-conflict swizzle, CShuffle epilogue, and hot-loop
scheduler hints, left for follow-up PRs.

Also adds:

  * tests/kernels/test_rdna3_gemm.py — 19 correctness tests
    covering f16/bf16 inputs, f32/f16/bf16 accumulators, multi-block
    sizes (128 to 1024), and a wave/reg-config sweep.

  * tests/bench_rdna3_gemm.py — rocBLAS-comparison benchmark
    that reports TFLOPS, %rocBLAS, and %WMMA-peak. Includes
    chip-table for back-calculating theoretical peak from per-CU
    per-cycle throughput.

Co-authored-by: Cursor <cursoragent@cursor.com>
Two small follow-ups to the previous commit:

  * Update the kernel docstring to clearly mark this as a PoC baseline
    (~56% of rocBLAS at 4096^3 on gfx1151) rather than "production
    starting point". List the optimizations that *would* close the gap
    to rocBLAS for follow-up work (LDS XOR swizzle, CShuffle epilogue,
    sched intrinsics, occupancy tuning, tail-tile predication) so users
    don't mistake this for a peak-perf kernel.

  * Remove the group_m workgroup-swizzle knob. Empirically every value
    > 1 was slower than the trivial bid_m / bid_n mapping on gfx1151
    (4 MB L2). Keeping a knob that only ever hurts perf is a code
    smell; reverting to the simpler 2D grid drops a handful of div /
    mod instructions per workgroup and makes the kernel match the
    structure of `examples/03-tiledMma.py`.

44/44 tests still pass (25 atom validation + 19 production-shape
GEMM); benchmark unchanged at 17.7 TFLOPS @ 4096^3.

Co-authored-by: Cursor <cursoragent@cursor.com>
@coderfeli coderfeli requested a review from sjfeng1999 May 8, 2026 13:49
Comment on lines +243 to +427
// RDNA 3 / 3.5 uses `WMMA256bInsts`: the intrinsic selector expects
// 16-wide lane operands. FlyDSL's per-lane fragment is 8 elements where
// lane-group 0 carries K=0..7 and lane-group 1 carries K=8..15.
//
// Build a 16-wide lane vector as:
// low half = lane-local 8 values
// high half = values pulled from the paired lane (lane xor 16)
//
// `rocdl.ds_bpermute` is 32-bit granularity, so for f16/bf16 we:
// vector<8x{f16|bf16}> -> bitcast vector<4xi32>
// bpermute each dword from lane^16
// bitcast back to vector<8x{f16|bf16}>
// concatenate [self8, pair8] to vector<16x...>
auto expandTo16WideWmma256 = [&](Value v) -> Value {
  auto vt = cast<VectorType>(v.getType());
  if (vt.getShape().size() != 1 || vt.getShape()[0] != 8)
    return v;
  auto wideTy = VectorType::get({16}, vt.getElementType());

  Value pairedHalf = v;
  if (vt.getElementType().isF16() || vt.getElementType().isBF16()) {
    auto i32Ty = IntegerType::get(ctx, 32);
    auto i64Ty = IntegerType::get(ctx, 64);
    auto packedTy = VectorType::get({4}, i32Ty);

    Value lane = ROCDL::ThreadIdXOp::create(builder, loc, i32Ty).getResult();
    Value c31 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                         builder.getI32IntegerAttr(31));
    Value cNeg32 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                            builder.getI32IntegerAttr(-32));
    Value c16 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                         builder.getI32IntegerAttr(16));
    Value c2 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                        builder.getI32IntegerAttr(2));

    Value laneInWave = LLVM::AndOp::create(builder, loc, lane, c31);
    Value waveBase = LLVM::AndOp::create(builder, loc, lane, cNeg32);
    Value pairedLaneInWave = LLVM::XOrOp::create(builder, loc, laneInWave, c16);
    Value pairedLane = LLVM::OrOp::create(builder, loc, waveBase, pairedLaneInWave);
    Value pairedLaneByteAddr = LLVM::ShlOp::create(builder, loc, pairedLane, c2);

    Value packed = LLVM::BitcastOp::create(builder, loc, packedTy, v);
    Value pairedPacked = LLVM::UndefOp::create(builder, loc, packedTy);
    for (int i = 0; i < 4; ++i) {
      Value idx = LLVM::ConstantOp::create(builder, loc, i64Ty,
                                           builder.getI64IntegerAttr(i));
      Value laneWord = LLVM::ExtractElementOp::create(builder, loc, packed, idx);
      Value pairedWord = ROCDL::DsBpermuteOp::create(builder, loc, i32Ty,
                                                     pairedLaneByteAddr, laneWord)
                             .getResult();
      pairedPacked = LLVM::InsertElementOp::create(builder, loc, pairedPacked,
                                                   pairedWord, idx);
    }
    pairedHalf = LLVM::BitcastOp::create(builder, loc, vt, pairedPacked);
  }

  Value lowHalf = v;
  Value highHalf = pairedHalf;

  SmallVector<int32_t> concatMask = {0, 1, 2, 3, 4, 5, 6, 7,
                                     8, 9, 10, 11, 12, 13, 14, 15};
  return LLVM::ShuffleVectorOp::create(builder, loc, wideTy, lowHalf, highHalf,
                                       concatMask);
};

// For IU8 / IU4, each lane carries <2xi32> (64 bits). WMMA256b expects
// a 128-bit source per operand lane, where the upper half comes from lane^16.
auto expandPackedI32To4WideWmma256 = [&](Value v) -> Value {
  auto vt = dyn_cast<VectorType>(v.getType());
  if (!vt || vt.getShape().size() != 1 || vt.getShape()[0] != 2 ||
      !vt.getElementType().isInteger(32))
    return v;

  auto i32Ty = IntegerType::get(ctx, 32);
  auto i64Ty = IntegerType::get(ctx, 64);
  auto wideTy = VectorType::get({4}, i32Ty);

  Value lane = ROCDL::ThreadIdXOp::create(builder, loc, i32Ty).getResult();
  Value c31 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                       builder.getI32IntegerAttr(31));
  Value cNeg32 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                          builder.getI32IntegerAttr(-32));
  Value c16 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                       builder.getI32IntegerAttr(16));
  Value c2 = LLVM::ConstantOp::create(builder, loc, i32Ty,
                                      builder.getI32IntegerAttr(2));

  Value laneInWave = LLVM::AndOp::create(builder, loc, lane, c31);
  Value waveBase = LLVM::AndOp::create(builder, loc, lane, cNeg32);
  Value pairedLaneInWave = LLVM::XOrOp::create(builder, loc, laneInWave, c16);
  Value pairedLane = LLVM::OrOp::create(builder, loc, waveBase, pairedLaneInWave);
  Value pairedLaneByteAddr = LLVM::ShlOp::create(builder, loc, pairedLane, c2);

  Value pairedHalf = LLVM::UndefOp::create(builder, loc, vt);
  for (int i = 0; i < 2; ++i) {
    Value idx = LLVM::ConstantOp::create(builder, loc, i64Ty,
                                         builder.getI64IntegerAttr(i));
    Value laneWord = LLVM::ExtractElementOp::create(builder, loc, v, idx);
    Value pairedWord = ROCDL::DsBpermuteOp::create(builder, loc, i32Ty,
                                                   pairedLaneByteAddr, laneWord)
                           .getResult();
    pairedHalf = LLVM::InsertElementOp::create(builder, loc, pairedHalf,
                                               pairedWord, idx);
  }

  SmallVector<int32_t> concatMask = {0, 1, 2, 3};
  return LLVM::ShuffleVectorOp::create(builder, loc, wideTy, v, pairedHalf,
                                       concatMask);
};

// F32 <- F16/F16
if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32()) {
  Value aDup = expandTo16WideWmma256(a);
  Value bDup = expandTo16WideWmma256(b);
  Value res = ROCDL::wmma_f32_16x16x16_f16::create(builder, loc, accTy, aDup,
                                                   bDup, c)
                  .getResult();
  return res;
}
// F32 <- BF16/BF16 (op A/B input ty is AnyInteger -> bitcast bf16 to i16)
if (elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32()) {
  Value aI = castToI16Vec(expandTo16WideWmma256(a));
  Value bI = castToI16Vec(expandTo16WideWmma256(b));
  Value res =
      ROCDL::wmma_f32_16x16x16_bf16::create(builder, loc, accTy, aI, bI, c)
          .getResult();
  return res;
}
// F16 <- F16/F16.
//
// The native f16-acc WMMA variant uses op_sel-packed 16-wide C/D lanes, and
// keeping that packed state across loop-carried accumulation in the current
// vec8 fragment model is error-prone. For correctness, we use the f32-acc
// WMMA path and truncate the final lane-local result to f16.
if (elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16()) {
  Value aDup = expandTo16WideWmma256(a);
  Value bDup = expandTo16WideWmma256(b);
  auto accF32Ty = VectorType::get({accVecSize}, builder.getF32Type());
  Value cF32 = LLVM::FPExtOp::create(builder, loc, accF32Ty, c);
  Value resF32 = ROCDL::wmma_f32_16x16x16_f16::create(
                     builder, loc, accF32Ty, aDup, bDup, cF32)
                     .getResult();
  return LLVM::FPTruncOp::create(builder, loc, accTy, resF32).getResult();
}
// BF16 <- BF16/BF16.
//
// The native bf16-acc WMMA variant uses op_sel-packed 16-wide C/D lanes,
// which conflicts with the FlyDSL vec8 fragment representation. For
// correctness and consistency with the f16-acc path, we promote to f32
// acc, run the F32_BF16 WMMA, and truncate back to bf16. (BF16 has the
// same exponent range as f32, so no value can overflow during the
// promote/truncate round-trip.)
if (elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16()) {
  Value aI = castToI16Vec(expandTo16WideWmma256(a));
  Value bI = castToI16Vec(expandTo16WideWmma256(b));
  auto accF32Ty = VectorType::get({accVecSize}, builder.getF32Type());
  Value cF32 = LLVM::FPExtOp::create(builder, loc, accF32Ty, c);
  Value resF32 =
      ROCDL::wmma_f32_16x16x16_bf16::create(builder, loc, accF32Ty, aI, bI,
                                            cF32)
          .getResult();
  return LLVM::FPTruncOp::create(builder, loc, accTy, resF32).getResult();
}
// I32 <- IU8 (signA = signB = 0 = unsigned, clamp = 0).
// For RDNA 3 / 3.5 WMMA256b, the upper half must come from lane^16.
if (elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32)) {
  Value aWide = expandPackedI32To4WideWmma256(a);
  Value bWide = expandPackedI32To4WideWmma256(b);
  Value res = ROCDL::wmma_i32_16x16x16_iu8::create(
                  builder, loc, accTy,
                  /*signA=*/false, aWide, /*signB=*/false, bWide, c,
                  /*clamp=*/false)
                  .getResult();
  return res;
}
// I32 <- IU4. Same WMMA256b lane^16 upper-half rule as IU8.
if (elemTyA.isInteger(4) && elemTyB.isInteger(4) && elemTyAcc.isInteger(32)) {
  Value aWide = expandPackedI32To4WideWmma256(a);
  Value bWide = expandPackedI32To4WideWmma256(b);
  Value res = ROCDL::wmma_i32_16x16x16_iu4::create(
                  builder, loc, accTy,
                  /*signA=*/false, aWide, /*signB=*/false, bWide, c,
                  /*clamp=*/false)
                  .getResult();
  return res;
}
Collaborator
I am not familiar with WMMA on RDNA3, but this part looks strange. Generating so many ops inside AtomCall is not a good idea, since they may be replicated multiple times when gemmOp is unpacked into atom calls.

If this wrapping step is the common way for RDNA3 WMMA, I would suggest giving FlyROCDL_MmaOpRDNA3_WMMA a wrapper-specific suffix/name, and reserving the instruction-matching name for an unwrapped version whose AtomCall is basically a bare rocdl.wmma call. This would leave room in the future for users to perform these cross-lane operations explicitly at the Python level if they want.

Author
Thanks for pointing this out.

I checked the lane mapping with matrix_calculator.py, and RDNA3 is the odd one here: in wave32, A/B operands need replication between lane i and lane i+16. For example:

./matrix_calculator.py --architecture rdna3 --instruction v_wmma_f16_16x16x16_f16 --A-matrix --get-register --I-coordinate 3 --K-coordinate 2
Architecture: RDNA3
Instruction: V_WMMA_F16_16X16X16_F16
A[3][2] = v1{3}.[15:0]
A[3][2] = v1{19}.[15:0]

Same query on RDNA4 and CDNA3 returns only one lane:

./matrix_calculator.py --architecture rdna4 --instruction v_wmma_f16_16x16x16_f16 --A-matrix --get-register --I-coordinate 3 --K-coordinate 2
Architecture: RDNA4
Instruction: V_WMMA_F16_16X16X16_F16
A[3][2] = v1{3}.[15:0]
./matrix_calculator.py --architecture cdna3 --instruction v_mfma_f32_16x16x16_f16 --A-matrix --get-register --I-coordinate 3 --K-coordinate 2
Architecture: CDNA3
Instruction: V_MFMA_F32_16X16X16_F16
A[3][2] = v1{3}.[15:0]

The reason I implemented it this way is that I wanted to keep the current FlyDSL compact fragment model (vec8 split across lane groups) unchanged and compatible with the existing tiling/copy pipeline. Given that model, reconstructing the operands cross-lane inside emitAtomCallSSA was the most direct way to satisfy the RDNA3 WMMA operand ABI requirements.

But I also agree this weakens the design contract of MmaOpRDNA3_WMMA, because it is no longer a clean instruction-matching atom and now embeds cross-lane wrapper behavior in the atom call path.

Given this affects atom semantics (instruction-matching vs wrapper behavior), I don’t think I should make that call solo in this PR. I’m going to close this for now, align on direction first (raw-only / wrapper-only / dual-mode), then continue in a follow-up.
