Skip to content

Commit

Permalink
merge with make_model
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 26, 2024
1 parent 40ad218 commit d35198b
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 77 deletions.
130 changes: 99 additions & 31 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Optional,
)

Expand Down Expand Up @@ -39,6 +40,95 @@
)


def model_call_from_call_lower(
*, # enforce keyword-only arguments
call_lower: Callable[
[
np.ndarray,
np.ndarray,
np.ndarray,
Optional[np.ndarray],
Optional[np.ndarray],
bool,
],
dict[str, np.ndarray],
],
rcut: float,
sel: list[int],
mixed_types: bool,
model_output_def: ModelOutputDef,
coord: np.ndarray,
atype: np.ndarray,
box: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
):
"""Return model prediction from lower interface.
Parameters
----------
coord
The coordinates of the atoms.
shape: nf x (nloc x 3)
atype
The type of atoms. shape: nf x nloc
box
The simulation box. shape: nf x 9
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
Returns
-------
ret_dict
The result dict of type dict[str,np.ndarray].
The keys are defined by the `ModelOutputDef`.
"""
nframes, nloc = atype.shape[:2]
cc, bb, fp, ap = coord, box, fparam, aparam
del coord, box, fparam, aparam
if bb is not None:
coord_normalized = normalize_coord(
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = cc.copy()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, bb, rcut
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
rcut,
sel,
distinguish_types=not mixed_types,
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
model_predict_lower = call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = communicate_extended_output(
model_predict_lower,
model_output_def,
mapping,
do_atomic_virial=do_atomic_virial,
)
return model_predict


def make_model(T_AtomicModel: type[BaseAtomicModel]):
"""Make a model as a derived class of an atomic model.
Expand Down Expand Up @@ -130,45 +220,23 @@ def call(
The keys are defined by the `ModelOutputDef`.
"""
nframes, nloc = atype.shape[:2]
cc, bb, fp, ap, input_prec = self.input_type_cast(
coord, box=box, fparam=fparam, aparam=aparam
)
del coord, box, fparam, aparam
if bb is not None:
coord_normalized = normalize_coord(
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = cc.copy()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, bb, self.get_rcut()
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
self.get_rcut(),
self.get_sel(),
distinguish_types=not self.mixed_types(),
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
model_predict_lower = self.call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
model_predict = model_call_from_call_lower(
call_lower=self.call_lower,
rcut=self.get_rcut(),
sel=self.get_sel(),
mixed_types=self.mixed_types(),
model_output_def=self.model_output_def(),
coord=cc,
atype=atype,
box=bb,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
mapping,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

Expand Down
59 changes: 13 additions & 46 deletions deepmd/jax/model/hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,14 @@
Optional,
)

from deepmd.dpmodel.model.transform_output import (
communicate_extended_output,
from deepmd.dpmodel.model.make_model import (
model_call_from_call_lower,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
)
from deepmd.dpmodel.utils.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.dpmodel.utils.region import (
normalize_coord,
)
from deepmd.jax.env import (
jax_export,
jnp,
Expand Down Expand Up @@ -148,44 +141,19 @@ def call(
The keys are defined by the `ModelOutputDef`.
"""
nframes, nloc = atype.shape[:2]
cc, bb, fp, ap = coord, box, fparam, aparam
del coord, box, fparam, aparam
if bb is not None:
coord_normalized = normalize_coord(
cc.reshape(nframes, nloc, 3),
bb.reshape(nframes, 3, 3),
)
else:
coord_normalized = cc.copy()
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, bb, self.get_rcut()
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
self.get_rcut(),
self.get_sel(),
distinguish_types=not self.mixed_types(),
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
model_predict_lower = self.call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = communicate_extended_output(
model_predict_lower,
self.model_output_def(),
mapping,
return model_call_from_call_lower(
call_lower=self.call_lower,
rcut=self.get_rcut(),
sel=self.get_sel(),
mixed_types=self.mixed_types(),
model_output_def=self.model_output_def(),
coord=coord,
atype=atype,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
return model_predict

def model_output_def(self):
return ModelOutputDef(
Expand Down Expand Up @@ -284,7 +252,6 @@ def get_min_nbor_dist(self) -> Optional[float]:

def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
# for C++ interface
return self.nsel

def get_sel(self) -> list[int]:
Expand Down

0 comments on commit d35198b

Please sign in to comment.