Skip to content

Commit

Permalink
fix single precision error
Browse files Browse the repository at this point in the history
  • Loading branch information
denghuilu committed Oct 14, 2021
1 parent 35d333e commit bc24ad4
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, List, Dict, Any

from deepmd.env import tf
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, get_np_precision
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.utils.argcheck import list_to_doc
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
Expand All @@ -13,7 +13,7 @@
from deepmd.utils.tabulate import DPTabulate
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
from .descriptor import Descriptor
from .se import DescrptSe

Expand Down Expand Up @@ -133,7 +133,6 @@ def __init__ (self,
self.compress_activation_fn = get_activation_func(activation_function)
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
self.exclude_types = set()
for tt in exclude_types:
assert(len(tt) == 2)
Expand Down Expand Up @@ -687,7 +686,7 @@ def _filter_lower(
net = 'filter_-1_net_' + str(type_i)
else:
net = 'filter_' + str(type_input) + '_net_' + str(type_i)
return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
return op_module.tabulate_fusion(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
else:
if (not is_exclude):
xyz_scatter = embedding_net(
Expand Down

0 comments on commit bc24ad4

Please sign in to comment.