Skip to content

Commit

Permalink
Get tags of saved model automatically
Browse files Browse the repository at this point in the history
Remove exception trail in tf parser error message

Fix lint

Fix comments
  • Loading branch information
yongwww committed Feb 8, 2019
1 parent ec1000e commit 412f9fb
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 79 deletions.
79 changes: 50 additions & 29 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
from __future__ import print_function

import warnings
# Numpy support
import numpy as np

Expand Down Expand Up @@ -303,7 +304,8 @@ def _impl(inputs, attr, params):
def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
warnings.warn("DecodeJpeg: It's a pass through, "
"please handle preprocessing before input")
return inputs[0]
return _impl

Expand Down Expand Up @@ -938,8 +940,6 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
'Split' : _split(False),
'SplitV' : _split(True),
'Unpack' : _unpack(),
'QueueDequeueManyV2' : _undef(),
'FIFOQueueV2' : _undef(),
}

# _convert_map_rnn defines maps of rnn operator name to
Expand Down Expand Up @@ -1184,42 +1184,57 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

for node in graph.node:
if node.op == 'Placeholder':
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
self._input_shapes[node.name][0] = 1
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
continue
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name)

# Ignore user's input shape for Non placeholder
elif node.op == 'Const':
tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
if shape and node.name in shape:
warnings.warn("Ignore the passed shape. "
"Shape in graphdef will be used for operator %s." % node.name)

final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
# Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
elif shape:
else:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None]
else:
self._output_shapes[node.name] = None

self._outputs_are_0d[node.name] = [ \
not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]]
Expand All @@ -1241,7 +1256,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

else:
# Pass the parsed shapes instead
output_shapes = self._output_shapes[node.name]
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

# Pass the node name too in attr
attr["_node_name"] = node.name
Expand Down Expand Up @@ -1282,7 +1297,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)

# Check is op is converted to param
# Check if op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = _sym.Variable(name=node.name,
Expand All @@ -1291,19 +1306,25 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Assuming only one output.
self._nodes[node.name] = op
final_op = op
# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape:
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
elif output_shapes == None:
g = _graph.create(node_output)
self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1]
else:
self._output_shapes[node.name] = output_shapes

# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
g = _graph.create(final_op)
self._output_shapes[node.name] = \
list(graph_util.infer_shape(g, **self._input_shapes))[-1]

if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])

# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes

out = []
if outputs is None:
Expand Down
77 changes: 27 additions & 50 deletions nnvm/python/nnvm/frontend/util/tensorflow_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,13 @@
from __future__ import absolute_import as _abs
from __future__ import print_function
import os

try:
from tensorflow.core.framework import graph_pb2
except ImportError as e:
from nnvm.frontend.protobuf import graph_pb2


try:
from tempfile import TemporaryDirectory
except ImportError:
import tempfile
import shutil

class TemporaryDirectory(object):
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name

def __exit__(self, exc, value, tb):
shutil.rmtree(self.name)
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util


class TFParser(object):
"""A Wrapper to handle tensorflow models parsing
Works w/o installing tensorflow,
Protocol Buffer is needed
TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
Expand All @@ -39,7 +20,7 @@ class TFParser(object):
"""

def __init__(self, model_dir):
self._tmp_dir = TemporaryDirectory()
self._tmp_dir = util.tempdir()
self._model_dir = model_dir
self._graph = graph_pb2.GraphDef()

Expand All @@ -51,41 +32,37 @@ def _get_graph(self):
"""Get Graph"""
return self._graph

def _output_graph(self):
import logging
logging.basicConfig(level=logging.DEBUG)
for node in self._get_graph().node:
logging.info("Name: {}".format(node.name))
logging.info("\top: {}".format(node.op))
for input in node.input:
logging.info("\t\tinput: {}".format(input))
logging.info("\t\tdevice: {}".format(node.device))
logging.info("\t\tAttrValue: ")
for key in node.attr.keys():
logging.info("\t\t\tkey: {} => value: {}"
.format(key, node.attr[key]))
logging.info(node.attr['shape'].shape)

def _load_pb_file(self):
"""Load single pb file"""
graph = self._get_graph()
with open(self._model_dir, "rb") as f:
graph.ParseFromString(f.read())
return graph

def _get_output_names(self, model_path):
def _get_tag_set(self):
"""Return the tag set of saved model, multiple metagraphs are not supported"""
try:
from tensorflow.contrib.saved_model.python.saved_model import reader
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import saved_model.reader which is "
"required to get tag set from saved model.")
tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
return tag_sets[0]

def _get_output_names(self):
"""Return the concatenated output names"""
try:
import tensorflow as tf
except ImportError as e:
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e))

"required to restore from saved model.")
tags = self._get_tag_set()
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess,
[tf.saved_model.tag_constants.SERVING],
model_path)
tags,
self._model_dir)
output_names = set()
for k in meta_graph_def.signature_def.keys():
outputs_tensor_info = meta_graph_def.signature_def[k].outputs
Expand All @@ -97,19 +74,18 @@ def _get_output_names(self, model_path):
def _load_saved_model(self):
"""Load the tensorflow saved model."""
try:
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util
except ImportError as e:
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e))
"required to restore from saved model.")

saved_model_dir = self._model_dir
output_graph_filename = os.path.join(self._tmp_dir.name, "neo_frozen_model.pb")
output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
input_saved_model_dir = saved_model_dir
output_node_names = self._get_output_names(self._model_dir)
output_node_names = self._get_output_names()

input_binary = False
input_saver_def_path = False
Expand All @@ -119,7 +95,7 @@ def _load_saved_model(self):
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = tf.saved_model.tag_constants.SERVING
saved_model_tags = ",".join(self._get_tag_set())

freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names,
Expand All @@ -145,6 +121,7 @@ def parse(self):
file.
"""
graph = None

if os.path.isdir(self._model_dir):
ckpt = os.path.join(self._model_dir, "checkpoint")
if not os.path.isfile(ckpt):
Expand Down

0 comments on commit 412f9fb

Please sign in to comment.