diff --git a/mace/modules/models.py b/mace/modules/models.py index c0d8ab43..05ce3864 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -93,7 +93,7 @@ def __init__( ) edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") if pair_repulsion: - self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) + self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff) self.pair_repulsion = True sh_irreps = o3.Irreps.spherical_harmonics(max_ell) diff --git a/mace/modules/radial.py b/mace/modules/radial.py index a928c184..cae2aa71 100644 --- a/mace/modules/radial.py +++ b/mace/modules/radial.py @@ -4,6 +4,8 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +import logging + import ase import numpy as np import torch @@ -110,8 +112,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] @compile_mode("script") class PolynomialCutoff(torch.nn.Module): - """ - Equation (8) + """Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max. + Equation (8) -- TODO: from where? """ p: torch.Tensor @@ -119,23 +121,26 @@ class PolynomialCutoff(torch.nn.Module): def __init__(self, r_max: float, p=6): super().__init__() - self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) self.register_buffer( "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) ) def forward(self, x: torch.Tensor) -> torch.Tensor: - # yapf: disable + return self.calculate_envelope(x, self.r_max, self.p.to(torch.int)) + + @staticmethod + def calculate_envelope( + x: torch.Tensor, r_max: torch.Tensor, p: int + ) -> torch.Tensor: + r_over_r_max = x / r_max envelope = ( - 1.0 - - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) - + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) - - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) + 1.0 + - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) + + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) + - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) ) - # yapf: enable - - # noinspection PyUnresolvedReferences - return envelope * (x < self.r_max) + return envelope * (x < r_max) def __repr__(self): return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" @@ -143,18 +148,19 @@ def __repr__(self): @compile_mode("script") class ZBLBasis(torch.nn.Module): - """ - Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + with a polynomial cutoff envelope. """ p: torch.Tensor - r_max: torch.Tensor - def __init__(self, r_max: float, p=6, trainable=False): + def __init__(self, p=6, trainable=False, **kwargs): super().__init__() - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) + if "r_max" in kwargs: + logging.warning( + "r_max is deprecated. r_max is determined from the covalent radii." + ) + # Pre-calculate the p coefficients for the ZBL potential self.register_buffer( "c", @@ -162,7 +168,7 @@ def __init__(self, r_max: float, p=6, trainable=False): [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() ), ) - self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) self.register_buffer( "covalent_radii", torch.tensor( @@ -170,7 +176,6 @@ def __init__(self, r_max: float, p=6, trainable=False): dtype=torch.get_default_dtype(), ), ) - self.cutoff = PolynomialCutoff(r_max, p) if trainable: self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) self.a_prefactor = torch.nn.Parameter( @@ -208,12 +213,7 @@ def forward( ) v_edges = (14.3996 * Z_u * Z_v) / x * phi r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] - envelope = ( - 1.0 - - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p) - + self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1) - - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2) - ) * (x < r_max) + envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p) v_edges = 0.5 * v_edges * envelope V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) return V_ZBL.squeeze(-1) @@ -224,8 +224,8 @@ def __repr__(self): @compile_mode("script") class AgnesiTransform(torch.nn.Module): - """ - Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160 + """Agnesi transform - see section on Radial transformations in + ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783). """ def __init__( @@ -265,21 +265,27 @@ def forward( ) Z_u = node_atomic_numbers[sender] Z_v = node_atomic_numbers[receiver] - r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + r_over_r_0 = x / r_0 return ( - 1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p))) - ) ** (-1) + 1 + + ( + self.a + * torch.pow(r_over_r_0, self.q) + / (1 + torch.pow(r_over_r_0, self.q - self.p)) + ) + ).reciprocal_() def __repr__(self): - return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})" + return ( + f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})" + ) @simplify_if_compile @compile_mode("script") class SoftTransform(torch.nn.Module): - """ - Soft Transform - """ + """Soft Transform.""" def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False): super().__init__() @@ -312,9 +318,10 @@ def forward( Z_u = node_atomic_numbers[sender] Z_v = node_atomic_numbers[receiver] r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4 + r_over_r_0 = x / r_0 y = ( x - + (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b)) + + (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b)) + 1 / 2 ) return y diff --git a/tests/modules/test_radial.py b/tests/modules/test_radial.py new file mode 100644 index 00000000..1d8a0c6d --- /dev/null +++ b/tests/modules/test_radial.py @@ -0,0 +1,83 @@ +import pytest +import torch +from mace.modules.radial import ZBLBasis, AgnesiTransform + +@pytest.fixture +def zbl_basis(): + return ZBLBasis(p=6, trainable=False) + +def test_zbl_basis_initialization(zbl_basis): + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert not zbl_basis.a_exp.requires_grad + assert not zbl_basis.a_prefactor.requires_grad + +def test_trainable_zbl_basis_initialization(zbl_basis): + zbl_basis = ZBLBasis(p=6, trainable=True) + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert zbl_basis.a_exp.requires_grad + assert zbl_basis.a_prefactor.requires_grad + +def test_forward(zbl_basis): + x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges] + node_attrs = torch.tensor([[1, 0], [0, 1]]) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers + edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges] + atomic_numbers = torch.tensor([1, 6]) # [n_nodes] + output = zbl_basis(x, node_attrs, edge_index, atomic_numbers) + + assert output.shape == torch.Size([node_attrs.shape[0]]) + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()), + rtol=1e-2 + ) + +@pytest.fixture +def agnesi(): + return AgnesiTransform(trainable=False) + +def test_agnesi_transform_initialization(agnesi: AgnesiTransform): + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert not agnesi.a.requires_grad + assert not agnesi.q.requires_grad + assert not agnesi.p.requires_grad + +def test_trainable_agnesi_transform_initialization(): + agnesi = AgnesiTransform(trainable=True) + + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert agnesi.a.requires_grad + assert agnesi.q.requires_grad + assert agnesi.p.requires_grad + +def test_agnesi_transform_forward(): + agnesi = AgnesiTransform() + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1) + node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype()) + edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) + atomic_numbers = torch.tensor([1, 6, 8]) + output = agnesi(x, node_attrs, edge_index, atomic_numbers) + assert output.shape == x.shape + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor( + [0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype() + ).unsqueeze(-1), + rtol=1e-2 + ) + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file