Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cluster force #675

Open
wants to merge 7 commits into
base: refactor_data
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mace/data/atomic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ class AtomicData(torch_geometric.data.Data):
energy: torch.Tensor
stress: torch.Tensor
virials: torch.Tensor
cluster: torch.Tensor
dipole: torch.Tensor
charges: torch.Tensor
weight: torch.Tensor
energy_weight: torch.Tensor
forces_weight: torch.Tensor
stress_weight: torch.Tensor
virials_weight: torch.Tensor
cluster_weight: torch.Tensor
dipole_weight: torch.Tensor
charges_weight: torch.Tensor

Expand All @@ -59,12 +61,14 @@ def __init__(
energy_weight: Optional[torch.Tensor], # [,]
forces_weight: Optional[torch.Tensor], # [,]
stress_weight: Optional[torch.Tensor], # [,]
cluster_weight: Optional[torch.Tensor], # [,]
virials_weight: Optional[torch.Tensor], # [,]
dipole_weight: Optional[torch.Tensor], # [,]
charges_weight: Optional[torch.Tensor], # [,]
forces: Optional[torch.Tensor], # [n_nodes, 3]
energy: Optional[torch.Tensor], # [, ]
stress: Optional[torch.Tensor], # [1,3,3]
cluster: Optional[torch.Tensor], # [n_nodes, ]
virials: Optional[torch.Tensor], # [1,3,3]
dipole: Optional[torch.Tensor], # [, 3]
charges: Optional[torch.Tensor], # [n_nodes, ]
Expand All @@ -83,12 +87,14 @@ def __init__(
assert forces_weight is None or len(forces_weight.shape) == 0
assert stress_weight is None or len(stress_weight.shape) == 0
assert virials_weight is None or len(virials_weight.shape) == 0
assert cluster_weight is None or len(cluster_weight.shape) == 0
assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight
assert charges_weight is None or len(charges_weight.shape) == 0
assert cell is None or cell.shape == (3, 3)
assert forces is None or forces.shape == (num_nodes, 3)
assert energy is None or len(energy.shape) == 0
assert stress is None or stress.shape == (1, 3, 3)
assert cluster is None or cluster.shape == (num_nodes,)
assert virials is None or virials.shape == (1, 3, 3)
assert dipole is None or dipole.shape[-1] == 3
assert charges is None or charges.shape == (num_nodes,)
Expand All @@ -106,12 +112,14 @@ def __init__(
"energy_weight": energy_weight,
"forces_weight": forces_weight,
"stress_weight": stress_weight,
"cluster_weight": cluster_weight,
"virials_weight": virials_weight,
"dipole_weight": dipole_weight,
"charges_weight": charges_weight,
"forces": forces,
"energy": energy,
"stress": stress,
"cluster": cluster,
"virials": virials,
"dipole": dipole,
"charges": charges,
Expand Down Expand Up @@ -184,6 +192,13 @@ def from_config(
if config.property_weights.get("stress") is not None
else torch.tensor(1.0, dtype=torch.get_default_dtype())
)
cluster_weight = (
torch.tensor(
config.property_weights.get("cluster"), dtype=torch.get_default_dtype()
)
if config.property_weights.get("cluster") is not None
else torch.tensor(1.0, dtype=torch.get_default_dtype())
)

virials_weight = (
torch.tensor(
Expand Down Expand Up @@ -238,6 +253,13 @@ def from_config(
if config.properties.get("stress") is not None
else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype())
)
cluster = (
torch.tensor(
config.properties.get("cluster"), dtype=torch.get_default_dtype()
)
if config.properties.get("cluster") is not None
else torch.zeros(num_atoms, dtype=torch.get_default_dtype())
)
virials = (
voigt_to_matrix(
torch.tensor(
Expand Down Expand Up @@ -274,12 +296,14 @@ def from_config(
energy_weight=energy_weight,
forces_weight=forces_weight,
stress_weight=stress_weight,
cluster_weight=cluster_weight,
virials_weight=virials_weight,
dipole_weight=dipole_weight,
charges_weight=charges_weight,
forces=forces,
energy=energy,
stress=stress,
cluster=cluster,
virials=virials,
dipole=dipole,
charges=charges,
Expand Down
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
UniversalLoss,
WeightedEnergyForcesDipoleLoss,
WeightedEnergyForcesLoss,
WeightedEnergyForcesLossForceCluster,
WeightedEnergyForcesStressLoss,
WeightedEnergyForcesStressLossForceCluster,
WeightedEnergyForcesVirialsLoss,
WeightedForcesLoss,
WeightedHuberEnergyForcesStressLoss,
Expand Down Expand Up @@ -97,6 +99,8 @@
"AtomicDipolesMACE",
"EnergyDipolesMACE",
"WeightedEnergyForcesLoss",
"WeightedEnergyForcesLossForceCluster",
"WeightedEnergyForcesStressLossForceCluster",
"WeightedForcesLoss",
"WeightedEnergyForcesVirialsLoss",
"WeightedEnergyForcesStressLoss",
Expand Down
101 changes: 101 additions & 0 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch

from mace.tools import TensorDict
from mace.tools.scatter import compute_effective_index, scatter_sum
from mace.tools.torch_geometric import Batch


Expand All @@ -27,6 +28,24 @@ def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Te
) # []


def weighted_mean_square_error_force_cluster(
ref: Batch, pred: TensorDict
) -> torch.Tensor:
effective_inicies, _ = compute_effective_index([ref.batch, ref.cluster])
cluster_forces_ref = scatter_sum(
ref["forces"],
effective_inicies,
dim=0,
)
cluster_forces_pred = scatter_sum(
pred["forces"],
effective_inicies,
dim=0,
)

return torch.mean(torch.square(cluster_forces_ref - cluster_forces_pred))


def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor:
# energy: [n_graphs, ]
configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ]
Expand Down Expand Up @@ -171,6 +190,88 @@ def __repr__(self):
)


class WeightedEnergyForcesLossForceCluster(torch.nn.Module):
def __init__(
self, energy_weight=1.0, forces_weight=1.0, cluster_weight=1.0
) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"cluster_weight",
torch.tensor(cluster_weight, dtype=torch.get_default_dtype()),
)

def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
tloss = (
self.energy_weight * weighted_mean_squared_error_energy(ref, pred)
+ self.forces_weight * mean_squared_error_forces(ref, pred)
+ self.cluster_weight * weighted_mean_square_error_force_cluster(ref, pred)
)
# logging.info(
# f"Cluster weight: {self.cluster_weight}, "
# f"Cluster error: {weighted_mean_square_error_force_cluster(ref, pred)},"
# f" Total loss: {tloss}"
# )
return tloss

def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f})"
f"cluster_weight={self.cluster_weight:.3f})"
)


class WeightedEnergyForcesStressLossForceCluster(torch.nn.Module):
def __init__(
self,
energy_weight=1.0,
forces_weight=1.0,
stress_weight=1.0,
cluster_weight=1.0,
) -> None:
super().__init__()
self.register_buffer(
"energy_weight",
torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"forces_weight",
torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"stress_weight",
torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
)
self.register_buffer(
"cluster_weight",
torch.tensor(cluster_weight, dtype=torch.get_default_dtype()),
)

def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
return (
self.energy_weight * weighted_mean_squared_error_energy(ref, pred)
+ self.forces_weight * mean_squared_error_forces(ref, pred)
+ self.stress_weight * weighted_mean_squared_stress(ref, pred)
+ self.cluster_weight * weighted_mean_square_error_force_cluster(ref, pred)
)

def __repr__(self):
return (
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, "
f"stress_weight={self.stress_weight:.3f}, "
f"cluster_weight={self.cluster_weight:.3f})"
)


class WeightedForcesLoss(torch.nn.Module):
def __init__(self, forces_weight=1.0) -> None:
super().__init__()
Expand Down
13 changes: 13 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"DipoleRMSE",
"DipoleMAE",
"EnergyDipoleRMSE",
"PerAtomRMSECluster",
"PerAtomRMSEstressCluster",
],
default="PerAtomRMSE",
)
Expand Down Expand Up @@ -449,6 +451,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
choices=[
"ef",
"weighted",
"weighted_cluster",
"weighted_cluster_stress",
"forces_only",
"virials",
"stress",
Expand Down Expand Up @@ -491,6 +495,15 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
default=10.0,
dest="swa_virials_weight",
)
parser.add_argument(
"--cluster_weight", help="weight of cluster loss", type=float, default=0.0
)
parser.add_argument(
"--swa_cluster_weight",
help="weight of cluster loss after starting swa",
type=float,
default=0.0,
)
parser.add_argument(
"--stress_weight", help="weight of virials loss", type=float, default=1.0
)
Expand Down
31 changes: 30 additions & 1 deletion mace/tools/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
See https://github.com/pytorch/pytorch/issues/63780.
"""

from typing import Optional
from typing import List, Optional, Tuple

import torch

Expand Down Expand Up @@ -110,3 +110,32 @@ def scatter_mean(
else:
out.div_(count, rounding_mode="floor")
return out


def compute_effective_index(
indices: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes an effective index from multiple index tensors. Useful for multi index scatter operations.

Args:
indices (List[torch.Tensor]): List of index tensors, each of shape (N,).

Returns:
effective_index (torch.Tensor): Tensor of shape (N,), where each element
is a unique integer representing the combination of indices.
unique_combinations (torch.Tensor): Tensor containing unique combinations
of indices, shape (num_unique_combinations, num_indices).
"""
# Stack indices to shape (num_indices, N)
indices_stack = torch.stack(indices, dim=0) # Shape: (num_indices, N)

# Transpose to get combinations per element
index_combinations = indices_stack.t() # Shape: (N, num_indices)

# Find unique combinations and get inverse indices
unique_combinations, inverse_indices = torch.unique(
index_combinations, dim=0, return_inverse=True
)

return inverse_indices, unique_combinations
Loading
Loading