Skip to content

Commit

Permalink
[NNVM][TENSORFLOW] Cleanup Const, Placeholder, _input_shapes.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Aug 4, 2018
1 parent 6d68a18 commit 94f1eeb
Showing 1 changed file with 61 additions and 74 deletions.
135 changes: 61 additions & 74 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,28 +833,6 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym


def _parse_import_prerequisites(graph):
    """Calculate the named preconditions from TensorFlow `graph`.

    Parameters
    ----------
    graph : GraphDef
        The (frozen) TensorFlow graph whose nodes are scanned.

    Returns
    -------
    missing_operators : set of str
        Set of operator names which don't have their mapping in TVM, i.e.
        which are not supported.
    """
    missing_operators = set()
    for node in graph.node:
        # Placeholder and Const nodes are handled specially by the importer
        # (inputs and params respectively), so they are never "missing".
        if node.op in ("Placeholder", "Const"):
            continue
        # An op is supported if it appears in any of the conversion tables.
        # Use a generator (not a list) so the scan short-circuits.
        if not any(node.op in table
                   for table in (_identity_list, _convert_map, _convert_map_rnn)):
            missing_operators.add(node.op)

    return missing_operators


class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
Expand All @@ -863,12 +841,8 @@ class GraphProto(object):
def __init__(self):
    # Map of TF node name -> constructed nnvm symbol.
    self._nodes = {}
    # NNVM params dict: param name -> tensor value (filled from Const nodes).
    self._params = {}
    # NOTE(review): _renames/_replacements are not referenced in this view —
    # confirm whether any caller still uses them before removing.
    self._renames = {}
    self._replacements = {}
    # Node name -> list of output shapes, taken from the Const tensor value
    # or from the node's '_output_shapes' attribute.
    self._output_shapes = {}
    # Running counts of graph inputs and params seen during import.
    self._num_input = 0
    self._num_param = 0
    # Name of the input (Placeholder) node, '' until one is found.
    self._input_node = ''
    # Flag: set truthy once an RNN layer has been processed.
    self._num_rnn_layer = False

def from_tensorflow(self, graph):
Expand Down Expand Up @@ -907,7 +881,7 @@ def from_tensorflow(self, graph):
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))

missing_operators = _parse_import_prerequisites(graph)
missing_operators = self._parse_import_prerequisites(graph)

if missing_operators:
raise NotImplementedError( \
Expand All @@ -917,58 +891,42 @@ def from_tensorflow(self, graph):
for node in graph.node:
# Tensorflow doesn't have a separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

input_shapes = {}

attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in attr['_output_shapes']]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")

if node.op == "Placeholder":
self._input_node = node.name
self._num_input += 1
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])

try:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])
input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
except KeyError:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
#input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
elif node.op == "Const":
if self._input_node == '':
self._input_node = node.name
self._num_input += 1
self._nodes[node.name] = _sym.Variable(name=node.name)
else:
# Rest all nodes are Param nodes, lets parse
self._num_param += 1
for key, value in node.attr.items():
self._parse_param(key, value, node.name)
if node.name not in self._nodes:
raise NotImplementedError( \
"Const {} couldn't be converted to Param.".format(node.name))
attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr
if 'value' in attr:
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
else:
# All Const nodes are Param nodes, lets parse
self._num_param += 1
for key, value in node.attr.items():
self._parse_param(key, value, node.name)
if node.name not in self._nodes:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
else:
"Const {} couldn't be converted to Param.".format(node.name))

attr = self._parse_attr(node.attr)
try:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in attr['_output_shapes']]
except KeyError:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")

else:
# Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name]

Expand All @@ -983,11 +941,12 @@ def from_tensorflow(self, graph):
if ":" in node.input[0]:
in_name, _ = node.input[0].split(':')
node.input[0] = in_name

# Fill shapes for all inputs in a list
try:
inputs = [self._nodes[i] for i in node.input]
for i in node.input:
if i not in self._params:
input_shapes[self._nodes[i]] = self._output_shapes[i]
input_shapes[self._nodes[i]] = self._output_shapes[i]
attr['_input_shapes'] = input_shapes
except KeyError:
# TODO: Need to find clean way to handle '^CheckNumerics'
Expand All @@ -999,18 +958,40 @@ def from_tensorflow(self, graph):
# Assuming only one output.
self._nodes[node.name] = op
node_output = op

# Assume the final node is the output node
out = node_output

#Add the RNN outputs also with 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out = [out, out_rnn]

if isinstance(out, list):
out = _sym.Group(out)

return out, self._params

def _parse_import_prerequisites(self, graph):
    """Calculate the named preconditions from TensorFlow `graph`.

    Parameters
    ----------
    graph : GraphDef
        The (frozen) TensorFlow graph whose nodes are scanned.

    Returns
    -------
    missing_operators : set of str
        Set of operator names which don't have their mapping in TVM, i.e.
        which are not supported.
    """
    missing_operators = set()
    for node in graph.node:
        # Placeholder and Const nodes are handled specially by the importer
        # (inputs and params respectively), so they are never "missing".
        if node.op in ("Placeholder", "Const"):
            continue
        # An op is supported if it appears in any of the conversion tables.
        # Use a generator (not a list) so the scan short-circuits.
        if not any(node.op in table
                   for table in (_identity_list, _convert_map, _convert_map_rnn)):
            missing_operators.add(node.op)

    return missing_operators

def _parse_param(self, key, value, name):
try:
from tensorflow.python.framework import tensor_util
Expand All @@ -1020,6 +1001,12 @@ def _parse_param(self, key, value, name):

if key == 'value':
np_array = tensor_util.MakeNdarray(value.tensor)

if np_array.dtype == np.dtype(object):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op). Just leave it as placeholder.
self._nodes[name] = _sym.Variable(name=name)
return

array_ndim = len(np_array.shape)
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
Expand Down

0 comments on commit 94f1eeb

Please sign in to comment.