From 08fda9c7a6db2bb16d5f946369c7c9ec63b355b2 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 3 Feb 2026 15:54:49 -0800 Subject: [PATCH 1/2] Fix FSDP when FSDP+EP Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..37cb4febf8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -936,7 +936,7 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and spec == gsr.fsdp_resource else spec + None if spec is not None and (spec == gsr.fsdp_resource or (isinstance(spec, tuple) and gsr.fsdp_resource in spec)) else spec for spec in rhs_non_cspecs ) From fb1a6a9723aa464422728ce43162d4a4a1b64154 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 17:10:57 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 37cb4febf8..a34cb030bf 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -936,7 +936,15 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and (spec == gsr.fsdp_resource or (isinstance(spec, tuple) and gsr.fsdp_resource in spec)) else spec + ( + None + if spec is not None + and ( + spec == gsr.fsdp_resource + or (isinstance(spec, tuple) and gsr.fsdp_resource in spec) + ) + else spec + ) for spec in rhs_non_cspecs )