[moe] Fix DTensor scatter under TP+SP full DTensor ( due to PR #3515)#3753
[moe] Fix DTensor scatter under TP+SP full DTensor ( due to PR #3515)#3753githubsgi wants to merge 1 commit into
Conversation
|
The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:
Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows. |
Both scatter sites in the MoE router operate on router outputs that are
DTensors sharded on the token dim (Shard(1)) when sequence parallel is
enabled with the full-DTensor / SPMD path:
- TokenChoiceTopKRouter._get_node_limited_routing_scores: group_mask.scatter_
- MoE.forward: routing_map_BLE = zeros_like(scores_BLE).scatter_
In both cases an in-place scatter_ on a Shard(1) DTensor raises:
RuntimeError: aten.scatter_.value: in-place operations that require
placement changes are not supported. The input has placement
(Shard(dim=1),), but no valid strategy preserves this placement.
DTensor's sharding-propagation table for aten.scatter_.value has no strategy
that preserves a Shard(1) input for an in-place write, so it refuses the op
even though the write is along the (replicated) expert dim and needs no
redistribution.
Switching to an out-of-place scatter avoids the crash but is not
placement-stable: DTensor may redistribute the result to Replicate(), which
then breaks the downstream routing_map_BLE.sum(dim=(0, 1)) -> Partial(sum)
contract that GroupedExperts expects, producing:
ValueError: GroupedExperts.num_local_tokens_per_expert_E: input DTensor
has placements (Replicate(),), but in_src_shardings expects (Partial(sum),).
Run both scatters on local tensors under local_map when the inputs are
DTensors, re-wrapping the output with the original Shard(1) placement. This
keeps the scatter local per token shard, preserves placements, and is
numerically identical.
This restores the local_map workaround for the routing map (removed in pytorch#3515,
which assumed scatter_ would run as a no-redistribution local op) and adds the
same handling to the node-limited routing path, which never had it.
19dd0d8 to
85b0a74
Compare
|
Hi @githubsgi, thanks for the report. i checked both cases you mentioned above and it worked well locally for me, can you check whether you have installed torch nightly after this pr pytorch/pytorch#186149 where i added dtensor strategy support. Besides, there is indeed a bug here when dp_shard = 1 and ep is off under full_dtensor mode, but it is not related to dtensor |
Both scatter sites in the MoE router operate on router outputs that are DTensors sharded on the token dim (Shard(1)) when sequence parallel is enabled with the full-DTensor / SPMD path:
In both cases an in-place scatter_ on a Shard(1) DTensor raises:
DTensor's sharding-propagation table for aten.scatter_.value has no strategy that preserves a Shard(1) input for an in-place write, so it refuses the op even though the write is along the (replicated) expert dim and needs no redistribution.
Switching to an out-of-place scatter avoids the crash but is not placement-stable: DTensor may redistribute the result to Replicate(), which then breaks the downstream routing_map_BLE.sum(dim=(0, 1)) -> Partial(sum) contract that GroupedExperts expects, producing:
Run both scatters on local tensors under local_map when the inputs are DTensors, re-wrapping the output with the original Shard(1) placement. This keeps the scatter local per token shard, preserves placements, and is numerically identical.
This restores the local_map workaround for the routing map (removed in #3515, which assumed scatter_ would run as a no-redistribution local op) and adds the same handling to the node-limited routing path, which never had it.