From d35198bcec67f8c6f8f0439402222b83ef1cad4f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 26 Oct 2024 02:47:01 -0400 Subject: [PATCH] merge with make_model Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/model/make_model.py | 130 ++++++++++++++++++++++------- deepmd/jax/model/hlo.py | 59 +++---------- 2 files changed, 112 insertions(+), 77 deletions(-) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index dc90f10da7..66719c1bb9 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Callable, Optional, ) @@ -39,6 +40,95 @@ ) +def model_call_from_call_lower( + *, # enforce keyword-only arguments + call_lower: Callable[ + [ + np.ndarray, + np.ndarray, + np.ndarray, + Optional[np.ndarray], + Optional[np.ndarray], + bool, + ], + dict[str, np.ndarray], + ], + rcut: float, + sel: list[int], + mixed_types: bool, + model_output_def: ModelOutputDef, + coord: np.ndarray, + atype: np.ndarray, + box: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + do_atomic_virial: bool = False, +): + """Return model prediction from lower interface. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,np.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + nframes, nloc = atype.shape[:2] + cc, bb, fp, ap = coord, box, fparam, aparam + del coord, box, fparam, aparam + if bb is not None: + coord_normalized = normalize_coord( + cc.reshape(nframes, nloc, 3), + bb.reshape(nframes, 3, 3), + ) + else: + coord_normalized = cc.copy() + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + model_predict_lower = call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + ) + model_predict = communicate_extended_output( + model_predict_lower, + model_output_def, + mapping, + do_atomic_virial=do_atomic_virial, + ) + return model_predict + + def make_model(T_AtomicModel: type[BaseAtomicModel]): """Make a model as a derived class of an atomic model. @@ -130,45 +220,23 @@ def call( The keys are defined by the `ModelOutputDef`. """ - nframes, nloc = atype.shape[:2] cc, bb, fp, ap, input_prec = self.input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam - if bb is not None: - coord_normalized = normalize_coord( - cc.reshape(nframes, nloc, 3), - bb.reshape(nframes, 3, 3), - ) - else: - coord_normalized = cc.copy() - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, bb, self.get_rcut() - ) - nlist = build_neighbor_list( - extended_coord, - extended_atype, - nloc, - self.get_rcut(), - self.get_sel(), - distinguish_types=not self.mixed_types(), - ) - extended_coord = extended_coord.reshape(nframes, -1, 3) - model_predict_lower = self.call_lower( - extended_coord, - extended_atype, - nlist, - mapping, + model_predict = model_call_from_call_lower( + call_lower=self.call_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=cc, + atype=atype, + box=bb, fparam=fp, aparam=ap, do_atomic_virial=do_atomic_virial, ) - model_predict = communicate_extended_output( - model_predict_lower, - self.model_output_def(), - mapping, - do_atomic_virial=do_atomic_virial, - ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 8c974e6502..010e3d7a5e 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -4,21 +4,14 @@ Optional, ) -from deepmd.dpmodel.model.transform_output import ( - communicate_extended_output, +from deepmd.dpmodel.model.make_model import ( + model_call_from_call_lower, ) from deepmd.dpmodel.output_def import ( FittingOutputDef, ModelOutputDef, OutputVariableDef, ) -from deepmd.dpmodel.utils.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, -) -from deepmd.dpmodel.utils.region import ( - normalize_coord, -) from deepmd.jax.env import ( jax_export, jnp, @@ -148,44 +141,19 @@ def call( The keys are defined by the `ModelOutputDef`. """ - nframes, nloc = atype.shape[:2] - cc, bb, fp, ap = coord, box, fparam, aparam - del coord, box, fparam, aparam - if bb is not None: - coord_normalized = normalize_coord( - cc.reshape(nframes, nloc, 3), - bb.reshape(nframes, 3, 3), - ) - else: - coord_normalized = cc.copy() - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, bb, self.get_rcut() - ) - nlist = build_neighbor_list( - extended_coord, - extended_atype, - nloc, - self.get_rcut(), - self.get_sel(), - distinguish_types=not self.mixed_types(), - ) - extended_coord = extended_coord.reshape(nframes, -1, 3) - model_predict_lower = self.call_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fp, - aparam=ap, - do_atomic_virial=do_atomic_virial, - ) - model_predict = communicate_extended_output( - model_predict_lower, - self.model_output_def(), - mapping, + return model_call_from_call_lower( + call_lower=self.call_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, do_atomic_virial=do_atomic_virial, ) - return model_predict def model_output_def(self): return ModelOutputDef( @@ -284,7 +252,6 @@ def get_min_nbor_dist(self) -> Optional[float]: def get_nnei(self) -> int: """Returns the total number of selected neighboring atoms in the cut-off radius.""" - # for C++ interface return self.nsel def get_sel(self) -> list[int]: