diff --git a/dptb/data/AtomicData.py b/dptb/data/AtomicData.py index bc05bfe70..4d0a390f3 100644 --- a/dptb/data/AtomicData.py +++ b/dptb/data/AtomicData.py @@ -15,6 +15,8 @@ from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator from ase.calculators.calculator import all_properties as ase_all_properties from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress +from ase.data import chemical_symbols +import itertools import torch import e3nn.o3 @@ -882,14 +884,15 @@ def without_nodes(self, which_nodes): assert _ERROR_ON_NO_EDGES in ("true", "false"), "NEQUIP_ERROR_ON_NO_EDGES must be 'true' or 'false'" _ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true" + def neighbor_list_and_relative_vec( - pos, - r_max, - self_interaction=False, - reduce=True, - atomic_numbers=None, - cell=None, - pbc=False, + pos, + r_max, + self_interaction=False, + reduce=True, + atomic_numbers=None, + cell=None, + pbc=False, ): """Create neighbor list and neighbor vectors based on radial cutoff. @@ -903,45 +906,69 @@ def neighbor_list_and_relative_vec( Thus, ``edge_index`` has the same convention as the relative vectors: :math:`\\vec{r}_{source, target}` - If the input positions are a tensor with ``requires_grad == True``, - the output displacement vectors will be correctly attached to the inputs - for autograd. - - All outputs are Tensors on the same device as ``pos``; this allows future - optimization of the neighbor list on the GPU. - Args: pos (shape [N, 3]): Positional coordinate; Tensor or numpy array. If Tensor, must be on CPU. - r_max (float): Radial cutoff distance for neighbor finding. + r_max (float, dict): Radial cutoff distance. Can be a global float, or a dictionary. cell (numpy shape [3, 3]): Cell for periodic boundary conditions. Ignored if ``pbc == False``. - pbc (bool or 3-tuple of bool): Whether the system is periodic in each of the three cell dimensions. - self_interaction (bool): Whether or not to include same periodic image self-edges in the neighbor list. - strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even if the two - instances of the atom are in different periodic images. Defaults to True, should be True for most applications. + pbc (bool or 3-tuple of bool): Periodic boundary conditions. + self_interaction (bool): Whether to include same periodic image self-edges. + reduce (bool): If True, returns an undirected graph (half edges). If False, returns a full directed graph. + atomic_numbers (array-like): Atomic numbers of the atoms, required if r_max is a dict. Returns: - edge_index (torch.tensor shape [2, num_edges]): List of edges. - edge_cell_shift (torch.tensor shape [num_edges, 3]): Relative cell shift - vectors. Returned only if cell is not None. - cell (torch.Tensor [3, 3]): the cell as a tensor on the correct device. - Returned only if cell is not None. + edge_index (torch.tensor shape [2, num_edges]) + shifts (torch.tensor shape [num_edges, 3]) + cell (torch.Tensor [3, 3]) """ if isinstance(pbc, bool): pbc = (pbc,) * 3 - mask_r = False + # 【优化】:确保如果输入的是 GPU Tensor,能够安全转换 + if atomic_numbers is not None: + if isinstance(atomic_numbers, torch.Tensor): + atomic_numbers_np = atomic_numbers.detach().cpu().numpy() + else: + atomic_numbers_np = np.asarray(atomic_numbers) + else: + atomic_numbers_np = None + + # ------------------------------------------------------------------------- + # 1. Parse r_max to ASE-compatible format for native pair-cutoff filtering + # ------------------------------------------------------------------------- if isinstance(r_max, dict): - _r_max = max(r_max.values()) - if _r_max - min(r_max.values()) > 1e-5: - mask_r = True - - if len(r_max) < len(set(atomic_numbers)): + if atomic_numbers_np is None: + raise ValueError("atomic_numbers must be provided when r_max is a dict.") + + if len(r_max) < len(set(atomic_numbers_np)): raise ValueError("The number of r_max is less than the number of required atom species.") + + first_key = next(iter(r_max.keys())) + key_parts = str(first_key).split("-") + + if len(key_parts) == 1: + # Atom-wise cutoffs: ASE naturally handles array input as R[i] + R[j] < cutoff + r_map = get_r_map(r_max, atomic_numbers_np) + r_map_np = r_map.detach().cpu().numpy() if isinstance(r_map, torch.Tensor) else np.asarray(r_map) + user_cutoff = 0.5 * r_map_np[atomic_numbers_np - 1] + + elif len(key_parts) == 2: + # Pair-wise cutoffs: Convert user string keys to ASE tuple keys + r_map = get_r_map_bondwise(r_max, atomic_numbers_np) + r_map_np = r_map.detach().cpu().numpy() if isinstance(r_map, torch.Tensor) else np.asarray(r_map) + user_cutoff = {} + unique_nums = np.unique(atomic_numbers_np) + for z1 in unique_nums: + for z2 in unique_nums: + user_cutoff[(int(z1), int(z2))] = float(r_map_np[int(z1) - 1, int(z2) - 1]) + else: + raise ValueError("The r_max keys should be either atomic number or atomic number pair.") else: - _r_max = r_max assert isinstance(r_max, (float, int)) + user_cutoff = float(r_max) - # Either the position or the cell may be on the GPU as tensors + # ------------------------------------------------------------------------- + # 2. Setup Device, Tensors, and Geometry + # ------------------------------------------------------------------------- if isinstance(pos, torch.Tensor): temp_pos = pos.detach().cpu().numpy() out_device = pos.device @@ -951,170 +978,88 @@ def neighbor_list_and_relative_vec( out_device = torch.device("cpu") out_dtype = torch.get_default_dtype() - # Right now, GPU tensors require a round trip if out_device.type != "cpu": warnings.warn( "Currently, neighborlists require a round trip to the CPU. Please pass CPU tensors if possible." ) - # Get a cell on the CPU no matter what + # 获取初始 cell 数据 if isinstance(cell, torch.Tensor): temp_cell = cell.detach().cpu().numpy() - cell_tensor = cell.to(device=out_device, dtype=out_dtype) elif cell is not None: temp_cell = np.asarray(cell) - cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) else: - # ASE will "complete" this correctly. temp_cell = np.zeros((3, 3), dtype=temp_pos.dtype) - cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) - # ASE dependent part + # ASE 补全缺失的晶格向量 temp_cell = ase.geometry.complete_cell(temp_cell) - first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list( + # 【修复2】:在此处(补全后)生成 cell_tensor,保证输出与 shifts 处于同一坐标系参考标准下 + cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype) + + # ------------------------------------------------------------------------- + # 3. Call core O(N) neighbor search algorithm + # ------------------------------------------------------------------------- + # By default, primitive_neighbor_list returns a fully directed graph representing + # both (i, j, S) and (j, i, -S). It also automatically removes self-edges (i=i, S=0) + # if self_interaction=False. + first_idx, second_idx, shifts = ase.neighborlist.primitive_neighbor_list( "ijS", pbc, temp_cell, temp_pos, - cutoff=float(_r_max), - self_interaction=self_interaction, # we want edges from atom to itself in different periodic images! + cutoff=user_cutoff, + numbers=atomic_numbers_np, + self_interaction=self_interaction, use_scaled_positions=False, ) + # ------------------------------------------------------------------------- + # 4. Handle graph reduction state + # ------------------------------------------------------------------------- + if reduce: + # Convert full directed graph to undirected half-graph + mask_lt = first_idx < second_idx + mask_eq = first_idx == second_idx - # Eliminate true self-edges that don't cross periodic boundaries - # if not self_interaction: - # bad_edge = first_idex == second_idex - # bad_edge &= np.all(shifts == 0, axis=1) - # keep_edge = ~bad_edge - # if _ERROR_ON_NO_EDGES and (not np.any(keep_edge)): - # raise ValueError( - # f"Every single atom has no neighbors within the cutoff r_max={r_max} (after eliminating self edges, no edges remain in this system)" - # ) - # first_idex = first_idex[keep_edge] - # second_idex = second_idex[keep_edge] - # shifts = shifts[keep_edge] + # Deduplicate mirrored periodic boundaries for i == j + eq_first = first_idx[mask_eq] + eq_shifts = shifts[mask_eq] + eq_keep = np.zeros(len(eq_first), dtype=bool) - """ - bond list is: i, j, shift; but i j shift and j i -shift are the same bond. so we need to remove the duplicate bonds.s - first for i != j; we only keep i < j; then the j i -shift will be removed. - then, for i == j; we only keep i i shift and remove i i -shift. - """ - # 1. for i != j, keep i < j - assert atomic_numbers is not None - atomic_numbers = torch.as_tensor(atomic_numbers, dtype=torch.long) - mask = first_idex <= second_idex - first_idex = first_idex[mask] - second_idex = second_idex[mask] - shifts = shifts[mask] - - # 2. for i == j - - mask = torch.ones(len(first_idex), dtype=torch.bool) - mask[first_idex == second_idex] = False - # get index bool type ~mask for i == j. - # Convert mask to numpy for consistent indexing behavior - mask_np = mask.cpu().numpy() - o_first_idex = first_idex[~mask_np] - o_second_idex = second_idex[~mask_np] - o_shift = shifts[~mask_np] - o_mask = mask[~mask] # this is all False, with length being the number all the bonds with i == j. - - # Ensure arrays are proper numpy arrays (not scalars) for isolated systems - o_first_idex = np.atleast_1d(o_first_idex) - o_second_idex = np.atleast_1d(o_second_idex) - o_shift = np.atleast_2d(o_shift) - - # using the dict key to remove the duplicate bonds, because it is O(1) to check if a key is in the dict. - rev_dict = {} - for i in range(len(o_first_idex)): - key = str(o_first_idex[i])+str(o_shift[i]) - key_rev = str(o_first_idex[i])+str(-o_shift[i]) - rev_dict[key] = True - # key_rev is the reverse key of key, if key_rev is in the dict, then the bond is duplicate. - # so, only when key_rev is not in the dict, we keep the bond. that is when rev_dict.get(key_rev, False) is False, we set o_mast = True. - if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)): - o_mask[i] = True - - if self_interaction: - log.warning("self_interaction is True, but usually we do not want the self-interaction, please check if it is correct.") - # for self-interaction, the above will remove the self-interaction, i.e. i == j, shift == [0, 0, 0]. since -0 = 0. - if (o_shift[i] == np.array([0, 0, 0])).all(): - o_mask[i] = True - - del rev_dict - del o_first_idex - del o_second_idex - del o_shift - mask[~mask] = o_mask - del o_mask - - # Convert mask to numpy for indexing numpy arrays (avoids torch/numpy compatibility issues) - mask_np = mask.cpu().numpy() - first_idex = torch.as_tensor(first_idex[mask_np], dtype=torch.long, device=out_device) - second_idex = torch.as_tensor(second_idex[mask_np], dtype=torch.long, device=out_device) - shifts = torch.as_tensor(shifts[mask_np], dtype=out_dtype, device=out_device) - - if not reduce: - assert self_interaction == False, "for self_interaction = True, i i 0 0 0 will be duplicated." - first_idex, second_idex = torch.cat((first_idex, second_idex), dim=0), torch.cat((second_idex, first_idex), dim=0) - shifts = torch.cat((shifts, -shifts), dim=0) - - # Build output: - edge_index = torch.vstack( - (torch.LongTensor(first_idex), torch.LongTensor(second_idex)) - ) + rev_dict = {} + for i in range(len(eq_first)): + key = f"{eq_first[i]}_{eq_shifts[i]}" + key_rev = f"{eq_first[i]}_{-eq_shifts[i]}" + rev_dict[key] = True - # TODO: mask the edges that is larger than r_max - if mask_r: - edge_vec = pos[edge_index[1]] - pos[edge_index[0]] - if cell is not None : - edge_vec = edge_vec + torch.einsum( - "ni,ij->nj", - shifts, - cell_tensor.reshape(3,3), # remove batch dimension - ) + if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)): + eq_keep[i] = True - edge_length = torch.linalg.norm(edge_vec, dim=-1) + if self_interaction and (eq_shifts[i] == 0).all(): + eq_keep[i] = True - # atom_species_num = [atomic_num_dict[k] for k in r_max.keys()] - # for i in set(atomic_numbers): - # assert i in atom_species_num - # r_map = torch.zeros(max(atom_species_num)) - # for k, v in r_max.items(): - # r_map[atomic_num_dict[k]-1] = v + # Combine reduction masks + final_mask = mask_lt.copy() + final_mask[mask_eq] = eq_keep - first_key = next(iter(r_max.keys())) - key_parts = first_key.split("-") - - if len(key_parts)==1: - r_map = get_r_map(r_max, atomic_numbers) - edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1]) - - elif len(key_parts)==2: - r_map = get_r_map_bondwise(r_max, atomic_numbers) - edge_length_max = r_map[atomic_numbers[edge_index[0]]-1,atomic_numbers[edge_index[1]]-1] - else: - raise ValueError("The r_max keys should be either atomic number or atomic number pair.") - - r_mask = edge_length <= edge_length_max - if any(~r_mask): - edge_index = edge_index[:, r_mask] - shifts = shifts[r_mask] - # 收集不同类型的边及其对应的最大截断半径 - #edge_types = {} - #for i in range(edge_index.shape[1]): - # atom_type_pair = (atomic_numbers[edge_index[0, i]], atomic_numbers[edge_index[1, i]]) - # if atom_type_pair not in edge_types: - # edge_types[atom_type_pair] = edge_length_max[i].item() - - del edge_length - del edge_vec - del r_map - del edge_length_max - del r_mask + first_idx = first_idx[final_mask] + second_idx = second_idx[final_mask] + shifts = shifts[final_mask] + + # Note: If `reduce=False`, the output of primitive_neighbor_list is exactly the + # full bidirectional graph structure required, so no post-processing is needed. + + # ------------------------------------------------------------------------- + # 5. Build output tensors + # ------------------------------------------------------------------------- + first_idx_t = torch.as_tensor(first_idx, dtype=torch.long, device=out_device) + second_idx_t = torch.as_tensor(second_idx, dtype=torch.long, device=out_device) + shifts_t = torch.as_tensor(shifts, dtype=out_dtype, device=out_device) + + edge_index = torch.vstack((first_idx_t, second_idx_t)) + + return edge_index, shifts_t, cell_tensor - return edge_index, shifts, cell_tensor def get_r_map(r_max: dict, atomic_numbers=None): """ diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py index 13fbf9846..a17278459 100644 --- a/dptb/nn/deeptb.py +++ b/dptb/nn/deeptb.py @@ -68,6 +68,7 @@ def __init__( dtype: Union[str, torch.dtype] = torch.float32, device: Union[str, torch.device] = torch.device("cpu"), transform: bool = True, + scale_type: str = 'scale_w_back_grad', **kwargs, ): @@ -103,7 +104,7 @@ def __init__( self.device = device self.model_options = {"embedding": embedding.copy(), "prediction": prediction.copy()} self.transform = transform - + self.scale_type = scale_type self.method = prediction.get("method", "e3tb") # self.soc = prediction.get("soc", False) @@ -298,9 +299,11 @@ def forward(self, data: AtomicDataDict.Type): data = self.embedding(data) if hasattr(self, "overlap") and self.method == "sktb": data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY] - - data = self.node_prediction_h(data) - data = self.edge_prediction_h(data) + + if self.method != "e3tb" or self.scale_type != "no_scale": + data = self.node_prediction_h(data) + data = self.edge_prediction_h(data) + if hasattr(self, "overlap"): data = self.edge_prediction_s(data) data[AtomicDataDict.NODE_OVERLAP_KEY] = self.overlaponsite_param[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] diff --git a/dptb/nn/embedding/lem.py b/dptb/nn/embedding/lem.py index 2987e71cd..abdc1b067 100644 --- a/dptb/nn/embedding/lem.py +++ b/dptb/nn/embedding/lem.py @@ -39,6 +39,7 @@ def __init__( avg_num_neighbors: Optional[float] = None, # cutoffs r_start_cos_ratio: float = 0.8, + norm_eps: float = 1e-8, PolynomialCutoff_p: float = 6, cutoff_type: str = "polynomial", # general hyperparameters: @@ -131,6 +132,7 @@ def __init__( cutoff_type=cutoff_type, device=device, dtype=dtype, + norm_eps=norm_eps ) self.layers = torch.nn.ModuleList() @@ -235,6 +237,7 @@ def __init__( latent_dim: int=128, # cutoffs r_start_cos_ratio: float = 0.8, + norm_eps: float = 1e-8, PolynomialCutoff_p: float = 6, cutoff_type: str = "polynomial", device: Union[str, torch.device] = torch.device("cpu"), @@ -290,7 +293,7 @@ def __init__( self.sln_n = SeperableLayerNorm( irreps=self.irreps_out, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -300,7 +303,7 @@ def __init__( self.sln_e = SeperableLayerNorm( irreps=self.irreps_out, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -438,6 +441,7 @@ def __init__( irreps_in: o3.Irreps, irreps_out: o3.Irreps, latent_dim: int, + norm_eps: float = 1e-8, radial_emb: bool=False, radial_channels: list=[128, 128], res_update: bool = True, @@ -470,7 +474,7 @@ def __init__( self.sln = SeperableLayerNorm( irreps=self.irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -480,7 +484,7 @@ def __init__( self.sln_e = SeperableLayerNorm( irreps=self.edge_irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -614,6 +618,7 @@ def __init__( irreps_in: o3.Irreps, irreps_out: o3.Irreps, latent_dim: int, + norm_eps: float = 1e-8, latent_channels: list=[128, 128], radial_emb: bool=False, radial_channels: list=[128, 128], @@ -675,7 +680,7 @@ def __init__( self.sln_e = SeperableLayerNorm( irreps=self.irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -685,7 +690,7 @@ def __init__( self.sln_n = SeperableLayerNorm( irreps=self.irreps_in, - eps=1e-5, + eps=norm_eps, affine=True, normalization='component', std_balance_degrees=True, @@ -806,6 +811,7 @@ def __init__( tp_radial_emb: bool=False, tp_radial_channels: list=[128, 128], # MLP parameters: + norm_eps: float = 1e-8, latent_channels: list=[128, 128], latent_dim: int=128, res_update: bool = True, @@ -842,6 +848,7 @@ def __init__( res_update_ratios_learnable=res_update_ratios_learnable, dtype=dtype, device=device, + norm_eps=norm_eps ) self.node_update = UpdateNode( @@ -857,6 +864,7 @@ def __init__( avg_num_neighbors=avg_num_neighbors, dtype=dtype, device=device, + norm_eps=norm_eps ) def forward(self, latents, node_features, edge_features, node_onehot, edge_index, edge_vector, atom_type, cutoff_coeffs, active_edges): diff --git a/dptb/nn/rescale.py b/dptb/nn/rescale.py index 8aca26aa9..8d9351404 100644 --- a/dptb/nn/rescale.py +++ b/dptb/nn/rescale.py @@ -220,6 +220,7 @@ def __init__( shifts_trainable: bool = False, dtype: Union[str, torch.dtype] = torch.float32, device: Union[str, torch.device] = torch.device("cpu"), + scale_type: str = 'scale_w_back_grad', **kwargs, ): """Sum edges into nodes.""" @@ -233,6 +234,8 @@ def __init__( self.dtype = dtype self.shift_index = [] self.scale_index = [] + self.scale_type = scale_type + self.scales_trainable = scales_trainable start = 0 start_scalar = 0 @@ -293,7 +296,6 @@ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None): self.register_buffer("shifts", shifts) - def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: if not (self.has_scales or self.has_shifts): @@ -305,22 +307,31 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: in_field = data[self.field][mask] species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()[mask] - - assert len(in_field) == len( edge_center[mask] ), "in_field doesnt seem to have correct per-edge shape" if self.has_scales: - in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field + scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim) + if self.scale_type == 'scale_w_back_grad': + in_field = scales * in_field + elif self.scale_type == 'scale_wo_back_grad': + if self.scales_trainable: + in_field = in_field + in_field.detach() * (scales - 1.0) + else: + in_field = in_field + (in_field * (scales - 1.0)).detach() + else: + raise NotImplementedError + if self.has_shifts: shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar) in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0] - + data[self.out_field][mask] = in_field return data + class E3PerSpeciesScaleShift(torch.nn.Module): """Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters. @@ -358,6 +369,7 @@ def __init__( shifts_trainable: bool = False, dtype: Union[str, torch.dtype] = torch.float32, device: Union[str, torch.device] = torch.device("cpu"), + scale_type: str = 'scale_w_back_grad', **kwargs, ): super().__init__() @@ -370,6 +382,8 @@ def __init__( self.scale_index = [] self.dtype = dtype self.device = device + self.scale_type = scale_type + self.scales_trainable = scales_trainable start = 0 start_scalar = 0 @@ -442,7 +456,16 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: species_idx ), "in_field doesnt seem to have correct per-atom shape" if self.has_scales: - in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field + scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim) + if self.scale_type == 'scale_w_back_grad': + in_field = scales * in_field + elif self.scale_type == 'scale_wo_back_grad': + if self.scales_trainable: + in_field = in_field + in_field.detach() * (scales - 1.0) + else: + in_field = in_field + (in_field * (scales - 1.0)).detach() + else: + raise NotImplementedError if self.has_shifts: shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar) in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0] diff --git a/dptb/nnops/loss.py b/dptb/nnops/loss.py index 4e8d6fcd6..6b3daec77 100644 --- a/dptb/nnops/loss.py +++ b/dptb/nnops/loss.py @@ -643,6 +643,92 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): return 0.5 * (onsite_loss + hopping_loss) +@Loss.register("hamil_abs_mae") +class HamilLossAbsMAE(nn.Module): + def __init__( + self, + basis: Dict[str, Union[str, list]] = None, + idp: Union[OrbitalMapper, None] = None, + overlap: bool = False, + onsite_shift: bool = False, + dtype: Union[str, torch.dtype] = torch.float32, + device: Union[str, torch.device] = torch.device("cpu"), + **kwargs, + ): + + super(HamilLossAbsMAE, self).__init__() + self.loss1 = nn.L1Loss() + self.loss2 = nn.MSELoss() + self.overlap = overlap + self.device = device + self.onsite_shift = onsite_shift + + if basis is not None: + self.idp = OrbitalMapper(basis, method="e3tb", device=self.device) + if idp is not None: + assert idp == self.idp, "The basis of idp and basis should be the same." + else: + assert idp is not None, "Either basis or idp should be provided." + self.idp = idp + + def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict): + + # ================= 修复 CodeRabbit 审查意见 ================= + if self.onsite_shift: + # 直接复用统一的 shift_mu 函数,保持与 HamilLossAbs/EigHamLoss 语义对齐 + # 内部会同时利用 node 和 edge 的 overlap 贡献综合推导并应用 mu + shift_mu(data, ref_data, self.idp) + # ============================================================ + + # onsite loss + pre_onsite = data[AtomicDataDict.NODE_FEATURES_KEY][ + self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + tgt_onsite = ref_data[AtomicDataDict.NODE_FEATURES_KEY][ + self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + + # hopping loss + pre_hopping = data[AtomicDataDict.EDGE_FEATURES_KEY][ + self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + tgt_hopping = ref_data[AtomicDataDict.EDGE_FEATURES_KEY][ + self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + + pre = torch.cat([pre_onsite, pre_hopping], dim=0) + tgt = torch.cat([tgt_onsite, tgt_hopping], dim=0) + + # ================= 保留 overlap loss 逻辑 ================= + if self.overlap: + # onsite overlap + pre_onsite_ovlp = data[AtomicDataDict.NODE_OVERLAP_KEY][ + self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + tgt_onsite_ovlp = ref_data[AtomicDataDict.NODE_OVERLAP_KEY][ + self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()] + ] + + # hopping overlap + pre_hopping_ovlp = data[AtomicDataDict.EDGE_OVERLAP_KEY][ + self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + tgt_hopping_ovlp = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][ + self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()] + ] + + pre_ovlp = torch.cat([pre_onsite_ovlp, pre_hopping_ovlp], dim=0) + tgt_ovlp = torch.cat([tgt_onsite_ovlp, tgt_hopping_ovlp], dim=0) + + # 将 overlap 特征拼接到 pre/tgt 中一同计算 MAE + pre = torch.cat([pre, pre_ovlp], dim=0) + tgt = torch.cat([tgt, tgt_ovlp], dim=0) + # ========================================================== + + total_loss = self.loss1(pre, tgt) + return total_loss + + @Loss.register("hamil_wt") class HamilLossWT(nn.Module): def __init__( diff --git a/dptb/postprocess/__init__.py b/dptb/postprocess/__init__.py index be849ddfc..6f51a6525 100644 --- a/dptb/postprocess/__init__.py +++ b/dptb/postprocess/__init__.py @@ -1,11 +1,11 @@ from .bandstructure import Band from .totbplas import TBPLaS from .write_block import write_block - +from .write_abacus_csr_file import write_blocks_to_abacus_csr __all__ = [ Band, TBPLaS, write_block, - + write_blocks_to_abacus_csr ] \ No newline at end of file diff --git a/dptb/postprocess/write_abacus_csr_file.py b/dptb/postprocess/write_abacus_csr_file.py new file mode 100644 index 000000000..ff83d3521 --- /dev/null +++ b/dptb/postprocess/write_abacus_csr_file.py @@ -0,0 +1,224 @@ +import os +import lmdb +import pickle +import re +import numpy as np +from scipy.sparse import csr_matrix, coo_matrix +from collections import defaultdict +import ase.data +from scipy.linalg import block_diag +from dftio.constants import ABACUS2DFTIO + +# DFTIO -> ABACUS +DFTIO2ABACUS = {l: M.T.astype(np.float32) for l, M in ABACUS2DFTIO.items()} + +ORBITAL_MAP = {'s': 0, 'p': 1, 'd': 2, 'f': 3, 'g': 4, 'h': 5} +KEY_RE = re.compile(r'^\s*(-?\d+)[ _](-?\d+)[ _](-?\d+)[ _](-?\d+)[ _](-?\d+)\s*$') +H_FACTOR = 13.605698 # Ryd -> eV factor for Hamiltonian + + +def parse_basis_to_l_list(basis_str): + """'2s2p1d' or 'spd' -> [0,0,1,1,2].""" + if basis_str is None: + return [] + s = re.sub(r"\s+", "", str(basis_str).lower()) + if s == "": + return [] + if not re.fullmatch(r"(?:\d*[spdfgh])+", s): + raise ValueError(f"Invalid basis string '{basis_str}'") + tokens = re.findall(r'(\d*)([spdfgh])', s) + lst = [] + for num, ch in tokens: + cnt = int(num) if num else 1 + if ch not in ORBITAL_MAP: + raise ValueError(f"Unsupported orbital '{ch}' in '{basis_str}'") + lst.extend([ORBITAL_MAP[ch]] * cnt) + return lst + + +def find_basis_for_Z_or_symbol(basis_dict, Z): + """Find basis string for atomic number Z (multiple key forms).""" + if Z in basis_dict: + return basis_dict[Z] + sym = ase.data.chemical_symbols[Z] + for key_try in (sym, sym.capitalize(), sym.upper(), str(Z)): + if key_try in basis_dict: + return basis_dict[key_try] + for k, v in basis_dict.items(): + if isinstance(k, str) and k.lower() == sym.lower(): + return v + return None + + +def transform_2_ABACUS(mat, l_lefts, l_rights): + """Transform block from DFTIO ordering to ABACUS ordering.""" + if max(*(list(l_lefts) + list(l_rights))) > 5: + raise NotImplementedError("Only support l = s..h.") + left_mats = [DFTIO2ABACUS[l] for l in l_lefts] + right_mats = [DFTIO2ABACUS[l] for l in l_rights] + left = block_diag(*left_mats) if left_mats else np.eye(0, dtype=np.float32) + right = block_diag(*right_mats) if right_mats else np.eye(0, dtype=np.float32) + return left @ mat @ right.T + + +def write_abacus_csr_format(matrix_dict, matrix_symbol, output_path, step=0): + """Write mapping 'Rx_Ry_Rz' -> csr_matrix into ABACUS text CSR.""" + if not matrix_dict: + print(f"Warning: empty matrix_dict for {matrix_symbol}") + return + first = next(iter(matrix_dict)) + norbits = matrix_dict[first].shape[0] + num_blocks = len(matrix_dict) + with open(output_path, 'w') as f: + f.write(f"STEP: {step}\n") + f.write(f"Matrix Dimension of {matrix_symbol}(R): {norbits}\n") + f.write(f"Matrix number of {matrix_symbol}(R): {num_blocks}\n") + for r_key, sparse_mat in matrix_dict.items(): + r_vector_str = r_key.replace('_', ' ') + nnz = int(sparse_mat.nnz) + f.write(f"{r_vector_str} {nnz}\n") + if nnz > 0: + np.savetxt(f, sparse_mat.data.reshape(1, -1), fmt='%.8e') + np.savetxt(f, sparse_mat.indices.reshape(1, -1), fmt='%d') + np.savetxt(f, sparse_mat.indptr.reshape(1, -1), fmt='%d') + else: + f.write("\n\n\n") + # print(f"Wrote {num_blocks} blocks to {output_path}") + + +def write_blocks_to_abacus_csr(atomic_numbers, basis_dict, blocks_dict, matrix_symbol, output_path, step=0): + """ + Entry function: + atomic_numbers: per-site Z array-like + basis_dict: parse_orbital_files result + blocks_dict: mapping 'i_j_Rx_Ry_Rz' -> small block (DFTIO ordering) + matrix_symbol: 'H'/'S'/'D' + """ + atomic_numbers = np.asarray(atomic_numbers, dtype=int) + if atomic_numbers.size == 0: + raise ValueError("empty atomic_numbers") + + # choose factor + factor = H_FACTOR if str(matrix_symbol).upper() == 'H' else 1.0 + + # element -> l-list + element_l_lists = {} + for Z in np.unique(atomic_numbers): + basis_str = find_basis_for_Z_or_symbol(basis_dict, int(Z)) + + # 1. 检测基组是否缺失 + if basis_str is None: + raise ValueError( + f"Matrix '{matrix_symbol}': find_basis_for_Z_or_symbol() could not find a basis for Z={Z}. " + f"Available keys in basis_dict: {list(basis_dict.keys())}. " + f"Aborting to prevent silent downstream dimension errors in element_l_lists." + ) + else: + ll = parse_basis_to_l_list(basis_str) + # 2. 检测基组字符串是否解析为空 + if not ll: + raise ValueError( + f"Matrix '{matrix_symbol}': parse_basis_to_l_list() returned an empty list " + f"for basis string '{basis_str}' (Z={Z}). " + f"Aborting to prevent silent downstream dimension errors in element_l_lists." + ) + element_l_lists[int(Z)] = ll + + # site norbits + site_norbits = np.array([sum(2 * l + 1 for l in element_l_lists[int(Z)]) for Z in atomic_numbers], dtype=int) + site_norbits_cumsum = np.cumsum(site_norbits) + norbits = int(site_norbits_cumsum[-1]) + + # aggregate COO data per R + r_vector_coo = defaultdict(lambda: {'data': [], 'rows': [], 'cols': []}) + + for raw_key, small_block in blocks_dict.items(): + key = raw_key.decode() if isinstance(raw_key, (bytes, bytearray)) else str(raw_key) + m = KEY_RE.match(key) + if not m: + # skip unparseable keys + continue + i_site = int(m.group(1)) + j_site = int(m.group(2)) + Rx = int(m.group(3)) + Ry = int(m.group(4)) + Rz = int(m.group(5)) + r_str = f"{Rx}_{Ry}_{Rz}" + + # l-lists + l_lefts = element_l_lists[int(atomic_numbers[i_site])] + l_rights = element_l_lists[int(atomic_numbers[j_site])] + + # get ndarray (support sparse objects) + if hasattr(small_block, "toarray"): + block_arr = small_block.toarray() + elif "torch" in str(type(small_block)): + if small_block.is_cuda: + block_arr = small_block.detach().cpu().numpy() + else: + block_arr = small_block.detach().numpy() + else: + block_arr = np.asarray(small_block) + if block_arr.size == 0: + continue + + # transform DFTIO -> ABACUS + transformed = transform_2_ABACUS(block_arr.astype(np.float32), l_lefts, l_rights) + + # offsets + row_offset = int(site_norbits_cumsum[i_site] - site_norbits[i_site]) + col_offset = int(site_norbits_cumsum[j_site] - site_norbits[j_site]) + + coo = coo_matrix(transformed) + if coo.nnz == 0: + continue + + # apply factor (H vs others) + r_vector_coo[r_str]['data'].append((coo.data.astype(np.float32) / factor)) + r_vector_coo[r_str]['rows'].append((coo.row + row_offset).astype(int)) + r_vector_coo[r_str]['cols'].append((coo.col + col_offset).astype(int)) + + # build final CSR dict + reassembled = {} + for r_str, parts in r_vector_coo.items(): + if not parts['data']: + full = csr_matrix((norbits, norbits), dtype=np.float32) + else: + data = np.concatenate(parts['data']).astype(np.float32) + rows = np.concatenate(parts['rows']).astype(int) + cols = np.concatenate(parts['cols']).astype(int) + full = csr_matrix((data, (rows, cols)), shape=(norbits, norbits)) + reassembled[r_str] = full + + write_abacus_csr_format(reassembled, matrix_symbol, output_path, step=step) + return reassembled, norbits + +# demo main +if __name__ == "__main__": + LMDB_PATH = r'E:\deeptb\large_DeepTB\0909\0910_lmdb\train\data.28400.lmdb' + ORBITAL_PATH = r'E:\deeptb\basis_set_test\production_use_dzp\orb_upf\public' + + from dprep.dptb_dpdispatcher import parse_orbital_files + _, basis_dict = parse_orbital_files(ORBITAL_PATH) + + env = lmdb.open(LMDB_PATH, readonly=True, lock=False) + with env.begin() as txn: + rec = txn.get((0).to_bytes(length=4, byteorder='big')) + if rec is None: + raise RuntimeError("No record at index 0") + data = pickle.loads(rec) + env.close() + + atomic_numbers = np.array(data['atomic_numbers'], dtype=int) + + if 'hamiltonian' in data and data['hamiltonian']: + write_blocks_to_abacus_csr( + atomic_numbers=atomic_numbers, + basis_dict=basis_dict, + blocks_dict=data['hamiltonian'], + matrix_symbol='H', + output_path='data-HR-sparse_SPIN0.csr', + step=0 + ) + else: + print("No hamiltonian in record 0.") diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py index 3ff888fc3..53280b74c 100644 --- a/dptb/utils/argcheck.py +++ b/dptb/utils/argcheck.py @@ -115,9 +115,10 @@ def train_options(): doc_sliding_win_size = "Sliding window size for the average of the latest iterations' loss. Used for the reduce on plateau learning rate scheduler in case of the pairing of large dataset and small batch size. Default: `50`" doc_optimizer = "\ - The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\ + The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `AdamW`, `SGD` and `LBFGS` \n\n\ For more information about these optmization algorithm, we refer to:\n\n\ - `Adam`: [Adam: A Method for Stochastic Optimization.](https://arxiv.org/abs/1412.6980)\n\n\ + - `AdamW`: [AdamW: Decoupled Weight Decay Regularization.](https://arxiv.org/abs/1711.05101)\n\n\ - `SGD`: [Stochastic Gradient Descent.](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)\n\n\ - `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\ " @@ -231,10 +232,11 @@ def LBFGS(): ] def optimizer(): - doc_type = "select type of optimizer, support type includes: `Adam`, `SGD` and `LBFGS`. Default: `Adam`" + doc_type = "select type of optimizer, support type includes: `Adam`, `AdamW`, `SGD` and `LBFGS`. Default: `Adam`" return Variant("type", [ Argument("Adam", dict, Adam()), + Argument("AdamW", dict, Adam()), Argument("SGD", dict, SGD()), Argument("RMSprop", dict, RMSprop()), Argument("LBFGS", dict, LBFGS()), @@ -635,6 +637,7 @@ def slem(): Argument("res_update_ratios", float, optional=True, default=0.5, doc="The ratios of residual update, should in (0,1)."), Argument("res_update_ratios_learnable", bool, optional=True, default=False, doc="Whether to make the ratios of residual update learnable."), Argument("universal", bool, optional=True, default=False, doc=doc_universal), + Argument("norm_eps", float, optional=True, default=1e-8, doc="eps in SeperableLayerNorm."), ] @@ -662,17 +665,21 @@ def sktb_prediction(): def e3tb_prediction(): - doc_scales_trainable = "whether to scale the trianing target." - doc_shifts_trainable = "whether to shift the training target." + doc_scales_trainable = "The scale parameter is from the statistics. Whether to train this parameter." + doc_shifts_trainable = "The scale parameter is from the statistics. Whether to train this parameter." doc_neurons = "neurons in the neural network." doc_activation = "activation function." doc_if_batch_normalized = "if to turn on batch normalization" + doc_scale_type = ("Which scale method to use. Can be no_scale, " + "scale_wo_back_grad (the scale parameter will not engage the back grad computation graph), " + "scale_w_back_grad (the scale parameter will engage the back grad computation graph)") nn = [ Argument("scales_trainable", bool, optional=True, default=False, doc=doc_scales_trainable), Argument("shifts_trainable", bool, optional=True, default=False, doc=doc_shifts_trainable), Argument("neurons", list, optional=True, default=None, doc=doc_neurons), Argument("activation", str, optional=True, default="tanh", doc=doc_activation), + Argument("scale_type", str, optional=True, default="scale_w_back_grad", doc=doc_scale_type), Argument("if_batch_normalized", bool, optional=True, default=False, doc=doc_if_batch_normalized), ] @@ -830,6 +837,7 @@ def loss_options(): - `eigvals`: The mse loss predicted and labeled eigenvalues and Delta eigenvalues between different k. - `hamil`: - `hamil_abs`: + - `hamil_abs_mae`: - `hamil_blas`: """ doc_train = "Loss options for training." @@ -867,6 +875,7 @@ def loss_options(): Argument("eigvals", dict, sub_fields=eigvals), Argument("skints", dict, sub_fields=skints), Argument("hamil_abs", dict, sub_fields=hamil), + Argument("hamil_abs_mae", dict, sub_fields=hamil), Argument("hamil_blas", dict, sub_fields=hamil), Argument("hamil_wt", dict, sub_fields=hamil+wt), Argument("eig_ham", dict, sub_fields=hamil+eigvals+eig_ham), @@ -1756,9 +1765,10 @@ def normalize_skf2nnsk(data): doc_lr_scheduler = "The learning rate scheduler tools settings, the lr scheduler is used to scales down the learning rate during the training process. Proper setting can make the training more stable and efficient. The supported lr schedular includes: `Exponential Decaying (exp)`, `Linear multiplication (linear)`" doc_optimizer = "\ - The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `SGD` and `LBFGS` \n\n\ + The optimizer setting for selecting the gradient optimizer of model training. Optimizer supported includes `Adam`, `AdamW`, `SGD` and `LBFGS` \n\n\ For more information about these optmization algorithm, we refer to:\n\n\ - `Adam`: [Adam: A Method for Stochastic Optimization.](https://arxiv.org/abs/1412.6980)\n\n\ + - `AdamW`: [AdamW: Decoupled Weight Decay Regularization.](https://arxiv.org/abs/1711.05101)\n\n\ - `SGD`: [Stochastic Gradient Descent.](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)\n\n\ - `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\ " diff --git a/dptb/utils/tools.py b/dptb/utils/tools.py index 8c8e5f1b5..2d6caeadd 100644 --- a/dptb/utils/tools.py +++ b/dptb/utils/tools.py @@ -125,8 +125,6 @@ def update_dict_with_warning(dict_input, update_list, update_value): return reconstruct_dict(flatten_input_dict) - - def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @@ -138,6 +136,8 @@ def setup_seed(seed): def get_optimizer(type: str, model_param, lr: float, **options: dict): if type == 'Adam': optimizer = optim.Adam(params=model_param, lr=lr, **options) + elif type == 'AdamW': + optimizer = optim.AdamW(params=model_param, lr=lr, **options) elif type == 'SGD': optimizer = optim.SGD(params=model_param, lr=lr, **options) elif type == 'RMSprop': @@ -145,7 +145,7 @@ def get_optimizer(type: str, model_param, lr: float, **options: dict): elif type == 'LBFGS': optimizer = optim.LBFGS(params=model_param, lr=lr, **options) else: - raise RuntimeError("Optimizer should be Adam/SGD/RMSprop, not {}".format(type)) + raise RuntimeError("Optimizer should be Adam/AdamW/SGD/RMSprop/LBFGS, not {}".format(type)) return optimizer def get_lr_scheduler(type: str, optimizer: optim.Optimizer, **sch_options):