diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..a34cb030bf 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -936,7 +936,15 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and spec == gsr.fsdp_resource else spec + ( + None + if spec is not None + and ( + spec == gsr.fsdp_resource + or (isinstance(spec, tuple) and gsr.fsdp_resource in spec) + ) + else spec + ) for spec in rhs_non_cspecs )