From a4688194e9c42ff285df877fe5fcbff384f5a470 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 22:09:42 -0400 Subject: [PATCH] feat(jax/array-api): property fitting (#4287) ## Summary by CodeRabbit - **New Features** - Introduced the `PropertyFittingNet` class for enhanced property-specific fitting operations. - Enhanced testing framework to support additional computational backends (JAX and Array API Strict). - **Bug Fixes** - Improved handling of attribute assignments in property fitting. - **Tests** - Added new methods and properties to the testing suite for evaluating property fitting with JAX and Array API Strict. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/jax/fitting/fitting.py | 11 ++++ .../tests/array_api_strict/fitting/fitting.py | 9 +++ .../tests/consistent/fitting/test_property.py | 62 +++++++++++++++++++ 3 files changed, 82 insertions(+) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 2a6186ac46..d62681490c 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -9,6 +9,9 @@ from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, @@ -51,6 +54,14 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("property") +@flax_module +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + @BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 5a2bd9c58f..323a49cfe8 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -9,6 +9,9 @@ from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from ..common import ( to_array_api_strict_array, @@ -43,6 +46,12 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index beb21d9c04..4e0fe04f9f 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -17,6 +17,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -32,6 +34,22 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: PropertyFittingPT = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PropertyFittingNet as PropertyFittingJAX +else: + PropertyFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PropertyFittingNet as PropertyFittingStrict, + ) +else: + PropertyFittingStrict = object + PropertyFittingTF = object @@ -84,9 +102,14 @@ def skip_pt(self) -> bool: def skip_tf(self) -> bool: return True + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + tf_class = PropertyFittingTF dp_class = PropertyFittingDP pt_class = PropertyFittingPT + jax_class = PropertyFittingJAX + array_api_strict_class = PropertyFittingStrict args = fitting_property() def setUp(self): @@ -183,6 +206,45 @@ def eval_dp(self, dp_obj: Any) -> Any: aparam=self.aparam if numb_aparam else None, )["property"] + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if numb_fparam else None, + aparam=jnp.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, + aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same