Skip to content

Commit

Permalink
update consistent check of dpa1
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Nov 30, 2024
1 parent 8b5b4a8 commit d98c644
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
44 changes: 44 additions & 0 deletions source/tests/consistent/descriptor/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PD,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -39,6 +40,10 @@
from deepmd.jax.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1JAX
else:
DescriptorDPA1JAX = None
if INSTALLED_PD:
from deepmd.pd.model.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1PD
else:
DescrptDPA1PD = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1Strict
else:
Expand Down Expand Up @@ -187,6 +192,34 @@ def skip_dp(self) -> bool:
temperature,
)

@property
def skip_pd(self) -> bool:
(
tebd_dim,
tebd_input_mode,
resnet_dt,
type_one_side,
attn,
attn_layer,
attn_dotr,
excluded_types,
env_protection,
set_davg_zero,
scaling_factor,
normalize,
temperature,
ln_eps,
smooth_type_embedding,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
return CommonTest.skip_pd or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
temperature,
)

@property
def skip_jax(self) -> bool:
(
Expand Down Expand Up @@ -287,6 +320,7 @@ def skip_tf(self) -> bool:
tf_class = DescrptDPA1TF
dp_class = DescrptDPA1DP
pt_class = DescrptDPA1PT
pd_class = DescrptDPA1PD
jax_class = DescriptorDPA1JAX
array_api_strict_class = DescriptorDPA1Strict

Expand Down Expand Up @@ -387,6 +421,16 @@ def eval_jax(self, jax_obj: Any) -> Any:
mixed_types=True,
)

def eval_pd(self, pd_obj: Any) -> Any:
return self.eval_pd_descriptor(
pd_obj,
self.natoms,
self.coords,
self.atype,
self.box,
mixed_types=True,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
Expand Down
28 changes: 28 additions & 0 deletions source/tests/consistent/model/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..common import (
INSTALLED_JAX,
INSTALLED_PD,
INSTALLED_PT,
INSTALLED_TF,
SKIP_FLAG,
Expand All @@ -37,6 +38,11 @@
model_args,
)

if INSTALLED_PD:
from deepmd.pd.model.model import get_model as get_model_pd
from deepmd.pd.model.model.ener_model import EnergyModel as EnergyModelPD
else:
EnergyModelPD = None
if INSTALLED_JAX:
from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX
from deepmd.jax.model.model import get_model as get_model_jax
Expand Down Expand Up @@ -90,6 +96,7 @@ def data(self) -> dict:
tf_class = EnergyModelTF
dp_class = EnergyModelDP
pt_class = EnergyModelPT
pd_class = EnergyModelPD
jax_class = EnergyModelJAX
args = model_args()

Expand All @@ -102,6 +109,8 @@ def get_reference_backend(self):
return self.RefBackend.PT
if not self.skip_tf:
return self.RefBackend.TF
if not self.skip_pd:
return self.RefBackend.PD
if not self.skip_jax:
return self.RefBackend.JAX
if not self.skip_dp:
Expand All @@ -119,6 +128,8 @@ def pass_data_to_cls(self, cls, data) -> Any:
return get_model_dp(data)
elif cls is EnergyModelPT:
return get_model_pt(data)
elif cls is EnergyModelPD:
return get_model_pd(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.additional_data)
Expand Down Expand Up @@ -190,6 +201,15 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_pd(self, pd_obj: Any) -> Any:
return self.eval_pd_model(
pd_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_model(
jax_obj,
Expand Down Expand Up @@ -225,6 +245,14 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
ret[3].ravel(),
ret[4].ravel(),
)
elif backend is self.RefBackend.PD:
return (
ret["energy"].flatten(),
ret["atom_energy"].flatten(),
ret["force"].flatten(),
ret["virial"].flatten(),
ret["atom_virial"].flatten(),
)
elif backend is self.RefBackend.JAX:
return (
ret["energy_redu"].ravel(),
Expand Down

0 comments on commit d98c644

Please sign in to comment.