Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up unused Polynomial Cutoff Class from ZBLBasis, remove r_max argument. #569

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
if pair_repulsion:
self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff)
self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
self.pair_repulsion = True

sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
Expand Down
81 changes: 44 additions & 37 deletions mace/modules/radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import logging

import ase
import numpy as np
import torch
Expand Down Expand Up @@ -110,67 +112,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]

@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
"""
Equation (8)
"""Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
Equation (8) -- TODO: from where?
"""

p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6):
super().__init__()
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# yapf: disable
return self.calculate_envelope(x, self.r_max, self.p.to(torch.int))

@staticmethod
def calculate_envelope(
x: torch.Tensor, r_max: torch.Tensor, p: int
) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
1.0
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2)
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
# yapf: enable

# noinspection PyUnresolvedReferences
return envelope * (x < self.r_max)
return envelope * (x < r_max)

def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"


@compile_mode("script")
class ZBLBasis(torch.nn.Module):
"""
Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""

p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6, trainable=False):
def __init__(self, p=6, trainable=False, **kwargs):
super().__init__()
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
if "r_max" in kwargs:
logging.warning(
"r_max is deprecated. r_max is determined from the covalent radii."
)

# Pre-calculate the p coefficients for the ZBL potential
self.register_buffer(
"c",
torch.tensor(
[0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
),
)
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
self.cutoff = PolynomialCutoff(r_max, p)
if trainable:
self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
self.a_prefactor = torch.nn.Parameter(
Expand Down Expand Up @@ -208,12 +213,7 @@ def forward(
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
envelope = (
1.0
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2)
) * (x < r_max)
envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
Expand All @@ -224,8 +224,8 @@ def __repr__(self):

@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
"""
Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160
"""Agnesi transform - see section on Radial transformations in
ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
"""

def __init__(
Expand Down Expand Up @@ -265,21 +265,27 @@ def forward(
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_over_r_0 = x / r_0
return (
1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p)))
) ** (-1)
1
+ (
self.a
* torch.pow(r_over_r_0, self.q)
/ (1 + torch.pow(r_over_r_0, self.q - self.p))
)
).reciprocal_()

def __repr__(self):
return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})"
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)


@simplify_if_compile
@compile_mode("script")
class SoftTransform(torch.nn.Module):
"""
Soft Transform
"""
"""Soft Transform."""

def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False):
super().__init__()
Expand Down Expand Up @@ -312,9 +318,10 @@ def forward(
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4
r_over_r_0 = x / r_0
y = (
x
+ (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b))
+ (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b))
+ 1 / 2
)
return y
Expand Down
83 changes: 83 additions & 0 deletions tests/modules/test_radial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest
import torch
from mace.modules.radial import ZBLBasis, AgnesiTransform

@pytest.fixture
def zbl_basis():
return ZBLBasis(p=6, trainable=False)

def test_zbl_basis_initialization(zbl_basis):
assert zbl_basis.p == torch.tensor(6.0)
assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817]))

assert zbl_basis.a_exp == torch.tensor(0.300)
assert zbl_basis.a_prefactor == torch.tensor(0.4543)
assert not zbl_basis.a_exp.requires_grad
assert not zbl_basis.a_prefactor.requires_grad

def test_trainable_zbl_basis_initialization(zbl_basis):
zbl_basis = ZBLBasis(p=6, trainable=True)
assert zbl_basis.p == torch.tensor(6.0)
assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817]))

assert zbl_basis.a_exp == torch.tensor(0.300)
assert zbl_basis.a_prefactor == torch.tensor(0.4543)
assert zbl_basis.a_exp.requires_grad
assert zbl_basis.a_prefactor.requires_grad

def test_forward(zbl_basis):
x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges]
node_attrs = torch.tensor([[1, 0], [0, 1]]) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers
edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges]
atomic_numbers = torch.tensor([1, 6]) # [n_nodes]
output = zbl_basis(x, node_attrs, edge_index, atomic_numbers)

assert output.shape == torch.Size([node_attrs.shape[0]])
assert torch.is_tensor(output)
assert torch.allclose(
output,
torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()),
rtol=1e-2
)

@pytest.fixture
def agnesi():
return AgnesiTransform(trainable=False)

def test_agnesi_transform_initialization(agnesi: AgnesiTransform):
assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4)
assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4)
assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4)
assert not agnesi.a.requires_grad
assert not agnesi.q.requires_grad
assert not agnesi.p.requires_grad

def test_trainable_agnesi_transform_initialization():
agnesi = AgnesiTransform(trainable=True)

assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4)
assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4)
assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4)
assert agnesi.a.requires_grad
assert agnesi.q.requires_grad
assert agnesi.p.requires_grad

def test_agnesi_transform_forward():
agnesi = AgnesiTransform()
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1)
node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype())
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
atomic_numbers = torch.tensor([1, 6, 8])
output = agnesi(x, node_attrs, edge_index, atomic_numbers)
assert output.shape == x.shape
assert torch.is_tensor(output)
assert torch.allclose(
output,
torch.tensor(
[0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype()
).unsqueeze(-1),
rtol=1e-2
)

if __name__ == "__main__":
pytest.main([__file__])
Loading