
Tma load grouped gemm #5937

Open
edwingao28 wants to merge 1 commit into pytorch:main from edwingao28:tma-load-grouped-gemm


@edwingao28


Summary

On Triton 3.x, FBGEMM's grouped-GEMM kernels silently run with TMA disabled: utils.HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) is False because that attribute was removed from Triton. PR #4866 migrated the store path to the device-side tl.make_tensor_descriptor API; this PR does the same for the load path and fixes the probe, restoring TMA on Hopper.

This is a performance forward-port, not a correctness fix — the kernels already produce correct results via the tl.load fallback; they were just stall-bound with TMA off.

Root cause

  • utils.py probes nv_tma_desc_type (removed in Triton 3.x) → HAS_TMA_DESC is False → _grouped_gemm disables USE_TMA_LOAD/USE_TMA_STORE.
  • After #4866 (Migrate to new device TMA API for grouped_gemm.py), the store path uses tl.make_tensor_descriptor, but the load path still calls the removed tl._experimental_descriptor_load. Even with the probe fixed, the load path would not compile, so both have to move together.
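
The probe failure can be reproduced without Triton installed. The sketch below is illustrative only: `old_tl` and `new_tl` are hypothetical stand-ins for `triton.language` before and after the API change, and the two helper functions are names invented for this example (the real code assigns `HAS_TMA_DESC` directly in utils.py).

```python
from types import SimpleNamespace


def has_tma_desc_old(tl) -> bool:
    # Old probe: looks for an attribute that Triton 3.x no longer exposes.
    return "nv_tma_desc_type" in dir(tl)


def has_tma_desc_new(tl) -> bool:
    # New probe: checks for the device-side descriptor API instead.
    return hasattr(tl, "make_tensor_descriptor")


# Hypothetical stand-ins for `triton.language` (assumption: only the
# attributes relevant to the probe are modelled).
old_tl = SimpleNamespace(
    nv_tma_desc_type=object(),
    _experimental_descriptor_load=object(),
)
new_tl = SimpleNamespace(make_tensor_descriptor=lambda *args, **kwargs: None)

for name, tl in [("old API", old_tl), ("new API", new_tl)]:
    print(f"{name}: old probe={has_tma_desc_old(tl)}, new probe={has_tma_desc_new(tl)}")
```

On the "new API" stand-in the old probe reports False even though a descriptor API is present, which is the silent fallback described above.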

What this PR does

  • utils.py: gate HAS_TMA_DESC on hasattr(tl, "make_tensor_descriptor") (the old nv_tma_desc_type check was likewise a pure API-presence probe).
  • grouped_gemm.py: in the two non-WS kernels (_fbgemm_grouped_gemm, _fbgemm_grouped_gemm_fp8_rowwise), build the A/B load descriptors in-kernel with tl.make_tensor_descriptor, once per group, using relative tile offsets, mirroring the c_desc_ptr store path from #4866 (Migrate to new device TMA API for grouped_gemm.py). The host passes the raw x/w tensors and registers the device allocator when either TMA load or TMA store is enabled.
  • grouped_gemm_test.py: broaden fp8 K coverage to {64, 128, 256}.
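
To illustrate what "once per group, using relative tile offsets" means, here is a host-side model in plain Python. It is a sketch, not the kernel: `Descriptor2D`, `BLOCK_M`, and the toy sizes are hypothetical, and the zero fill for out-of-bounds elements models the usual TMA load behaviour rather than anything taken from this PR's diff.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Descriptor2D:
    """Hypothetical host-side model of a 2D descriptor over a row-major buffer."""

    data: List[float]  # flat backing buffer
    base: int          # element offset where this view starts
    rows: int          # logical shape of the view
    cols: int
    block_rows: int    # tile shape returned by load()
    block_cols: int

    def load(self, r0: int, c0: int) -> List[List[float]]:
        # Coordinates are relative to `base`. Elements outside the view's
        # shape read as 0.0 (assumed to mirror TMA out-of-bounds zero fill).
        tile = []
        for r in range(r0, r0 + self.block_rows):
            row = []
            for c in range(c0, c0 + self.block_cols):
                inside = 0 <= r < self.rows and 0 <= c < self.cols
                row.append(self.data[self.base + r * self.cols + c] if inside else 0.0)
            tile.append(row)
        return tile


K = 4
BLOCK_M = 2
group_sizes = [3, 2]  # rows of x owned by each group
x = [float(i) for i in range(sum(group_sizes) * K)]

tiles = []
m_start = 0
for m_size in group_sizes:
    # One descriptor per group: base points at the group's first row and the
    # shape is the group's own extent, so tile offsets are group-relative.
    desc = Descriptor2D(x, base=m_start * K, rows=m_size, cols=K,
                        block_rows=BLOCK_M, block_cols=K)
    num_tiles = (m_size + BLOCK_M - 1) // BLOCK_M
    for tile_m in range(num_tiles):
        tiles.append(desc.load(tile_m * BLOCK_M, 0))
    m_start += m_size

for t in tiles:
    print(t)
```

In this model the last tile of the first group is padded with zeros instead of picking up the first row of the second group, because the descriptor's shape ends at the group boundary.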

What this PR does NOT do (and why)

  • The warp-specialized kernels (*_ws) are not migrated. They require the facebookexperimental/triton ws-3.2.x fork to exercise, which I can't test against, so they keep their existing host-built-descriptor load path (the host still builds descriptors when use_warp_specialization) — behavior there is unchanged. Migrating them is a follow-up.

Results

H100, Triton 3.6, fp8 rowwise, Llama-4-shaped gate/up (E=16, N=16384, K=5120, M=8192), triton.testing.do_bench:

| | latency |
|---|---|
| TMA off (before) | 1.587 ms |
| TMA on (this PR) | 1.147 ms (1.38×) |
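
For reference, the 1.38× figure follows directly from the two latencies. The snippet below only redoes that arithmetic; the variable names are illustrative.

```python
before_ms = 1.587  # TMA off (before)
after_ms = 1.147   # TMA on (this PR)

speedup = before_ms / after_ms
latency_reduction = 1 - after_ms / before_ms

print(f"speedup: {speedup:.2f}x")                      # speedup: 1.38x
print(f"latency reduction: {latency_reduction:.1%}")   # latency reduction: 27.7%
```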

Testing

  • pytest fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py passes on H100 (fp8 + bf16).
  • Patched fp8 output matches the TMA-off path to fp8 rounding noise.

Related: #4866 (the store-path migration this mirrors).

meta-cla Bot added the cla signed label Jun 19, 2026
@edwingao28

Author

Hi @pchen7e2 @htyu @dshi7, could you take a look at this PR when you have a moment?

This is a follow-up to #4866. That PR migrated the grouped-GEMM store path to the device-side tl.make_tensor_descriptor API, but the load path still calls the removed tl._experimental_descriptor_load, and the HAS_TMA_DESC probe still checks for the removed nv_tma_desc_type, so on Triton 3.x HAS_TMA_DESC is False and TMA stays disabled in grouped_gemm.py. This PR mirrors #4866's approach for the load path (device-side descriptors, built per group) and fixes the probe, restoring TMA on Hopper (1.38× on the fp8 grouped GEMM via do_bench; grouped_gemm_test.py passes for fp8 + bf16).

@meta-codesync

meta-codesync Bot commented Jun 22, 2026

Contributor

@q10 has imported this pull request. If you are a Meta employee, you can view this in D109358672.
