Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
512eeb6
feat(pt_expt): multi-task training support
Apr 15, 2026
9f4d232
fix(dpmodel): wrap fparam/aparam reshape with descriptive ValueError
Apr 16, 2026
9f1f1d8
fix: address CodeQL findings in PR #5397
Apr 16, 2026
f3f5474
fix(pt_expt): access unwrapped module in _compile_model for DDP compat
Apr 16, 2026
665b85a
test(pt_expt): add DDP + torch.compile training tests
Apr 16, 2026
aabb710
feat(pt_expt): use inductor+dynamic for training compile
Apr 16, 2026
f774cd2
test(pt_expt): port silut activation + repformers accessors from #5393
Apr 16, 2026
0b5468e
test(pt_expt): assert virial in compile correctness tests
Apr 16, 2026
9bf006b
test(pt_expt): port silu compile and varying-natoms tests from #5393
Apr 16, 2026
7722f52
test(pt_expt): compare compiled vs uncompiled with varying natoms
Apr 16, 2026
be14ac2
test(pt_expt): cover DPA2/DPA3 in varying-natoms compile correctness
Apr 16, 2026
4c0b8ec
test(pt_expt): exercise DPA2 three-body branch in compile correctness
Apr 16, 2026
80c714c
fix(dpmodel): restore nf in reshapes to fix zero-atom and add silu_ba…
Apr 17, 2026
6158d9c
fix: address CodeQL findings in PR #5397
Apr 17, 2026
c2efbf1
fix(pt): wrap fparam/aparam reshape with descriptive ValueError
Apr 17, 2026
1e694a3
feat(pt_expt): reject DPA1/se_atten_v2 with attention at compile time
Apr 18, 2026
6d39ddf
fix(pt_expt): remove false DPA1 attention compile guard
Apr 18, 2026
23eb6dd
refactor(dpmodel): remove unused get_numb_attn_layer API
Apr 18, 2026
bacd312
fix(test): use real path for PT water data, remove unused API
Apr 18, 2026
f834202
fix(pt_expt): rebuild FX graph after detach node removal to avoid seg…
Apr 18, 2026
447a572
fix(pt_expt): tune inductor options for compile training
Apr 18, 2026
fb25ccb
fix(pt_expt): disable DDPOptimizer to prevent compiled graph splitting
Apr 18, 2026
479900d
fix(test): add .cpu() before .numpy() for GPU-compatible activation t…
Apr 18, 2026
b67a181
fix(pt_expt): revert inductor options that cause numerical divergence
Apr 18, 2026
7ce7352
fix(test): make DDP tests device-adaptive instead of hardcoding CPU
Apr 18, 2026
975db17
fix(test): correct freeze test docstrings to match dpa3 guard
Apr 18, 2026
64dc703
fix(pt_expt): move optimize_ddp into _compile_model, resolve test sym…
Apr 18, 2026
28fbcac
fix(test): backup/restore fparam.npy in TestFparam instead of deleting
Apr 18, 2026
fbb361a
fix(test): skip DDP tests when NCCL is selected with fewer than 2 GPUs
Apr 18, 2026
17aa5d7
fix(test): force single-threaded reductions in compiled-vs-uncompiled…
Apr 20, 2026
88d3772
feat(pt_expt): warn on compiled DPA1+attention; drop flaky compile test
Apr 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,10 @@ def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()

def get_numb_attn_layer(self) -> int:
    """Return the number of attention layers used by the ``se_atten`` block."""
    return self.se_atten.attn_layer

def share_params(
self, base_class: "DescrptDPA1", shared_level: int, resume: bool = False
) -> NoReturn:
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,14 @@ def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut

def get_rcut_smth(self) -> float:
    """Return the radius at which neighbor contributions begin to smoothly decay toward zero."""
    return self.rcut_smth

def get_env_protection(self) -> float:
    """Return the protection value applied when building the environment matrix."""
    return self.env_protection

