Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
512eeb6
feat(pt_expt): multi-task training support
Apr 15, 2026
9f4d232
fix(dpmodel): wrap fparam/aparam reshape with descriptive ValueError
Apr 16, 2026
9f1f1d8
fix: address CodeQL findings in PR #5397
Apr 16, 2026
f3f5474
fix(pt_expt): access unwrapped module in _compile_model for DDP compat
Apr 16, 2026
665b85a
test(pt_expt): add DDP + torch.compile training tests
Apr 16, 2026
aabb710
feat(pt_expt): use inductor+dynamic for training compile
Apr 16, 2026
f774cd2
test(pt_expt): port silut activation + repformers accessors from #5393
Apr 16, 2026
0b5468e
test(pt_expt): assert virial in compile correctness tests
Apr 16, 2026
9bf006b
test(pt_expt): port silu compile and varying-natoms tests from #5393
Apr 16, 2026
7722f52
test(pt_expt): compare compiled vs uncompiled with varying natoms
Apr 16, 2026
be14ac2
test(pt_expt): cover DPA2/DPA3 in varying-natoms compile correctness
Apr 16, 2026
4c0b8ec
test(pt_expt): exercise DPA2 three-body branch in compile correctness
Apr 16, 2026
80c714c
fix(dpmodel): restore nf in reshapes to fix zero-atom and add silu_ba…
Apr 17, 2026
6158d9c
fix: address CodeQL findings in PR #5397
Apr 17, 2026
c2efbf1
fix(pt): wrap fparam/aparam reshape with descriptive ValueError
Apr 17, 2026
1e694a3
feat(pt_expt): reject DPA1/se_atten_v2 with attention at compile time
Apr 18, 2026
6d39ddf
fix(pt_expt): remove false DPA1 attention compile guard
Apr 18, 2026
23eb6dd
refactor(dpmodel): remove unused get_numb_attn_layer API
Apr 18, 2026
bacd312
fix(test): use real path for PT water data, remove unused API
Apr 18, 2026
f834202
fix(pt_expt): rebuild FX graph after detach node removal to avoid seg…
Apr 18, 2026
447a572
fix(pt_expt): tune inductor options for compile training
Apr 18, 2026
fb25ccb
fix(pt_expt): disable DDPOptimizer to prevent compiled graph splitting
Apr 18, 2026
479900d
fix(test): add .cpu() before .numpy() for GPU-compatible activation t…
Apr 18, 2026
b67a181
fix(pt_expt): revert inductor options that cause numerical divergence
Apr 18, 2026
7ce7352
fix(test): make DDP tests device-adaptive instead of hardcoding CPU
Apr 18, 2026
975db17
fix(test): correct freeze test docstrings to match dpa3 guard
Apr 18, 2026
64dc703
fix(pt_expt): move optimize_ddp into _compile_model, resolve test sym…
Apr 18, 2026
28fbcac
fix(test): backup/restore fparam.npy in TestFparam instead of deleting
Apr 18, 2026
fbb361a
fix(test): skip DDP tests when NCCL is selected with fewer than 2 GPUs
Apr 18, 2026
7739fad
perf(pt2): optimize .pt2 C++ inference path
Apr 20, 2026
19272c2
Merge upstream/master into perf-pt-expt-pt2-cpp
Apr 20, 2026
b7509db
feat(pt2): make nlist nnei dimension dynamic in .pt2 export
Apr 20, 2026
eec2528
fix(pt2): pad nlist in Python eval path for dynamic nnei
Apr 20, 2026
217a587
fix(pt2): move atomic virial check before run_model and reject unsupp…
Apr 20, 2026
8a9fe63
fix(pt2): move nlist padding inside traced fn and strip shape assertions
Apr 20, 2026
711a1f4
fix(test): export test models with atomic_virial=True for .pte/.pt2
Apr 21, 2026
acba914
perf(pt2): cache firstneigh_tensor across timesteps
Apr 21, 2026
9426572
test(pt2): add regression test for oversized nlist with distance sorting
Apr 21, 2026
97870bf
test(pt2): add regression test for oversized nlist with distance sorting
Apr 21, 2026
755c2d3
refactor(pt2): remove need_sorted_nlist_for_lower monkey-patch
Apr 21, 2026
444b4e5
refactor(pt2): remove redundant C++ nnei+1 nlist padding
Apr 21, 2026
583df2e
fix(test): link torch in C++ test binary, fix spin PtExpt guards
Apr 21, 2026
862f560
refactor(cc): unify createNlistTensor and fix oversized nlist handling
Apr 22, 2026
12ad03b
fix(test): remove unused _DESCRIPTOR_DPA1_WITH_ATTN config
Apr 22, 2026
b0096e0
fix(pt2): relax nnei dynamic shape lower bound from 2 to 1
Apr 22, 2026
afa732f
fix(pt2): remove eliminate_dead_code from _strip_shape_assertions
Apr 23, 2026
59973be
fix(pt2): neutralise shape assertions instead of erasing them
Apr 24, 2026
39f84ac
fix(pt2): prevent CUDA NaN in attention backward by disabling kernel …
Apr 24, 2026
57cfe3e
fix(pt2): only apply realize_opcount_threshold=0 on CUDA
Apr 24, 2026
cdbfead
refactor(pt_expt): use extra_nlist_sort override instead of stripping…
Apr 25, 2026
f480456
chore(pt2): note thread-safety constraint of realize_opcount_threshold
Apr 25, 2026
d41eefd
feat(cli): warn instead of raise for --atomic-virial on non-pt_expt o…
Apr 25, 2026
c80db58
refactor(pt2): move atomic_virial check to compute() entry; clarify e…
Apr 25, 2026
72f95f8
fix(cc): guard empty-vector dereferences in NeighborListData and conv…
Apr 25, 2026
43d42e6
test(cc): tighten oversized-nlist sanity check; verify all predictions
Apr 25, 2026
e30206d
fix(cc): apply empty-vector guard to make_inlist and DeepSpinTF; add …
Apr 25, 2026
9d2f577
test(cc): regression test for atomic_virial fail-fast guard
Apr 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,13 @@ def _format_nlist(
axis=-1,
)

