Tma load grouped gemm #5937
Open
edwingao28 wants to merge 1 commit into
Author
Hi @pchen7e2 @htyu @dshi7, could you take a look at this PR when you have a moment? This is a follow-up to #4866. That PR migrated the grouped-GEMM store path to the device-side
Contributor
@q10 has imported this pull request. If you are a Meta employee, you can view this in D109358672.
Summary
On Triton 3.x, FBGEMM's grouped-GEMM kernels silently run with TMA disabled: `utils.HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)` is `False` because that attribute was removed from Triton. PR #4866 migrated the store path to the device-side `tl.make_tensor_descriptor` API; this PR does the same for the load path and fixes the probe, restoring TMA on Hopper.

This is a performance forward-port, not a correctness fix: the kernels already produce correct results via the `tl.load` fallback; they were just stall-bound with TMA off.

Root cause
- `utils.py` probes `nv_tma_desc_type` (removed in Triton 3.x) → `HAS_TMA_DESC` is `False` → `_grouped_gemm` disables `USE_TMA_LOAD` / `USE_TMA_STORE`.
- #4866 migrated the store path to `tl.make_tensor_descriptor`, but the load path still calls the removed `tl._experimental_descriptor_load`, so even with the probe fixed, the load path wouldn't compile. Both have to move together.

What this PR does
- `utils.py`: gate `HAS_TMA_DESC` on `hasattr(tl, "make_tensor_descriptor")` (the old `nv_tma_desc_type` check was likewise a pure API-presence probe).
- `grouped_gemm.py`: in the two non-WS kernels (`_fbgemm_grouped_gemm`, `_fbgemm_grouped_gemm_fp8_rowwise`), build the A/B load descriptors in-kernel with `tl.make_tensor_descriptor`, once per group, using relative tile offsets, mirroring the `c_desc_ptr` store path from "Migrate to new device TMA API for grouped_gemm.py" (#4866). The host passes the raw `x`/`w` tensors and registers the device allocator for TMA load or store.
- `grouped_gemm_test.py`: broaden fp8 K coverage to `{64, 128, 256}`.

What this PR does NOT do (and why)
- The warp-specialized kernels (`*_ws`) are not migrated. They require the `facebookexperimental/triton` `ws-3.2.x` fork to exercise, which I can't test against, so they keep their existing host-built-descriptor load path (the host still builds descriptors when `use_warp_specialization` is set); behavior there is unchanged. Migrating them is a follow-up.

Results
H100, Triton 3.6, fp8 rowwise, Llama-4-shaped gate/up (E=16, N=16384, K=5120, M=8192), `triton.testing.do_bench`:

Testing
`pytest fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py` passes on H100 (fp8 + bf16).
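Because the failure mode described in the summary was silent, it can help to confirm which probe a given Triton build satisfies. The following is a minimal sketch of the two presence checks, not the actual `utils.py` code: `legacy_probe` and `new_probe` are hypothetical names, and the two `SimpleNamespace` objects are stand-ins for `triton.language` so the snippet runs without Triton or a GPU.

```python
from types import SimpleNamespace


def legacy_probe(tl) -> bool:
    # Old check: looks for the descriptor type that newer Triton removed.
    return "nv_tma_desc_type" in dir(tl)


def new_probe(tl) -> bool:
    # New check: looks for the device-side descriptor API instead.
    return hasattr(tl, "make_tensor_descriptor")


# Stand-ins for triton.language (assumption: only attribute presence matters).
old_tl = SimpleNamespace(nv_tma_desc_type=object())
new_tl = SimpleNamespace(make_tensor_descriptor=lambda *args, **kwargs: None)

for name, tl in (("old", old_tl), ("new", new_tl)):
    print(name, "legacy:", legacy_probe(tl), "new:", new_probe(tl))
```

On a real install, passing the imported `triton.language` module to `new_probe` should show whether the TMA path can be enabled at all.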
Related: #4866 (the store-path migration this mirrors).
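To illustrate what "once per group, using relative tile offsets" means for the load descriptors, here is a host-side sketch in plain Python. It is an assumption about the indexing scheme, not code from `grouped_gemm.py`: `group_tiles` is a hypothetical helper, and it assumes each group's descriptor is based at that group's first row with its M extent bounded to the group size, so tile offsets are counted from the group start rather than from the start of the whole tensor.

```python
def group_tiles(m_sizes, block_m):
    """List (group, base_row, relative_row_offset, valid_rows) per M tile.

    base_row is where a per-group descriptor would start (assumed);
    relative_row_offset is the tile offset inside that group.
    """
    tiles = []
    base_row = 0
    for group, m_size in enumerate(m_sizes):
        # Empty groups contribute no tiles but still advance nothing.
        for rel in range(0, m_size, block_m):
            # The last tile of a group may be partial; rows past the
            # group's end belong to the next group and must not be read.
            valid_rows = min(block_m, m_size - rel)
            tiles.append((group, base_row, rel, valid_rows))
        base_row += m_size
    return tiles


# Three groups with 5, 0, and 3 rows, tiled with BLOCK_M = 4.
for tile in group_tiles([5, 0, 3], block_m=4):
    print(tile)
```

With a descriptor bounded to the group's own rows, a partial last tile stays inside the group, whereas absolute offsets into one shared descriptor would need extra masking at group boundaries.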