From 6e815a25a883adeac56d9c234247657e62728956 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 Nov 2024 03:30:05 -0500 Subject: [PATCH 1/5] fix(dpmodel): fix precision (#4343) ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced a new environment variable `DP_DTYPE_PROMOTION_STRICT` to enhance precision handling in TensorFlow tests. - Added a decorator `@cast_precision` to several descriptor classes, improving precision management during computations. - Updated JAX configuration to enable strict dtype promotion based on the new environment variable. - Enhanced serialization and deserialization processes to include precision attributes across multiple classes. - **Bug Fixes** - Enhanced type handling and input processing in the `GeneralFitting` class for better output predictions. - Improved handling of atomic contributions and exclusions in the `BaseAtomicModel` class. - Addressed potential type mismatches during matrix operations in the `NativeLayer` class. - **Chores** - Updated caching mechanisms in the testing workflow to ensure unique keys based on run parameters. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .github/workflows/test_python.yml | 1 + .../dpmodel/atomic_model/base_atomic_model.py | 15 +-- deepmd/dpmodel/common.py | 104 ++++++++++++++++++ deepmd/dpmodel/descriptor/dpa1.py | 3 + deepmd/dpmodel/descriptor/dpa2.py | 3 + deepmd/dpmodel/descriptor/se_e2_a.py | 14 +-- deepmd/dpmodel/descriptor/se_r.py | 3 +- deepmd/dpmodel/descriptor/se_t.py | 4 +- deepmd/dpmodel/descriptor/se_t_tebd.py | 5 +- deepmd/dpmodel/fitting/dipole_fitting.py | 4 + deepmd/dpmodel/fitting/general_fitting.py | 20 ++-- deepmd/dpmodel/fitting/invar_fitting.py | 4 + .../dpmodel/fitting/polarizability_fitting.py | 11 +- deepmd/dpmodel/utils/network.py | 4 + deepmd/jax/env.py | 3 + source/tests/common/dpmodel/test_network.py | 11 +- 16 files changed, 175 insertions(+), 34 deletions(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 1b1935a2f6..9437c69ae8 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -62,6 +62,7 @@ jobs: env: NUM_WORKERS: 0 DP_TEST_TF2_ONLY: 1 + DP_DTYPE_PROMOTION_STRICT: 1 if: matrix.group == 1 - run: mv .test_durations .test_durations_${{ matrix.group }} - name: Upload partial durations diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 04cd5b0014..eb95886598 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -201,18 +201,19 @@ def forward_common_atomic( ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc - atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32) + atom_mask = ext_atom_mask[:, :nloc] if self.atom_excl is not None: - atom_mask *= self.atom_excl.build_type_exclude_mask(atype) + atom_mask = xp.logical_and( + atom_mask, self.atom_excl.build_type_exclude_mask(atype) + ) for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape out_shape2 = math.prod(out_shape[2:]) - ret_dict[kk] = ( - ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) - * atom_mask[:, :, None] - ).reshape(out_shape) - ret_dict["mask"] = atom_mask + tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) + tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr)) 
+ ret_dict[kk] = xp.reshape(tmp_arr, out_shape) + ret_dict["mask"] = xp.astype(atom_mask, xp.int32) return ret_dict diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 6e6113b494..2bef086726 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -3,9 +3,14 @@ ABC, abstractmethod, ) +from functools import ( + wraps, +) from typing import ( Any, + Callable, Optional, + overload, ) import array_api_compat @@ -116,6 +121,105 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]: return np.from_dlpack(x) +def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]: + """A decorator that casts and casts back the input + and output tensor of a method. + + The decorator should be used on an instance method. + + The decorator will do the following thing: + (1) It casts input arrays from the global precision + to precision defined by property `precision`. + (2) It casts output arrays from `precision` to + the global precision. + (3) It checks inputs and outputs and only casts when + input or output is an array and its dtype matches + the global precision and `precision`, respectively. + If it does not match (e.g. it is an integer), the decorator + will do nothing on it. + + The decorator supports the array API. + + Returns + ------- + Callable + a decorator that casts and casts back the input and + output array of a method + + Examples + -------- + >>> class A: + ... def __init__(self): + ... self.precision = "float32" + ... + ... @cast_precision + ... def f(x: Array, y: Array) -> Array: + ... return x**2 + y + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + # only convert tensors + returned_tensor = func( + self, + *[safe_cast_array(vv, "global", self.precision) for vv in args], + **{ + kk: safe_cast_array(vv, "global", self.precision) + for kk, vv in kwargs.items() + }, + ) + if isinstance(returned_tensor, tuple): + return tuple( + safe_cast_array(vv, self.precision, "global") for vv in returned_tensor + ) + elif isinstance(returned_tensor, dict): + return { + kk: safe_cast_array(vv, self.precision, "global") + for kk, vv in returned_tensor.items() + } + else: + return safe_cast_array(returned_tensor, self.precision, "global") + + return wrapper + + +@overload +def safe_cast_array( + input: np.ndarray, from_precision: str, to_precision: str +) -> np.ndarray: ... +@overload +def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ... +def safe_cast_array( + input: Optional[np.ndarray], from_precision: str, to_precision: str +) -> Optional[np.ndarray]: + """Convert an array from a precision to another precision. + + If input is not an array or without the specific precision, the method will not + cast it. + + Array API is supported. 
+ + Parameters + ---------- + input : np.ndarray or None + Input array + from_precision : str + Array data type that is casted from + to_precision : str + Array data type that casts to + + Returns + ------- + np.ndarray or None + casted array + """ + if array_api_compat.is_array_api_obj(input): + xp = array_api_compat.array_namespace(input) + if input.dtype == get_xp_precision(xp, from_precision): + return xp.astype(input, get_xp_precision(xp, to_precision)) + return input + + __all__ = [ "GLOBAL_NP_FLOAT_PRECISION", "GLOBAL_ENER_FLOAT_PRECISION", diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index dd8acba872..62ab2a5a9a 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -20,6 +20,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -330,6 +331,7 @@ def __init__( self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -451,6 +453,7 @@ def change_type_map( obj["davg"] = obj["davg"][remap_index] obj["dstd"] = obj["dstd"][remap_index] + @cast_precision def call( self, coord_ext, diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 6aa06c47c3..eb6bfa4766 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -15,6 +15,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -595,6 +596,7 @@ def init_subclass_params(sub_data, sub_class): self.rcut = self.repinit.get_rcut() self.ntypes = ntypes self.sel = self.repinit.sel + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -760,6 +762,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]: stddev_list.append(self.repinit_three_body.stddev) return mean_list, stddev_list + @cast_precision def call( self, coord_ext: np.ndarray, diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index a43b92082c..598d5c5fcc 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -16,6 +16,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -30,9 +31,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -343,6 +341,7 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + @cast_precision def call( self, coord_ext, @@ -418,9 +417,7 @@ def call( # nf x nloc x ng x ng1 grrg = np.einsum("flid,fljd->flij", gr, gr1) # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( - GLOBAL_NP_FLOAT_PRECISION - ) + grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) return grrg, gr[..., 1:], None, None, ww def serialize(self) -> dict: @@ -509,6 +506,7 @@ def update_sel( class DescrptSeAArrayAPI(DescrptSeA): + @cast_precision def call( self, coord_ext, @@ -588,7 +586,5 @@ def call( # grrg = xp.einsum("flid,fljd->flij", gr, gr1) grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) # nf x nloc x (ng x ng1) - grrg = xp.astype( - xp.reshape(grrg, (nf, nloc, ng * 
self.axis_neuron)), input_dtype - ) + grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) return grrg, gr[..., 1:], None, None, ww diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index f1260bbab6..54e22a909a 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -15,6 +15,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, get_xp_precision, to_numpy_array, ) @@ -292,6 +293,7 @@ def cal_g( gg = self.embeddings[(ll,)].call(ss) return gg + @cast_precision def call( self, coord_ext, @@ -355,7 +357,6 @@ def call( res_rescale = 1.0 / 5.0 res = xyz_scatter * res_rescale res = xp.reshape(res, (nf, nloc, ng)) - res = xp.astype(res, get_xp_precision(xp, "global")) return res, None, None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index a0f60baebb..366af90fd9 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -15,6 +15,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, get_xp_precision, to_numpy_array, ) @@ -267,6 +268,7 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + @cast_precision def call( self, coord_ext, @@ -320,7 +322,6 @@ def call( # we don't require atype is the same in all frames exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) rr = xp.reshape(rr, (nf * nloc, nnei, 4)) - rr = xp.astype(rr, get_xp_precision(xp, self.precision)) for embedding_idx in itertools.product( range(self.ntypes), repeat=self.embeddings.ndim @@ -352,7 +353,6 @@ def call( result += res_ij # nf x nloc x ng result = xp.reshape(result, (nf, nloc, ng)) - result = xp.astype(result, get_xp_precision(xp, "global")) return result, None, None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 0079c2f6aa..1efa991047 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -17,7 +17,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( - get_xp_precision, + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -169,6 +169,7 @@ def __init__( self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -290,6 +291,7 @@ def change_type_map( obj["davg"] = obj["davg"][remap_index] obj["dstd"] = obj["dstd"][remap_index] + @cast_precision def call( self, coord_ext, @@ -744,7 +746,6 @@ def call( res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei)) # nf x nl x ng result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1])) - result = xp.astype(result, get_xp_precision(xp, "global")) return ( result, None, diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 2a03934f3b..c872ef0555 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -11,6 +11,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.dpmodel.common import ( + cast_precision, +) from deepmd.dpmodel.fitting.base_fitting import ( BaseFitting, ) @@ -174,6 +177,7 @@ def output_def(self): ] ) + @cast_precision def call( self, descriptor: np.ndarray, diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 
f71f322dd1..2958a7d18d 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -439,18 +439,22 @@ def _call_common( ): assert xx_zeros is not None atom_property -= self.nets[(type_i,)](xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i, ...] - atom_property = atom_property * xp.astype(mask, atom_property.dtype) + atom_property = xp.where( + mask, atom_property, xp.zeros_like(atom_property) + ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] else: - outs = self.nets[()](xx) + xp.reshape( - xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), - [nf, nloc, net_dim_out], - ) + outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) + outs += xp.reshape( + xp.take( + xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0 + ), + [nf, nloc, net_dim_out], + ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod - outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype) - return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))} + outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) + return {self.var_name: outs} diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index d840d76149..219589d9ee 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -11,6 +11,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.dpmodel.common import ( + cast_precision, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -204,6 +207,7 @@ def output_def(self): ] ) + @cast_precision def call( self, descriptor: np.ndarray, diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 616ac20437..021359a96e 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -15,6 +15,7 @@ DEFAULT_PRECISION, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.fitting.base_fitting import ( @@ -241,6 +242,7 @@ def change_type_map( self.scale = self.scale[remap_index] self.constant_matrix = self.constant_matrix[remap_index] + @cast_precision def call( self, descriptor: np.ndarray, @@ -285,7 +287,8 @@ def call( ] # out = out * self.scale[atype, ...] 
scale_atype = xp.reshape( - xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1) + xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, [-1]), axis=0), + (*atype.shape, 1), ) out = out * scale_atype # (nframes * nloc, m1, 3) @@ -308,7 +311,11 @@ def call( if self.shift_diag: # bias = self.constant_matrix[atype] bias = xp.reshape( - xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0), + xp.take( + xp.astype(self.constant_matrix, out.dtype), + xp.reshape(atype, [-1]), + axis=0, + ), (nframes, nloc), ) # (nframes, nloc, 1) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 971ab00894..18071f8eca 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -248,6 +248,10 @@ def call(self, x: np.ndarray) -> np.ndarray: if self.b is not None else xp.matmul(x, self.w) ) + if y.dtype != x.dtype: + # workaround for bfloat16 + # https://github.com/jax-ml/ml_dtypes/issues/235 + y = xp.astype(y, x.dtype) y = fn(y) if self.idt is not None: y *= self.idt diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 1b90433b00..02e31ae66e 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -13,6 +13,9 @@ jax.config.update("jax_enable_x64", True) # jax.config.update("jax_debug_nans", True) +if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1": + jax.config.update("jax_numpy_dtype_promotion", "strict") + __all__ = [ "jax", "jnp", diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index dde5992746..3feb64f72f 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -8,6 +8,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + get_xp_precision, +) from deepmd.dpmodel.utils import ( EmbeddingNet, FittingNet, @@ -46,7 +49,9 @@ def test_serialize_deserize(self) -> None: inp_shap = [ni] if ashp is not None: inp_shap = ashp + inp_shap - inp = np.arange(np.prod(inp_shap)).reshape(inp_shap) + inp = np.arange( + np.prod(inp_shap), dtype=get_xp_precision(np, prec) + ).reshape(inp_shap) np.testing.assert_allclose(nl0.call(inp), nl1.call(inp)) def test_shape_error(self) -> None: @@ -168,7 +173,7 @@ def test_embedding_net(self) -> None: resnet_dt=idt, ) en1 = EmbeddingNet.deserialize(en0.serialize()) - inp = np.ones([ni]) + inp = np.ones([ni], dtype=get_xp_precision(np, prec)) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) @@ -191,7 +196,7 @@ def test_fitting_net(self) -> None: bias_out=bo, ) en1 = FittingNet.deserialize(en0.serialize()) - inp = np.ones([ni]) + inp = np.ones([ni], dtype=get_xp_precision(np, prec)) en0.call(inp) en1.call(inp) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) From d7cf48c3ba0dd86e14765922152e1146ec142f0b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 Nov 2024 04:03:33 -0500 Subject: [PATCH 2/5] fix(tf): fix model out_bias deserialize (#4350) per discussion ## Summary by CodeRabbit - **New Features** - Enhanced handling of model serialization and deserialization, particularly for bias parameters. - Updated output structure for the `PT` backend in the energy model tests. - **Bug Fixes** - Improved logic for managing unsupported model configurations, ensuring clearer error reporting. - **Documentation** - Updated method signatures to reflect changes in functionality for model handling and testing. 
--------- Signed-off-by: Jinzhe Zeng --- deepmd/tf/model/model.py | 46 ++++++++++++++++++++-- source/tests/consistent/model/test_ener.py | 4 +- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 2edfe4b651..8991bf1baf 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -17,6 +17,9 @@ from deepmd.common import ( j_get_type, ) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) from deepmd.tf.descriptor.descriptor import ( Descriptor, ) @@ -803,10 +806,34 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": ------- Descriptor The deserialized descriptor + + Raises + ------ + ValueError + If both fitting/@variables/bias_atom_e and @variables/out_bias are non-zero """ data = data.copy() check_version_compatibility(data.pop("@version", 2), 2, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) + if data["fitting"].get("@variables", {}).get("bias_atom_e") is not None: + # careful: copy each level and don't modify the input array, + # otherwise it will affect the original data + # deepcopy is not used for performance reasons + data["fitting"] = data["fitting"].copy() + data["fitting"]["@variables"] = data["fitting"]["@variables"].copy() + if ( + int(np.any(data["fitting"]["@variables"]["bias_atom_e"])) + + int(np.any(data["@variables"]["out_bias"])) + > 1 + ): + raise ValueError( + "fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero" + ) + data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][ + "@variables" + ]["bias_atom_e"] + data["@variables"]["out_bias"].reshape( + data["fitting"]["@variables"]["bias_atom_e"].shape + ) fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) # pass descriptor type embedding to model if descriptor.explicit_ntypes: @@ -815,8 +842,10 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": else: type_embedding = None # BEGINE not supported keys - data.pop("atom_exclude_types") - data.pop("pair_exclude_types") + if len(data.pop("atom_exclude_types")) > 0: + raise NotImplementedError("atom_exclude_types is not supported") + if len(data.pop("pair_exclude_types")) > 0: + raise NotImplementedError("pair_exclude_types is not supported") data.pop("rcond", None) data.pop("preset_out_bias", None) data.pop("@variables", None) @@ -853,6 +882,17 @@ def serialize(self, suffix: str = "") -> dict: ntypes = len(self.get_type_map()) dict_fit = self.fitting.serialize(suffix=suffix) + if dict_fit.get("@variables", {}).get("bias_atom_e") is not None: + out_bias = dict_fit["@variables"]["bias_atom_e"].reshape( + [1, ntypes, dict_fit["dim_out"]] + ) + dict_fit["@variables"]["bias_atom_e"] = np.zeros_like( + dict_fit["@variables"]["bias_atom_e"] + ) + else: + out_bias = np.zeros( + [1, ntypes, dict_fit["dim_out"]], dtype=GLOBAL_NP_FLOAT_PRECISION + ) return { "@class": "Model", "type": "standard", @@ -866,7 +906,7 @@ def serialize(self, suffix: str = "") -> dict: "rcond": None, "preset_out_bias": None, "@variables": { - "out_bias": np.zeros([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype + "out_bias": out_bias, "out_std": np.ones([1, ntypes, dict_fit["dim_out"]]), # pylint: disable=no-explicit-dtype }, } diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 45d85861fd..4c50c08bef 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -141,7 +141,9 @@ def 
pass_data_to_cls(self, cls, data) -> Any: if cls is EnergyModelDP: return get_model_dp(data) elif cls is EnergyModelPT: - return get_model_pt(data) + model = get_model_pt(data) + model.atomic_model.out_bias.uniform_() + return model elif cls is EnergyModelJAX: return get_model_jax(data) return cls(**data, **self.additional_data) From 6d9d8bb64e2c571f3fbe5b0ed9a8369001417869 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 Nov 2024 04:04:02 -0500 Subject: [PATCH 3/5] chore(pt): Delete deepmd/pt/utils/cache.py (#4356) It's not used and not covered. Delete per discussion. ## Summary by CodeRabbit - **Chores** - Removed a custom LRU cache decorator to streamline functionality and reduce complexity. Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/cache.py | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 deepmd/pt/utils/cache.py diff --git a/deepmd/pt/utils/cache.py b/deepmd/pt/utils/cache.py deleted file mode 100644 index c40c4050b7..0000000000 --- a/deepmd/pt/utils/cache.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import copy as copy_lib -import functools - - -def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False): - if deepcopy: - - def decorator(f): - cached_func = functools.lru_cache(maxsize, typed)(f) - - @functools.wraps(f) - def wrapper(*args, **kwargs): - return copy_lib.deepcopy(cached_func(*args, **kwargs)) - - return wrapper - - elif copy: - - def decorator(f): - cached_func = functools.lru_cache(maxsize, typed)(f) - - @functools.wraps(f) - def wrapper(*args, **kwargs): - return copy_lib.copy(cached_func(*args, **kwargs)) - - return wrapper - - else: - decorator = functools.lru_cache(maxsize, typed) - return decorator From d3095cf16ace2fd29c3d3ff8773f54def28c8570 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 Nov 2024 04:04:55 -0500 Subject: [PATCH 4/5] chore(cc): merge get backend codes (#4355) Fix #4308. ## Summary by CodeRabbit - **New Features** - Introduced a new function to dynamically determine the backend framework based on the model file type. - **Improvements** - Enhanced backend detection logic in multiple classes, allowing for more flexible model initialization. - Simplified control flow in the initialization methods of various components. - **Bug Fixes** - Improved error handling for unsupported backends and model formats during initialization processes. Signed-off-by: Jinzhe Zeng --- source/api_cc/include/common.h | 7 +++++++ source/api_cc/src/DataModifier.cc | 3 +-- source/api_cc/src/DeepPot.cc | 12 +----------- source/api_cc/src/DeepSpin.cc | 9 +-------- source/api_cc/src/DeepTensor.cc | 3 +-- source/api_cc/src/common.cc | 12 ++++++++++++ 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index def3df933b..2bd0cf7135 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -15,6 +15,13 @@ namespace deepmd { typedef double ENERGYTYPE; enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown }; +/** + * @brief Get the backend of the model. + * @param[in] model The model name. + * @return The backend of the model. 
+ **/ +DPBackend get_backend(const std::string& model); + struct NeighborListData { /// Array stores the core region atom's index std::vector ilist; diff --git a/source/api_cc/src/DataModifier.cc b/source/api_cc/src/DataModifier.cc index 23e0321410..4f319b4f66 100644 --- a/source/api_cc/src/DataModifier.cc +++ b/source/api_cc/src/DataModifier.cc @@ -28,8 +28,7 @@ void DipoleChargeModifier::init(const std::string& model, << std::endl; return; } - // TODO: To implement detect_backend - DPBackend backend = deepmd::DPBackend::TensorFlow; + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dcm = std::make_shared(model, gpu_rank, diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 8769f5b211..29f8b99cfd 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -39,17 +39,7 @@ void DeepPot::init(const std::string& model, << std::endl; return; } - DPBackend backend; - if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { - backend = deepmd::DPBackend::PyTorch; - } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { - backend = deepmd::DPBackend::TensorFlow; - } else if (model.length() >= 11 && - model.substr(model.length() - 11) == ".savedmodel") { - backend = deepmd::DPBackend::JAX; - } else { - throw deepmd::deepmd_exception("Unsupported model file format"); - } + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dp = std::make_shared(model, gpu_rank, file_content); diff --git a/source/api_cc/src/DeepSpin.cc b/source/api_cc/src/DeepSpin.cc index d761e9d3c2..eb37828410 100644 --- a/source/api_cc/src/DeepSpin.cc +++ b/source/api_cc/src/DeepSpin.cc @@ -36,14 +36,7 @@ void DeepSpin::init(const std::string& model, << std::endl; return; } - DPBackend backend; - if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { - backend = deepmd::DPBackend::PyTorch; - } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { - backend = deepmd::DPBackend::TensorFlow; - } else { - throw deepmd::deepmd_exception("Unsupported model file format"); - } + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dp = std::make_shared(model, gpu_rank, file_content); diff --git a/source/api_cc/src/DeepTensor.cc b/source/api_cc/src/DeepTensor.cc index a0596e046f..a9031472e6 100644 --- a/source/api_cc/src/DeepTensor.cc +++ b/source/api_cc/src/DeepTensor.cc @@ -30,8 +30,7 @@ void DeepTensor::init(const std::string &model, << std::endl; return; } - // TODO: To implement detect_backend - DPBackend backend = deepmd::DPBackend::TensorFlow; + const DPBackend backend = get_backend(model); if (deepmd::DPBackend::TensorFlow == backend) { #ifdef BUILD_TENSORFLOW dt = std::make_shared(model, gpu_rank, name_scope_); diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index bd3f18c579..5a4f05d75c 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -1399,3 +1399,15 @@ void deepmd::print_summary(const std::string& pre) { << "set tf inter_op_parallelism_threads: " << num_inter_nthreads << std::endl; } + +deepmd::DPBackend deepmd::get_backend(const std::string& model) { + if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") { + return deepmd::DPBackend::PyTorch; + } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") { + return 
deepmd::DPBackend::TensorFlow; + } else if (model.length() >= 11 && + model.substr(model.length() - 11) == ".savedmodel") { + return deepmd::DPBackend::JAX; + } + throw deepmd::deepmd_exception("Unsupported model file format"); +} From 0ad42893bf76829f68127c145984d58e5c6133eb Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 15 Nov 2024 00:10:38 +0800 Subject: [PATCH 5/5] feat(pt): add universal test for loss (#4354) ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced a new `LossTest` class for enhanced testing of loss functions. - Added multiple parameterized test functions for various loss functions in the new `test_loss.py` file. - **Bug Fixes** - Corrected tensor operations in the `DOSLoss` class to ensure accurate cumulative sum calculations. - **Documentation** - Added SPDX license identifiers to multiple files for clarity on licensing terms. - **Chores** - Refactored data conversion methods in the `PTTestCase` class for improved handling of tensors and arrays. --------- Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: Jinzhe Zeng --- deepmd/dpmodel/common.py | 4 +- deepmd/pt/loss/dos.py | 8 +- .../universal/common/cases/loss/__init__.py | 1 + .../tests/universal/common/cases/loss/loss.py | 11 + .../universal/common/cases/loss/utils.py | 79 +++++++ .../tests/universal/dpmodel/loss/__init__.py | 1 + .../tests/universal/dpmodel/loss/test_loss.py | 203 ++++++++++++++++++ source/tests/universal/pt/backend.py | 44 +++- source/tests/universal/pt/loss/__init__.py | 1 + source/tests/universal/pt/loss/test_loss.py | 47 ++++ 10 files changed, 388 insertions(+), 11 deletions(-) create mode 100644 source/tests/universal/common/cases/loss/__init__.py create mode 100644 source/tests/universal/common/cases/loss/loss.py create mode 100644 source/tests/universal/common/cases/loss/utils.py create mode 100644 source/tests/universal/dpmodel/loss/__init__.py create mode 100644 source/tests/universal/dpmodel/loss/test_loss.py create mode 100644 source/tests/universal/pt/loss/__init__.py create mode 100644 source/tests/universal/pt/loss/test_loss.py diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 2bef086726..8353cc28e3 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -34,7 +34,7 @@ "double": np.float64, "int32": np.int32, "int64": np.int64, - "bool": bool, + "bool": np.bool_, "default": GLOBAL_NP_FLOAT_PRECISION, # NumPy doesn't have bfloat16 (and doesn't plan to add) # ml_dtypes is a solution, but it seems not supporting np.save/np.load @@ -50,7 +50,7 @@ np.int32: "int32", np.int64: "int64", ml_dtypes.bfloat16: "bfloat16", - bool: "bool", + np.bool_: "bool", } assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values()) DEFAULT_PRECISION = "float64" diff --git a/deepmd/pt/loss/dos.py b/deepmd/pt/loss/dos.py index 7a64c6fbd3..30a3f715ef 100644 --- a/deepmd/pt/loss/dos.py +++ b/deepmd/pt/loss/dos.py @@ -151,10 +151,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label: find_local = label.get("find_atom_dos", 0.0) pref_acdf = pref_acdf * find_local - local_tensor_pred_cdf = torch.cusum( + local_tensor_pred_cdf = torch.cumsum( model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1 ) - local_tensor_label_cdf = torch.cusum( + local_tensor_label_cdf = torch.cumsum( label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1 ) diff = 
(local_tensor_pred_cdf - local_tensor_label_cdf).reshape( @@ -199,10 +199,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False if self.has_cdf and "dos" in model_pred and "dos" in label: find_global = label.get("find_dos", 0.0) pref_cdf = pref_cdf * find_global - global_tensor_pred_cdf = torch.cusum( + global_tensor_pred_cdf = torch.cumsum( model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1 ) - global_tensor_label_cdf = torch.cusum( + global_tensor_label_cdf = torch.cumsum( label["dos"].reshape([-1, self.numb_dos]), dim=-1 ) diff = global_tensor_pred_cdf - global_tensor_label_cdf diff --git a/source/tests/universal/common/cases/loss/__init__.py b/source/tests/universal/common/cases/loss/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/cases/loss/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/cases/loss/loss.py b/source/tests/universal/common/cases/loss/loss.py new file mode 100644 index 0000000000..a3b585114f --- /dev/null +++ b/source/tests/universal/common/cases/loss/loss.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +from .utils import ( + LossTestCase, +) + + +class LossTest(LossTestCase): + def setUp(self) -> None: + LossTestCase.setUp(self) diff --git a/source/tests/universal/common/cases/loss/utils.py b/source/tests/universal/common/cases/loss/utils.py new file mode 100644 index 0000000000..63e6e3ed27 --- /dev/null +++ b/source/tests/universal/common/cases/loss/utils.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np + +from deepmd.utils.data import ( + DataRequirementItem, +) + +from .....seed import ( + GLOBAL_SEED, +) + + +class LossTestCase: + """Common test case for loss function.""" + + def setUp(self): + pass + + def test_label_keys(self): + module = self.forward_wrapper(self.module) + label_requirement = self.module.label_requirement + label_dict = {item.key: item for item in label_requirement} + label_keys = sorted(label_dict.keys()) + label_keys_expected = sorted( + [key for key in self.key_to_pref_map if self.key_to_pref_map[key] > 0] + ) + np.testing.assert_equal(label_keys_expected, label_keys) + + def test_forward(self): + module = self.forward_wrapper(self.module) + label_requirement = self.module.label_requirement + label_dict = {item.key: item for item in label_requirement} + label_keys = sorted(label_dict.keys()) + natoms = 5 + nframes = 2 + + def fake_model(): + model_predict = { + data_key: fake_input( + label_dict[data_key], natoms=natoms, nframes=nframes + ) + for data_key in label_keys + } + if "atom_ener" in model_predict: + model_predict["atom_energy"] = model_predict.pop("atom_ener") + model_predict.update( + {"mask_mag": np.ones([nframes, natoms, 1], dtype=np.bool_)} + ) + return model_predict + + labels = { + data_key: fake_input(label_dict[data_key], natoms=natoms, nframes=nframes) + for data_key in label_keys + } + labels.update({"find_" + data_key: 1.0 for data_key in label_keys}) + + _, loss, more_loss = module( + {}, + fake_model, + labels, + natoms, + 1.0, + ) + + +def fake_input(data_item: DataRequirementItem, natoms=5, nframes=2) -> np.ndarray: + ndof = data_item.ndof + atomic = data_item.atomic + repeat = data_item.repeat + rng = np.random.default_rng(seed=GLOBAL_SEED) + dtype = data_item.dtype if data_item.dtype is not None else np.float64 + if atomic: + data = rng.random([nframes, natoms, ndof], dtype) + else: + data = 
rng.random([nframes, ndof], dtype) + if repeat != 1: + data = np.repeat(data, repeat).reshape([nframes, -1]) + return data diff --git a/source/tests/universal/dpmodel/loss/__init__.py b/source/tests/universal/dpmodel/loss/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/loss/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/loss/test_loss.py b/source/tests/universal/dpmodel/loss/test_loss.py new file mode 100644 index 0000000000..6473c159da --- /dev/null +++ b/source/tests/universal/dpmodel/loss/test_loss.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections import ( + OrderedDict, +) + +from ....consistent.common import ( + parameterize_func, +) + + +def LossParamEnergy( + starter_learning_rate=1.0, + pref_e=1.0, + pref_f=1.0, + pref_v=1.0, + pref_ae=1.0, +): + key_to_pref_map = { + "energy": pref_e, + "force": pref_f, + "virial": pref_v, + "atom_ener": pref_ae, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "starter_learning_rate": starter_learning_rate, + "start_pref_e": pref_e, + "limit_pref_e": pref_e / 2, + "start_pref_f": pref_f, + "limit_pref_f": pref_f / 2, + "start_pref_v": pref_v, + "limit_pref_v": pref_v / 2, + "start_pref_ae": pref_ae, + "limit_pref_ae": pref_ae / 2, + } + return input_dict + + +LossParamEnergyList = parameterize_func( + LossParamEnergy, + OrderedDict( + { + "pref_e": (1.0, 0.0), + "pref_f": (1.0, 0.0), + "pref_v": (1.0, 0.0), + "pref_ae": (1.0, 0.0), + } + ), +) +# to get name for the default function +LossParamEnergy = LossParamEnergyList[0] + + +def LossParamEnergySpin( + starter_learning_rate=1.0, + pref_e=1.0, + pref_fr=1.0, + pref_fm=1.0, + pref_v=1.0, + pref_ae=1.0, +): + key_to_pref_map = { + "energy": pref_e, + "force": pref_fr, + "force_mag": pref_fm, + "virial": pref_v, + "atom_ener": pref_ae, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "starter_learning_rate": starter_learning_rate, + "start_pref_e": pref_e, + "limit_pref_e": pref_e / 2, + "start_pref_fr": pref_fr, + "limit_pref_fr": pref_fr / 2, + "start_pref_fm": pref_fm, + "limit_pref_fm": pref_fm / 2, + "start_pref_v": pref_v, + "limit_pref_v": pref_v / 2, + "start_pref_ae": pref_ae, + "limit_pref_ae": pref_ae / 2, + } + return input_dict + + +LossParamEnergySpinList = parameterize_func( + LossParamEnergySpin, + OrderedDict( + { + "pref_e": (1.0, 0.0), + "pref_fr": (1.0, 0.0), + "pref_fm": (1.0, 0.0), + "pref_v": (1.0, 0.0), + "pref_ae": (1.0, 0.0), + } + ), +) +# to get name for the default function +LossParamEnergySpin = LossParamEnergySpinList[0] + + +def LossParamDos( + starter_learning_rate=1.0, + pref_dos=1.0, + pref_ados=1.0, +): + key_to_pref_map = { + "dos": pref_dos, + "atom_dos": pref_ados, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "starter_learning_rate": starter_learning_rate, + "numb_dos": 2, + "start_pref_dos": pref_dos, + "limit_pref_dos": pref_dos / 2, + "start_pref_ados": pref_ados, + "limit_pref_ados": pref_ados / 2, + "start_pref_cdf": 0.0, + "limit_pref_cdf": 0.0, + "start_pref_acdf": 0.0, + "limit_pref_acdf": 0.0, + } + return input_dict + + +LossParamDosList = parameterize_func( + LossParamDos, + OrderedDict( + { + "pref_dos": (1.0,), + "pref_ados": (1.0, 0.0), + } + ), +) + parameterize_func( + LossParamDos, + OrderedDict( + { + "pref_dos": (0.0,), + "pref_ados": (1.0,), + } + ), +) + +# to get name for the default function +LossParamDos = LossParamDosList[0] + + 
+def LossParamTensor( + pref=1.0, + pref_atomic=1.0, +): + tensor_name = "test_tensor" + key_to_pref_map = { + tensor_name: pref, + f"atomic_{tensor_name}": pref_atomic, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "tensor_name": tensor_name, + "tensor_size": 2, + "label_name": tensor_name, + "pref": pref, + "pref_atomic": pref_atomic, + } + return input_dict + + +LossParamTensorList = parameterize_func( + LossParamTensor, + OrderedDict( + { + "pref": (1.0,), + "pref_atomic": (1.0, 0.0), + } + ), +) + parameterize_func( + LossParamTensor, + OrderedDict( + { + "pref": (0.0,), + "pref_atomic": (1.0,), + } + ), +) +# to get name for the default function +LossParamTensor = LossParamTensorList[0] + + +def LossParamProperty(): + key_to_pref_map = { + "property": 1.0, + } + input_dict = { + "key_to_pref_map": key_to_pref_map, + "task_dim": 2, + } + return input_dict + + +LossParamPropertyList = [LossParamProperty] +# to get name for the default function +LossParamProperty = LossParamPropertyList[0] diff --git a/source/tests/universal/pt/backend.py b/source/tests/universal/pt/backend.py index ae857d6105..082787780e 100644 --- a/source/tests/universal/pt/backend.py +++ b/source/tests/universal/pt/backend.py @@ -83,8 +83,8 @@ def forward_wrapper(self, module, on_cpu=False): def create_wrapper_method(method): def wrapper_method(self, *args, **kwargs): # convert to torch tensor - args = [to_torch_tensor(arg) for arg in args] - kwargs = {k: to_torch_tensor(v) for k, v in kwargs.items()} + args = [_to_torch_tensor(arg) for arg in args] + kwargs = {k: _to_torch_tensor(v) for k, v in kwargs.items()} if on_cpu: args = [ arg.detach().cpu() if arg is not None else None for arg in args @@ -97,11 +97,11 @@ def wrapper_method(self, *args, **kwargs): output = method(*args, **kwargs) # convert to numpy array if isinstance(output, tuple): - output = tuple(to_numpy_array(o) for o in output) + output = tuple(_to_numpy_array(o) for o in output) elif isinstance(output, dict): - output = {k: to_numpy_array(v) for k, v in output.items()} + output = {k: _to_numpy_array(v) for k, v in output.items()} else: - output = to_numpy_array(output) + output = _to_numpy_array(output) return output return wrapper_method @@ -112,3 +112,37 @@ class wrapper_module: forward_lower = create_wrapper_method(module.forward_lower) return wrapper_module() + + +def _to_torch_tensor(xx): + if isinstance(xx, dict): + return {kk: to_torch_tensor(xx[kk]) for kk in xx} + elif callable(xx): + return convert_to_torch_callable(xx) + else: + return to_torch_tensor(xx) + + +def convert_to_torch_callable(func): + def wrapper(*args, **kwargs): + output = _to_torch_tensor(func(*args, **kwargs)) + return output + + return wrapper + + +def _to_numpy_array(xx): + if isinstance(xx, dict): + return {kk: to_numpy_array(xx[kk]) for kk in xx} + elif callable(xx): + return convert_to_numpy_callable(xx) + else: + return to_numpy_array(xx) + + +def convert_to_numpy_callable(func): + def wrapper(*args, **kwargs): + output = _to_numpy_array(func(*args, **kwargs)) + return output + + return wrapper diff --git a/source/tests/universal/pt/loss/__init__.py b/source/tests/universal/pt/loss/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/loss/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/loss/test_loss.py b/source/tests/universal/pt/loss/test_loss.py new file mode 100644 index 0000000000..47c2d06fbc --- /dev/null +++ 
b/source/tests/universal/pt/loss/test_loss.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.pt.loss import ( + DOSLoss, + EnergySpinLoss, + EnergyStdLoss, + PropertyLoss, + TensorLoss, +) + +from ....consistent.common import ( + parameterized, +) +from ...common.cases.loss.loss import ( + LossTest, +) +from ...dpmodel.loss.test_loss import ( + LossParamDosList, + LossParamEnergyList, + LossParamEnergySpinList, + LossParamPropertyList, + LossParamTensorList, +) +from ..backend import ( + PTTestCase, +) + + +@parameterized( + ( + *[(param_func, EnergyStdLoss) for param_func in LossParamEnergyList], + *[(param_func, EnergySpinLoss) for param_func in LossParamEnergySpinList], + *[(param_func, DOSLoss) for param_func in LossParamDosList], + *[(param_func, TensorLoss) for param_func in LossParamTensorList], + *[(param_func, PropertyLoss) for param_func in LossParamPropertyList], + ) # class_param & class +) +class TestLossPT(unittest.TestCase, LossTest, PTTestCase): + def setUp(self): + (LossParam, Loss) = self.param[0] + LossTest.setUp(self) + self.module_class = Loss + self.input_dict = LossParam() + self.key_to_pref_map = self.input_dict.pop("key_to_pref_map") + self.module = Loss(**self.input_dict) + self.skip_test_jit = True
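
The central addition of the first patch is the `cast_precision` decorator in `deepmd/dpmodel/common.py`, which casts array arguments from the global precision down to the owning object's `precision` before the wrapped method runs and casts the results back afterwards. A minimal sketch of the intended usage is below; `ToyLayer` and its inputs are hypothetical, and the only assumption is the default global precision of float64 (no `DP_INTERFACE_PREC` override).

```python
# Hypothetical illustration of the `cast_precision` decorator added in
# deepmd/dpmodel/common.py; assumes the global precision is float64.
import numpy as np

from deepmd.dpmodel.common import cast_precision


class ToyLayer:
    def __init__(self) -> None:
        # the decorator reads this attribute to choose the working precision
        self.precision = "float32"

    @cast_precision
    def call(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        # both inputs arrive here already cast from the global precision
        # (float64) down to float32; integer or non-array arguments are
        # passed through unchanged, as described in the decorator docstring
        return x**2 + y


layer = ToyLayer()
out = layer.call(np.ones(3), np.ones(3))
# the computation ran in float32, but the result is cast back to the
# global precision before being returned to the caller
print(out.dtype)  # float64
```

Internally the decorator delegates to `safe_cast_array`, so any argument whose dtype does not match the source precision, or that is not an array at all, is left untouched; this is what lets the same decorated `call` accept mixed float and integer inputs such as `atype` arrays.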