diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 9a302da72ae6..095313110c3b 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -3,6 +3,7 @@
 from __future__ import absolute_import as _abs
 from __future__ import print_function
+import warnings

 # Numpy support
 import numpy as np
@@ -303,7 +304,8 @@ def _impl(inputs, attr, params):
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
-        print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
+        warnings.warn("DecodeJpeg: It's a pass through, "
+                      "please handle preprocessing before input")
         return inputs[0]
     return _impl

@@ -355,6 +357,11 @@ def _impl(inputs, attr, params):
     return _impl

+def _undef():
+    def _impl(inputs, attr, params):
+        return _sym.__undef__()
+    return _impl
+
 def _identity():
     def _impl(inputs, attr, params):
         return inputs[0]
@@ -1129,6 +1136,7 @@ def __init__(self):
         self._num_param = 0
         self._num_rnn_layer = False
         self._outputs_are_0d = {}
+        self._input_shapes = {}

     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct nnvm nodes from tensorflow graph definition - GraphDef.
@@ -1177,43 +1185,63 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))

+        for node in graph.node:
+            if node.op == 'Placeholder':
+                if shape and node.name in shape:
+                    self._input_shapes[node.name] = list(shape[node.name])
+                    continue
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+                for idx, dim in enumerate(self._input_shapes[node.name]):
+                    if dim < 0:
+                        self._input_shapes[node.name][idx] = 1
+                        warnings.warn("Use 1 instead of -1 in shape of operator %s." % node.name)
+
+            # Ignore user's input shape for non-placeholder operators
+            elif node.op == 'Const':
+                tensor_value = node.attr['value'].tensor
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
+                if shape and node.name in shape:
+                    warnings.warn("Ignore the passed shape. "
+                                  "Shape in graphdef will be used for operator %s." % node.name)
+
         final_op = None
         # Parse the nodes to re-create TF graph using Symbol API of NNVM
         for node in graph.node:
-            # Tensorflow doesn't have seperate list for params extraction.
+            # Tensorflow doesn't have separate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build NNVM params dict.
             input_shapes = {}
             input_0d_mismatch = set()
             attr = self._parse_attr(node.attr)
-            #Variable converted to Const will not have only value attr
+            # Variable converted to Const will not have only value attr
             if 'value' in attr and node.op == 'Const':
-                tensor_value = attr['value']
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList( \
-                        tensor_value.tensor_shape)]
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
             elif shape and node.name in shape:
                 # Give priority to user argument.
                 self._output_shapes[node.name] = [shape[node.name]]
+            elif node.op == 'Placeholder':
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
             elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
                     [tensor_util.TensorShapeProtoToList(tshape) \
                     for tshape in attr['_output_shapes']]
-            elif shape:
+            else:
                 # Keep the list indexable to avoid key error.
                 # Actual value will be filled after node creation.
+                # Will infer shapes if the graph is not frozen with add_shapes=True
                 self._output_shapes[node.name] = [None]
-            else:
-                raise NotImplementedError( \
-                    "Please freeze the graph with add_shapes=True")
+
             self._outputs_are_0d[node.name] = [ \
                 not tshape if isinstance(tshape, list) else False \
                 for tshape in self._output_shapes[node.name]]

             if node.op == "Placeholder":
                 self._nodes[node.name] = _sym.Variable(name=node.name,
-                                                       shape=self._output_shapes[node.name][0])
+                                                       shape=self._input_shapes[node.name])

             elif node.op == "Const":
                 # All Const nodes are Param nodes, let's parse
@@ -1228,7 +1256,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             else:
                 # Pass the parsed shapes instead
-                attr["_output_shapes"] = self._output_shapes[node.name]
+                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

                 # Pass the node name too in attr
                 attr["_node_name"] = node.name
@@ -1269,7 +1297,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                 inputs = self._fix_extranodes(node.op, attr, inputs)

                 op = self._convert_operator(node.op, inputs, attr, graph)
