Skip to content

Commit

Permalink
feat(jax/array-api): property fitting (#4287)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced the `PropertyFittingNet` class for enhanced
property-specific fitting operations.
- Enhanced testing framework to support additional computational
backends (JAX and Array API Strict).

- **Bug Fixes**
	- Improved handling of attribute assignments in property fitting.

- **Tests**
- Added new methods and properties to the testing suite for evaluating
property fitting with JAX and Array API Strict.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Nov 1, 2024
1 parent 704db2f commit a468819
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 0 deletions.
11 changes: 11 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.dpmodel.fitting.polarizability_fitting import (
PolarFitting as PolarFittingNetDP,
)
from deepmd.dpmodel.fitting.property_fitting import (
PropertyFittingNet as PropertyFittingNetDP,
)
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
Expand Down Expand Up @@ -51,6 +54,14 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@BaseFitting.register("property")
@flax_module
class PropertyFittingNet(PropertyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


@BaseFitting.register("dos")
@flax_module
class DOSFittingNet(DOSFittingNetDP):
Expand Down
9 changes: 9 additions & 0 deletions source/tests/array_api_strict/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.dpmodel.fitting.polarizability_fitting import (
PolarFitting as PolarFittingNetDP,
)
from deepmd.dpmodel.fitting.property_fitting import (
PropertyFittingNet as PropertyFittingNetDP,
)

from ..common import (
to_array_api_strict_array,
Expand Down Expand Up @@ -43,6 +46,12 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


class PropertyFittingNet(PropertyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
Expand Down
62 changes: 62 additions & 0 deletions source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
CommonTest,
parameterized,
Expand All @@ -32,6 +34,22 @@
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
else:
PropertyFittingPT = object
if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import PropertyFittingNet as PropertyFittingJAX
else:
PropertyFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import (
PropertyFittingNet as PropertyFittingStrict,
)
else:
PropertyFittingStrict = object

PropertyFittingTF = object


Expand Down Expand Up @@ -84,9 +102,14 @@ def skip_pt(self) -> bool:
def skip_tf(self) -> bool:
return True

skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

tf_class = PropertyFittingTF
dp_class = PropertyFittingDP
pt_class = PropertyFittingPT
jax_class = PropertyFittingJAX
array_api_strict_class = PropertyFittingStrict
args = fitting_property()

def setUp(self):
Expand Down Expand Up @@ -183,6 +206,45 @@ def eval_dp(self, dp_obj: Any) -> Any:
aparam=self.aparam if numb_aparam else None,
)["property"]

def eval_jax(self, jax_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
task_dim,
intensive,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
aparam=jnp.asarray(self.aparam) if numb_aparam else None,
)["property"]
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
task_dim,
intensive,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
)["property"]
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
Expand Down

0 comments on commit a468819

Please sign in to comment.