606 changes: 392 additions & 214 deletions deepmd/entrypoints/test.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions deepmd/jax/utils/auto_batch_size.py
@@ -24,10 +24,13 @@ def __init__(
         self,
         initial_batch_size: int = 1024,
         factor: float = 2.0,
+        *,
+        silent: bool = False,
     ) -> None:
         super().__init__(
             initial_batch_size=initial_batch_size,
             factor=factor,
+            silent=silent,
         )

     def is_gpu_available(self) -> bool:
3 changes: 3 additions & 0 deletions deepmd/pd/utils/auto_batch_size.py
@@ -22,10 +22,13 @@ def __init__(
         self,
         initial_batch_size: int = 1024,
         factor: float = 2.0,
+        *,
+        silent: bool = False,
     ) -> None:
         super().__init__(
             initial_batch_size=initial_batch_size,
             factor=factor,
+            silent=silent,
         )

     def is_gpu_available(self) -> bool:
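
Both backends gain the same keyword-only `silent` flag and simply forward it to the shared base class, presumably to mute the automatic batch-size adjustment messages. A minimal usage sketch — the class name `AutoBatchSize` is assumed from the file paths, since the class statement itself is outside the shown hunks:

from deepmd.pd.utils.auto_batch_size import AutoBatchSize  # deepmd.jax mirrors this

# silent is keyword-only, so it must be passed by name; it is forwarded
# unchanged to the base-class constructor.
auto_batch = AutoBatchSize(initial_batch_size=1024, factor=2.0, silent=True)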
1 change: 1 addition & 0 deletions deepmd/pt/entrypoints/main.py
@@ -108,6 +108,7 @@ def get_trainer(
     finetune_links: dict[str, Any] | None = None,
 ) -> training.Trainer:
     multi_task = "model_dict" in config.get("model", {})
+    config = normalize(config, multi_task=multi_task, check=False)

     def prepare_trainer_input_single(
         model_params_single: dict[str, Any],
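
A one-line change with some reach: normalizing the config before the trainer is built means later code (including the new full-validation checks in training.py) can rely on schema defaults being filled in. A short sketch of the call site — the import path is an assumption, since it sits outside the shown hunk; `check=False` presumably skips strict value checking at this stage:

from deepmd.utils.argcheck import normalize  # assumed import path, not shown in this diff

multi_task = "model_dict" in config.get("model", {})
# Fill in schema defaults up front; defer strict checking (check=False).
config = normalize(config, multi_task=multi_task, check=False)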
90 changes: 90 additions & 0 deletions deepmd/pt/train/training.py
@@ -54,6 +54,10 @@
     KFOptimizerWrapper,
     LKFOptimizer,
 )
+from deepmd.pt.train.validation import (
+    FullValidator,
+    resolve_full_validation_start_step,
+)
 from deepmd.pt.train.wrapper import (
     ModelWrapper,
 )
@@ -857,11 +861,89 @@ def single_model_finetune(
         self.enable_profiler = training_params.get("enable_profiler", False)
         self.profiling = training_params.get("profiling", False)
         self.profiling_file = training_params.get("profiling_file", "timeline.json")
+        validating_params = config.get("validating") or {}
+        self.full_validator = self._create_full_validator(
+            validating_params=validating_params,
+            validation_data=validation_data,
+        )

         # Log model parameter count
         if self.rank == 0:
             self._log_parameter_count()

+    def _create_full_validator(
+        self,
+        *,
+        validating_params: dict[str, Any],
+        validation_data: DpLoaderSet | None,
+    ) -> FullValidator | None:
+        """Create the runtime full validator when it is active."""
+        if not self._is_full_validation_requested(validating_params):
+            return None
+        self._raise_if_full_validation_unsupported(validation_data)
+        if validation_data is None:
+            raise RuntimeError(
+                "validation_data must be available after full validation checks."
+            )
+        return FullValidator(
+            validating_params=validating_params,
+            validation_data=validation_data,
+            model=self.model,
+            train_infos=self._get_inner_module().train_infos,
+            num_steps=self.num_steps,
+            rank=self.rank,
+            zero_stage=self.zero_stage,
+            restart_training=self.restart_training,
+            checkpoint_dir=Path(self.save_ckpt).parent,
+        )
+
+    def _is_full_validation_requested(self, validating_params: dict[str, Any]) -> bool:
+        """Check whether full validation can trigger during this training run."""
+        if not validating_params.get("full_validation", False):
+            return False
+        start_step = resolve_full_validation_start_step(
+            validating_params.get("full_val_start", 0.5),
+            self.num_steps,
+        )
+        return start_step is not None and start_step <= self.num_steps
+
+    def _raise_if_full_validation_unsupported(
+        self,
+        validation_data: DpLoaderSet | None,
+    ) -> None:
+        """Validate runtime full validation constraints."""
+        if self.multi_task:
+            raise ValueError(
+                "validating.full_validation only supports single-task energy "
+                "training; multi-task training is not supported."
+            )
+
+        has_spin = getattr(self.model, "has_spin", False)
+        if callable(has_spin):
+            has_spin = has_spin()
+        if has_spin or isinstance(self.loss, EnergySpinLoss):
+            raise ValueError(
+                "validating.full_validation only supports single-task energy "
+                "training; spin-energy training is not supported."
+            )
+
+        if not isinstance(self.loss, EnergyStdLoss):
+            raise ValueError(
+                "validating.full_validation only supports single-task energy training."
+            )
+
+        if validation_data is None:
+            raise ValueError(
+                "validating.full_validation requires `training.validation_data` "
+                "to be configured."
+            )
+
+        if self.zero_stage >= 2:
+            raise ValueError(
+                "validating.full_validation only supports single-task energy "
+                "training with training.zero_stage < 2."
+            )
+
     @staticmethod
     def _count_parameters(model: torch.nn.Module) -> tuple[int, int]:
         """
@@ -1363,6 +1445,14 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                     fout, display_step_id, cur_lr, train_results, valid_results
                 )

+            if self.full_validator is not None:
+                self.full_validator.run(
+                    step_id=_step_id,
+                    display_step=display_step_id,
+                    lr=cur_lr,
+                    save_checkpoint=self.save_model,
+                )
+
             if (
                 (
                     (display_step_id) % self.save_freq == 0
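
From the user's side, the checks above imply a top-level `validating` section in the training configuration. A minimal sketch using only the two keys that appear in this diff (`full_validation` and `full_val_start`); any further keys consumed by `FullValidator` are not visible here:

{
    "training": {
        "training_data": { ... },
        "validation_data": { ... }
    },
    "validating": {
        "full_validation": true,
        "full_val_start": 0.5
    }
}

Note the constraints enforced at startup: single-task energy training with `EnergyStdLoss`, `training.validation_data` configured, no spin model or `EnergySpinLoss`, and `zero_stage < 2`.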