if n_nnei > nnei or extra_nlist_sort:
# Order matters for torch.export: Python evaluates `or` left-to-right
# with short-circuit. When `extra_nlist_sort=True` (Python bool) is
# on the left, the right-hand `n_nnei > nnei` is not evaluated, so no
# symbolic guard is registered on the dynamic `n_nnei` dimension.
# Swapping the operands would force the SymInt comparison to run and
# emit an `_assert_scalar` node in the exported graph.
if extra_nlist_sort or n_nnei > nnei:
n_nf, n_nloc, n_nnei = nlist.shape
# make a copy before revise
m_real_nei = nlist >= 0
Expand Down
27 changes: 25 additions & 2 deletions deepmd/entrypoints/convert_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
Any,
)
Expand All @@ -7,11 +8,14 @@
Backend,
)

log = logging.getLogger(__name__)


def convert_backend(
*, # Enforce keyword-only arguments
INPUT: str,
OUTPUT: str,
atomic_virial: bool = False,
**kwargs: Any,
) -> None:
"""Convert a model file from one backend to another.
Expand All @@ -20,12 +24,31 @@ def convert_backend(
----------
INPUT : str
The input model file.
INPUT : str
OUTPUT : str
The output model file.
atomic_virial : bool
If True, export .pt2/.pte models with per-atom virial correction.
This adds ~2.5x inference cost. Default False. Ignored with a
warning for backends that don't support the flag.
"""
inp_backend: Backend = Backend.detect_backend_by_model(INPUT)()
out_backend: Backend = Backend.detect_backend_by_model(OUTPUT)()
inp_hook = inp_backend.serialize_hook
out_hook = out_backend.deserialize_hook
data = inp_hook(INPUT)
out_hook(OUTPUT, data)
# Forward atomic_virial to pt_expt deserialize_to_file if applicable;
# warn and skip the flag for backends that don't accept it so that
# scripts passing --atomic-virial indiscriminately don't break.
import inspect

sig = inspect.signature(out_hook)
if "do_atomic_virial" in sig.parameters:
out_hook(OUTPUT, data, do_atomic_virial=atomic_virial)
else:
if atomic_virial:
log.warning(
"--atomic-virial is only meaningful for pt_expt .pt2/.pte "
"outputs; ignoring it for output backend %s",
out_backend.name,
)
out_hook(OUTPUT, data)
9 changes: 9 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,15 @@ def main_parser() -> argparse.ArgumentParser:
)
parser_convert_backend.add_argument("INPUT", help="The input model file.")
parser_convert_backend.add_argument("OUTPUT", help="The output model file.")
parser_convert_backend.add_argument(
"--atomic-virial",
action="store_true",
default=False,
help="Export .pt2/.pte models with per-atom virial correction. "
"This adds ~2.5x inference cost but is required for "
"LAMMPS compute/atom virial output. "
"Ignored (with a warning) for other output backends.",
)

# * show model ******************************************************************
parser_show = subparsers.add_parser(
Expand Down
16 changes: 13 additions & 3 deletions deepmd/pt_expt/model/dipole_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import types
from typing import (
Any,
)
Expand All @@ -16,6 +17,7 @@
)

