Skip to content

Commit

Permalink
fix: make the lint checks pass again
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Sep 12, 2024
1 parent 8a966a5 commit 6a8babe
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(
r_maxs = np.array(r_maxs)
assert np.all(
r_maxs == r_maxs[0]
), "committee r_max are not all the same {' '.join(r_maxs)}"
), f"committee r_max are not all the same {' '.join(r_maxs)}"
self.r_max = float(r_maxs[0])

self.device = torch_tools.init_device(device)
Expand Down
2 changes: 1 addition & 1 deletion mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
if pair_repulsion:
self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff)
self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
self.pair_repulsion = True

sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
Expand Down
28 changes: 18 additions & 10 deletions mace/modules/radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ def __init__(self, r_max: float, p=6):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._calculate_envelope(x, self.r_max, self.p)
return self.calculate_envelope(x, self.r_max, self.p)

@staticmethod
def _calculate_envelope(x: torch.Tensor, r_max: float, p: int) -> torch.Tensor:
def calculate_envelope(x: torch.Tensor, r_max: float, p: int) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
return envelope * (x < r_max)

Expand All @@ -147,6 +147,7 @@ class ZBLBasis(torch.nn.Module):
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""

p: torch.Tensor

def __init__(self, p=6, trainable=False):
Expand Down Expand Up @@ -203,7 +204,7 @@ def forward(
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
envelope = PolynomialCutoff._calculate_envelope(x, r_max, self.p)
envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
Expand Down Expand Up @@ -255,14 +256,21 @@ def forward(
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 : torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_over_r_0 = x / r_0
return (
1 + (self.a * torch.pow(r_over_r_0, self.q) / (1 + torch.pow(r_over_r_0, self.q - self.p)))
1
+ (
self.a
* torch.pow(r_over_r_0, self.q)
/ (1 + torch.pow(r_over_r_0, self.q - self.p))
)
).reciprocal_()

def __repr__(self):
return f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)


@simplify_if_compile
Expand Down

0 comments on commit 6a8babe

Please sign in to comment.