From d94010c64099613e330da9e9e0501c96a5b9ffbf Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Sun, 5 Jan 2025 18:40:05 +0800
Subject: [PATCH] feat: dpmodel energy loss & consistent tests

Fix #4105. Fix #4429.

Signed-off-by: Jinzhe Zeng
---
 deepmd/dpmodel/loss/__init__.py           |   1 +
 deepmd/dpmodel/loss/ener.py               | 352 ++++++++++++++++++++++
 deepmd/dpmodel/loss/loss.py               | 100 ++++++
 deepmd/pt/loss/ener.py                    |  51 ++++
 deepmd/pt/loss/loss.py                    |  26 ++
 deepmd/tf/loss/ener.py                    |  58 ++++
 deepmd/tf/loss/loss.py                    |  52 ++++
 source/tests/consistent/loss/__init__.py  |   2 +
 source/tests/consistent/loss/common.py    |   5 +
 source/tests/consistent/loss/test_ener.py | 241 +++++++++++++++
 10 files changed, 888 insertions(+)
 create mode 100644 deepmd/dpmodel/loss/__init__.py
 create mode 100644 deepmd/dpmodel/loss/ener.py
 create mode 100644 deepmd/dpmodel/loss/loss.py
 create mode 100644 source/tests/consistent/loss/__init__.py
 create mode 100644 source/tests/consistent/loss/common.py
 create mode 100644 source/tests/consistent/loss/test_ener.py