-                # Check is op is converted to param
+                # Check if op is converted to param
                 if isinstance(op, np.ndarray):
                     self._params[node.name] = tvm.nd.array(op)
                     op = _sym.Variable(name=node.name,
@@ -1279,9 +1307,19 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             self._nodes[node.name] = op
             final_op = op

+            # Infer shapes even without specifying "add_shapes=True"
+            if output_shapes == [None]:
+                g = _graph.create(final_op)
+                self._output_shapes[node.name] = \
+                    list(graph_util.infer_shape(g, **self._input_shapes))[-1]
+
+            if self._output_shapes[node.name] and shape and node.name in shape:
+                assert self._output_shapes[node.name] == list(shape[node.name])
+
             # Infer shapes if passed explicitly
             node_output = self._nodes[node.name]
-            if shape:
+            if shape and (not self._output_shapes[node.name][0]
+                          or -1 in self._output_shapes[node.name][0]):
                 g = _graph.create(node_output)
                 shape_dict = {k: v.shape for k, v in self._params.items()}
                 shape_dict.update(shape)
diff --git a/nnvm/python/nnvm/frontend/util/__init__.py b/nnvm/python/nnvm/frontend/util/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/nnvm/python/nnvm/frontend/util/tensorflow_parser.py b/nnvm/python/nnvm/frontend/util/tensorflow_parser.py
new file mode 100644
index 000000000000..9b745c9d02c9
--- /dev/null
+++ b/nnvm/python/nnvm/frontend/util/tensorflow_parser.py
@@ -0,0 +1,153 @@
+"""TF: Tensorflow parser"""
+from __future__ import absolute_import as _abs
+from __future__ import print_function
+import os
+from tensorflow.core.framework import graph_pb2
+from tvm.contrib import util
+
+
+class TFParser(object):
+    """A wrapper to handle TensorFlow model parsing. TensorFlow is required.
+    ```
+    parser = TFParser(model_dir)
+    graph = parser.parse()
+    ```
+    Parameters
+    ----------
+    model_dir : tensorflow frozen pb file or a directory that contains saved
+        model or checkpoints.
+ """ + + def __init__(self, model_dir): + self._tmp_dir = util.tempdir() + self._model_dir = model_dir + self._graph = graph_pb2.GraphDef() + + def _set_graph(self, graph): + """Set Graph""" + self._graph = graph + + def _get_graph(self): + """Get Graph""" + return self._graph + + def _load_pb_file(self): + """Load single pb file""" + graph = self._get_graph() + with open(self._model_dir, "rb") as f: + graph.ParseFromString(f.read()) + return graph + + def _get_tag_set(self): + """Return the tag set of saved model, multiple metagraphs are not supported""" + try: + from tensorflow.contrib.saved_model.python.saved_model import reader + except ImportError: + raise ImportError( + "InputConfiguration: Unable to import saved_model.reader which is " + "required to get tag set from saved model.") + tag_sets = reader.get_saved_model_tag_sets(self._model_dir) + return tag_sets[0] + + def _get_output_names(self): + """Return the concatenated output names""" + try: + import tensorflow as tf + except ImportError: + raise ImportError( + "InputConfiguration: Unable to import tensorflow which is " + "required to restore from saved model.") + tags = self._get_tag_set() + with tf.Session() as sess: + meta_graph_def = tf.saved_model.loader.load(sess, + tags, + self._model_dir) + output_names = set() + for k in meta_graph_def.signature_def.keys(): + outputs_tensor_info = meta_graph_def.signature_def[k].outputs + for output_tensor in outputs_tensor_info.values(): + output_names.add(output_tensor.name) + output_names = [i.replace(":0", "") for i in output_names] + return ",".join(output_names) + + def _load_saved_model(self): + """Load the tensorflow saved model.""" + try: + from tensorflow.python.tools import freeze_graph + from tensorflow.python.framework import ops + from tensorflow.python.framework import graph_util + except ImportError: + raise ImportError( + "InputConfiguration: Unable to import tensorflow which is " + "required to restore from saved model.") + + saved_model_dir = self._model_dir + output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb") + input_saved_model_dir = saved_model_dir + output_node_names = self._get_output_names() + + input_binary = False + input_saver_def_path = False + restore_op_name = None + filename_tensor_name = None + clear_devices = True + input_meta_graph = False + checkpoint_path = None + input_graph_filename = None + saved_model_tags = ",".join(self._get_tag_set()) + + freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, + input_binary, checkpoint_path, output_node_names, + restore_op_name, filename_tensor_name, + output_graph_filename, clear_devices, "", "", "", + input_meta_graph, input_saved_model_dir, + saved_model_tags) + + with ops.Graph().as_default(): + output_graph_def = graph_pb2.GraphDef() + with open(output_graph_filename, "rb") as f: + output_graph_def.ParseFromString(f.read()) + output_graph_def = graph_util.remove_training_nodes(output_graph_def) + return output_graph_def + + def _load_ckpt(self): + """TODO: Load checkpoint model.""" + raise RuntimeError("InputConfiguration: Loading tf checkpoint model is " + "not supported yet.") + + def parse(self): + """Parse tensorflow models: checkpoints, saved models, and single pb + file. 
+ """ + graph = None + + if os.path.isdir(self._model_dir): + ckpt = os.path.join(self._model_dir, "checkpoint") + if not os.path.isfile(ckpt): + if not os.path.isdir(os.path.join(self._model_dir, "variables")): + raise RuntimeError("InputConfiguration: Invalid model path.") + graph = self._load_saved_model() + else: + graph = self._load_ckpt() + elif os.path.isfile(self._model_dir): + # Only .pb or .pbtxt is a valid suffix name. + if self._model_dir.endswith(".pb") or \ + self._model_dir.endswith(".pbtxt"): + cur_dir = os.path.dirname(self._model_dir) + else: + raise RuntimeError("InputConfiguration: Invalid model format.") + + # It is a saved model if `variables` directory is present at the + # same directory with the pb or pbtxt file. + if os.path.isdir(os.path.join(cur_dir, "variables")): + self._model_dir = cur_dir + graph = self._load_saved_model() + else: + graph = self._load_pb_file() + else: + raise RuntimeError("InputConfiguration: Unrecognized model " + "file or path.") + + self._set_graph(graph) + return graph