Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add consistency test for ZBL between dp and pt #4292

Merged
merged 23 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions deepmd/dpmodel/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.atomic_model.linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_model import (
DPModelCommon,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)

from .make_model import (
make_model,
)

DPEnergyModel_ = make_model(DPZBLLinearEnergyAtomicModel)


@BaseModel.register("zbl")
class DPZBLModel(DPEnergyModel_):
def __init__(
self,
*args,
**kwargs,
):
DPEnergyModel_.__init__(self, *args, **kwargs)


@classmethod
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[list[str]],
local_jdata: dict,
) -> tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
train_data : DeepmdDataSystem
data used to do neighbor statistics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel(
train_data, type_map, local_jdata["dpmodel"]
)
return local_jdata_cpy, min_nbor_dist
52 changes: 52 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
PairTabAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -8,6 +17,9 @@
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_zbl_model import (
DPZBLModel,
)
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
Expand Down Expand Up @@ -55,6 +67,44 @@ def get_standard_model(data: dict) -> EnergyModel:
)


def get_zbl_model(data: dict):
descriptor = BaseDescriptor(**data["descriptor"])
fitting_type = data["fitting_net"].pop("type")
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")

dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"])
# pairtab
filepath = data["use_srtab"]
pt_model = PairTabAtomicModel(
filepath,
data["descriptor"]["rcut"],
data["descriptor"]["sel"],
type_map=data["type_map"],
)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
anyangml marked this conversation as resolved.
Show resolved Hide resolved

rmin = data["sw_rmin"]
rmax = data["sw_rmax"]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
atom_exclude_types = data.get("atom_exclude_types", [])
pair_exclude_types = data.get("pair_exclude_types", [])
return DPZBLModel(
dp_model,
pt_model,
rmin,
rmax,
type_map=data["type_map"],
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)


def get_spin_model(data: dict) -> SpinModel:
"""Get a spin model from a dictionary.

Expand Down Expand Up @@ -100,6 +150,8 @@ def get_model(data: dict):
if model_type == "standard":
if "spin" in data:
return get_spin_model(data)
elif "use_srtab" in data:
return get_zbl_model(data)
else:
return get_standard_model(data)
else:
Expand Down
4 changes: 2 additions & 2 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
class CommonTest(ABC):
data: ClassVar[dict]
"""Arguments data."""
addtional_data: ClassVar[dict] = {}
additional_data: ClassVar[dict] = {}
"""Additional data that will not be checked."""
tf_class: ClassVar[Optional[type]]
"""TensorFlow model class."""
Expand Down Expand Up @@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any:

def pass_data_to_cls(self, cls, data) -> Any:
"""Pass data to the class."""
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

@abstractmethod
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any:
return get_model_pt(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

def setUp(self):
CommonTest.setUp(self)
Expand Down
Loading
Loading