Skip to content

Commit

Permalink
use TF's built-in method to get numpy dtype (deepmodeling#1035)
Browse files Browse the repository at this point in the history
* use TF's built-in method to get numpy dtype

I got a way to get the numpy type from a int. Take an example
```py
>>> tf.dtypes.as_dtype(19).as_numpy_dtype
<class 'numpy.float16'>
```

`PRECISION_MAPPING` is not used any more, as it's actually not a public API.

By the way, it also supports `str`
```py
>>> tf.dtypes.as_dtype("float16")
tf.float16
```

* sadly only `tf.as_dtype` is supported in TF 1.8
  • Loading branch information
njzjz authored Aug 26, 2021
1 parent 5d028c4 commit 0a6a392
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 12 deletions.
6 changes: 0 additions & 6 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@
"float64": tf.float64,
}

PRECISION_MAPPING: Dict[int, type] = {
1: np.float32,
2: np.float64,
19: np.float16,
}


def gelu(x: tf.Tensor) -> tf.Tensor:
"""Gaussian Error Linear Unit.
Expand Down
5 changes: 2 additions & 3 deletions deepmd/entrypoints/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import Dict, Optional, Sequence, Tuple
from deepmd.env import tf
from deepmd.common import PRECISION_MAPPING
import re
import numpy as np
import logging
Expand Down Expand Up @@ -121,8 +120,8 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph:

check_dim(raw_graph_node, old_graph_node, node.name)
tensor_shape = [dim.size for dim in raw_node.tensor_shape.dim]
old_graph_dtype = PRECISION_MAPPING[old_node.dtype]
raw_graph_dtype = PRECISION_MAPPING[raw_node.dtype]
old_graph_dtype = tf.as_dtype(old_node.dtype).as_numpy_dtype
raw_graph_dtype = tf.as_dtype(raw_node.dtype).as_numpy_dtype
log.info(
f"{node.name} is passed from old graph({old_graph_dtype}) "
f"to raw graph({raw_graph_dtype})"
Expand Down
5 changes: 2 additions & 3 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
from typing import Tuple, Dict
from deepmd.env import tf
from deepmd.common import PRECISION_MAPPING
from deepmd.utils.sess import run_sess
from deepmd.utils.errors import GraphWithoutTensorError

Expand Down Expand Up @@ -174,7 +173,7 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def)
for item in embedding_net_nodes:
node = embedding_net_nodes[item]
dtype = PRECISION_MAPPING[node.dtype]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor_value = np.frombuffer(node.tensor_content)
Expand Down Expand Up @@ -262,7 +261,7 @@ def get_fitting_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict:
fitting_net_nodes = get_fitting_net_nodes_from_graph_def(graph_def)
for item in fitting_net_nodes:
node = fitting_net_nodes[item]
dtype= PRECISION_MAPPING[node.dtype]
dtype= tf.as_dtype(node.dtype).as_numpy_dtype
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor_value = np.frombuffer(node.tensor_content)
Expand Down

0 comments on commit 0a6a392

Please sign in to comment.