diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 70c024cf480f8..d91e3954e0e80 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -7,6 +7,7 @@ import numpy as np import tvm +import warnings from .. import symbol as _sym from .. import graph as _graph from .. compiler import graph_util, build_module @@ -303,7 +304,8 @@ def _impl(inputs, attr, params): def _decode_image(): def _impl(inputs, attr, params): # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. - print("DecodeJpeg: It's a pass through, please handle preprocessing before input") + warnings.warn("DecodeJpeg: It's a pass through, " + "please handle preprocessing before input") return inputs[0] return _impl @@ -938,8 +940,6 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1): 'Split' : _split(False), 'SplitV' : _split(True), 'Unpack' : _unpack(), - 'QueueDequeueManyV2' : _undef(), - 'FIFOQueueV2' : _undef(), } # _convert_map_rnn defines maps of rnn operator name to @@ -1184,42 +1184,57 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): if missing_operators: raise NotImplementedError( \ "The following operators are not implemented: {}".format(missing_operators)) + for node in graph.node: if node.op == 'Placeholder': - self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) - self._input_shapes[node.name][0] = 1 + if shape and node.name in shape: + self._input_shapes[node.name] = list(shape[node.name]) + continue + self._input_shapes[node.name] = \ + tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) + for idx, dim in enumerate(self._input_shapes[node.name]): + if dim < 0: + self._input_shapes[node.name][idx] = 1 + warnings.warn("Use 1 instead of -1 in shape of operator %s." + % node.name) + + # Ignore user's input shape for Non placeholder elif node.op == 'Const': tensor_value = node.attr['value'].tensor - self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape) + self._input_shapes[node.name] = \ + tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape) + if shape and node.name in shape: + warnings.warn("Ignore the passed shape. " + "Shape in graphdef will be used for operator %s." % node.name) final_op = None # Parse the nodes to re-create TF graph using Symbol API of NNVM for node in graph.node: - # Tensorflow doesn't have seperate list for params extraction. + # Tensorflow doesn't have separate list for params extraction. # Operator name 'Const' is treated as a parameter to build NNVM params dict. input_shapes = {} input_0d_mismatch = set() attr = self._parse_attr(node.attr) - #Variable converted to Const will not have only value attr + # Variable converted to Const will not have only value attr if 'value' in attr and node.op == 'Const': self._output_shapes[node.name] = [self._input_shapes[node.name]] + elif shape and node.name in shape: + # Give priority to user argument. + self._output_shapes[node.name] = [shape[node.name]] elif node.op == 'Placeholder': self._output_shapes[node.name] = [self._input_shapes[node.name]] - elif shape and node.name in shape: - # Give priority to user argument. - self._output_shapes[node.name] = [shape[node.name]] elif '_output_shapes' in attr: self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList(tshape) \ for tshape in attr['_output_shapes']] - elif shape: + else: # Keep the list indexable to avoid key error. # Actual value will be filled after node creation. + # Will infer shapes if the graph is not frozen with add_shapes=True self._output_shapes[node.name] = [None] - else: - self._output_shapes[node.name] = None + self._outputs_are_0d[node.name] = [ \ not tshape if isinstance(tshape, list) else False \ for tshape in self._output_shapes[node.name]] @@ -1241,7 +1256,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): else: # Pass the parsed shapes instead - output_shapes = self._output_shapes[node.name] + attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] # Pass the node name too in attr attr["_node_name"] = node.name @@ -1291,20 +1306,27 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Assuming only one output. self._nodes[node.name] = op final_op = op - # Infer shapes if passed explicitely - node_output = self._nodes[node.name] - if shape: - g = _graph.create(node_output) - shape_dict = {k: v.shape for k, v in self._params.items()} - shape_dict.update(shape) - _, out_shapes = graph_util.infer_shape(g, **shape_dict) - self._output_shapes[node.name] = out_shapes - elif output_shapes == None: - g = _graph.create(node_output) - self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1] + + # Infer shapes even without specifying "add_shapes=True" + if output_shapes == [None]: + g = _graph.create(final_op) + self._output_shapes[node.name] = \ + list(graph_util.infer_shape(g, **self._input_shapes))[-1] else: self._output_shapes[node.name] = output_shapes + if self._output_shapes[node.name] and shape and node.name in shape: + assert self._input_shapes[node.name] == list(shape[node.name]) + + # Infer shapes if passed explicitely + node_output = self._nodes[node.name] + if shape: + g = _graph.create(node_output) + shape_dict = {k: v.shape for k, v in self._params.items()} + shape_dict.update(shape) + _, out_shapes = graph_util.infer_shape(g, **shape_dict) + self._output_shapes[node.name] = out_shapes + out = [] if outputs is None: out.append(final_op) diff --git a/nnvm/python/nnvm/frontend/util/tensorflow_parser.py b/nnvm/python/nnvm/frontend/util/tensorflow_parser.py index ddd50633843f4..9b745c9d02c9a 100644 --- a/nnvm/python/nnvm/frontend/util/tensorflow_parser.py +++ b/nnvm/python/nnvm/frontend/util/tensorflow_parser.py @@ -2,32 +2,13 @@ from __future__ import absolute_import as _abs from __future__ import print_function import os - -try: - from tensorflow.core.framework import graph_pb2 -except ImportError as e: - from nnvm.frontend.protobuf import graph_pb2 - - -try: - from tempfile import TemporaryDirectory -except ImportError: - import tempfile - import shutil - - class TemporaryDirectory(object): - def __enter__(self): - self.name = tempfile.mkdtemp() - return self.name - - def __exit__(self, exc, value, tb): - shutil.rmtree(self.name) +from tensorflow.core.framework import graph_pb2 +from tvm.contrib import util class TFParser(object): """A Wrapper to handle tensorflow models parsing - Works w/o installing tensorflow, - Protocol Buffer is needed + TensorFlow is needed ``` parser = TfParser(model_dir) graph = parser.parse() @@ -39,7 +20,7 @@ class TFParser(object): """ def __init__(self, model_dir): - self._tmp_dir = TemporaryDirectory() + self._tmp_dir = util.tempdir() self._model_dir = model_dir self._graph = graph_pb2.GraphDef() @@ -51,21 +32,6 @@ def _get_graph(self): """Get Graph""" return self._graph - def _output_graph(self): - import logging - logging.basicConfig(level=logging.DEBUG) - for node in self._get_graph().node: - logging.info("Name: {}".format(node.name)) - logging.info("\top: {}".format(node.op)) - for input in node.input: - logging.info("\t\tinput: {}".format(input)) - logging.info("\t\tdevice: {}".format(node.device)) - logging.info("\t\tAttrValue: ") - for key in node.attr.keys(): - logging.info("\t\t\tkey: {} => value: {}" - .format(key, node.attr[key])) - logging.info(node.attr['shape'].shape) - def _load_pb_file(self): """Load single pb file""" graph = self._get_graph() @@ -73,19 +39,30 @@ def _load_pb_file(self): graph.ParseFromString(f.read()) return graph - def _get_output_names(self, model_path): + def _get_tag_set(self): + """Return the tag set of saved model, multiple metagraphs are not supported""" + try: + from tensorflow.contrib.saved_model.python.saved_model import reader + except ImportError: + raise ImportError( + "InputConfiguration: Unable to import saved_model.reader which is " + "required to get tag set from saved model.") + tag_sets = reader.get_saved_model_tag_sets(self._model_dir) + return tag_sets[0] + + def _get_output_names(self): """Return the concatenated output names""" try: import tensorflow as tf - except ImportError as e: + except ImportError: raise ImportError( "InputConfiguration: Unable to import tensorflow which is " - "required to restore from saved model. {}".format(e)) - + "required to restore from saved model.") + tags = self._get_tag_set() with tf.Session() as sess: meta_graph_def = tf.saved_model.loader.load(sess, - [tf.saved_model.tag_constants.SERVING], - model_path) + tags, + self._model_dir) output_names = set() for k in meta_graph_def.signature_def.keys(): outputs_tensor_info = meta_graph_def.signature_def[k].outputs @@ -97,19 +74,18 @@ def _get_output_names(self, model_path): def _load_saved_model(self): """Load the tensorflow saved model.""" try: - import tensorflow as tf from tensorflow.python.tools import freeze_graph from tensorflow.python.framework import ops from tensorflow.python.framework import graph_util - except ImportError as e: + except ImportError: raise ImportError( "InputConfiguration: Unable to import tensorflow which is " - "required to restore from saved model. {}".format(e)) + "required to restore from saved model.") saved_model_dir = self._model_dir - output_graph_filename = os.path.join(self._tmp_dir.name, "neo_frozen_model.pb") + output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb") input_saved_model_dir = saved_model_dir - output_node_names = self._get_output_names(self._model_dir) + output_node_names = self._get_output_names() input_binary = False input_saver_def_path = False @@ -119,7 +95,7 @@ def _load_saved_model(self): input_meta_graph = False checkpoint_path = None input_graph_filename = None - saved_model_tags = tf.saved_model.tag_constants.SERVING + saved_model_tags = ",".join(self._get_tag_set()) freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, input_binary, checkpoint_path, output_node_names, @@ -145,6 +121,7 @@ def parse(self): file. """ graph = None + if os.path.isdir(self._model_dir): ckpt = os.path.join(self._model_dir, "checkpoint") if not os.path.isfile(ckpt):