diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 73b291e486af..636f55adb863 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -880,6 +880,28 @@ def _get_abs_layer_name(node): params, num_layers) return sym + +def _parse_import_prerequisites(graph): + """ Calculate the named preconditions from TensorFlow `graph`. + Return prerequisites for parsing: + a. Set of operator names which don't have their mapping in TVM, i.e. + which are not supported + """ + missing_operators = set() + for node in graph.node: + if node.op == "Placeholder": + pass + elif node.op == "Const": + pass + else: + if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]): + pass + else: + missing_operators.add(node.op) + + return missing_operators + + class GraphProto(object): """ A helper class for handling nnvm graph copying from Tensorflow GraphDef. Definition: @@ -902,7 +924,7 @@ def from_tensorflow(self, graph): Follow the tensorflow graph definition to parse and convert it to NNVM. Some of the assumptions listed below. - -> First Const or Placeholder node will be considered as graph input. + -> First Placeholder or Const node will be considered as graph input. -> Rest all Const nodes are params. -> Last node is assumed as graph output. -> _output_shapes : Attribute should present in the tenserflow forzen graph. @@ -911,6 +933,7 @@ def from_tensorflow(self, graph): -> CheckNumerics: No implementation as of now for this. Just copies input to output. + TODO: Change algorithm to stop treating first 'Const' in a special way. Parameters ---------- @@ -924,10 +947,6 @@ def from_tensorflow(self, graph): params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ - # Parse throught all nodes and start extracting - # params aka Const nodes - # input nodes : First const node - # normal nodes : other normal nodes try: from tensorflow.python.framework import tensor_util @@ -935,12 +954,18 @@ def from_tensorflow(self, graph): raise ImportError( "Unable to import tensorflow which is required {}".format(e)) + missing_operators = _parse_import_prerequisites(graph) + + if missing_operators: + raise NotImplementedError( \ + "The following operators are not implemented: {}".format(missing_operators)) + + # Parse the nodes to re-create TF graph using Symbol API of NNVM for node in graph.node: # Tensorflow doesn't have seperate list for params extraction. # Operator name 'Const' is treated as a parameter to build NNVM params dict. input_shapes = {} if node.op == "Placeholder": - # Assuming only one input graph with type 'Placeholder' self._input_node = node.name self._num_input += 1 @@ -955,7 +980,6 @@ def from_tensorflow(self, graph): raise NotImplementedError( \ "Please freeze the graph with add_shapes=True") elif node.op == "Const": - # Assuming first Const node as Graph Input node if self._input_node == '': self._input_node = node.name self._num_input += 1 @@ -998,7 +1022,7 @@ def from_tensorflow(self, graph): # Pass the node name too in attr attr["_node_name"] = node.name - #ToDo: Some of the tensorflow operators maintain internaly maintain + #ToDo: Some of the tensorflow operators internaly maintain #execution layers and its output name will the layer number along with #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,