
[JAX] Fix FSDP when FSDP+EP is active #2649

Open
jberchtold-nvidia wants to merge 3 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/fix-fsdp-when-ep-is-active


@jberchtold-nvidia
Collaborator

Description

In some models where FSDP and EP are both active, the non-MoE blocks use the sharding (('fsdp', 'expert'), None, None, None), with the EP GPU domain acting as FSDP. Our TE/JAX GEMM did not handle this correctly because it assumed fsdp would appear alone on an axis, not as part of a tuple. This resulted in unnecessary AllGathers of the inputs, blocking the critical path.
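
To illustrate why the tuple case slipped through, here is a minimal sketch in plain Python. The names `FSDP_AXIS` and `old_check` are hypothetical stand-ins, not the actual TE/JAX code; plain tuples model the partition spec and the string `"fsdp"` models the configured FSDP mesh-axis name.

```python
# Hypothetical stand-ins for illustration only (not the TE/JAX source).
FSDP_AXIS = "fsdp"

# Sharding used by the non-MoE blocks, as quoted in the description above.
non_moe_spec = (("fsdp", "expert"), None, None, None)


def old_check(axis_spec):
    """Equality-only check: matches 'fsdp' only when it stands alone."""
    return axis_spec == FSDP_AXIS


# A tuple never compares equal to a string, so the first entry is missed.
matches = [old_check(s) for s in non_moe_spec]
print(matches)            # [False, False, False, False]
print(old_check("fsdp"))  # True
```

Under this assumption, the entry ('fsdp', 'expert') is never recognized as FSDP-sharded, which matches the behavior the description reports.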

This PR fixes the check: if the TE GEMM sees a sharding with fsdp as part of a tuple, it performs FSDP over the GPU domain of all axes specified in the tuple.
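
One way to read "the GPU domain of all axes specified in the tuple": when a dimension is sharded over a tuple of mesh axes, it is split across the product of those axes' sizes. The sketch below shows that arithmetic. The mesh sizes and the helper `sharding_group_size` are assumptions made up for this example, not values or functions from the PR.

```python
from math import prod

# Hypothetical mesh axis sizes, chosen only for illustration.
mesh_shape = {"fsdp": 4, "expert": 2, "tensor": 2}


def sharding_group_size(axis_spec, mesh_shape):
    """Number of devices one array dimension is split across for a spec entry."""
    if axis_spec is None:
        return 1  # dimension is not sharded
    names = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,)
    return prod(mesh_shape[name] for name in names)


print(sharding_group_size("fsdp", mesh_shape))              # 4
print(sharding_group_size(("fsdp", "expert"), mesh_shape))  # 8
```

With these assumed sizes, the non-MoE weights would be sharded 8 ways rather than 4, since the expert axis contributes to the FSDP domain.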

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Update the check in TE/JAX GEMM to handle ('fsdp', 'expert') along the same axis

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Feb 4, 2026

Greptile Overview

Greptile Summary

Fixed FSDP sharding check in JAX GEMM to handle tuple specs like ('fsdp', 'expert') that occur when FSDP and Expert Parallelism (EP) are both active in non-MoE blocks.

  • Extended the FSDP resource check to use isinstance(spec, tuple) and gsr.fsdp_resource in spec in addition to the existing spec == gsr.fsdp_resource check
  • This prevents unnecessary AllGather operations on the input that were blocking the critical path
  • The fix is localized to the non-contracting dimension handling for RHS in infer_sharding_from_operands
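
The extended check summarized above can be sketched as follows. This is an approximation under stated assumptions, not the code from gemm.py: `is_fsdp_spec` and `strip_fsdp` are hypothetical helper names, `fsdp_resource` stands in for `gsr.fsdp_resource`, and the replacement of matched entries with None follows the sequence diagram in this review.

```python
def is_fsdp_spec(spec, fsdp_resource):
    """True if the spec entry is the FSDP axis alone or a tuple containing it."""
    return spec == fsdp_resource or (
        isinstance(spec, tuple) and fsdp_resource in spec
    )


def strip_fsdp(rhs_non_cspecs, fsdp_resource):
    """Replace FSDP-sharded entries with None and keep all other entries."""
    return tuple(
        None if is_fsdp_spec(spec, fsdp_resource) else spec
        for spec in rhs_non_cspecs
    )


print(strip_fsdp(("fsdp", "tensor"), "fsdp"))              # (None, 'tensor')
print(strip_fsdp((("fsdp", "expert"), "tensor"), "fsdp"))  # (None, 'tensor')
```

The `isinstance(spec, tuple)` guard matters in Python: without it, `fsdp_resource in spec` on a string entry would be a substring test, so an unrelated axis name such as "fsdp_extra" would match by accident.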

Confidence Score: 4/5

  • Safe to merge with high confidence - targeted bug fix with clear logic
  • The change is minimal, well-scoped, and addresses a specific issue where FSDP+EP tuple specs were not handled. The logic extension is sound and backward compatible.
  • No files require special attention

Important Files Changed

Filename | Overview
transformer_engine/jax/cpp_extensions/gemm.py | Extended FSDP check to handle tuple specs like ('fsdp', 'expert') by adding isinstance(spec, tuple) check

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GemmPrimitive
    participant infer_sharding_from_operands
    participant MeshResource as gsr (MeshResource)
    
    Caller->>GemmPrimitive: Execute GEMM with FSDP+EP
    GemmPrimitive->>infer_sharding_from_operands: Get sharding specs
    infer_sharding_from_operands->>MeshResource: Get fsdp_resource
    MeshResource-->>infer_sharding_from_operands: Returns 'fsdp'
    
    Note over infer_sharding_from_operands: Extract non-contracting specs<br/>from RHS (rhs_non_cspecs)
    
    alt Spec is single string 'fsdp'
        infer_sharding_from_operands->>infer_sharding_from_operands: spec == gsr.fsdp_resource → True
        Note over infer_sharding_from_operands: Replace with None (gather)
    else Spec is tuple ('fsdp', 'expert')
        infer_sharding_from_operands->>infer_sharding_from_operands: isinstance(spec, tuple) → True
        infer_sharding_from_operands->>infer_sharding_from_operands: gsr.fsdp_resource in spec → True
        Note over infer_sharding_from_operands: Replace with None (gather)<br/>NEW: Fixes FSDP+EP case
    else Spec is other
        Note over infer_sharding_from_operands: Keep original spec
    end
    
    infer_sharding_from_operands-->>GemmPrimitive: Return modified specs
    GemmPrimitive-->>Caller: Execute with correct sharding

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia force-pushed the jberchtold/fix-fsdp-when-ep-is-active branch from 4ee4d73 to 08fda9c on February 4, 2026 17:10
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

