Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tensorflow Frontend] Add saved model support and infer shape w/o add_shapes=True #2493

Merged
merged 3 commits into from
Feb 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 53 additions & 15 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
from __future__ import print_function

import warnings
# Numpy support
import numpy as np

Expand Down Expand Up @@ -303,7 +304,8 @@ def _impl(inputs, attr, params):
def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
warnings.warn("DecodeJpeg: It's a pass through, "
"please handle preprocessing before input")
return inputs[0]
return _impl

Expand Down Expand Up @@ -355,6 +357,11 @@ def _impl(inputs, attr, params):

return _impl

def _undef():
def _impl(inputs, attr, params):
return _sym.__undef__()
return _impl

def _identity():
def _impl(inputs, attr, params):
return inputs[0]
Expand Down Expand Up @@ -1129,6 +1136,7 @@ def __init__(self):
self._num_param = 0
self._num_rnn_layer = False
self._outputs_are_0d = {}
self._input_shapes = {}

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
Expand Down Expand Up @@ -1177,43 +1185,63 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

for node in graph.node:
if node.op == 'Placeholder':
yongwww marked this conversation as resolved.
Show resolved Hide resolved
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
continue
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name)

# Ignore user's input shape for Non placeholder
elif node.op == 'Const':
tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
if shape and node.name in shape:
warnings.warn("Ignore the passed shape. "
"Shape in graphdef will be used for operator %s." % node.name)

final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
# Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
elif shape:
else:
yongwww marked this conversation as resolved.
Show resolved Hide resolved
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")

self._outputs_are_0d[node.name] = [ \
not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]]

if node.op == "Placeholder":
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])
shape=self._input_shapes[node.name])

elif node.op == "Const":
# All Const nodes are Param nodes, lets parse
Expand All @@ -1228,7 +1256,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

else:
# Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name]
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

# Pass the node name too in attr
attr["_node_name"] = node.name
Expand Down Expand Up @@ -1269,7 +1297,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)

# Check is op is converted to param
# Check if op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = _sym.Variable(name=node.name,
Expand All @@ -1279,9 +1307,19 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
self._nodes[node.name] = op
final_op = op

# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
g = _graph.create(final_op)
self._output_shapes[node.name] = \
list(graph_util.infer_shape(g, **self._input_shapes))[-1]

if self._output_shapes[node.name] and shape and node.name in shape:
srkreddy1238 marked this conversation as resolved.
Show resolved Hide resolved
assert self._output_shapes[node.name] == list(shape[node.name])

# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape:
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
Expand Down
Empty file.
153 changes: 153 additions & 0 deletions nnvm/python/nnvm/frontend/util/tensorflow_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""TF: Tensorflow parser"""
from __future__ import absolute_import as _abs
from __future__ import print_function
import os
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util


class TFParser(object):
"""A Wrapper to handle tensorflow models parsing
TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
```
Parameters
----------
model_dir : tensorflow frozen pb file or a directory that contains saved
model or checkpoints.
"""

def __init__(self, model_dir):
self._tmp_dir = util.tempdir()
self._model_dir = model_dir
self._graph = graph_pb2.GraphDef()

def _set_graph(self, graph):
"""Set Graph"""
self._graph = graph

def _get_graph(self):
"""Get Graph"""
return self._graph

def _load_pb_file(self):
"""Load single pb file"""
graph = self._get_graph()
with open(self._model_dir, "rb") as f:
graph.ParseFromString(f.read())
return graph

def _get_tag_set(self):
"""Return the tag set of saved model, multiple metagraphs are not supported"""
try:
from tensorflow.contrib.saved_model.python.saved_model import reader
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import saved_model.reader which is "
"required to get tag set from saved model.")
tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
return tag_sets[0]

def _get_output_names(self):
"""Return the concatenated output names"""
try:
import tensorflow as tf
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.")
tags = self._get_tag_set()
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess,
tags,
self._model_dir)
output_names = set()
for k in meta_graph_def.signature_def.keys():
outputs_tensor_info = meta_graph_def.signature_def[k].outputs
for output_tensor in outputs_tensor_info.values():
output_names.add(output_tensor.name)
output_names = [i.replace(":0", "") for i in output_names]
return ",".join(output_names)

def _load_saved_model(self):
"""Load the tensorflow saved model."""
try:
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.")

saved_model_dir = self._model_dir
output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
input_saved_model_dir = saved_model_dir
output_node_names = self._get_output_names()

input_binary = False
input_saver_def_path = False
restore_op_name = None
filename_tensor_name = None
clear_devices = True
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = ",".join(self._get_tag_set())

freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names,
restore_op_name, filename_tensor_name,
output_graph_filename, clear_devices, "", "", "",
input_meta_graph, input_saved_model_dir,
saved_model_tags)

with ops.Graph().as_default():
output_graph_def = graph_pb2.GraphDef()
with open(output_graph_filename, "rb") as f:
output_graph_def.ParseFromString(f.read())
output_graph_def = graph_util.remove_training_nodes(output_graph_def)
return output_graph_def

def _load_ckpt(self):
"""TODO: Load checkpoint model."""
raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
"not supported yet.")

def parse(self):
"""Parse tensorflow models: checkpoints, saved models, and single pb
file.
"""
graph = None

if os.path.isdir(self._model_dir):
ckpt = os.path.join(self._model_dir, "checkpoint")
if not os.path.isfile(ckpt):
if not os.path.isdir(os.path.join(self._model_dir, "variables")):
raise RuntimeError("InputConfiguration: Invalid model path.")
graph = self._load_saved_model()
else:
graph = self._load_ckpt()
elif os.path.isfile(self._model_dir):
# Only .pb or .pbtxt is a valid suffix name.
if self._model_dir.endswith(".pb") or \
self._model_dir.endswith(".pbtxt"):
cur_dir = os.path.dirname(self._model_dir)
else:
raise RuntimeError("InputConfiguration: Invalid model format.")

# It is a saved model if `variables` directory is present at the
# same directory with the pb or pbtxt file.
if os.path.isdir(os.path.join(cur_dir, "variables")):
self._model_dir = cur_dir
graph = self._load_saved_model()
else:
graph = self._load_pb_file()
else:
raise RuntimeError("InputConfiguration: Unrecognized model "
"file or path.")

self._set_graph(graph)
return graph