diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index bd9e01dc99c4..143d995803d8 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -833,28 +833,6 @@ def _get_abs_layer_name(node): params, num_layers) return sym - -def _parse_import_prerequisites(graph): - """ Calculate the named preconditions from TensorFlow `graph`. - Return prerequisites for parsing: - a. Set of operator names which don't have their mapping in TVM, i.e. - which are not supported - """ - missing_operators = set() - for node in graph.node: - if node.op == "Placeholder": - pass - elif node.op == "Const": - pass - else: - if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]): - pass - else: - missing_operators.add(node.op) - - return missing_operators - - class GraphProto(object): """ A helper class for handling nnvm graph copying from Tensorflow GraphDef. Definition: @@ -863,12 +841,8 @@ class GraphProto(object): def __init__(self): self._nodes = {} self._params = {} - self._renames = {} - self._replacements = {} self._output_shapes = {} - self._num_input = 0 self._num_param = 0 - self._input_node = '' self._num_rnn_layer = False def from_tensorflow(self, graph): @@ -907,7 +881,7 @@ def from_tensorflow(self, graph): raise ImportError( "Unable to import tensorflow which is required {}".format(e)) - missing_operators = _parse_import_prerequisites(graph) + missing_operators = self._parse_import_prerequisites(graph) if missing_operators: raise NotImplementedError( \ @@ -917,58 +891,42 @@ def from_tensorflow(self, graph): for node in graph.node: # Tensorflow doesn't have seperate list for params extraction. # Operator name 'Const' is treated as a parameter to build NNVM params dict. + input_shapes = {} + + attr = self._parse_attr(node.attr) + + #Variable converted to Const will not have only value attr + if 'value' in attr and node.op == 'Const': + tensor_value = attr['value'] + self._output_shapes[node.name] = \ + [tensor_util.TensorShapeProtoToList( \ + tensor_value.tensor_shape)] + elif '_output_shapes' in attr: + self._output_shapes[node.name] = \ + [tensor_util.TensorShapeProtoToList(shape) \ + for shape in attr['_output_shapes']] + else: + raise NotImplementedError( \ + "Please freeze the graph with add_shapes=True") + if node.op == "Placeholder": - self._input_node = node.name - self._num_input += 1 + self._nodes[node.name] = _sym.Variable(name=node.name, + shape=self._output_shapes[node.name][0]) - try: - self._output_shapes[node.name] = \ - [tensor_util.TensorShapeProtoToList(shape) \ - for shape in self._parse_attr(node.attr)['_output_shapes']] - self._nodes[node.name] = _sym.Variable(name=node.name, - shape=self._output_shapes[node.name][0]) - input_shapes[self._nodes[node.name]] = self._output_shapes[node.name] - except KeyError: - raise NotImplementedError( \ - "Please freeze the graph with add_shapes=True") + #input_shapes[self._nodes[node.name]] = self._output_shapes[node.name] elif node.op == "Const": - if self._input_node == '': - self._input_node = node.name - self._num_input += 1 - self._nodes[node.name] = _sym.Variable(name=node.name) - else: - # Rest all nodes are Param nodes, lets parse - self._num_param += 1 - for key, value in node.attr.items(): - self._parse_param(key, value, node.name) - if node.name not in self._nodes: - raise NotImplementedError( \ - "Const {} couldn't be converted to Param.".format(node.name)) - attr = self._parse_attr(node.attr) - #Variable converted to Const will not have only value attr - if 'value' in attr: - tensor_value = attr['value'] - self._output_shapes[node.name] = \ - [tensor_util.TensorShapeProtoToList( \ - tensor_value.tensor_shape)] - elif '_output_shapes' in attr: - self._output_shapes[node.name] = \ - [tensor_util.TensorShapeProtoToList(shape) \ - for shape in self._parse_attr(node.attr)['_output_shapes']] - else: + # All Const nodes are Param nodes, lets parse + self._num_param += 1 + for key, value in node.attr.items(): + self._parse_param(key, value, node.name) + if node.name not in self._nodes: raise NotImplementedError( \ - "Please freeze the graph with add_shapes=True") - else: + "Const {} couldn't be converted to Param.".format(node.name)) + attr = self._parse_attr(node.attr) - try: - self._output_shapes[node.name] = \ - [tensor_util.TensorShapeProtoToList(shape) \ - for shape in attr['_output_shapes']] - except KeyError: - raise NotImplementedError( \ - "Please freeze the graph with add_shapes=True") + else: # Pass the parsed shapes instead attr["_output_shapes"] = self._output_shapes[node.name] @@ -983,11 +941,12 @@ def from_tensorflow(self, graph): if ":" in node.input[0]: in_name, _ = node.input[0].split(':') node.input[0] = in_name + + # Fill shapes for all inputs in a list try: inputs = [self._nodes[i] for i in node.input] for i in node.input: - if i not in self._params: - input_shapes[self._nodes[i]] = self._output_shapes[i] + input_shapes[self._nodes[i]] = self._output_shapes[i] attr['_input_shapes'] = input_shapes except KeyError: # TODO: Need to find clean way to handle '^CheckNumerics' @@ -999,6 +958,7 @@ def from_tensorflow(self, graph): # Assuming only one output. self._nodes[node.name] = op node_output = op + # Assume the final node is the output node out = node_output @@ -1006,11 +966,32 @@ def from_tensorflow(self, graph): if self._num_rnn_layer: out_rnn = _sym.concatenate(*self._out_rnn, axis=0) out = [out, out_rnn] + if isinstance(out, list): out = _sym.Group(out) return out, self._params + def _parse_import_prerequisites(self, graph): + """ Calculate the named preconditions from TensorFlow `graph`. + Return prerequisites for parsing: + a. Set of operator names which don't have their mapping in TVM, i.e. + which are not supported + """ + missing_operators = set() + for node in graph.node: + if node.op == "Placeholder": + pass + elif node.op == "Const": + pass + else: + if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]): + pass + else: + missing_operators.add(node.op) + + return missing_operators + def _parse_param(self, key, value, name): try: from tensorflow.python.framework import tensor_util @@ -1020,6 +1001,13 @@ def _parse_param(self, key, value, name): if key == 'value': np_array = tensor_util.MakeNdarray(value.tensor) + + if np_array.dtype == np.dtype(object): + # Object types are generally tensorflow DT_STRING (DecodeJpeg op). + # Just leave it as placeholder. + self._nodes[name] = _sym.Variable(name=name) + return + array_ndim = len(np_array.shape) if array_ndim == 0: new_array = np.empty([1], dtype=np_array.dtype)