diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 292b114b..3a8b9e1a 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -135,7 +135,7 @@ def __init__( r_maxs = np.array(r_maxs) assert np.all( r_maxs == r_maxs[0] - ), "committee r_max are not all the same {' '.join(r_maxs)}" + ), f"committee r_max are not all the same {' '.join(r_maxs)}" self.r_max = float(r_maxs[0]) self.device = torch_tools.init_device(device) diff --git a/mace/modules/models.py b/mace/modules/models.py index c0d8ab43..05ce3864 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -93,7 +93,7 @@ def __init__( ) edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") if pair_repulsion: - self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) + self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff) self.pair_repulsion = True sh_irreps = o3.Irreps.spherical_harmonics(max_ell) diff --git a/mace/modules/radial.py b/mace/modules/radial.py index 8dba50cb..72f234fa 100644 --- a/mace/modules/radial.py +++ b/mace/modules/radial.py @@ -125,16 +125,16 @@ def __init__(self, r_max: float, p=6): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self._calculate_envelope(x, self.r_max, self.p) + return self.calculate_envelope(x, self.r_max, self.p) @staticmethod - def _calculate_envelope(x: torch.Tensor, r_max: float, p: int) -> torch.Tensor: + def calculate_envelope(x: torch.Tensor, r_max: float, p: int) -> torch.Tensor: r_over_r_max = x / r_max envelope = ( - 1.0 - - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) - + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) - - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) + 1.0 + - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) + + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) + - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) ) return envelope * (x < r_max) @@ -147,6 +147,7 @@ class ZBLBasis(torch.nn.Module): """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential with a polynomial cutoff envelope. """ + p: torch.Tensor def __init__(self, p=6, trainable=False): @@ -203,7 +204,7 @@ def forward( ) v_edges = (14.3996 * Z_u * Z_v) / x * phi r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] - envelope = PolynomialCutoff._calculate_envelope(x, r_max, self.p) + envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p) v_edges = 0.5 * v_edges * envelope V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) return V_ZBL.squeeze(-1) @@ -255,14 +256,21 @@ def forward( ) Z_u = node_atomic_numbers[sender] Z_v = node_atomic_numbers[receiver] - r_0 : torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) r_over_r_0 = x / r_0 return ( - 1 + (self.a * torch.pow(r_over_r_0, self.q) / (1 + torch.pow(r_over_r_0, self.q - self.p))) + 1 + + ( + self.a + * torch.pow(r_over_r_0, self.q) + / (1 + torch.pow(r_over_r_0, self.q - self.p)) + ) ).reciprocal_() def __repr__(self): - return f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})" + return ( + f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})" + ) @simplify_if_compile