def get_nsel(self) -> int:
"""Returns the number of selected atoms in the cut-off radius."""
return sum(self.sel)
Expand Down
31 changes: 21 additions & 10 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def compute_input_stats(
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
self._param_stats: dict[str, list[StatItem]] = {}
if self.numb_fparam == 0 and self.numb_aparam == 0:
# skip data statistics
return
Expand Down Expand Up @@ -296,6 +297,7 @@ def compute_input_stats(
self._save_param_stats_to_file(
stat_file_path, "fparam", fparam_stats
)
self._param_stats["fparam"] = fparam_stats
fparam_avg = np.array(
[s.compute_avg() for s in fparam_stats], dtype=np.float64
)
Expand Down Expand Up @@ -362,6 +364,7 @@ def compute_input_stats(
self._save_param_stats_to_file(
stat_file_path, "aparam", aparam_stats
)
self._param_stats["aparam"] = aparam_stats
aparam_avg = np.array(
[s.compute_avg() for s in aparam_stats], dtype=np.float64
)
Expand Down Expand Up @@ -407,6 +410,10 @@ def _load_param_stats_from_file(
for ii in range(numb)
]

def get_param_stats(self) -> dict[str, list[StatItem]]:
    """Return the fparam/aparam statistics recorded by ``compute_input_stats``.

    An empty dict is returned when statistics have not been computed yet
    (i.e. the ``_param_stats`` attribute has never been set on this instance).
    """
    try:
        return self._param_stats
    except AttributeError:
        # compute_input_stats has not run on this instance yet.
        return {}

@abstractmethod
def _net_out_dim(self) -> int:
"""Set the FittingNet output dim."""
Expand Down Expand Up @@ -666,11 +673,13 @@ def _call_common(
# check fparam dim, concate to input descriptor
if self.numb_fparam > 0:
assert fparam is not None, "fparam should not be None"
if fparam.shape[-1] != self.numb_fparam:
try:
fparam = xp.reshape(fparam, (nf, self.numb_fparam))
except (ValueError, RuntimeError) as e:
raise ValueError(
f"get an input fparam of dim {fparam.shape[-1]}, "
f"which is not consistent with {self.numb_fparam}."
)
f"input fparam: cannot reshape {fparam.shape} "
f"into ({nf}, {self.numb_fparam})."
Comment thread
wanghan-iapcm marked this conversation as resolved.
) from e
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
fparam = xp.tile(
xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
Expand All @@ -687,12 +696,13 @@ def _call_common(
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
try:
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
except (ValueError, RuntimeError) as e:
raise ValueError(
f"get an input aparam of dim {aparam.shape[-1]}, "
f"which is not consistent with {self.numb_aparam}."
)
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
f"input aparam: cannot reshape {aparam.shape} "
f"into ({nf}, {nloc}, {self.numb_aparam})."
) from e
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
xx = xp.concat(
[xx, aparam],
Expand Down Expand Up @@ -735,7 +745,8 @@ def _call_common(
)
for type_i in range(self.ntypes):
mask = xp.tile(
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
xp.reshape((atype == type_i), (nf, nloc, 1)),
(1, 1, net_dim_out),
)
atom_property = self.nets[(type_i,)](xx)
if self.remove_vaccum_contribution is not None and not (
Expand Down
8 changes: 6 additions & 2 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def _make_env_mat(
xp = array_api_compat.array_namespace(nlist)
nf, nloc, nnei = nlist.shape
# nf x nall x 3
coord = xp.reshape(coord, (nf, -1, 3))
# Callers may pass either (nf, nall*3) or (nf, nall, 3); normalise
# both to (nf, nall, 3) using shape-based inference so the concrete nf
# value is not baked into the reshape.
if coord.ndim == 2:
coord = xp.reshape(coord, (-1, coord.shape[1] // 3, 3))
mask = nlist >= 0
nlist = nlist * xp.astype(mask, nlist.dtype)
# nf x (nloc x nnei) x 3
Expand All @@ -77,7 +81,7 @@ def _make_env_mat(
# nf x nloc x nnei x 3
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
# nf x nloc x 1 x 3
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3))
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, nloc, 1, 3))
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
Expand Down
69 changes: 69 additions & 0 deletions deepmd/dpmodel/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,75 @@
)


def merge_env_stat(
    base_obj: Union["Descriptor", "DescriptorBlock"],
    link_obj: Union["Descriptor", "DescriptorBlock"],
    model_prob: float = 1.0,
) -> None:
    """Merge descriptor env-mat statistics from ``link_obj`` into ``base_obj``.

    The merge is probability weighted: for every key,
    ``merged = base_stats + link_stats * model_prob`` with
    ``model_prob = link_prob / base_prob``. ``base_obj.stats`` is updated in
    place so repeated calls can chain three or more models.

    Parameters
    ----------
    base_obj : Descriptor or DescriptorBlock
        The base descriptor whose stats will be updated.
    link_obj : Descriptor or DescriptorBlock
        The linked descriptor whose stats will be merged in.
    model_prob : float
        The probability weight ratio (link_prob / base_prob).
    """
    base_stats = getattr(base_obj, "stats", None)
    link_stats = getattr(link_obj, "stats", None)
    # Nothing to merge when either side has no accumulated statistics.
    if base_stats is None or link_stats is None:
        return
    # A constant stddev together with a zeroed davg means the buffers are
    # fixed by construction; merging would have no observable effect.
    if getattr(base_obj, "set_stddev_constant", False) and getattr(
        base_obj, "set_davg_zero", False
    ):
        return

    # Probability-weighted merge of the per-key StatItem accumulators.
    merged_stats = {
        kk: base_stats[kk] + link_stats[kk] * model_prob for kk in base_stats
    }

    # Recompute mean/stddev (numpy arrays) from the merged accumulators.
    env_stat = EnvMatStatSe(base_obj)
    env_stat.stats = merged_stats
    mean, stddev = env_stat()

    # Store merged stats on the base object so further merges can chain.
    base_obj.stats = merged_stats

    skip_avg = getattr(base_obj, "set_davg_zero", False)

    def _write_buffers(avg_buf, std_buf) -> None:
        # Copy the numpy results into the backend buffers in place,
        # matching each buffer's array namespace, dtype and device.
        xp = array_api_compat.array_namespace(std_buf)
        device = array_api_compat.device(std_buf)
        if not skip_avg:
            avg_buf[...] = xp.asarray(mean, dtype=avg_buf.dtype, device=device)
        std_buf[...] = xp.asarray(stddev, dtype=std_buf.dtype, device=device)

    # Simple descriptors expose davg/dstd; descriptor blocks expose mean/stddev.
    if hasattr(base_obj, "davg"):
        _write_buffers(base_obj.davg, base_obj.dstd)
    elif hasattr(base_obj, "mean"):
        _write_buffers(base_obj.mean, base_obj.stddev)


class EnvMatStat(BaseEnvMatStat):
def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]:
"""Compute the statistics of the environment matrix for a single system.
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,10 @@ def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()

def get_numb_attn_layer(self) -> int:
    """Return how many attention layers the ``se_atten`` block stacks."""
    return self.se_atten.attn_layer

def share_params(
self, base_class: Any, shared_level: int, resume: bool = False
) -> None:
Expand Down
12 changes: 6 additions & 6 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,10 +779,10 @@ def _forward_common(
assert fparam is not None, "fparam should not be None"
assert self.fparam_avg is not None
assert self.fparam_inv_std is not None
if fparam.shape[-1] != self.numb_fparam:
if fparam.numel() != nf * self.numb_fparam:
raise ValueError(
"get an input fparam of dim {fparam.shape[-1]}, ",
"which is not consistent with {self.numb_fparam}.",
f"input fparam: cannot reshape {list(fparam.shape)} "
f"into ({nf}, {self.numb_fparam})."
)
fparam = fparam.view([nf, self.numb_fparam])
nb, _ = fparam.shape
Expand All @@ -804,10 +804,10 @@ def _forward_common(
assert aparam is not None, "aparam should not be None"
assert self.aparam_avg is not None
assert self.aparam_inv_std is not None
if aparam.shape[-1] != self.numb_aparam:
if aparam.numel() % (nf * self.numb_aparam) != 0:
raise ValueError(
f"get an input aparam of dim {aparam.shape[-1]}, ",
f"which is not consistent with {self.numb_aparam}.",
f"input aparam: cannot reshape {list(aparam.shape)} "
f"into ({nf}, nloc, {self.numb_aparam})."
)
aparam = aparam.view([nf, -1, self.numb_aparam])
nb, nloc, _ = aparam.shape
Comment thread
wanghan-iapcm marked this conversation as resolved.
Expand Down
28 changes: 28 additions & 0 deletions deepmd/pt_expt/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
cast_precision,
)
from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP
from deepmd.dpmodel.utils.env_mat_stat import (
merge_env_stat,
)
from deepmd.pt_expt.common import (
torch_module,
)
Expand All @@ -26,6 +29,31 @@
class DescrptDPA1(DescrptDPA1DP):
_update_sel_cls = UpdateSel

def share_params(
    self,
    base_class: Any,
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Share parameters with ``base_class`` for multi-task training.

    Level 0 shares the type embedding and the whole ``se_atten`` block
    (all modules and buffers); level 1 shares the type embedding only.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # Both supported levels share the type embedding.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 0:
        if not resume:
            # Fold this model's env-mat statistics into the shared block
            # before aliasing it.
            merge_env_stat(base_class.se_atten, self.se_atten, model_prob)
        self._modules["se_atten"] = base_class._modules["se_atten"]

def enable_compression(
self,
min_nbor_dist: float,
Expand Down
44 changes: 44 additions & 0 deletions deepmd/pt_expt/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.dpmodel.utils.env_mat_stat import (
merge_env_stat,
)
from deepmd.pt_expt.common import (
torch_module,
)
Expand All @@ -30,6 +33,47 @@
class DescrptDPA2(DescrptDPA2DP):
_update_sel_cls = UpdateSel

def share_params(
    self,
    base_class: "DescrptDPA2",
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Share parameters with ``base_class`` for multi-task training.

    Level 0 shares type_embedding, repinit, repinit_three_body,
    g1_shape_tranform and repformers; level 1 shares type_embedding only.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # Both supported levels share the type embedding.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level != 0:
        return
    # Level 0: additionally share the whole descriptor backbone.
    share_three_body = (
        self.use_three_body and "repinit_three_body" in base_class._modules
    )
    if not resume:
        # Fold this model's env-mat statistics into the shared blocks
        # before aliasing them.
        merge_env_stat(base_class.repinit, self.repinit, model_prob)
        if share_three_body:
            merge_env_stat(
                base_class.repinit_three_body,
                self.repinit_three_body,
                model_prob,
            )
        merge_env_stat(base_class.repformers, self.repformers, model_prob)
    self._modules["repinit"] = base_class._modules["repinit"]
    if share_three_body:
        self._modules["repinit_three_body"] = base_class._modules[
            "repinit_three_body"
        ]
    self._modules["g1_shape_tranform"] = base_class._modules["g1_shape_tranform"]
    self._modules["repformers"] = base_class._modules["repformers"]

def enable_compression(
self,
min_nbor_dist: float,
Expand Down
28 changes: 28 additions & 0 deletions deepmd/pt_expt/descriptor/dpa3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP
from deepmd.dpmodel.utils.env_mat_stat import (
merge_env_stat,
)
from deepmd.pt_expt.common import (
torch_module,
)
Expand All @@ -16,3 +19,28 @@
@torch_module
class DescrptDPA3(DescrptDPA3DP):
_update_sel_cls = UpdateSel

def share_params(
    self,
    base_class: "DescrptDPA3",
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Share parameters with ``base_class`` for multi-task training.

    Level 0 shares the type embedding and the ``repflows`` block;
    level 1 shares the type embedding only.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # Both supported levels share the type embedding.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 0:
        if not resume:
            # Fold this model's env-mat statistics into the shared block
            # before aliasing it.
            merge_env_stat(base_class.repflows, self.repflows, model_prob)
        self._modules["repflows"] = base_class._modules["repflows"]
Loading
Loading