from .make_model import (
_pad_nlist_for_export,
make_model,
)
from .model import (
Expand Down Expand Up @@ -137,6 +139,7 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
nlist = _pad_nlist_for_export(nlist)
return model.forward_lower(
extended_coord,
extended_atype,
Expand All @@ -147,6 +150,13 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
# See make_model.py for the rationale of the pad + monkeypatch.
_orig_need_sort = model.need_sorted_nlist_for_lower
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
finally:
model.need_sorted_nlist_for_lower = _orig_need_sort
return traced
16 changes: 13 additions & 3 deletions deepmd/pt_expt/model/dos_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import types
from typing import (
Any,
)
Expand All @@ -16,6 +17,7 @@
)

from .make_model import (
_pad_nlist_for_export,
make_model,
)
from .model import (
Expand Down Expand Up @@ -117,6 +119,7 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
nlist = _pad_nlist_for_export(nlist)
return model.forward_lower(
extended_coord,
extended_atype,
Expand All @@ -127,6 +130,13 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
# See make_model.py for the rationale of the pad + monkeypatch.
_orig_need_sort = model.need_sorted_nlist_for_lower
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
finally:
model.need_sorted_nlist_for_lower = _orig_need_sort
return traced
16 changes: 13 additions & 3 deletions deepmd/pt_expt/model/dp_linear_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import types
from typing import (
Any,
)
Expand All @@ -19,6 +20,7 @@
)

from .make_model import (
_pad_nlist_for_export,
make_model,
)
from .model import (
Expand Down Expand Up @@ -142,6 +144,7 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
nlist = _pad_nlist_for_export(nlist)
return model.forward_lower(
extended_coord,
extended_atype,
Expand All @@ -152,9 +155,16 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
# See make_model.py for the rationale of the pad + monkeypatch.
_orig_need_sort = model.need_sorted_nlist_for_lower
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
finally:
model.need_sorted_nlist_for_lower = _orig_need_sort
return traced

@classmethod
def update_sel(
Expand Down
18 changes: 15 additions & 3 deletions deepmd/pt_expt/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import types
from typing import (
Any,
)
Expand All @@ -16,6 +17,7 @@
)

from .make_model import (
_pad_nlist_for_export,
make_model,
)
from .model import (
Expand Down Expand Up @@ -139,6 +141,7 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
nlist = _pad_nlist_for_export(nlist)
return model.forward_lower(
extended_coord,
extended_atype,
Expand All @@ -149,6 +152,15 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
# Force `_format_nlist`'s sort branch into the compiled graph so the
# exported model tolerates oversized nlists at runtime — see
# make_model.py for the full rationale.
_orig_need_sort = model.need_sorted_nlist_for_lower
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
finally:
model.need_sorted_nlist_for_lower = _orig_need_sort
return traced
51 changes: 44 additions & 7 deletions deepmd/pt_expt/model/make_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import math
import types
from typing import (
Any,
)
Expand Down Expand Up @@ -28,6 +29,28 @@
)


def _pad_nlist_for_export(nlist: torch.Tensor) -> torch.Tensor:
"""Append a single ``-1`` column to ``nlist`` for export-time tracing.

Used inside ``forward_common_lower_exportable`` (and its spin counterpart)
so that ``_format_nlist``'s terminal slice ``ret[..., :nnei]`` truncates
to a statically sized output. Without the extra column, torch.export
cannot prove the ``ret.shape[-1] == nnei`` assertion at trace time and
would specialise the dynamic ``nnei`` dim to the sample value.

Combined with the short-circuit order in ``_format_nlist``
(``extra_nlist_sort`` on the left) and the ``need_sorted_nlist_for_lower``
override during tracing, this keeps the compiled graph's ``nnei`` axis
fully dynamic and free of symbolic shape guards.
"""
pad = -torch.ones(
(*nlist.shape[:2], 1),
dtype=nlist.dtype,
device=nlist.device,
)
return torch.cat([nlist, pad], dim=-1)


def _cal_hessian_ext(
model: Any,
kk: str,
Expand Down Expand Up @@ -346,6 +369,7 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
nlist = _pad_nlist_for_export(nlist)
return model.forward_common_lower(
extended_coord,
extended_atype,
Expand All @@ -356,13 +380,26 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
# Force `_format_nlist`'s sort branch into the compiled graph so the
# exported model tolerates oversized nlists at runtime (LAMMPS builds
# nlists with rcut+skin). Combined with the short-circuit order in
# `_format_nlist`, no symbolic guard on the dynamic `nnei` axis is
# emitted.
_orig_need_sort = model.need_sorted_nlist_for_lower
model.need_sorted_nlist_for_lower = types.MethodType(
lambda self: True, model
)
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
finally:
model.need_sorted_nlist_for_lower = _orig_need_sort
return traced

return CM
16 changes: 13 additions & 3 deletions deepmd/pt_expt/model/polar_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import types
from typing import (
Any,
)
Expand All @@ -16,6 +17,7 @@
)

from .make_model import (
_pad_nlist_for_export,
make_model,
)
from .model import (
Expand Down Expand Up @@ -117,6 +119,7 @@ def fn(
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
nlist = _pad_nlist_for_export(nlist)
return model.forward_lower(
extended_coord,
extended_atype,
Expand All @@ -127,6 +130,13 @@ def fn(
do_atomic_virial=do_atomic_virial,
)

return make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
# See make_model.py for the rationale of the pad + monkeypatch.
_orig_need_sort = model.need_sorted_nlist_for_lower
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
try:
traced = make_fx(fn, **make_fx_kwargs)(
extended_coord, extended_atype, nlist, mapping, fparam, aparam
)
finally:
model.need_sorted_nlist_for_lower = _orig_need_sort
return traced
Loading
Loading