diff --git a/deepmd/common.py b/deepmd/common.py index 94f2e43cd6..03d7d8caf3 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -42,12 +42,6 @@ "float64": tf.float64, } -PRECISION_MAPPING: Dict[int, type] = { - 1: np.float32, - 2: np.float64, - 19: np.float16, -} - def gelu(x: tf.Tensor) -> tf.Tensor: """Gaussian Error Linear Unit. diff --git a/deepmd/entrypoints/transfer.py b/deepmd/entrypoints/transfer.py index 9efc07c668..576df74c80 100644 --- a/deepmd/entrypoints/transfer.py +++ b/deepmd/entrypoints/transfer.py @@ -2,7 +2,6 @@ from typing import Dict, Optional, Sequence, Tuple from deepmd.env import tf -from deepmd.common import PRECISION_MAPPING import re import numpy as np import logging @@ -121,8 +120,8 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph: check_dim(raw_graph_node, old_graph_node, node.name) tensor_shape = [dim.size for dim in raw_node.tensor_shape.dim] - old_graph_dtype = PRECISION_MAPPING[old_node.dtype] - raw_graph_dtype = PRECISION_MAPPING[raw_node.dtype] + old_graph_dtype = tf.as_dtype(old_node.dtype).as_numpy_dtype + raw_graph_dtype = tf.as_dtype(raw_node.dtype).as_numpy_dtype log.info( f"{node.name} is passed from old graph({old_graph_dtype}) " f"to raw graph({raw_graph_dtype})" diff --git a/deepmd/utils/graph.py b/deepmd/utils/graph.py index ace87f5438..31452be7f5 100644 --- a/deepmd/utils/graph.py +++ b/deepmd/utils/graph.py @@ -2,7 +2,6 @@ import numpy as np from typing import Tuple, Dict from deepmd.env import tf -from deepmd.common import PRECISION_MAPPING from deepmd.utils.sess import run_sess from deepmd.utils.errors import GraphWithoutTensorError @@ -174,7 +173,7 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict: embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def) for item in embedding_net_nodes: node = embedding_net_nodes[item] - dtype = PRECISION_MAPPING[node.dtype] + dtype = tf.as_dtype(node.dtype).as_numpy_dtype tensor_shape = tf.TensorShape(node.tensor_shape).as_list() if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): tensor_value = np.frombuffer(node.tensor_content) @@ -262,7 +261,7 @@ def get_fitting_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict: fitting_net_nodes = get_fitting_net_nodes_from_graph_def(graph_def) for item in fitting_net_nodes: node = fitting_net_nodes[item] - dtype= PRECISION_MAPPING[node.dtype] + dtype= tf.as_dtype(node.dtype).as_numpy_dtype tensor_shape = tf.TensorShape(node.tensor_shape).as_list() if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): tensor_value = np.frombuffer(node.tensor_content)