From 94d38b7739507646a502289e6c7aaa83dccb31d4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 15 Mar 2021 20:19:22 +0800 Subject: [PATCH 1/6] add model version support. implemented python interface. --- deepmd/entrypoints/freeze.py | 2 + deepmd/infer/deep_dipole.py | 6 +- deepmd/infer/deep_eval.py | 211 ++++++---------------------- deepmd/infer/deep_polar.py | 8 +- deepmd/infer/deep_pot.py | 55 ++++++-- deepmd/infer/deep_tensor.py | 172 +++++++++++++++++++++++ deepmd/infer/deep_wfc.py | 5 +- source/api_cc/include/version.h.in | 1 + source/tests/infer/deepdipole.pbtxt | 21 +++ source/tests/infer/deeppolar.pbtxt | 21 +++ source/tests/infer/deeppot-1.pbtxt | 21 +++ source/tests/infer/deeppot-r.pbtxt | 21 +++ source/tests/infer/deeppot.pbtxt | 21 +++ source/tests/test_deeppot_a.py | 59 ++++++++ source/train/MODEL_VER | 1 + source/train/model.py | 9 +- source/train/run_config.ini | 3 +- source/train/run_options.py | 2 + 18 files changed, 445 insertions(+), 194 deletions(-) create mode 100644 deepmd/infer/deep_tensor.py create mode 100644 source/train/MODEL_VER diff --git a/deepmd/entrypoints/freeze.py b/deepmd/entrypoints/freeze.py index d47dffaf87..e11ac0c906 100755 --- a/deepmd/entrypoints/freeze.py +++ b/deepmd/entrypoints/freeze.py @@ -43,6 +43,7 @@ def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> Li "descrpt_attr/ntypes", "model_attr/tmap", "model_attr/model_type", + "model_attr/model_version", ] if model_type == "ener": @@ -59,6 +60,7 @@ def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> Li nodes += [ "o_wfc", "model_attr/sel_type", + "model_attr/output_dim", ] elif model_type == "dipole": nodes += [ diff --git a/deepmd/infer/deep_dipole.py b/deepmd/infer/deep_dipole.py index 7ad3ee49a5..d158cbafac 100644 --- a/deepmd/infer/deep_dipole.py +++ b/deepmd/infer/deep_dipole.py @@ -1,6 +1,6 @@ -from typing import TYPE_CHECKING +from deepmd.infer.deep_tensor import DeepTensor -from deepmd.infer.deep_eval import DeepTensor +from typing import TYPE_CHECKING if TYPE_CHECKING: from pathlib import Path @@ -33,7 +33,6 @@ def __init__( # instance namespace self.tensors = dict( { - "t_sel_type": "model_attr/sel_type:0", # output tensor "t_tensor": "o_dipole:0", }, @@ -43,7 +42,6 @@ def __init__( DeepTensor.__init__( self, model_file, - 3, load_prefix=load_prefix, default_tf_graph=default_tf_graph, ) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 2c52644eb1..dc31995e58 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -4,6 +4,7 @@ import numpy as np from deepmd.common import make_default_mesh from deepmd.env import default_tf_session_config, tf +from deepmd.run_options import MODEL_VERSION if TYPE_CHECKING: from pathlib import Path @@ -13,6 +14,7 @@ class DeepEval: """Common methods for DeepPot, DeepWFC, DeepPolar, ...""" _model_type: Optional[str] = None + _model_version: Optional[str] = None load_prefix: str # set by subclass def __init__( @@ -26,6 +28,12 @@ def __init__( ) self.load_prefix = load_prefix + if not self._graph_compatable(): + raise RuntimeError( + f"model in graph (version {self.model_version}) is incompatible" + f"with the model (version {MODEL_VERSION}) supported by the current code." + ) + @property def model_type(self) -> str: """Get type of model. @@ -37,9 +45,44 @@ def model_type(self) -> str: sess = tf.Session(graph=self.graph, config=default_tf_session_config) [mt] = sess.run([t_mt], feed_dict={}) self._model_type = mt.decode("utf-8") - return self._model_type + @property + def model_version(self) -> str: + """Get type of model. + + :type:str + """ + if not self._model_version: + try: + t_mt = self._get_tensor("model_attr/model_version:0") + sess = tf.Session(graph=self.graph, config=default_tf_session_config) + [mt] = sess.run([t_mt], feed_dict={}) + self._model_version = mt.decode("utf-8") + except KeyError: + # For deepmd-kit version 0.x - 1.x, set model version to 0.0 + self._model_version = "0.0" + return self._model_version + + def _graph_compatable( + self + ) -> bool : + """ Check the model compatability + + Return + bool + If the model stored in the graph file is compatable with the current code + """ + model_version_major = int(self.model_version.split('.')[0]) + model_version_minor = int(self.model_version.split('.')[1]) + MODEL_VERSION_MAJOR = int(MODEL_VERSION.split('.')[0]) + MODEL_VERSION_MINOR = int(MODEL_VERSION.split('.')[1]) + if (model_version_major != MODEL_VERSION_MAJOR) or \ + (model_version_minor > MODEL_VERSION_MINOR) : + return False + else: + return True + def _get_tensor( self, tensor_name: str, attr_name: Optional[str] = None ) -> tf.Tensor: @@ -70,8 +113,6 @@ def _get_tensor( def _load_graph( frozen_graph_filename: "Path", prefix: str = "load", default_tf_graph: bool = False ): - - # We load the protobuf file from the disk and parse it to retrieve the # unserialized graph_def with tf.gfile.GFile(str(frozen_graph_filename), "rb") as f: @@ -101,168 +142,6 @@ def _load_graph( return graph -class DeepTensor(DeepEval): - """Evaluates a tensor model. - - Constructor - - Parameters - ---------- - model_file: str - The name of the frozen model file. - variable_dof: int - The DOF of the variable to evaluate. - load_prefix: str - The prefix in the load computational graph - default_tf_graph : bool - If uses the default tf graph, otherwise build a new tf graph for evaluation - """ - - tensors = { - "t_ntypes": "descrpt_attr/ntypes:0", - "t_rcut": "descrpt_attr/rcut:0", - "t_tmap": "model_attr/tmap:0", - # inputs - "t_coord": "t_coord:0", - "t_type": "t_type:0", - "t_natoms": "t_natoms:0", - "t_box": "t_box:0", - "t_mesh": "t_mesh:0", - } - - def __init__( - self, - model_file: "Path", - variable_dof: Optional[int], - load_prefix: str = 'load', - default_tf_graph: bool = False - ) -> None: - DeepEval.__init__( - self, - model_file, - load_prefix=load_prefix, - default_tf_graph=default_tf_graph - ) - self.variable_dof = variable_dof - - # now load tensors to object attributes - for attr_name, tensor_name in self.tensors.items(): - self._get_tensor(tensor_name, attr_name) - - # start a tf session associated to the graph - self.sess = tf.Session(graph=self.graph, config=default_tf_session_config) - self._run_default_sess() - self.tmap = self.tmap.decode('UTF-8').split() - - def _run_default_sess(self): - [self.ntypes, self.rcut, self.tmap, self.tselt] = self.sess.run( - [self.t_ntypes, self.t_rcut, self.t_tmap, self.t_sel_type] - ) - - def get_ntypes(self) -> int: - """Get the number of atom types of this model.""" - return self.ntypes - - def get_rcut(self) -> float: - """Get the cut-off radius of this model.""" - return self.rcut - - def get_type_map(self) -> List[int]: - """Get the type map (element name of the atom types) of this model.""" - return self.tmap - - def get_sel_type(self) -> List[int]: - """Get the selected atom types of this model.""" - return self.tselt - - def get_dim_fparam(self) -> int: - """Get the number (dimension) of frame parameters of this DP.""" - return self.dfparam - - def get_dim_aparam(self) -> int: - """Get the number (dimension) of atomic parameters of this DP.""" - return self.daparam - - def eval( - self, - coords: np.array, - cells: np.array, - atom_types: List[int], - atomic: bool = True, - fparam: Optional[np.array] = None, - aparam: Optional[np.array] = None, - efield: Optional[np.array] = None - ) -> np.array: - """Evaluate the model. - - Parameters - ---------- - coords - The coordinates of atoms. - The array should be of size nframes x natoms x 3 - cells - The cell of the region. - If None then non-PBC is assumed, otherwise using PBC. - The array should be of size nframes x 9 - atom_types - The atom types - The list should contain natoms ints - atomic - Calculate the atomic energy and virial - fparam - Not used in this model - aparam - Not used in this model - efield - Not used in this model - - Returns - ------- - tensor - The returned tensor - If atomic == False then of size nframes x variable_dof - else of size nframes x natoms x variable_dof - """ - # standarize the shape of inputs - coords = np.array(coords) - cells = np.array(cells) - atom_types = np.array(atom_types, dtype = int) - - # reshape the inputs - cells = np.reshape(cells, [-1, 9]) - nframes = cells.shape[0] - coords = np.reshape(coords, [nframes, -1]) - natoms = coords.shape[1] // 3 - - # sort inputs - coords, atom_types, imap, sel_at, sel_imap = self.sort_input(coords, atom_types, sel_atoms = self.get_sel_type()) - - # make natoms_vec and default_mesh - natoms_vec = self.make_natoms_vec(atom_types) - assert(natoms_vec[0] == natoms) - - # evaluate - tensor = [] - feed_dict_test = {} - feed_dict_test[self.t_natoms] = natoms_vec - feed_dict_test[self.t_type ] = np.tile(atom_types, [nframes,1]).reshape([-1]) - t_out = [self.t_tensor] - feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) - feed_dict_test[self.t_box ] = np.reshape(cells , [-1]) - feed_dict_test[self.t_mesh ] = make_default_mesh(cells) - v_out = self.sess.run (t_out, feed_dict = feed_dict_test) - tensor = v_out[0] - - # reverse map of the outputs - if atomic: - tensor = np.array(tensor) - tensor = self.reverse_map(np.reshape(tensor, [nframes,-1,self.variable_dof]), sel_imap) - tensor = np.reshape(tensor, [nframes, len(sel_at), self.variable_dof]) - else: - tensor = np.reshape(tensor, [nframes, self.variable_dof]) - - return tensor - @staticmethod def sort_input( coord : np.array, atom_type : np.array, sel_atoms : List[int] = None @@ -339,6 +218,7 @@ def reverse_map(vec : np.ndarray, imap : List[int]) -> np.ndarray: ret[:,ii,:] = vec[:,idx,:] return ret + def make_natoms_vec(self, atom_types : np.ndarray) -> np.ndarray : """Make the natom vector used by deepmd-kit. @@ -363,4 +243,3 @@ def make_natoms_vec(self, atom_types : np.ndarray) -> np.ndarray : for ii in range (self.ntypes) : natoms_vec[ii+2] = np.count_nonzero(atom_types == ii) return natoms_vec - diff --git a/deepmd/infer/deep_polar.py b/deepmd/infer/deep_polar.py index 48f565faa2..7ee02cf1c8 100644 --- a/deepmd/infer/deep_polar.py +++ b/deepmd/infer/deep_polar.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, List, Optional - +from deepmd.infer.deep_tensor import DeepTensor import numpy as np -from deepmd.infer.deep_eval import DeepTensor + +from typing import TYPE_CHECKING, List, Optional if TYPE_CHECKING: from pathlib import Path @@ -34,7 +34,6 @@ def __init__( # instance namespace self.tensors = dict( { - "t_sel_type": "model_attr/sel_type:0", # output tensor "t_tensor": "o_polar:0", }, @@ -44,7 +43,6 @@ def __init__( DeepTensor.__init__( self, model_file, - 9, load_prefix=load_prefix, default_tf_graph=default_tf_graph, ) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 8fefe843d5..a8e70d5a72 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -3,8 +3,9 @@ import numpy as np from deepmd.common import make_default_mesh +from deepmd.env import default_tf_session_config, tf from deepmd.infer.data_modifier import DipoleChargeModifier -from deepmd.infer.deep_eval import DeepEval, DeepTensor +from deepmd.infer.deep_eval import DeepEval if TYPE_CHECKING: from pathlib import Path @@ -12,7 +13,7 @@ log = logging.getLogger(__name__) -class DeepPot(DeepTensor): +class DeepPot(DeepEval): """Constructor. Parameters @@ -43,9 +44,20 @@ def __init__( # instance namespace self.tensors = dict( { - # general + # descrpt attrs + "t_ntypes": "descrpt_attr/ntypes:0", + "t_rcut": "descrpt_attr/rcut:0", + # fitting attrs "t_dfparam": "fitting_attr/dfparam:0", "t_daparam": "fitting_attr/daparam:0", + # model attrs + "t_tmap": "model_attr/tmap:0", + # inputs + "t_coord": "t_coord:0", + "t_type": "t_type:0", + "t_natoms": "t_natoms:0", + "t_box": "t_box:0", + "t_mesh": "t_mesh:0", # add output tensors "t_energy": "o_energy:0", "t_force": "o_force:0", @@ -53,7 +65,6 @@ def __init__( "t_ae": "o_atom_energy:0", "t_av": "o_atom_virial:0" }, - **self.tensors ) DeepEval.__init__( self, @@ -90,15 +101,14 @@ def __init__( self.t_aparam = None self.has_aparam = False - # now when tensors are set initialize DeepTensor which will load them all - # to class attributes, the run session assciated with the graph - DeepTensor.__init__( - self, - model_file, - None, - load_prefix=load_prefix, - default_tf_graph=default_tf_graph - ) + # now load tensors to object attributes + for attr_name, tensor_name in self.tensors.items(): + self._get_tensor(tensor_name, attr_name) + + # start a tf session associated to the graph + self.sess = tf.Session(graph=self.graph, config=default_tf_session_config) + self._run_default_sess() + self.tmap = self.tmap.decode('UTF-8').split() # setup modifier try: @@ -123,9 +133,28 @@ def _run_default_sess(self): [self.t_ntypes, self.t_rcut, self.t_dfparam, self.t_daparam, self.t_tmap] ) + def get_ntypes(self) -> int: + """Get the number of atom types of this model.""" + return self.ntypes + + def get_rcut(self) -> float: + """Get the cut-off radius of this model.""" + return self.rcut + + def get_type_map(self) -> List[int]: + """Get the type map (element name of the atom types) of this model.""" + return self.tmap + def get_sel_type(self) -> List[int]: """Unsupported in this model.""" raise NotImplementedError("This model type does not support this attribute") + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return self.dfparam + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return self.daparam def eval( self, diff --git a/deepmd/infer/deep_tensor.py b/deepmd/infer/deep_tensor.py new file mode 100644 index 0000000000..24a7832a32 --- /dev/null +++ b/deepmd/infer/deep_tensor.py @@ -0,0 +1,172 @@ +import os +from typing import List, Optional, TYPE_CHECKING + +import numpy as np +from deepmd.common import make_default_mesh +from deepmd.env import default_tf_session_config, tf +from deepmd.infer.deep_eval import DeepEval + +if TYPE_CHECKING: + from pathlib import Path + +class DeepTensor(DeepEval): + """Evaluates a tensor model. + + Constructor + + Parameters + ---------- + model_file: str + The name of the frozen model file. + load_prefix: str + The prefix in the load computational graph + default_tf_graph : bool + If uses the default tf graph, otherwise build a new tf graph for evaluation + """ + + tensors = { + # descriptor attrs + "t_ntypes": "descrpt_attr/ntypes:0", + "t_rcut": "descrpt_attr/rcut:0", + # model attrs + "t_tmap": "model_attr/tmap:0", + "t_sel_type": "model_attr/sel_type:0", + "t_ouput_dim": "model_attr/output_dim:0", + # inputs + "t_coord": "t_coord:0", + "t_type": "t_type:0", + "t_natoms": "t_natoms:0", + "t_box": "t_box:0", + "t_mesh": "t_mesh:0", + } + + def __init__( + self, + model_file: "Path", + load_prefix: str = 'load', + default_tf_graph: bool = False + ) -> None: + DeepEval.__init__( + self, + model_file, + load_prefix=load_prefix, + default_tf_graph=default_tf_graph + ) + # now load tensors to object attributes + for attr_name, tensor_name in self.tensors.items(): + self._get_tensor(tensor_name, attr_name) + + # start a tf session associated to the graph + self.sess = tf.Session(graph=self.graph, config=default_tf_session_config) + self._run_default_sess() + self.tmap = self.tmap.decode('UTF-8').split() + + def _run_default_sess(self): + [self.ntypes, self.rcut, self.tmap, self.tselt, self.output_dim] \ + = self.sess.run( + [self.t_ntypes, self.t_rcut, self.t_tmap, self.t_sel_type, self.t_ouput_dim] + ) + + def get_ntypes(self) -> int: + """Get the number of atom types of this model.""" + return self.ntypes + + def get_rcut(self) -> float: + """Get the cut-off radius of this model.""" + return self.rcut + + def get_type_map(self) -> List[int]: + """Get the type map (element name of the atom types) of this model.""" + return self.tmap + + def get_sel_type(self) -> List[int]: + """Get the selected atom types of this model.""" + return self.tselt + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return self.dfparam + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return self.daparam + + def eval( + self, + coords: np.array, + cells: np.array, + atom_types: List[int], + atomic: bool = True, + fparam: Optional[np.array] = None, + aparam: Optional[np.array] = None, + efield: Optional[np.array] = None + ) -> np.array: + """Evaluate the model. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + atomic + Calculate the atomic energy and virial + fparam + Not used in this model + aparam + Not used in this model + efield + Not used in this model + + Returns + ------- + tensor + The returned tensor + If atomic == False then of size nframes x output_dim + else of size nframes x natoms x output_dim + """ + # standarize the shape of inputs + coords = np.array(coords) + cells = np.array(cells) + atom_types = np.array(atom_types, dtype = int) + + # reshape the inputs + cells = np.reshape(cells, [-1, 9]) + nframes = cells.shape[0] + coords = np.reshape(coords, [nframes, -1]) + natoms = coords.shape[1] // 3 + + # sort inputs + coords, atom_types, imap, sel_at, sel_imap = self.sort_input(coords, atom_types, sel_atoms = self.get_sel_type()) + + # make natoms_vec and default_mesh + natoms_vec = self.make_natoms_vec(atom_types) + assert(natoms_vec[0] == natoms) + + # evaluate + tensor = [] + feed_dict_test = {} + feed_dict_test[self.t_natoms] = natoms_vec + feed_dict_test[self.t_type ] = np.tile(atom_types, [nframes,1]).reshape([-1]) + t_out = [self.t_tensor] + feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) + feed_dict_test[self.t_box ] = np.reshape(cells , [-1]) + feed_dict_test[self.t_mesh ] = make_default_mesh(cells) + v_out = self.sess.run (t_out, feed_dict = feed_dict_test) + tensor = v_out[0] + + # reverse map of the outputs + if atomic: + tensor = np.array(tensor) + tensor = self.reverse_map(np.reshape(tensor, [nframes,-1,self.output_dim]), sel_imap) + tensor = np.reshape(tensor, [nframes, len(sel_at), self.output_dim]) + else: + tensor = np.reshape(tensor, [nframes, self.output_dim]) + + return tensor diff --git a/deepmd/infer/deep_wfc.py b/deepmd/infer/deep_wfc.py index 1f976af04a..40d3cd6a5c 100644 --- a/deepmd/infer/deep_wfc.py +++ b/deepmd/infer/deep_wfc.py @@ -1,4 +1,5 @@ -from deepmd.infer.deep_eval import DeepTensor +from deepmd.infer.deep_tensor import DeepTensor + from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -32,7 +33,6 @@ def __init__( # instance namespace self.tensors = dict( { - "t_sel_type": "model_attr/sel_type:0", # output tensor "t_tensor": "o_wfc:0", }, @@ -41,7 +41,6 @@ def __init__( DeepTensor.__init__( self, model_file, - 12, load_prefix=load_prefix, default_tf_graph=default_tf_graph, ) diff --git a/source/api_cc/include/version.h.in b/source/api_cc/include/version.h.in index d3acbf5aff..23c22ae588 100644 --- a/source/api_cc/include/version.h.in +++ b/source/api_cc/include/version.h.in @@ -16,3 +16,4 @@ const std::string global_git_date="@GIT_DATE@"; const std::string global_git_branch="@GIT_BRANCH@"; const std::string global_tf_include_dir="@TensorFlow_INCLUDE_DIRS@"; const std::string global_tf_lib="@TensorFlow_LIBRARY@"; +const std::string global_model_version="@MODEL_VERSION@"; diff --git a/source/tests/infer/deepdipole.pbtxt b/source/tests/infer/deepdipole.pbtxt index a8f09a3331..b503c29336 100644 --- a/source/tests/infer/deepdipole.pbtxt +++ b/source/tests/infer/deepdipole.pbtxt @@ -5950,6 +5950,27 @@ node { } } } +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "model_attr/sel_type" op: "Const" diff --git a/source/tests/infer/deeppolar.pbtxt b/source/tests/infer/deeppolar.pbtxt index d7f29bdd87..49b9645b68 100644 --- a/source/tests/infer/deeppolar.pbtxt +++ b/source/tests/infer/deeppolar.pbtxt @@ -6178,6 +6178,27 @@ node { } } } +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "model_attr/sel_type" op: "Const" diff --git a/source/tests/infer/deeppot-1.pbtxt b/source/tests/infer/deeppot-1.pbtxt index 3f8034de55..0819df4b9e 100644 --- a/source/tests/infer/deeppot-1.pbtxt +++ b/source/tests/infer/deeppot-1.pbtxt @@ -8875,6 +8875,27 @@ node { } } } +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "model_attr/tmap" op: "Const" diff --git a/source/tests/infer/deeppot-r.pbtxt b/source/tests/infer/deeppot-r.pbtxt index 3ddd2effe4..c307be00f0 100644 --- a/source/tests/infer/deeppot-r.pbtxt +++ b/source/tests/infer/deeppot-r.pbtxt @@ -8529,6 +8529,27 @@ node { } } } +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "model_attr/tmap" op: "Const" diff --git a/source/tests/infer/deeppot.pbtxt b/source/tests/infer/deeppot.pbtxt index 7a5ba00f78..c7c49e2483 100644 --- a/source/tests/infer/deeppot.pbtxt +++ b/source/tests/infer/deeppot.pbtxt @@ -8875,6 +8875,27 @@ node { } } } +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "model_attr/tmap" op: "Const" diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 8dee81df75..b9a519672d 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -4,6 +4,7 @@ from infer.convert2pb import convert_pbtxt_to_pb from deepmd.infer import DeepPot +from deepmd.run_options import MODEL_VERSION from common import tests_path from deepmd.run_options import GLOBAL_NP_FLOAT_PRECISION @@ -12,6 +13,64 @@ else : default_places = 10 +class TestModelMajorCompatability(unittest.TestCase) : + def setUp(self): + model_file = str(tests_path / os.path.join("infer","deeppot.pbtxt")) + with open(model_file, 'r') as fp: + # data = fp.read().replace('\n', '') + data = fp.read().split("\n") + for ii in range(len(data)): + if "model_attr/model_version" in data[ii]: + for jj in range(ii, len(data)): + if "string_val:" in data[jj]: + data[jj] = data[jj].replace(MODEL_VERSION, "0.0") + break + with open("deeppot-ver.pbtxt", "w") as fp: + fp.write("\n".join(data)) + + convert_pbtxt_to_pb(str(tests_path / os.path.join("deeppot-ver.pbtxt")), "deeppot.pb") + + def tearDown(self): + os.remove("deeppot-ver.pbtxt") + os.remove("deeppot.pb") + + def test(self): + with self.assertRaises(RuntimeError) as context: + DeepPot("deeppot.pb") + self.assertTrue('incompatible' in str(context.exception)) + self.assertTrue(MODEL_VERSION in str(context.exception)) + self.assertTrue('0.0' in str(context.exception)) + + +class TestModelMinorCompatability(unittest.TestCase) : + def setUp(self): + model_file = str(tests_path / os.path.join("infer","deeppot.pbtxt")) + with open(model_file, 'r') as fp: + # data = fp.read().replace('\n', '') + data = fp.read().split("\n") + for ii in range(len(data)): + if "model_attr/model_version" in data[ii]: + for jj in range(ii, len(data)): + if "string_val:" in data[jj]: + data[jj] = data[jj].replace(MODEL_VERSION, "0.1000000") + break + with open("deeppot-ver.pbtxt", "w") as fp: + fp.write("\n".join(data)) + + convert_pbtxt_to_pb(str(tests_path / os.path.join("deeppot-ver.pbtxt")), "deeppot.pb") + + def tearDown(self): + os.remove("deeppot-ver.pbtxt") + os.remove("deeppot.pb") + + def test(self): + with self.assertRaises(RuntimeError) as context: + DeepPot("deeppot.pb") + self.assertTrue('incompatible' in str(context.exception)) + self.assertTrue(MODEL_VERSION in str(context.exception)) + self.assertTrue('0.1000000' in str(context.exception)) + + class TestDeepPotAPBC(unittest.TestCase) : def setUp(self): convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppot.pbtxt")), "deeppot.pb") diff --git a/source/train/MODEL_VER b/source/train/MODEL_VER new file mode 100644 index 0000000000..d3827e75a5 --- /dev/null +++ b/source/train/MODEL_VER @@ -0,0 +1 @@ +1.0 diff --git a/source/train/model.py b/source/train/model.py index d56d390204..8ccee419bf 100644 --- a/source/train/model.py +++ b/source/train/model.py @@ -4,7 +4,7 @@ from deepmd.utils.pair_tab import PairTab from deepmd.common import ClassArg -from deepmd.run_options import global_cvt_2_ener_float +from deepmd.run_options import global_cvt_2_ener_float, MODEL_VERSION from deepmd.env import op_module @@ -146,6 +146,9 @@ def build (self, t_mt = tf.constant(self.model_type, name = 'model_type', dtype = tf.string) + t_ver = tf.constant(MODEL_VERSION, + name = 'model_version', + dtype = tf.string) if self.srtab is not None : tab_info, tab_data = self.srtab.get() @@ -335,11 +338,13 @@ def build (self, t_mt = tf.constant(self.model_type, name = 'model_type', dtype = tf.string) + t_ver = tf.constant(MODEL_VERSION, + name = 'model_version', + dtype = tf.string) t_od = tf.constant(self.get_out_size(), name = 'output_dim', dtype = tf.int32) - dout \ = self.descrpt.build(coord_, atype_, diff --git a/source/train/run_config.ini b/source/train/run_config.ini index 5579f13134..3f2e8cc86a 100644 --- a/source/train/run_config.ini +++ b/source/train/run_config.ini @@ -6,4 +6,5 @@ GIT_DATE = @GIT_DATE@ GIT_BRANCH = @GIT_BRANCH@ TF_INCLUDE_DIR = @TensorFlow_INCLUDE_DIRS@ TF_LIBS = @TensorFlow_LIBRARY@ -PRECISION = @PREC_DEF@ \ No newline at end of file +PRECISION = @PREC_DEF@ +MODEL_VERSION=@MODEL_VERSION@ diff --git a/source/train/run_options.py b/source/train/run_options.py index 632c18af63..0ef39a7e95 100644 --- a/source/train/run_options.py +++ b/source/train/run_options.py @@ -36,6 +36,7 @@ class TFServerV1(Protocol): "global_cvt_2_tf_float", "global_cvt_2_ener_float", "RunOptions", + "MODEL_VERSION", ] log = logging.getLogger(__name__) @@ -62,6 +63,7 @@ def _get_package_constants( GLOBAL_CONFIG = _get_package_constants() +MODEL_VERSION = GLOBAL_CONFIG["model_version"] if GLOBAL_CONFIG["precision"] == "-DHIGH_PREC": GLOBAL_TF_FLOAT_PRECISION = tf.float64 From 3c307f73014b41106f704d3e94ddab051345f1c7 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 16 Mar 2021 10:36:01 +0800 Subject: [PATCH 2/6] support model version check in c++ interface --- source/CMakeLists.txt | 5 +++ source/api_cc/include/DeepPot.h | 4 +++ source/api_cc/include/DeepTensor.h | 1 + source/api_cc/include/common.h | 5 +++ source/api_cc/src/DeepPot.cc | 18 ++++++++++- source/api_cc/src/DeepTensor.cc | 9 +++++- source/api_cc/src/common.cc | 40 +++++++++++++++++++++++ source/api_cc/tests/CMakeLists.txt | 6 ++-- source/tests/infer/dipolecharge_d.pbtxt | 21 +++++++++++++ source/tests/infer/dipolecharge_e.pbtxt | 42 +++++++++++++++++++++++++ 10 files changed, 147 insertions(+), 4 deletions(-) diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 1803a10f5d..3aa8373aba 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -46,6 +46,11 @@ endif(GIT_FOUND) list (APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/) set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wno-ignored-attributes") +# model version +file(READ ${PROJECT_SOURCE_DIR}/train/MODEL_VER MODEL_VERSION) +string(REPLACE "\n" " " MODEL_VERSION ${MODEL_VERSION}) +message(STATUS "Current model version: ${MODEL_VERSION}") + # define USE_CUDA_TOOLKIT if (DEFINED USE_CUDA_TOOLKIT) if (USE_CUDA_TOOLKIT) diff --git a/source/api_cc/include/DeepPot.h b/source/api_cc/include/DeepPot.h index ba6162a62b..53996903c1 100644 --- a/source/api_cc/include/DeepPot.h +++ b/source/api_cc/include/DeepPot.h @@ -70,6 +70,8 @@ class DeepPot // int get_ntypes () const; VALUETYPE rcut; VALUETYPE cell_size; + std::string model_type; + std::string model_version; int ntypes; int dfparam; int daparam; @@ -163,6 +165,8 @@ class DeepPotModelDevi // int get_ntypes () const; VALUETYPE rcut; VALUETYPE cell_size; + std::string model_type; + std::string model_version; int ntypes; int dfparam; int daparam; diff --git a/source/api_cc/include/DeepTensor.h b/source/api_cc/include/DeepTensor.h index 786908f295..a60c6e1ed9 100644 --- a/source/api_cc/include/DeepTensor.h +++ b/source/api_cc/include/DeepTensor.h @@ -38,6 +38,7 @@ class DeepTensor VALUETYPE cell_size; int ntypes; std::string model_type; + std::string model_version; int odim; std::vector sel_type; template VT get_scalar(const std::string & name) const; diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 698ebe04ed..46eb076d4f 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -42,6 +42,10 @@ struct NeighborListData void make_inlist(InputNlist & inlist); }; +bool +model_compatable( + std::string & model_version); + void select_by_type(std::vector & fwd_map, std::vector & bkw_map, @@ -204,3 +208,4 @@ select_map(std::vector & out, } } + diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 71ca3226e9..f122f1173d 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -224,6 +224,14 @@ init (const std::string & model, const int & gpu_rank, const std::string & file_ daparam = get_scalar("fitting_attr/daparam"); if (dfparam < 0) dfparam = 0; if (daparam < 0) daparam = 0; + model_type = get_scalar("model_attr/model_type"); + model_version = get_scalar("model_attr/model_version"); + if(! model_compatable(model_version)){ + throw std::runtime_error( + "incompatable model: version " + model_version + + " in graph, but version " + global_model_version + + " supported "); + } inited = true; init_nbor = false; @@ -547,6 +555,14 @@ init (const std::vector & models, const int & gpu_rank, const std:: daparam = get_scalar("fitting_attr/daparam"); if (dfparam < 0) dfparam = 0; if (daparam < 0) daparam = 0; + model_type = get_scalar("model_attr/model_type"); + model_version = get_scalar("model_attr/model_version"); + if(! model_compatable(model_version)){ + throw std::runtime_error( + "incompatable model: version " + model_version + + " in graph, but version " + global_model_version + + " supported "); + } // rcut = get_rcut(); // cell_size = rcut; // ntypes = get_ntypes(); @@ -560,7 +576,7 @@ VT DeepPotModelDevi:: get_scalar(const std::string name) const { - VT myrcut = 0; + VT myrcut; for (unsigned ii = 0; ii < numb_models; ++ii){ VT ret = session_get_scalar(sessions[ii], name); if (ii == 0){ diff --git a/source/api_cc/src/DeepTensor.cc b/source/api_cc/src/DeepTensor.cc index 270d307013..e701c044f9 100644 --- a/source/api_cc/src/DeepTensor.cc +++ b/source/api_cc/src/DeepTensor.cc @@ -36,9 +36,16 @@ init (const std::string & model, rcut = get_scalar("descrpt_attr/rcut"); cell_size = rcut; ntypes = get_scalar("descrpt_attr/ntypes"); - model_type = get_scalar("model_attr/model_type"); odim = get_scalar("model_attr/output_dim"); get_vector(sel_type, "model_attr/sel_type"); + model_type = get_scalar("model_attr/model_type"); + model_version = get_scalar("model_attr/model_version"); + if(! model_compatable(model_version)){ + throw std::runtime_error( + "incompatable model: version " + model_version + + " in graph, but version " + global_model_version + + " supported "); + } inited = true; } diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 4a431602d3..560790c1a1 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -2,6 +2,46 @@ #include "AtomMap.h" #include "device.h" +static std::vector +split(const std::string &input_, + const std::string &delimiter) +{ + std::string input = input_; + size_t pos = 0; + std::vector res; + while ((pos = input.find(delimiter)) != std::string::npos) { + res.push_back(input.substr(0, pos)); + input.erase(0, pos + delimiter.length()); + } + res.push_back(input); + return res; +} + +bool +model_compatable( + std::string & model_version) +{ + std::vector words_mv = split(model_version, "."); + std::vector words_gmv = split(global_model_version, "."); + if(words_mv.size() != 2){ + throw std::runtime_error("invalid graph model version string " + model_version); + } + if(words_gmv.size() != 2){ + throw std::runtime_error("invalid supported model version string " + global_model_version); + } + int model_version_major = atoi(words_mv[0].c_str()); + int model_version_minor = atoi(words_mv[1].c_str()); + int MODEL_VERSION_MAJOR = atoi(words_gmv[0].c_str()); + int MODEL_VERSION_MINOR = atoi(words_gmv[1].c_str()); + if(model_version_major != MODEL_VERSION_MAJOR || + model_version_minor > MODEL_VERSION_MINOR){ + return false; + } + else{ + return true; + } +} + void select_by_type(std::vector & fwd_map, std::vector & bkw_map, diff --git a/source/api_cc/tests/CMakeLists.txt b/source/api_cc/tests/CMakeLists.txt index aedb5b0f20..8f4bef0492 100644 --- a/source/api_cc/tests/CMakeLists.txt +++ b/source/api_cc/tests/CMakeLists.txt @@ -11,8 +11,10 @@ enable_testing() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -message(${PROJECT_SOURCE_DIR}) -message(${CMAKE_SOURCE_DIR}) +# model version +file(READ ${PROJECT_SOURCE_DIR}/../../train/MODEL_VER MODEL_VERSION) +string(REPLACE "\n" " " MODEL_VERSION ${MODEL_VERSION}) +message(STATUS "Current model version: ${MODEL_VERSION}") set(libname "deepmd") set(LIB_BASE_DIR ${CMAKE_SOURCE_DIR}/../../lib) diff --git a/source/tests/infer/dipolecharge_d.pbtxt b/source/tests/infer/dipolecharge_d.pbtxt index 6bd2b430e3..6be963119f 100644 --- a/source/tests/infer/dipolecharge_d.pbtxt +++ b/source/tests/infer/dipolecharge_d.pbtxt @@ -7103,6 +7103,27 @@ node { } } } +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "model_attr/sel_type" op: "Const" diff --git a/source/tests/infer/dipolecharge_e.pbtxt b/source/tests/infer/dipolecharge_e.pbtxt index 61326137aa..ec9412a111 100644 --- a/source/tests/infer/dipolecharge_e.pbtxt +++ b/source/tests/infer/dipolecharge_e.pbtxt @@ -12402,6 +12402,27 @@ node { } } } +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "model_attr/tmap" op: "Const" @@ -56014,6 +56035,27 @@ node { } } } +node { + name: "dipole_charge/model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} node { name: "dipole_charge/model_attr/sel_type" op: "Const" From 8d57dc3446e877e9047d18abf90e26c48029d815 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 16 Mar 2021 11:37:26 +0800 Subject: [PATCH 3/6] improved cmake message for model version --- source/CMakeLists.txt | 2 +- source/api_cc/tests/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 3aa8373aba..07f5bf7760 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -49,7 +49,7 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wno-ignored-attributes") # model version file(READ ${PROJECT_SOURCE_DIR}/train/MODEL_VER MODEL_VERSION) string(REPLACE "\n" " " MODEL_VERSION ${MODEL_VERSION}) -message(STATUS "Current model version: ${MODEL_VERSION}") +message(STATUS "Supported model version: ${MODEL_VERSION}") # define USE_CUDA_TOOLKIT if (DEFINED USE_CUDA_TOOLKIT) diff --git a/source/api_cc/tests/CMakeLists.txt b/source/api_cc/tests/CMakeLists.txt index 8f4bef0492..aabf452038 100644 --- a/source/api_cc/tests/CMakeLists.txt +++ b/source/api_cc/tests/CMakeLists.txt @@ -14,7 +14,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") # model version file(READ ${PROJECT_SOURCE_DIR}/../../train/MODEL_VER MODEL_VERSION) string(REPLACE "\n" " " MODEL_VERSION ${MODEL_VERSION}) -message(STATUS "Current model version: ${MODEL_VERSION}") +message(STATUS "Supported model version: ${MODEL_VERSION}") set(libname "deepmd") set(LIB_BASE_DIR ${CMAKE_SOURCE_DIR}/../../lib) From bf775f7051dbcde8a205be40ac023a0769a23ee9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 16 Mar 2021 11:41:17 +0800 Subject: [PATCH 4/6] print model version on summary string --- source/api_cc/src/DeepPot.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index f122f1173d..a5b9e6db8e 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -246,6 +246,7 @@ print_summary(const std::string &pre) const std::cout << pre << "source brach: " + global_git_branch << std::endl; std::cout << pre << "source commit: " + global_git_hash << std::endl; std::cout << pre << "source commit at: " + global_git_date << std::endl; + std::cout << pre << "surpport model ver.:" + global_model_version << std::endl; std::cout << pre << "build float prec: " + global_float_prec << std::endl; std::cout << pre << "build with tf inc: " + global_tf_include_dir << std::endl; std::cout << pre << "build with tf lib: " + global_tf_lib << std::endl; From 26b24cdf8ab3601f998d04d270adaf2459647dfd Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 16 Mar 2021 14:17:28 +0800 Subject: [PATCH 5/6] fix bugs in UT --- deepmd/infer/deep_eval.py | 1 + source/tests/test_data_modifier.py | 4 ++-- source/tests/test_data_modifier_shuffle.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index dc31995e58..83f9a7aa6d 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -28,6 +28,7 @@ def __init__( ) self.load_prefix = load_prefix + # graph_compatable should be called after graph and prefix are set if not self._graph_compatable(): raise RuntimeError( f"model in graph (version {self.model_version}) is incompatible" diff --git a/source/tests/test_data_modifier.py b/source/tests/test_data_modifier.py index 70999f5734..9c47fa15dd 100644 --- a/source/tests/test_data_modifier.py +++ b/source/tests/test_data_modifier.py @@ -79,7 +79,7 @@ def _setUp(self): sess.run(init_op) graph = tf.get_default_graph() input_graph_def = graph.as_graph_def() - nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type" + nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type,model_attr/output_dim,model_attr/model_version" output_graph_def = tf.graph_util.convert_variables_to_constants( sess, input_graph_def, @@ -150,7 +150,7 @@ def _test_fv (self): num_deriv = np.transpose(num_deriv, [0,2,1]) t_esti = np.matmul(num_deriv, box3) - print(t_esti, '\n', vv.reshape([-1, 3, 3])) + # print(t_esti, '\n', vv.reshape([-1, 3, 3])) for ff in range(nframes): for ii in range(3): for jj in range(3): diff --git a/source/tests/test_data_modifier_shuffle.py b/source/tests/test_data_modifier_shuffle.py index 39269c2ff5..36ebfef35a 100644 --- a/source/tests/test_data_modifier_shuffle.py +++ b/source/tests/test_data_modifier_shuffle.py @@ -83,7 +83,7 @@ def _setUp(self): sess.run(init_op) graph = tf.get_default_graph() input_graph_def = graph.as_graph_def() - nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type" + nodes = "o_dipole,o_rmat,o_rmat_deriv,o_nlist,o_rij,descrpt_attr/rcut,descrpt_attr/ntypes,descrpt_attr/sel,descrpt_attr/ndescrpt,model_attr/tmap,model_attr/sel_type,model_attr/model_type,model_attr/output_dim,model_attr/model_version" output_graph_def = tf.graph_util.convert_variables_to_constants( sess, input_graph_def, From 91971a2585ea2df5a686d7741bd1ac5dae34efa6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 16 Mar 2021 15:02:20 +0800 Subject: [PATCH 6/6] fix bug of model path in UT --- source/tests/test_deeppot_a.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index b9a519672d..8541cb8fa3 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -25,18 +25,19 @@ def setUp(self): if "string_val:" in data[jj]: data[jj] = data[jj].replace(MODEL_VERSION, "0.0") break - with open("deeppot-ver.pbtxt", "w") as fp: + self.version_pbtxt = str(tests_path / "deeppot-ver.pbtxt") + self.version_pb = str(tests_path / "deeppot.pb") + with open(self.version_pbtxt, "w") as fp: fp.write("\n".join(data)) - - convert_pbtxt_to_pb(str(tests_path / os.path.join("deeppot-ver.pbtxt")), "deeppot.pb") + convert_pbtxt_to_pb(self.version_pbtxt, self.version_pb) def tearDown(self): - os.remove("deeppot-ver.pbtxt") - os.remove("deeppot.pb") + os.remove(self.version_pbtxt) + os.remove(self.version_pb) def test(self): with self.assertRaises(RuntimeError) as context: - DeepPot("deeppot.pb") + DeepPot(str(self.version_pb)) self.assertTrue('incompatible' in str(context.exception)) self.assertTrue(MODEL_VERSION in str(context.exception)) self.assertTrue('0.0' in str(context.exception)) @@ -54,18 +55,19 @@ def setUp(self): if "string_val:" in data[jj]: data[jj] = data[jj].replace(MODEL_VERSION, "0.1000000") break - with open("deeppot-ver.pbtxt", "w") as fp: + self.version_pbtxt = str(tests_path / "deeppot-ver.pbtxt") + self.version_pb = str(tests_path / "deeppot.pb") + with open(self.version_pbtxt, "w") as fp: fp.write("\n".join(data)) - - convert_pbtxt_to_pb(str(tests_path / os.path.join("deeppot-ver.pbtxt")), "deeppot.pb") + convert_pbtxt_to_pb(self.version_pbtxt, self.version_pb) def tearDown(self): - os.remove("deeppot-ver.pbtxt") - os.remove("deeppot.pb") + os.remove(self.version_pbtxt) + os.remove(self.version_pb) def test(self): with self.assertRaises(RuntimeError) as context: - DeepPot("deeppot.pb") + DeepPot(self.version_pb) self.assertTrue('incompatible' in str(context.exception)) self.assertTrue(MODEL_VERSION in str(context.exception)) self.assertTrue('0.1000000' in str(context.exception))