Skip to content

Commit

Permalink
[NNVM] TF, Report missing operators in advance
Browse files Browse the repository at this point in the history
  • Loading branch information
sergei-mironov authored and Sergey Mironov committed Aug 2, 2018
1 parent 1a3264e commit fa69b93
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,28 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym


def _parse_import_prerequisites(graph):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators = set()
for node in graph.node:
if node.op == "Placeholder":
pass
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
pass
else:
missing_operators.add(node.op)

return missing_operators


class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
Expand All @@ -902,7 +924,7 @@ def from_tensorflow(self, graph):
Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below.
-> First Const or Placeholder node will be considered as graph input.
-> First Placeholder or Const node will be considered as graph input.
-> Rest all Const nodes are params.
-> Last node is assumed as graph output.
-> _output_shapes : Attribute should present in the tenserflow forzen graph.
Expand All @@ -911,6 +933,7 @@ def from_tensorflow(self, graph):
-> CheckNumerics: No implementation as of now for this.
Just copies input to output.
TODO: Change algorithm to stop treating first 'Const' in a special way.
Parameters
----------
Expand All @@ -924,23 +947,25 @@ def from_tensorflow(self, graph):
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
# Parse throught all nodes and start extracting
# params aka Const nodes
# input nodes : First const node
# normal nodes : other normal nodes

try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))

missing_operators = _parse_import_prerequisites(graph)

if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {}
if node.op == "Placeholder":
# Assuming only one input graph with type 'Placeholder'
self._input_node = node.name
self._num_input += 1

Expand All @@ -955,7 +980,6 @@ def from_tensorflow(self, graph):
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
elif node.op == "Const":
# Assuming first Const node as Graph Input node
if self._input_node == '':
self._input_node = node.name
self._num_input += 1
Expand Down Expand Up @@ -998,7 +1022,7 @@ def from_tensorflow(self, graph):
# Pass the node name too in attr
attr["_node_name"] = node.name

#ToDo: Some of the tensorflow operators maintain internaly maintain
#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
Expand Down

0 comments on commit fa69b93

Please sign in to comment.