Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Get tags of saved model automatically
Browse files Browse the repository at this point in the history
Remove exception trail in tf parser error message

Fix lint

Fix comments
yongwww committed Feb 5, 2019
1 parent ec1000e commit af15ec4
Showing 2 changed files with 75 additions and 76 deletions.
74 changes: 48 additions & 26 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import numpy as np

import tvm
import warnings
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util, build_module
@@ -303,7 +304,8 @@ def _impl(inputs, attr, params):
def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
warnings.warn("DecodeJpeg: It's a pass through, "
"please handle preprocessing before input")
return inputs[0]
return _impl

@@ -938,8 +940,6 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
'Split' : _split(False),
'SplitV' : _split(True),
'Unpack' : _unpack(),
'QueueDequeueManyV2' : _undef(),
'FIFOQueueV2' : _undef(),
}

# _convert_map_rnn defines maps of rnn operator name to
@@ -1184,42 +1184,57 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

for node in graph.node:
if node.op == 'Placeholder':
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
self._input_shapes[node.name][0] = 1
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
continue
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name)

# Ignore user's input shape for Non placeholder
elif node.op == 'Const':
tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
if shape and node.name in shape:
warnings.warn("Ignore the passed shape. "
"Shape in graphdef will be used for operator %s." % node.name)

final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
# Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
elif shape:
else:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None]
else:
self._output_shapes[node.name] = None

self._outputs_are_0d[node.name] = [ \
not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]]
@@ -1241,7 +1256,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

else:
# Pass the parsed shapes instead
output_shapes = self._output_shapes[node.name]
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

# Pass the node name too in attr
attr["_node_name"] = node.name
@@ -1291,20 +1306,27 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Assuming only one output.
self._nodes[node.name] = op
final_op = op
# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape:
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
elif output_shapes == None:
g = _graph.create(node_output)
self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1]

# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
g = _graph.create(final_op)
self._output_shapes[node.name] = \
list(graph_util.infer_shape(g, **self._input_shapes))[-1]
else:
self._output_shapes[node.name] = output_shapes

if self._output_shapes[node.name] and shape and node.name in shape:
assert self._input_shapes[node.name] == list(shape[node.name])

# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape:
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes

out = []
if outputs is None:
out.append(final_op)
77 changes: 27 additions & 50 deletions nnvm/python/nnvm/frontend/util/tensorflow_parser.py
Original file line number Diff line number Diff line change
@@ -2,32 +2,13 @@
from __future__ import absolute_import as _abs
from __future__ import print_function
import os

try:
from tensorflow.core.framework import graph_pb2
except ImportError as e:
from nnvm.frontend.protobuf import graph_pb2


try:
from tempfile import TemporaryDirectory
except ImportError:
import tempfile
import shutil

class TemporaryDirectory(object):
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name

def __exit__(self, exc, value, tb):
shutil.rmtree(self.name)
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util


class TFParser(object):
"""A Wrapper to handle tensorflow models parsing
Works w/o installing tensorflow,
Protocol Buffer is needed
TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
@@ -39,7 +20,7 @@ class TFParser(object):
"""

def __init__(self, model_dir):
self._tmp_dir = TemporaryDirectory()
self._tmp_dir = util.tempdir()
self._model_dir = model_dir
self._graph = graph_pb2.GraphDef()

@@ -51,41 +32,37 @@ def _get_graph(self):
"""Get Graph"""
return self._graph

def _output_graph(self):
import logging
logging.basicConfig(level=logging.DEBUG)
for node in self._get_graph().node:
logging.info("Name: {}".format(node.name))
logging.info("\top: {}".format(node.op))
for input in node.input:
logging.info("\t\tinput: {}".format(input))
logging.info("\t\tdevice: {}".format(node.device))
logging.info("\t\tAttrValue: ")
for key in node.attr.keys():
logging.info("\t\t\tkey: {} => value: {}"
.format(key, node.attr[key]))
logging.info(node.attr['shape'].shape)

def _load_pb_file(self):
"""Load single pb file"""
graph = self._get_graph()
with open(self._model_dir, "rb") as f:
graph.ParseFromString(f.read())
return graph

def _get_output_names(self, model_path):
def _get_tag_set(self):
"""Return the tag set of saved model, multiple metagraphs are not supported"""
try:
from tensorflow.contrib.saved_model.python.saved_model import reader
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import saved_model.reader which is "
"required to get tag set from saved model.")
tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
return tag_sets[0]

def _get_output_names(self):
"""Return the concatenated output names"""
try:
import tensorflow as tf
except ImportError as e:
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e))

"required to restore from saved model.")
tags = self._get_tag_set()
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess,
[tf.saved_model.tag_constants.SERVING],
model_path)
tags,
self._model_dir)
output_names = set()
for k in meta_graph_def.signature_def.keys():
outputs_tensor_info = meta_graph_def.signature_def[k].outputs
@@ -97,19 +74,18 @@ def _get_output_names(self, model_path):
def _load_saved_model(self):
"""Load the tensorflow saved model."""
try:
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util
except ImportError as e:
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e))
"required to restore from saved model.")

saved_model_dir = self._model_dir
output_graph_filename = os.path.join(self._tmp_dir.name, "neo_frozen_model.pb")
output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
input_saved_model_dir = saved_model_dir
output_node_names = self._get_output_names(self._model_dir)
output_node_names = self._get_output_names()

input_binary = False
input_saver_def_path = False
@@ -119,7 +95,7 @@ def _load_saved_model(self):
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = tf.saved_model.tag_constants.SERVING
saved_model_tags = ",".join(self._get_tag_set())

freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names,
@@ -145,6 +121,7 @@ def parse(self):
file.
"""
graph = None

if os.path.isdir(self._model_dir):
ckpt = os.path.join(self._model_dir, "checkpoint")
if not os.path.isfile(ckpt):

0 comments on commit af15ec4

Please sign in to comment.