perf(pt_expt): use inductor+dynamic for torch.compile training #5393

Open

wanghan-iapcm wants to merge 2 commits into deepmodeling:master from wanghan-iapcm:feat-pt-expt-compile-dynamic

Conversation

@wanghan-iapcm
Collaborator

@wanghan-iapcm wanghan-iapcm commented Apr 11, 2026

Summary

  • Replace aot_eager backend + manual nall padding with inductor backend + dynamic=True for training compilation
  • Use make_fx(tracing_mode="symbolic") instead of tracing_mode="real" to capture shape-polymorphic ops
  • Inductor options: shape_padding=True, max_autotune=False, epilogue_fusion=False, triton.cudagraphs=False, max_fusion_size=8
  • Removes ~120 lines of manual padding/recompilation infrastructure (_CompiledModel._recompile, max_nall estimation from 20 sampled batches, etc.)
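The compile configuration described in the bullets above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: `compile_for_training` and `model` are placeholder names, and the option values are taken from the summary.

```python
# Sketch of the inductor options listed above; `model` is a placeholder
# for the traced training module, not an identifier from this PR.
inductor_options = {
    "shape_padding": True,        # pad tensor dims for kernel alignment
    "max_autotune": False,        # skip expensive autotuning at compile time
    "epilogue_fusion": False,
    "triton.cudagraphs": False,   # CUDA graphs pin shapes, so keep them off
    "max_fusion_size": 8,
}

def compile_for_training(model):
    import torch
    # dynamic=True marks input dims as symbolic, so a varying nall does
    # not trigger a retrace/recompile for every new shape.
    return torch.compile(
        model, backend="inductor", dynamic=True, options=inductor_options
    )
```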

Speed Benchmark (V100 GPU, se_atten_compressible: rcut=6, sel=120, fitting=[240,240,240], float64)

| Mode | bs=1 | bs=4 |
| --- | --- | --- |
| Uncompiled | 21.8 ms | 42.9 ms |
| Old compiled (aot_eager) | 18.1 ms (1.20x) | 38.3 ms (1.12x) |
| New compiled (inductor) | 9.8 ms (2.22x) | 20.4 ms (2.10x) |

Convergence Benchmark (1000 steps, stop_lr=1e-4, se_atten_compressible)

Force validation RMSE (rmse_f_val):

| step | Uncompiled | Old compiled | New compiled |
| --- | --- | --- | --- |
| 1 | 1.37 | 1.36 | 1.36 |
| 500 | 0.412 | 0.535 | 0.497 |
| 1000 | 0.291 | 0.360 | 0.316 |

All three converge to similar levels. Spread at step 1000 is within normal run-to-run variation from random batch ordering.

Test plan

  • All 11 training tests pass locally
  • CI passes

```python
# Ghost-atom forces must be scatter-summed back to local atoms
# via ``mapping`` — the same operation ``communicate_extended_output``
# performs in the uncompiled path.
actual_nall = ext_coord.shape[1]
```

Add max_autotune, epilogue_fusion, triton.cudagraphs, and max_fusion_size options to match the reference implementation.
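The scatter-sum that the comment above refers to can be illustrated with a minimal pure-Python sketch. The real code operates on torch tensors (e.g. via scatter_add); the function and variable names here are illustrative only.

```python
def scatter_sum_forces(ext_forces, mapping, nloc):
    """Sum forces on extended (local + ghost) atoms back onto local atoms.

    ext_forces: per-extended-atom 3-vectors, length nall
    mapping:    mapping[j] = index of the local atom owning extended atom j
    nloc:       number of local atoms
    """
    local_forces = [[0.0, 0.0, 0.0] for _ in range(nloc)]
    for j, owner in enumerate(mapping):
        for d in range(3):
            local_forces[owner][d] += ext_forces[j][d]
    return local_forces

# Two local atoms; extended atom 2 is a ghost image of local atom 0.
forces = scatter_sum_forces(
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.0, 0.0]],
    mapping=[0, 1, 0],
    nloc=2,
)
# forces == [[1.5, 0.0, 0.0], [0.0, 1.0, 0.0]]
```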
@coderabbitai
Contributor

coderabbitai bot commented Apr 11, 2026

📝 Walkthrough

Training compilation now uses shape‑polymorphic symbolic FX tracing and fully dynamic torch.compile (Inductor, dynamic=True). The prior max_nall padding/recompilation flow and padded tracing were removed; compiled model accepts varying nall across batches without manual re-tracing.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Compilation Strategy Overhaul**<br>`deepmd/pt_expt/train/training.py` | Replaced concrete-shape FX tracing and manual max_nall padding/recompilation with symbolic FX tracing (`make_fx(..., tracing_mode="symbolic", _allow_non_fake_inputs=True)`) and `torch.compile(backend="inductor", dynamic=True)`. Removed `_CompiledModel` tracking of max_nall and padding of ext_coord/ext_atype/mapping. Simplified force scatter-add to use mapping directly; tracing now uses a single sampled batch without padded nlist building. |
| **Test Updates**<br>`source/tests/pt_expt/test_training.py` | Renamed test class and method to reflect dynamic shapes: `TestCompiledRecompile` → `TestCompiledDynamicShapes`, `test_nall_growth_triggers_recompile()` → `test_compiled_handles_varying_nall()`. Test adjusted to run multiple training steps and assert finite losses across varying nall, removing explicit recompilation/state assertions. |
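The symbolic tracing mentioned in the summary uses torch's make_fx API. A minimal sketch of the call pattern follows; everything except the make_fx keywords (function name, arguments) is a placeholder, not code from this PR.

```python
def trace_symbolically(fn, example_inputs):
    # Deferred import keeps the sketch self-contained; requires PyTorch.
    from torch.fx.experimental.proxy_tensor import make_fx

    # tracing_mode="symbolic" records SymInt shapes rather than the
    # concrete sizes of example_inputs, so the resulting FX graph stays
    # valid when dimensions such as nall change between batches.
    return make_fx(
        fn,
        tracing_mode="symbolic",
        _allow_non_fake_inputs=True,
    )(*example_inputs)
```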

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main change: switching from the aot_eager backend with manual padding to the inductor backend with dynamic=True for torch.compile training. |
| Description check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
source/tests/pt_expt/test_training.py (1)

167-177: ⚠️ Potential issue | 🟡 Minor

Add the required 60s timeout to this training test.

This new training-path test does not set the repository’s required timeout, so a compile regression can hang CI instead of failing fast.

As per coding guidelines, **/tests/**/*training*.py: Set training test timeouts to 60 seconds maximum for validation purposes, as real training takes hours or days.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/test_training.py` around lines 167-177: this training
test lacks the repository-required 60s timeout. Add a 60-second timeout to
test_compiled_handles_varying_nall (in class TestCompiledDynamicShapes) by
decorating the test method with a timeout decorator (e.g.,
`@pytest.mark.timeout(60)`) and import pytest at the top if not present, so the
test will fail fast instead of hanging CI.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 2f7348ce-2f4d-4c7f-b068-a1d65cf52d0b

📥 Commits

Reviewing files that changed from the base of the PR and between baab3e8 and a183f95.

📒 Files selected for processing (2)
  • deepmd/pt_expt/train/training.py
  • source/tests/pt_expt/test_training.py

Comment on lines +199 to +211
```python
# Run several training steps — each may have different nall
trainer.wrapper.train()
trainer.optimizer.zero_grad(set_to_none=True)
inp, lab = trainer.get_data(is_train=True)
lr = trainer.scheduler.get_last_lr()[0]
_, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab)
loss.backward()
trainer.optimizer.step()

# max_nall should have grown beyond 1
new_max_nall = compiled_model._max_nall
self.assertGreater(new_max_nall, 1)

# compiled_forward_lower should be a new object
self.assertIsNot(
    compiled_model.compiled_forward_lower,
    old_compiled_lower,
)

# Loss should be a finite scalar
self.assertFalse(torch.isnan(loss))
self.assertFalse(torch.isinf(loss))

for _ in range(3):
    trainer.optimizer.zero_grad(set_to_none=True)
    inp, lab = trainer.get_data(is_train=True)
    lr = trainer.scheduler.get_last_lr()[0]
    _, loss, _more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab)
    loss.backward()
    trainer.optimizer.step()

    # Loss should be a finite scalar
    self.assertFalse(torch.isnan(loss))
    self.assertFalse(torch.isinf(loss))
```

⚠️ Potential issue | 🟠 Major

Make this loop prove the dynamic-shape path was exercised.

Right now the test only checks that three compiled steps produce finite loss. It never asserts that any step actually hit a different nall, so it can pass without covering the regression this PR is meant to prevent. It also calls trainer.optimizer.step() directly, which skips Trainer._optimizer_step()’s scheduler update and @torch.compiler.disable guard, so the test no longer matches the real training path. Please record/assert at least two distinct observed nall values and use the trainer helper for stepping.

Minimal fix for the stepping path

```diff
-                    trainer.optimizer.step()
+                    trainer._optimizer_step()
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:

```python
# Run several training steps — each may have different nall
trainer.wrapper.train()
trainer.optimizer.zero_grad(set_to_none=True)
inp, lab = trainer.get_data(is_train=True)
lr = trainer.scheduler.get_last_lr()[0]
_, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab)
loss.backward()
trainer.optimizer.step()

# max_nall should have grown beyond 1
new_max_nall = compiled_model._max_nall
self.assertGreater(new_max_nall, 1)

# compiled_forward_lower should be a new object
self.assertIsNot(
    compiled_model.compiled_forward_lower,
    old_compiled_lower,
)

# Loss should be a finite scalar
self.assertFalse(torch.isnan(loss))
self.assertFalse(torch.isinf(loss))

for _ in range(3):
    trainer.optimizer.zero_grad(set_to_none=True)
    inp, lab = trainer.get_data(is_train=True)
    lr = trainer.scheduler.get_last_lr()[0]
    _, loss, _more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab)
    loss.backward()
    trainer.optimizer.step()

    # Loss should be a finite scalar
    self.assertFalse(torch.isnan(loss))
    self.assertFalse(torch.isinf(loss))
```

After:

```python
# Run several training steps — each may have different nall
trainer.wrapper.train()
for _ in range(3):
    trainer.optimizer.zero_grad(set_to_none=True)
    inp, lab = trainer.get_data(is_train=True)
    lr = trainer.scheduler.get_last_lr()[0]
    _, loss, _more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab)
    loss.backward()
    trainer._optimizer_step()

    # Loss should be a finite scalar
    self.assertFalse(torch.isnan(loss))
    self.assertFalse(torch.isinf(loss))
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/test_training.py` around lines 199-211: the loop needs
to record and assert that the dynamic-shape path ran. Capture the distinct nall
values returned by the model and use the trainer's stepping helper instead of
calling optimizer.step() directly: inside the loop keep a set (e.g.
observed_nall), extract nall from the wrapper's extra output (the
_more_loss/aux dict returned by trainer.wrapper(**inp, cur_lr=lr, label=lab))
and add it to the set, and call Trainer._optimizer_step (or the public
trainer.step helper) instead of trainer.optimizer.step(). After the loop,
assert that len(observed_nall) >= 2 to prove at least two different nall values
were seen, and retain the existing finite-loss assertions.
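The assertion pattern the reviewer asks for can be sketched abstractly. This is a hypothetical illustration only: `run_varying_nall_steps` and its `batches` argument are stand-ins for the real trainer loop and its per-step (nall, loss) outputs.

```python
def run_varying_nall_steps(batches):
    # Each batch yields (nall, loss); record distinct nall values so the
    # test proves the dynamic-shape path was actually exercised.
    observed_nall = set()
    losses = []
    for nall, loss in batches:
        observed_nall.add(nall)
        losses.append(loss)
    assert len(observed_nall) >= 2, "dynamic-shape path never exercised"
    return losses

# Three steps, two distinct nall values -> the assertion passes.
losses = run_varying_nall_steps([(192, 0.51), (240, 0.44), (192, 0.39)])
```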


@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
deepmd/pt_expt/train/training.py (1)

312-312: ⚠️ Potential issue | 🟠 Major

Remove or use the unused local at Line 312

actual_nall is assigned but never read. This has already been flagged by prior scanning and may fail lint/static-analysis gates.

Suggested change

```diff
-        actual_nall = ext_coord.shape[1]
         out: dict[str, torch.Tensor] = {}
```
As per coding guidelines: **/*.py: Install linter and run ruff check . before committing changes or the CI will fail.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/train/training.py` at line 312: remove the unused local
variable assignment `actual_nall = ext_coord.shape[1]` (it is assigned but
never read). Either delete the line or use the value where intended: search for
actual_nall or ext_coord in the surrounding function and, if a column count is
needed, use ext_coord.shape[1] directly or assign it to a name that is actually
consumed.
🧹 Nitpick comments (1)
deepmd/pt_expt/train/training.py (1)

221-225: Avoid mutating caller-owned compile_opts in place

pop()/setdefault() currently mutate the dict from training config. A local copy is safer and avoids side effects if the same config is reused.

Suggested change

```diff
 def _trace_and_compile(
@@
-    # Override backend and dynamic — the inductor backend with
+    # Work on a local copy to avoid mutating caller-owned config.
+    compile_opts = deepcopy(compile_opts)
+
+    # Override backend and dynamic — the inductor backend with
     # dynamic=True handles varying shapes automatically.
     compile_opts.pop("dynamic", None)
     compile_opts.pop("backend", None)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/train/training.py` around lines 221-225: the code mutates
the caller-owned dict compile_opts via pop() and by adding "options". Instead,
create a shallow copy (e.g., local_compile_opts = compile_opts.copy()) and
operate on that copy: use local_compile_opts.pop("dynamic", None) and
local_compile_opts.pop("backend", None), and ensure "options" exists on
local_compile_opts before reading opts = local_compile_opts["options"]. Leave
the original compile_opts untouched so callers can reuse it.
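The defensive pattern suggested here is a standard copy-before-mutate idiom. A hypothetical sketch (the function name and config keys are illustrative, not the PR's actual code):

```python
def normalize_compile_opts(compile_opts):
    # Shallow-copy so pop()/setdefault() never mutate the caller's dict.
    # Note: nested dicts such as "options" would still be shared; use
    # copy.deepcopy if those must be isolated as well.
    opts = dict(compile_opts)
    opts.pop("dynamic", None)
    opts.pop("backend", None)
    opts.setdefault("options", {})
    return opts

cfg = {"backend": "aot_eager", "dynamic": False}
normalized = normalize_compile_opts(cfg)
# cfg is untouched; normalized has backend/dynamic stripped and "options" added.
```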

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 547d2b4a-1e06-47cc-9b70-259db8391305

📥 Commits

Reviewing files that changed from the base of the PR and between a183f95 and 4ebce58.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

@codecov

codecov bot commented Apr 11, 2026

Codecov Report

❌ Patch coverage is 97.05882% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 80.32%. Comparing base (baab3e8) to head (4ebce58).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| deepmd/pt_expt/train/training.py | 97.05% | 1 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master    #5393      +/-   ##
==========================================
- Coverage   80.33%   80.32%   -0.02%     
==========================================
  Files         819      819              
  Lines       85356    85316      -40     
  Branches     4139     4139              
==========================================
- Hits        68571    68529      -42     
- Misses      15509    15510       +1     
- Partials     1276     1277       +1     
```