diff --git a/deepmd/dpmodel/loss/__init__.py b/deepmd/dpmodel/loss/__init__.py
new file mode 100644
index 0000000000..6ceb116d85
--- /dev/null
+++ b/deepmd/dpmodel/loss/__init__.py
@@ -0,0 +1 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py
new file mode 100644
index 0000000000..52c447b60b
--- /dev/null
+++ b/deepmd/dpmodel/loss/ener.py
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+)
+
+import array_api_compat
+import numpy as np
+
+from deepmd.dpmodel.loss.loss import (
+    Loss,
+)
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+
+class EnergyLoss(Loss):
+    def __init__(
+        self,
+        starter_learning_rate: float,
+        start_pref_e: float = 0.02,
+        limit_pref_e: float = 1.00,
+        start_pref_f: float = 1000,
+        limit_pref_f: float = 1.00,
+        start_pref_v: float = 0.0,
+        limit_pref_v: float = 0.0,
+        start_pref_ae: float = 0.0,
+        limit_pref_ae: float = 0.0,
+        start_pref_pf: float = 0.0,
+        limit_pref_pf: float = 0.0,
+        relative_f: Optional[float] = None,
+        enable_atom_ener_coeff: bool = False,
+        start_pref_gf: float = 0.0,
+        limit_pref_gf: float = 0.0,
+        numb_generalized_coord: int = 0,
+        **kwargs,
+    ) -> None:
+        self.starter_learning_rate = starter_learning_rate
+        self.start_pref_e = start_pref_e
+        self.limit_pref_e = limit_pref_e
+        self.start_pref_f = start_pref_f
+        self.limit_pref_f = limit_pref_f
+        self.start_pref_v = start_pref_v
+        self.limit_pref_v = limit_pref_v
+        self.start_pref_ae = start_pref_ae
+        self.limit_pref_ae = limit_pref_ae
+        self.start_pref_pf = start_pref_pf
+        self.limit_pref_pf = limit_pref_pf
+        self.relative_f = relative_f
+        self.enable_atom_ener_coeff = enable_atom_ener_coeff
+        self.start_pref_gf = start_pref_gf
+        self.limit_pref_gf = limit_pref_gf
+        self.numb_generalized_coord = numb_generalized_coord
+        self.has_e = self.start_pref_e != 0.0 or self.limit_pref_e != 0.0
+        self.has_f = self.start_pref_f != 0.0 or self.limit_pref_f != 0.0
+        self.has_v = self.start_pref_v != 0.0 or self.limit_pref_v != 0.0
+        self.has_ae = self.start_pref_ae != 0.0 or self.limit_pref_ae != 0.0
+        self.has_pf = self.start_pref_pf != 0.0 or self.limit_pref_pf != 0.0
+        self.has_gf = self.start_pref_gf != 0.0 or self.limit_pref_gf != 0.0
+        if self.has_gf and self.numb_generalized_coord < 1:
+            raise RuntimeError(
+                "When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
+            )
+
+    def call(
+        self,
+        learning_rate: float,
+        natoms: int,
+        model_dict: dict[str, np.ndarray],
+        label_dict: dict[str, np.ndarray],
+    ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
+        """Calculate loss from model results and labeled results."""
+        energy = model_dict["energy"]
+        force = model_dict["force"]
+        virial = model_dict["virial"]
+        atom_ener = model_dict["atom_ener"]
+        energy_hat = label_dict["energy"]
+        force_hat = label_dict["force"]
+        virial_hat = label_dict["virial"]
+        atom_ener_hat = label_dict["atom_ener"]
+        atom_pref = label_dict["atom_pref"]
+        find_energy = label_dict["find_energy"]
+        find_force = label_dict["find_force"]
+        find_virial = label_dict["find_virial"]
+        find_atom_ener = label_dict["find_atom_ener"]
+        find_atom_pref = label_dict["find_atom_pref"]
+        xp = array_api_compat.array_namespace(
+            energy,
+            force,
+            virial,
+            atom_ener,
+            energy_hat,
+            force_hat,
+            virial_hat,
+            atom_ener_hat,
+            atom_pref,
+        )
+        if self.has_gf:
+            drdq = label_dict["drdq"]
+            find_drdq = label_dict["find_drdq"]
+
+        if self.enable_atom_ener_coeff:
+            # when ener_coeff (\nu) is defined, the energy is defined as
+            # E = \sum_i \nu_i E_i
+            # instead of the sum of atomic energies.
+            #
+            # A case is that we want to train reaction energy
+            # A + B -> C + D
+            # E = - E(A) - E(B) + E(C) + E(D)
+            # A, B, C, D could be put far away from each other
+            atom_ener_coeff = label_dict["atom_ener_coeff"]
+            atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener))
+            energy = xp.sum(atom_ener_coeff * atom_ener, axis=1)
+        if self.has_e:
+            l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
+
+        if self.has_f or self.has_pf or self.relative_f or self.has_gf:
+            force_reshape = xp.reshape(force, [-1])
+            force_hat_reshape = xp.reshape(force_hat, [-1])
+            diff_f = force_hat_reshape - force_reshape
+
+            if self.relative_f is not None:
+                force_hat_3 = xp.reshape(force_hat, [-1, 3])
+                norm_f = xp.reshape(xp.linalg.vector_norm(force_hat_3, axis=1), [-1, 1]) + self.relative_f
+                diff_f_3 = xp.reshape(diff_f, [-1, 3])
+                diff_f_3 = diff_f_3 / norm_f
+                diff_f = xp.reshape(diff_f_3, [-1])
+
+            if self.has_f:
+                l2_force_loss = xp.mean(xp.square(diff_f))
+
+            if self.has_pf:
+                atom_pref_reshape = xp.reshape(atom_pref, [-1])
+                l2_pref_force_loss = xp.mean(
+                    xp.multiply(xp.square(diff_f), atom_pref_reshape),
+                )
+
+            if self.has_gf:
+                drdq = label_dict["drdq"]
+                force_reshape_nframes = xp.reshape(force, [-1, natoms * 3])
+                force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms * 3])
+                drdq_reshape = xp.reshape(
+                    drdq, [-1, natoms * 3, self.numb_generalized_coord]
+                )
+                gen_force_hat = xp.einsum(
+                    "bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
+                )
+                gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes)
+                diff_gen_force = gen_force_hat - gen_force
+                l2_gen_force_loss = xp.mean(xp.square(diff_gen_force))
+
+        if self.has_v:
+            virial_reshape = xp.reshape(virial, [-1])
+            virial_hat_reshape = xp.reshape(virial_hat, [-1])
+            l2_virial_loss = xp.mean(
+                xp.square(virial_hat_reshape - virial_reshape),
+            )
+
+        if self.has_ae:
+            atom_ener_reshape = xp.reshape(atom_ener, [-1])
+            atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
+            l2_atom_ener_loss = xp.mean(
+                xp.square(atom_ener_hat_reshape - atom_ener_reshape),
+            )
+
+        atom_norm = 1.0 / natoms
+        atom_norm_ener = 1.0 / natoms
+        lr_ratio = learning_rate / self.starter_learning_rate
+        pref_e = find_energy * (
+            self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * lr_ratio
+        )
+        pref_f = find_force * (
+            self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * lr_ratio
+        )
+        pref_v = find_virial * (
+            self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * lr_ratio
+        )
+        pref_ae = find_atom_ener * (
+            self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * lr_ratio
+        )
+        pref_pf = find_atom_pref * (
+            self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio
+        )
+        if self.has_gf:
+            pref_gf = find_drdq * (
+                self.limit_pref_gf
+                + (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
+            )
+
+        l2_loss = 0
+        more_loss = {}
+        if self.has_e:
+            l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
+            more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
+        if self.has_f:
+            l2_loss += pref_f * l2_force_loss
+            more_loss["l2_force_loss"] = self.display_if_exist(
+                l2_force_loss, find_force
+            )
+        if self.has_v:
+            l2_loss += atom_norm * (pref_v * l2_virial_loss)
+            more_loss["l2_virial_loss"] = self.display_if_exist(
+                l2_virial_loss, find_virial
+            )
+        if self.has_ae:
+            l2_loss += pref_ae * l2_atom_ener_loss
+            more_loss["l2_atom_ener_loss"] = self.display_if_exist(
+                l2_atom_ener_loss, find_atom_ener
+            )
+        if self.has_pf:
+            l2_loss += pref_pf * l2_pref_force_loss
+            more_loss["l2_pref_force_loss"] = self.display_if_exist(
+                l2_pref_force_loss, find_atom_pref
+            )
+        if self.has_gf:
+            l2_loss += pref_gf * l2_gen_force_loss
+            more_loss["l2_gen_force_loss"] = self.display_if_exist(
+                l2_gen_force_loss, find_drdq
+            )
+
+        self.l2_l = l2_loss
+        self.l2_more = more_loss
+        return l2_loss, more_loss
+
+    @property
+    def label_requirement(self) -> list[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        label_requirement = []
+        if self.has_e:
+            label_requirement.append(
+                DataRequirementItem(
+                    "energy",
+                    ndof=1,
+                    atomic=False,
+                    must=False,
+                    high_prec=True,
+                )
+            )
+        if self.has_f:
+            label_requirement.append(
+                DataRequirementItem(
+                    "force",
+                    ndof=3,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.has_v:
+            label_requirement.append(
+                DataRequirementItem(
+                    "virial",
+                    ndof=9,
+                    atomic=False,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.has_ae:
+            label_requirement.append(
+                DataRequirementItem(
+                    "atom_ener",
+                    ndof=1,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.has_pf:
+            label_requirement.append(
+                DataRequirementItem(
+                    "atom_pref",
+                    ndof=1,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                    repeat=3,
+                )
+            )
+        if self.has_gf:
+            label_requirement.append(
+                DataRequirementItem(
+                    "drdq",
+                    ndof=self.numb_generalized_coord * 3,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                )
+            )
+        if self.enable_atom_ener_coeff:
+            label_requirement.append(
+                DataRequirementItem(
+                    "atom_ener_coeff",
+                    ndof=1,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                    default=1.0,
+                )
+            )
+        return label_requirement
+
+    def serialize(self) -> dict:
+        """Serialize the loss module.
+
+        Returns
+        -------
+        dict
+            The serialized loss module
+        """
+        return {
+            "@class": "EnergyLoss",
+            "@version": 1,
+            "starter_learning_rate": self.starter_learning_rate,
+            "start_pref_e": self.start_pref_e,
+            "limit_pref_e": self.limit_pref_e,
+            "start_pref_f": self.start_pref_f,
+            "limit_pref_f": self.limit_pref_f,
+            "start_pref_v": self.start_pref_v,
+            "limit_pref_v": self.limit_pref_v,
+            "start_pref_ae": self.start_pref_ae,
+            "limit_pref_ae": self.limit_pref_ae,
+            "start_pref_pf": self.start_pref_pf,
+            "limit_pref_pf": self.limit_pref_pf,
+            "relative_f": self.relative_f,
+            "enable_atom_ener_coeff": self.enable_atom_ener_coeff,
+            "start_pref_gf": self.start_pref_gf,
+            "limit_pref_gf": self.limit_pref_gf,
+            "numb_generalized_coord": self.numb_generalized_coord,
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "Loss":
+        """Deserialize the loss module.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized loss module
+
+        Returns
+        -------
+        Loss
+            The deserialized loss module
+        """
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        data.pop("@class")
+        return cls(**data)
diff --git a/deepmd/dpmodel/loss/loss.py b/deepmd/dpmodel/loss/loss.py
new file mode 100644
index 0000000000..a297380cce
--- /dev/null
+++ b/deepmd/dpmodel/loss/loss.py
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from abc import (
+    ABC,
+    abstractmethod,
+)
+
+import array_api_compat
+import numpy as np
+
+from deepmd.dpmodel.common import (
+    NativeOP,
+)
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+from deepmd.utils.plugin import (
+    make_plugin_registry,
+)
+
+
+class Loss(NativeOP, ABC, make_plugin_registry("loss")):
+    @abstractmethod
+    def call(
+        self,
+        learning_rate: float,
+        natoms: int,
+        model_dict: dict[str, np.ndarray],
+        label_dict: dict[str, np.ndarray],
+    ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
+        """Calculate loss from model results and labeled results."""
+
+    @property
+    @abstractmethod
+    def label_requirement(self) -> list[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+
+    @staticmethod
+    def display_if_exist(loss: np.ndarray, find_property: float) -> np.ndarray:
+        """Display NaN if labeled property is not found.
+
+        Parameters
+        ----------
+        loss : np.ndarray
+            the loss scalar
+        find_property : float
+            whether the property is found
+
+        Returns
+        -------
+        np.ndarray
+            the loss scalar or NaN
+        """
+        xp = array_api_compat.array_namespace(loss)
+        return loss if bool(find_property) else xp.nan
+
+    @classmethod
+    def get_loss(cls, loss_params: dict) -> "Loss":
+        """Get the loss module by the parameters.
+
+        By default, all the parameters are directly passed to the constructor.
+        If not, override this method.
+
+        Parameters
+        ----------
+        loss_params : dict
+            The loss parameters
+
+        Returns
+        -------
+        Loss
+            The loss module
+        """
+        loss = cls(**loss_params)
+        return loss
+
+    @abstractmethod
+    def serialize(self) -> dict:
+        """Serialize the loss module.
+
+        Returns
+        -------
+        dict
+            The serialized loss module
+        """
+
+    @classmethod
+    @abstractmethod
+    def deserialize(cls, data: dict) -> "Loss":
+        """Deserialize the loss module.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized loss module
+
+        Returns
+        -------
+        Loss
+            The deserialized loss module
+        """
diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py
index b564aa57ec..47b1a2ecc2 100644
--- a/deepmd/pt/loss/ener.py
+++ b/deepmd/pt/loss/ener.py
@@ -18,6 +18,9 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
 
 
 class EnergyStdLoss(TaskLoss):
@@ -412,6 +415,54 @@ def label_requirement(self) -> list[DataRequirementItem]:
         )
         return label_requirement
 
+    def serialize(self) -> dict:
+        """Serialize the loss module.
+
+        Returns
+        -------
+        dict
+            The serialized loss module
+        """
+        return {
+            "@class": "EnergyLoss",
+            "@version": 1,
+            "starter_learning_rate": self.starter_learning_rate,
+            "start_pref_e": self.start_pref_e,
+            "limit_pref_e": self.limit_pref_e,
+            "start_pref_f": self.start_pref_f,
+            "limit_pref_f": self.limit_pref_f,
+            "start_pref_v": self.start_pref_v,
+            "limit_pref_v": self.limit_pref_v,
+            "start_pref_ae": self.start_pref_ae,
+            "limit_pref_ae": self.limit_pref_ae,
+            "start_pref_pf": self.start_pref_pf,
+            "limit_pref_pf": self.limit_pref_pf,
+            "relative_f": self.relative_f,
+            "enable_atom_ener_coeff": self.enable_atom_ener_coeff,
+            "start_pref_gf": self.start_pref_gf,
+            "limit_pref_gf": self.limit_pref_gf,
+            "numb_generalized_coord": self.numb_generalized_coord,
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "TaskLoss":
+        """Deserialize the loss module.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized loss module
+
+        Returns
+        -------
+        TaskLoss
+            The deserialized loss module
+        """
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        data.pop("@class")
+        return cls(**data)
+
 
 class EnergyHessianStdLoss(EnergyStdLoss):
     def __init__(
diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py
index dfe62f4da0..d1777a29b3 100644
--- a/deepmd/pt/loss/loss.py
+++ b/deepmd/pt/loss/loss.py
@@ -64,3 +64,29 @@ def get_loss(cls, loss_params: dict) -> "TaskLoss":
         """
         loss = cls(**loss_params)
         return loss
+
+    def serialize(self) -> dict:
+        """Serialize the loss module.
+
+        Returns
+        -------
+        dict
+            The serialized loss module
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "TaskLoss":
+        """Deserialize the loss module.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized loss module
+
+        Returns
+        -------
+        TaskLoss
+            The deserialized loss module
+        """
+        raise NotImplementedError
diff --git a/deepmd/tf/loss/ener.py b/deepmd/tf/loss/ener.py
index 95cc8adafb..2b5eb7f3d5 100644
--- a/deepmd/tf/loss/ener.py
+++ b/deepmd/tf/loss/ener.py
@@ -16,6 +16,9 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
 
 from .loss import (
     Loss,
@@ -402,6 +405,61 @@ def label_requirement(self) -> list[DataRequirementItem]:
         )
         return data_requirements
 
+    def serialize(self, suffix: str = "") -> dict:
+        """Serialize the loss module.
+
+        Parameters
+        ----------
+        suffix : str
+            The suffix of the loss module
+
+        Returns
+        -------
+        dict
+            The serialized loss module
+        """
+        return {
+            "@class": "EnergyLoss",
+            "@version": 1,
+            "starter_learning_rate": self.starter_learning_rate,
+            "start_pref_e": self.start_pref_e,
+            "limit_pref_e": self.limit_pref_e,
+            "start_pref_f": self.start_pref_f,
+            "limit_pref_f": self.limit_pref_f,
+            "start_pref_v": self.start_pref_v,
+            "limit_pref_v": self.limit_pref_v,
+            "start_pref_ae": self.start_pref_ae,
+            "limit_pref_ae": self.limit_pref_ae,
+            "start_pref_pf": self.start_pref_pf,
+            "limit_pref_pf": self.limit_pref_pf,
+            "relative_f": self.relative_f,
+            "enable_atom_ener_coeff": self.enable_atom_ener_coeff,
+            "start_pref_gf": self.start_pref_gf,
+            "limit_pref_gf": self.limit_pref_gf,
+            "numb_generalized_coord": self.numb_generalized_coord,
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict, suffix: str = "") -> "Loss":
+        """Deserialize the loss module.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized loss module
+        suffix : str
+            The suffix of the loss module
+
+        Returns
+        -------
+        Loss
+            The deserialized loss module
+        """
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        data.pop("@class")
+        return cls(**data)
+
 
 class EnerSpinLoss(Loss):
     def __init__(
diff --git a/deepmd/tf/loss/loss.py b/deepmd/tf/loss/loss.py
index 351da7b748..0d6d8c9c40 100644
--- a/deepmd/tf/loss/loss.py
+++ b/deepmd/tf/loss/loss.py
@@ -95,3 +95,55 @@ def display_if_exist(loss: tf.Tensor, find_property: float) -> tf.Tensor:
     @abstractmethod
     def label_requirement(self) -> list[DataRequirementItem]:
         """Return data label requirements needed for this loss calculation."""
+
+    def serialize(self, suffix: str = "") -> dict:
+        """Serialize the loss module.
+
+        Parameters
+        ----------
+        suffix : str
+            The suffix of the loss module
+
+        Returns
+        -------
+        dict
+            The serialized loss module
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def deserialize(cls, data: dict, suffix: str = "") -> "Loss":
+        """Deserialize the loss module.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized loss module
+        suffix : str
+            The suffix of the loss module
+
+        Returns
+        -------
+        Loss
+            The deserialized loss module
+        """
+        raise NotImplementedError
+
+    def init_variables(
+        self,
+        graph: tf.Graph,
+        graph_def: tf.GraphDef,
+        suffix: str = "",
+    ) -> None:
+        """No actual effect.
+
+        Parameters
+        ----------
+        graph : tf.Graph
+            The input frozen model graph
+        graph_def : tf.GraphDef
+            The input frozen model graph_def
+        suffix : str, optional
+            The suffix of the scope
+        """
+        pass
diff --git a/source/tests/consistent/loss/__init__.py b/source/tests/consistent/loss/__init__.py
new file mode 100644
index 0000000000..3a1ed68529
--- /dev/null
+++ b/source/tests/consistent/loss/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Test consistency of loss among backends."""
diff --git a/source/tests/consistent/loss/common.py b/source/tests/consistent/loss/common.py
new file mode 100644
index 0000000000..efe4a33968
--- /dev/null
+++ b/source/tests/consistent/loss/common.py
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+
+class LossTest:
+    """Useful utilities for loss tests."""
diff --git a/source/tests/consistent/loss/test_ener.py b/source/tests/consistent/loss/test_ener.py
new file mode 100644
index 0000000000..6354a69d98
--- /dev/null
+++ b/source/tests/consistent/loss/test_ener.py
@@ -0,0 +1,241 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from typing import (
+    Any,
+)
+
+import numpy as np
+
+from deepmd.dpmodel.common import (
+    to_numpy_array,
+)
+from deepmd.dpmodel.loss.ener import EnergyLoss as EnerLossDP
+from deepmd.utils.argcheck import (
+    loss_ener,
+)
+
+from ..common import (
+    INSTALLED_ARRAY_API_STRICT,
+    INSTALLED_JAX,
+    INSTALLED_PD,
+    INSTALLED_PT,
+    INSTALLED_TF,
+    CommonTest,
+)
+from .common import (
+    LossTest,
+)
+
+if INSTALLED_TF:
+    from deepmd.tf.env import (
+        GLOBAL_TF_FLOAT_PRECISION,
+        tf,
+    )
+    from deepmd.tf.loss.ener import EnerStdLoss as EnerLossTF
+else:
+    EnerLossTF = None
+if INSTALLED_PT:
+    import torch
+
+    from deepmd.pt.loss.ener import EnergyStdLoss as EnerLossPT
+else:
+    EnerLossPT = None
+if INSTALLED_PD:
+    import paddle
+
+    from deepmd.pd.loss.ener import EnergyStdLoss as EnerLossPD
+else:
+    EnerLossPD = None
+if INSTALLED_JAX:
+    from deepmd.jax.env import (
+        jnp,
+    )
+if INSTALLED_ARRAY_API_STRICT:
+    import array_api_strict
+
+
+class TestEner(CommonTest, LossTest, unittest.TestCase):
+    @property
+    def data(self) -> dict:
+        return {
+            "start_pref_e": 0.02,
+            "limit_pref_e": 1.0,
+            "start_pref_f": 1000.0,
+            "limit_pref_f": 1.0,
+            "start_pref_v": 1.0,
+            "limit_pref_v": 1.0,
+            "start_pref_ae": 1.0,
+            "limit_pref_ae": 1.0,
+            "start_pref_pf": 1.0,
+            "limit_pref_pf": 1.0,
+        }
+
+    skip_tf = CommonTest.skip_tf
+    skip_pt = CommonTest.skip_pt
+    skip_jax = not INSTALLED_JAX
+    skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
+    skip_pd = not INSTALLED_PD
+
+    tf_class = EnerLossTF
+    dp_class = EnerLossDP
+    pt_class = EnerLossPT
+    jax_class = EnerLossDP
+    pd_class = EnerLossPD
+    array_api_strict_class = EnerLossDP
+    args = loss_ener()
+
+    def setUp(self) -> None:
+        CommonTest.setUp(self)
+        self.learning_rate = 1e-3
+        rng = np.random.default_rng(20250105)
+        self.nframes = 2
+        self.natoms = 6
+        self.predict = {
+            "energy": rng.random((self.nframes,)),
+            "force": rng.random((self.nframes, self.natoms, 3)),
+            "virial": rng.random((self.nframes, 9)),
+            "atom_ener": rng.random(
+                (
+                    self.nframes,
+                    self.natoms,
+                )
+            ),
+        }
+        self.label = {
+            "energy": rng.random((self.nframes,)),
+            "force": rng.random((self.nframes, self.natoms, 3)),
+            "virial": rng.random((self.nframes, 9)),
+            "atom_ener": rng.random(
+                (
+                    self.nframes,
+                    self.natoms,
+                )
+            ),
+            "atom_pref": np.ones((self.nframes, self.natoms, 3)),
+            "find_energy": 1.0,
+            "find_force": 1.0,
+            "find_virial": 1.0,
+            "find_atom_ener": 1.0,
"find_atom_ener": 1.0, + "find_atom_pref": 1.0, + } + + @property + def additional_data(self) -> dict: + return { + "starter_learning_rate": 1e-3, + } + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + predict = { + kk: tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, vv.shape, name="i_predict_" + kk + ) + for kk, vv in self.predict.items() + } + label = { + kk: tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, vv.shape, name="i_label_" + kk + ) + if isinstance(vv, np.ndarray) + else vv + for kk, vv in self.label.items() + } + + loss, more_loss = obj.build( + self.learning_rate, + [self.natoms], + predict, + label, + suffix=suffix, + ) + return [loss], { + **{ + vv: self.predict[kk] + for kk, vv in predict.items() + if isinstance(vv, tf.Tensor) + }, + **{ + vv: self.label[kk] + for kk, vv in label.items() + if isinstance(vv, tf.Tensor) + }, + } + + def eval_pt(self, pt_obj: Any) -> Any: + predict = {kk: torch.asarray(vv) for kk, vv in self.predict.items()} + label = {kk: torch.asarray(vv) for kk, vv in self.label.items()} + predict["atom_energy"] = predict.pop("atom_ener") + _, loss, more_loss = pt_obj( + {}, + lambda: predict, + label, + self.natoms, + self.learning_rate, + ) + loss = to_numpy_array(loss) + more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def eval_dp(self, dp_obj: Any) -> Any: + return dp_obj( + self.learning_rate, + self.natoms, + self.predict, + self.label, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + predict = {kk: jnp.asarray(vv) for kk, vv in self.predict.items()} + label = {kk: jnp.asarray(vv) for kk, vv in self.label.items()} + + loss, more_loss = jax_obj( + self.learning_rate, + self.natoms, + predict, + label, + ) + loss = to_numpy_array(loss) + more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + predict = {kk: array_api_strict.asarray(vv) for kk, vv in self.predict.items()} + label = {kk: array_api_strict.asarray(vv) for kk, vv in self.label.items()} + + loss, more_loss = array_api_strict_obj( + self.learning_rate, + self.natoms, + predict, + label, + ) + loss = to_numpy_array(loss) + more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def eval_pd(self, pd_obj: Any) -> Any: + predict = {kk: paddle.asarray(vv) for kk, vv in self.predict.items()} + label = {kk: paddle.asarray(vv) for kk, vv in self.label.items()} + predict["atom_energy"] = predict.pop("atom_ener") + _, loss, more_loss = pd_obj( + {}, + lambda: predict, + label, + self.natoms, + self.learning_rate, + ) + loss = to_numpy_array(loss) + more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()} + return loss, more_loss + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + return (ret[0],) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + return 1e-10 + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + return 1e-10