283 changes: 114 additions & 169 deletions dptb/data/AtomicData.py

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions dptb/nn/deeptb.py
@@ -68,6 +68,7 @@ def __init__(
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
transform: bool = True,
scale_type: str = 'scale_w_back_grad',
**kwargs,
):

@@ -103,7 +104,7 @@ def __init__(
self.device = device
self.model_options = {"embedding": embedding.copy(), "prediction": prediction.copy()}
self.transform = transform

self.scale_type = scale_type

self.method = prediction.get("method", "e3tb")
# self.soc = prediction.get("soc", False)
@@ -298,9 +299,11 @@ def forward(self, data: AtomicDataDict.Type):
data = self.embedding(data)
if hasattr(self, "overlap") and self.method == "sktb":
data[AtomicDataDict.EDGE_OVERLAP_KEY] = data[AtomicDataDict.EDGE_FEATURES_KEY]

data = self.node_prediction_h(data)
data = self.edge_prediction_h(data)

if self.method != "e3tb" or self.scale_type != "no_scale":
data = self.node_prediction_h(data)
data = self.edge_prediction_h(data)

if hasattr(self, "overlap"):
data = self.edge_prediction_s(data)
data[AtomicDataDict.NODE_OVERLAP_KEY] = self.overlaponsite_param[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
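A minimal sketch (not repository code) of the decision implied by the new guard in DPTB.forward: only the combination of method "e3tb" with scale_type "no_scale" bypasses the per-type rescaling heads. The helper name below is hypothetical.

def rescale_applied(method: str, scale_type: str) -> bool:
    # Mirrors the condition added above: rescale unless e3tb + no_scale.
    return method != "e3tb" or scale_type != "no_scale"

assert rescale_applied("sktb", "no_scale")           # sktb always rescales
assert rescale_applied("e3tb", "scale_w_back_grad")  # default path unchanged
assert not rescale_applied("e3tb", "no_scale")       # the new bypass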
20 changes: 14 additions & 6 deletions dptb/nn/embedding/lem.py
@@ -39,6 +39,7 @@ def __init__(
avg_num_neighbors: Optional[float] = None,
# cutoffs
r_start_cos_ratio: float = 0.8,
norm_eps: float = 1e-8,
PolynomialCutoff_p: float = 6,
cutoff_type: str = "polynomial",
# general hyperparameters:
@@ -131,6 +132,7 @@ def __init__(
cutoff_type=cutoff_type,
device=device,
dtype=dtype,
norm_eps=norm_eps
)

self.layers = torch.nn.ModuleList()
@@ -235,6 +237,7 @@ def __init__(
latent_dim: int=128,
# cutoffs
r_start_cos_ratio: float = 0.8,
norm_eps: float = 1e-8,
PolynomialCutoff_p: float = 6,
cutoff_type: str = "polynomial",
device: Union[str, torch.device] = torch.device("cpu"),
@@ -290,7 +293,7 @@ def __init__(

self.sln_n = SeperableLayerNorm(
irreps=self.irreps_out,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -300,7 +303,7 @@

self.sln_e = SeperableLayerNorm(
irreps=self.irreps_out,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -438,6 +441,7 @@ def __init__(
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
latent_dim: int,
norm_eps: float = 1e-8,
radial_emb: bool=False,
radial_channels: list=[128, 128],
res_update: bool = True,
@@ -470,7 +474,7 @@ def __init__(

self.sln = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -480,7 +484,7 @@

self.sln_e = SeperableLayerNorm(
irreps=self.edge_irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -614,6 +618,7 @@ def __init__(
irreps_in: o3.Irreps,
irreps_out: o3.Irreps,
latent_dim: int,
norm_eps: float = 1e-8,
latent_channels: list=[128, 128],
radial_emb: bool=False,
radial_channels: list=[128, 128],
@@ -675,7 +680,7 @@ def __init__(

self.sln_e = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -685,7 +690,7 @@

self.sln_n = SeperableLayerNorm(
irreps=self.irreps_in,
eps=1e-5,
eps=norm_eps,
affine=True,
normalization='component',
std_balance_degrees=True,
@@ -806,6 +811,7 @@ def __init__(
tp_radial_emb: bool=False,
tp_radial_channels: list=[128, 128],
# MLP parameters:
norm_eps: float = 1e-8,
latent_channels: list=[128, 128],
latent_dim: int=128,
res_update: bool = True,
@@ -842,6 +848,7 @@ def __init__(
res_update_ratios_learnable=res_update_ratios_learnable,
dtype=dtype,
device=device,
norm_eps=norm_eps
)

self.node_update = UpdateNode(
@@ -857,6 +864,7 @@
avg_num_neighbors=avg_num_neighbors,
dtype=dtype,
device=device,
norm_eps=norm_eps
)

def forward(self, latents, node_features, edge_features, node_onehot, edge_index, edge_vector, atom_type, cutoff_coeffs, active_edges):
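A hedged illustration of why the layer-norm epsilon is now exposed as norm_eps rather than hard-coded to 1e-5: once feature magnitudes approach the epsilon, the epsilon dominates the denominator and damps the normalized output, so a smaller value preserves more signal for tiny features. The component_norm helper below is hypothetical, not SeperableLayerNorm itself.

import torch

def component_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    # RMS-normalize each row, with eps stabilizing the denominator.
    rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
    return x / (rms + eps)

x = torch.full((1, 4), 1e-6)
print(component_norm(x, eps=1e-5))  # ~0.09: eps dominates and damps the output
print(component_norm(x, eps=1e-8))  # ~0.99: close to unit component scale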
35 changes: 29 additions & 6 deletions dptb/nn/rescale.py
@@ -220,6 +220,7 @@ def __init__(
shifts_trainable: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
scale_type: str = 'scale_w_back_grad',
**kwargs,
):
"""Sum edges into nodes."""
@@ -233,6 +234,8 @@ def __init__(
self.dtype = dtype
self.shift_index = []
self.scale_index = []
self.scale_type = scale_type
self.scales_trainable = scales_trainable

start = 0
start_scalar = 0
@@ -293,7 +296,6 @@ def set_scale_shift(self, scales: torch.Tensor=None, shifts: torch.Tensor=None):
self.register_buffer("shifts", shifts)



def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

if not (self.has_scales or self.has_shifts):
@@ -305,22 +307,31 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
in_field = data[self.field][mask]
species_idx = data[AtomicDataDict.EDGE_TYPE_KEY].flatten()[mask]



assert len(in_field) == len(
edge_center[mask]
), "in_field doesnt seem to have correct per-edge shape"

if self.has_scales:
in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim)
if self.scale_type == 'scale_w_back_grad':
in_field = scales * in_field
elif self.scale_type == 'scale_wo_back_grad':
if self.scales_trainable:
in_field = in_field + in_field.detach() * (scales - 1.0)
else:
in_field = in_field + (in_field * (scales - 1.0)).detach()
else:
raise NotImplementedError

if self.has_shifts:
shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]

data[self.out_field][mask] = in_field

return data
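A hedged illustration of the shift step above, with toy tensors: shifts are added only to the scalar (l = 0) components selected by shift_index >= 0, leaving higher-order irrep components untouched.

import torch

in_field = torch.zeros(2, 4)                 # toy layout: [scalar, x, y, z]
shift_index = torch.tensor([0, -1, -1, -1])  # only component 0 is a scalar
shifts = torch.ones(2, 1)                    # one shift per scalar channel
in_field[:, shift_index >= 0] = shifts + in_field[:, shift_index >= 0]
print(in_field)  # column 0 becomes 1.0; the vector components stay 0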


class E3PerSpeciesScaleShift(torch.nn.Module):
"""Scale and/or shift a predicted per-atom property based on (learnable) per-species/type parameters.

@@ -358,6 +369,7 @@ def __init__(
shifts_trainable: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
scale_type: str = 'scale_w_back_grad',
**kwargs,
):
super().__init__()
@@ -370,6 +382,8 @@ def __init__(
self.scale_index = []
self.dtype = dtype
self.device = device
self.scale_type = scale_type
self.scales_trainable = scales_trainable

start = 0
start_scalar = 0
Expand Down Expand Up @@ -442,7 +456,16 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
species_idx
), "in_field doesnt seem to have correct per-atom shape"
if self.has_scales:
in_field = self.scales[species_idx][:,self.scale_index].view(-1, self.irreps_in.dim) * in_field
scales = self.scales[species_idx][:, self.scale_index].view(-1, self.irreps_in.dim)
if self.scale_type == 'scale_w_back_grad':
in_field = scales * in_field
elif self.scale_type == 'scale_wo_back_grad':
if self.scales_trainable:
in_field = in_field + in_field.detach() * (scales - 1.0)
else:
in_field = in_field + (in_field * (scales - 1.0)).detach()
else:
raise NotImplementedError
if self.has_shifts:
shifts = self.shifts[species_idx][:,self.shift_index[self.shift_index>=0]].view(-1, self.num_scalar)
in_field[:, self.shift_index>=0] = shifts + in_field[:, self.shift_index>=0]
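A minimal sketch (assumed, not repository code) verifying the identity behind scale_wo_back_grad in both forward methods above: the forward value still equals scales * in_field, but the gradient with respect to the input behaves as if no scaling were applied, while trainable scales still receive a gradient through the detached copy.

import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
scales = torch.tensor([10.0, 10.0], requires_grad=True)

# Trainable-scales branch: x + x.detach() * (scales - 1)
y = x + x.detach() * (scales - 1.0)
assert torch.allclose(y, scales * x)                # forward value is still scaled
y.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))   # grad w.r.t. x bypasses the scale
assert torch.allclose(scales.grad, x.detach())      # scales still learn
# The non-trainable branch detaches the whole correction term instead,
# so no parameter sees the scaling in the backward pass.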
86 changes: 86 additions & 0 deletions dptb/nnops/loss.py
@@ -643,6 +643,92 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
return 0.5 * (onsite_loss + hopping_loss)


@Loss.register("hamil_abs_mae")
class HamilLossAbsMAE(nn.Module):
def __init__(
self,
basis: Dict[str, Union[str, list]] = None,
idp: Union[OrbitalMapper, None] = None,
overlap: bool = False,
onsite_shift: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
**kwargs,
):

super(HamilLossAbsMAE, self).__init__()
self.loss1 = nn.L1Loss()
self.loss2 = nn.MSELoss()
self.overlap = overlap
self.device = device
self.onsite_shift = onsite_shift

if basis is not None:
self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
if idp is not None:
assert idp == self.idp, "The idp and the basis argument must describe the same basis."
else:
assert idp is not None, "Either basis or idp should be provided."
self.idp = idp

def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):

# ================= Fix for the CodeRabbit review comment =================
if self.onsite_shift:
# Reuse the shared shift_mu helper to stay semantically aligned with HamilLossAbs/EigHamLoss;
# it derives mu from both the node and edge overlap contributions and applies it in place.
shift_mu(data, ref_data, self.idp)
# ============================================================

# onsite loss
pre_onsite = data[AtomicDataDict.NODE_FEATURES_KEY][
self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
]
tgt_onsite = ref_data[AtomicDataDict.NODE_FEATURES_KEY][
self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
]

# hopping loss
pre_hopping = data[AtomicDataDict.EDGE_FEATURES_KEY][
self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]
]
tgt_hopping = ref_data[AtomicDataDict.EDGE_FEATURES_KEY][
self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]
]

pre = torch.cat([pre_onsite, pre_hopping], dim=0)
tgt = torch.cat([tgt_onsite, tgt_hopping], dim=0)

# ================= Keep the overlap loss logic =================
if self.overlap:
# onsite overlap
pre_onsite_ovlp = data[AtomicDataDict.NODE_OVERLAP_KEY][
self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
]
tgt_onsite_ovlp = ref_data[AtomicDataDict.NODE_OVERLAP_KEY][
self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]
]

# hopping overlap
pre_hopping_ovlp = data[AtomicDataDict.EDGE_OVERLAP_KEY][
self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]
]
tgt_hopping_ovlp = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][
self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]
]

pre_ovlp = torch.cat([pre_onsite_ovlp, pre_hopping_ovlp], dim=0)
tgt_ovlp = torch.cat([tgt_onsite_ovlp, tgt_hopping_ovlp], dim=0)

# Concatenate the overlap features onto pre/tgt so the MAE covers them as well
pre = torch.cat([pre, pre_ovlp], dim=0)
tgt = torch.cat([tgt, tgt_ovlp], dim=0)
# ==========================================================

total_loss = self.loss1(pre, tgt)
return total_loss
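A hedged usage sketch: the decorator suggests the loss can also be built through the Loss registry under "hamil_abs_mae", but that factory call is not shown in this diff, so the class is constructed directly here; the basis dict is a placeholder, and data/ref_data stand for AtomicDataDict-style prediction and reference dictionaries.

loss_fn = HamilLossAbsMAE(basis={"Si": ["3s", "3p"]}, overlap=True)  # placeholder basis
loss = loss_fn(data, ref_data)  # L1 over onsite + hopping (+ overlap) blocks
loss.backward()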


@Loss.register("hamil_wt")
class HamilLossWT(nn.Module):
def __init__(
4 changes: 2 additions & 2 deletions dptb/postprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .bandstructure import Band
from .totbplas import TBPLaS
from .write_block import write_block

from .write_abacus_csr_file import write_blocks_to_abacus_csr

__all__ = [
"Band",
"TBPLaS",
"write_block",

"write_blocks_to_abacus_csr"
]