[moe] Fix DTensor scatter under TP+SP full DTensor (due to PR #3515) #3753

Open
githubsgi wants to merge 1 commit into pytorch:main from githubsgi:fix/moe-dtensor-scatter-shard

@githubsgi

Contributor

Both scatter sites in the MoE router operate on router outputs that are DTensors sharded on the token dim (Shard(1)) when sequence parallel is enabled with the full-DTensor / SPMD path:

  • TokenChoiceTopKRouter._get_node_limited_routing_scores: group_mask.scatter_
  • MoE.forward: routing_map_BLE = zeros_like(scores_BLE).scatter_
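For readers unfamiliar with the pattern, here is a minimal sketch of what the second scatter site computes, on plain (non-distributed) tensors. The sizes, the `idx_BLK` name, the scatter dim of -1, and the written value of 1 are illustrative assumptions, not taken from the torchtitan source.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: batch B, sequence length L, experts E, top-k K.
B, L, E, K = 2, 4, 8, 2

scores_BLE = torch.rand(B, L, E)
# Indices of the K highest-scoring experts for every token (assumed name).
_, idx_BLK = torch.topk(scores_BLE, k=K, dim=-1)

# Start from zeros and write 1 at the selected expert positions along the
# last (expert) dim. With a scalar value this dispatches to
# aten.scatter_.value, the op named in the error message.
routing_map_BLE = torch.zeros_like(scores_BLE).scatter_(-1, idx_BLK, 1)

# Every token is routed to exactly K distinct experts.
tokens_routed_to_k = torch.equal(
    routing_map_BLE.sum(dim=-1), torch.full((B, L), float(K))
)
```

The write only touches the expert dim, so each token's row is filled independently of every other token.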

In both cases an in-place scatter_ on a Shard(1) DTensor raises:

    RuntimeError: aten.scatter_.value: in-place operations that require
    placement changes are not supported. The input has placement
    (Shard(dim=1),), but no valid strategy preserves this placement.

DTensor's sharding-propagation table for aten.scatter_.value has no strategy that preserves a Shard(1) input for an in-place write, so it refuses the op even though the write is along the (replicated) expert dim and needs no redistribution.
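The claim that the write needs no redistribution can be checked without any distributed setup. The sketch below emulates Shard(1) by chunking plain tensors along the token dim; `num_shards` and the tensor sizes are made-up values, and `routing_map` is a hypothetical helper.

```python
import torch

torch.manual_seed(0)

B, L, E, K = 2, 8, 4, 2   # hypothetical sizes
num_shards = 4            # stand-in for the sequence-parallel degree

scores_BLE = torch.rand(B, L, E)
_, idx_BLK = torch.topk(scores_BLE, k=K, dim=-1)


def routing_map(scores, idx):
    """Scatter 1s along the expert dim (hypothetical helper)."""
    return torch.zeros_like(scores).scatter_(-1, idx, 1)


# Reference: scatter on the full, unsharded tensor.
full = routing_map(scores_BLE, idx_BLK)

# Emulate Shard(1): split the token dim and scatter on each shard alone.
per_shard = [
    routing_map(s, i)
    for s, i in zip(
        scores_BLE.chunk(num_shards, dim=1),
        idx_BLK.chunk(num_shards, dim=1),
    )
]
stitched = torch.cat(per_shard, dim=1)

shards_match_full = torch.equal(stitched, full)
```

Because no shard reads or writes another shard's tokens, the per-shard results concatenate to exactly the full result.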

Switching to an out-of-place scatter avoids the crash but is not placement-stable: DTensor may redistribute the result to Replicate(), which then breaks the downstream routing_map_BLE.sum(dim=(0, 1)) -> Partial(sum) contract that GroupedExperts expects, producing:

    ValueError: GroupedExperts.num_local_tokens_per_expert_E: input DTensor
    has placements (Replicate(),), but in_src_shardings expects (Partial(sum),).
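To illustrate the Partial(sum) contract, the sketch below again emulates token sharding with plain tensors. Each emulated rank reduces only its own tokens, and the logical per-expert count is the sum of those local counts. Sizes and names are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

B, L, E, K = 2, 8, 4, 2   # hypothetical sizes
num_shards = 4            # stand-in for the sequence-parallel degree

scores_BLE = torch.rand(B, L, E)
_, idx_BLK = torch.topk(scores_BLE, k=K, dim=-1)
routing_map_BLE = torch.zeros_like(scores_BLE).scatter_(-1, idx_BLK, 1)

# Global tokens-per-expert count, computed on the unsharded map.
global_counts_E = routing_map_BLE.sum(dim=(0, 1))

# With a Shard(1) map, each rank only sees its own slice of tokens, so
# reducing over dims (0, 1) yields a local, partial count per rank.
local_counts = [
    shard.sum(dim=(0, 1)) for shard in routing_map_BLE.chunk(num_shards, dim=1)
]

# Partial(sum) means the logical value is the sum of the per-rank values.
reduced_counts_E = torch.stack(local_counts).sum(dim=0)
```

A Replicate() map would instead give every rank the full global count, which is a different placement from the pending-sum one that GroupedExperts declares for its input.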

Run both scatters on local tensors under local_map when the inputs are DTensors, re-wrapping the output with the original Shard(1) placement. This keeps the scatter local per token shard, preserves placements, and is numerically identical.
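A rough sketch of the unwrap, scatter, re-wrap idea follows. It is not the PR's diff: it uses `DTensor.to_local()` and `DTensor.from_local()` directly as a simpler stand-in for `local_map`, and the function name, argument names, scatter dim, and written value are all assumptions. Only the plain-tensor path is exercised here; the DTensor branch needs an initialized device mesh.

```python
import torch

try:
    from torch.distributed.tensor import DTensor  # public path in recent PyTorch
except ImportError:  # older releases or builds without distributed support
    DTensor = None


def scatter_routing_map(scores_BLE, idx_BLK):
    """Build a 0/1 routing map; scatter on local shards for DTensor inputs.

    Hypothetical helper for illustration only.
    """
    if DTensor is not None and isinstance(scores_BLE, DTensor):
        mesh = scores_BLE.device_mesh
        placements = scores_BLE.placements
        local_scores = scores_BLE.to_local()
        local_idx = idx_BLK.to_local() if isinstance(idx_BLK, DTensor) else idx_BLK
        # Plain-tensor scatter: DTensor sharding propagation is not involved.
        local_out = torch.zeros_like(local_scores).scatter_(-1, local_idx, 1)
        # Re-wrap with the input's placements (Shard(1) in the PR's scenario).
        return DTensor.from_local(local_out, mesh, placements, run_check=False)
    # Plain tensors: scatter directly.
    return torch.zeros_like(scores_BLE).scatter_(-1, idx_BLK, 1)


# Usage on plain tensors (no process group required).
torch.manual_seed(0)
demo_scores = torch.rand(2, 4, 8)
demo_idx = torch.topk(demo_scores, k=2, dim=-1).indices
demo_map = scatter_routing_map(demo_scores, demo_idx)
```

Note that `from_local` infers the global shape from the local shard, so this sketch assumes the token dim is evenly sharded across ranks.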

This restores the local_map workaround for the routing map (removed in #3515, which assumed scatter_ would run as a no-redistribution local op) and adds the same handling to the node-limited routing path, which never had it.

meta-cla bot added the CLA Signed label on Jun 23, 2026
@pytorch-bot

pytorch-bot Bot commented Jun 23, 2026

The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:

  • ciflow/8gpu

Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows.

githubsgi changed the title from "[moe] Fix DTensor scatter under TP+SP full DTensor" to "[moe] Fix DTensor scatter under TP+SP full DTensor (due to PR #3515)" on Jun 23, 2026
githubsgi force-pushed the fix/moe-dtensor-scatter-shard branch from 19dd0d8 to 85b0a74 on June 23, 2026 23:25
@shuhuayu

Contributor

Hi @githubsgi, thanks for the report. I checked both cases you mentioned above and they worked locally for me. Can you check whether you have installed a torch nightly that includes pytorch/pytorch#186149, where I added DTensor strategy support?

Besides, there is indeed a bug here when dp_shard = 1 and ep is off under full_dtensor mode, but it is not related to DTensor scatter. I opened a PR to fix it: #3762
