Skip to content

[graph_trainer] Fix DSv3 bucketing order for multinode bitwise numerics with eager#3770

Merged
IvanKobzarev merged 1 commit into
pytorch:mainfrom
IvanKobzarev:ds-bucketorder-fix
Jun 25, 2026
Merged

[graph_trainer] Fix DSv3 bucketing order for multinode bitwise numerics with eager#3770
IvanKobzarev merged 1 commit into
pytorch:mainfrom
IvanKobzarev:ds-bucketorder-fix

Conversation

@IvanKobzarev

Copy link
Copy Markdown
Contributor

DSv3 GraphTrainer numerics sweeps were comparing Eager and GraphTrainer weight hashes after each step. After the bucketing-order investigation, one ordering difference was that GraphTrainer chunked loss can move the lm_head weight use under module_fqn "loss". The non-chunked default transformer buckets ended with only ["norm", "lm_head"], so chunked loss left the final lm_head-related work outside the bucket intended to preserve final-layer ordering.

The default final transformer block bucket remains ["norm", "lm_head"] for non-chunked loss. When GraphTrainer is configured with chunked CE loss, the pass builder calls the bucket helper with chunked_loss_enabled=True, which extends that final bucket to include "loss". That keeps chunked-loss lm_head uses in the same final bucket as the lm_head module and avoids letting GraphTrainer reorder that final dependency relative to the rest of the model, while preserving the original bucket plan for non-chunked loss.

…cs with eager

Bug report:
DSv3 GraphTrainer numerics sweeps were comparing Eager and GraphTrainer weight hashes after each step. After the bucketing-order investigation, one ordering difference was that GraphTrainer chunked loss can move the lm_head weight use under module_fqn "loss". The non-chunked default transformer buckets ended with only ["norm", "lm_head"], so chunked loss left the final lm_head-related work outside the bucket intended to preserve final-layer ordering.

Repro:
Run the 16-GPU DSv3 numerics matrix with weight hashes enabled, for example the FLASH/NOFLASH, CG/NOCG, BS=1/16 sweep with:

```bash
GT_WEIGHT_HASH=1 --debug.deterministic --metrics.perf_metrics_only
```

The BS16 cases exposed weight-hash mismatches and required checking bucket ordering around the final norm/lm_head/loss region.

Explanation and fix:
The default final transformer block bucket remains ["norm", "lm_head"] for non-chunked loss. When GraphTrainer is configured with chunked CE loss, the pass builder calls the bucket helper with `chunked_loss_enabled=True`, which extends that final bucket to include "loss". That keeps chunked-loss lm_head uses in the same final bucket as the lm_head module and avoids letting GraphTrainer reorder that final dependency relative to the rest of the model, while preserving the original bucket plan for non-chunked loss.

The unit test pins the compile-time pass decision so the loss bucket is enabled only when the configured loss is chunked CE, without asserting the helper's literal bucket values.

Test Plan:
```bash
/home/ivankobzarev/local/b/pytorch-env/bin/python -m py_compile \
  torchtitan/experiments/graph_trainer/common_utils.py \
  torchtitan/experiments/graph_trainer/passes.py \
  torchtitan/experiments/graph_trainer/tests/test_passes.py

/home/ivankobzarev/local/b/pytorch-env/bin/python -m unittest \
  torchtitan.experiments.graph_trainer.tests.test_passes.TestDefaultTransformerBlockBuckets
```

Authored with assistance from OpenAI Codex.
@IvanKobzarev IvanKobzarev marked this pull request as ready for review June 24, 2026 16:36
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 24, 2026
@IvanKobzarev IvanKobzarev merged commit 854898a into pytorch:main Jun 25, 2026
20 of 22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants