-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TF parser: return the consistent error message to error handler
- Loading branch information
Showing
2 changed files
with
165 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,176 @@ | ||
"""TF: Tensorflow parser""" | ||
from __future__ import absolute_import as _abs | ||
from __future__ import print_function | ||
from nnvm.frontend.protobuf import graph_pb2 | ||
import os | ||
|
||
try: | ||
from tensorflow.core.framework import graph_pb2 | ||
except ImportError as e: | ||
from nnvm.frontend.protobuf import graph_pb2 | ||
|
||
|
||
try: | ||
from tempfile import TemporaryDirectory | ||
except ImportError: | ||
import tempfile | ||
import shutil | ||
|
||
class TemporaryDirectory(object): | ||
def __enter__(self): | ||
self.name = tempfile.mkdtemp() | ||
return self.name | ||
|
||
def __exit__(self, exc, value, tb): | ||
shutil.rmtree(self.name) | ||
|
||
|
||
class TFParser(object): | ||
"""A Wrapper to handle tensorflow frozen model parsing | ||
"""A Wrapper to handle tensorflow models parsing | ||
Works w/o installing tensorflow, | ||
Protocol Buffer is needed | ||
``` | ||
parser = TfParser(pb_file) | ||
parser = TfParser(model_dir) | ||
graph = parser.parse() | ||
``` | ||
Parameters | ||
---------- | ||
pb_file : tensorflow frozen pb file | ||
The pb file should include both operations and tensors | ||
model_dir : tensorflow frozen pb file or a directory that contains saved | ||
model or checkpoints. | ||
""" | ||
|
||
def __init__(self, pb_file): | ||
self._pb = pb_file | ||
def __init__(self, model_dir): | ||
self._tmp_dir = TemporaryDirectory() | ||
self._model_dir = model_dir | ||
self._graph = graph_pb2.GraphDef() | ||
|
||
def _load_model(self): | ||
"""load frozen tensorflow model, return GraphDef """ | ||
with open(self._pb, "rb") as f: | ||
self._graph.ParseFromString(f.read()) | ||
def _set_graph(self, graph): | ||
"""Set Graph""" | ||
self._graph = graph | ||
|
||
def _get_graph(self): | ||
"""Get Graph""" | ||
return self._graph | ||
|
||
def _output_graph(self): | ||
import logging | ||
logging.basicConfig(level=logging.DEBUG) | ||
for node in self._get_graph().node: | ||
logging.info("Name: {}".format(node.name)) | ||
logging.info("\top: {}".format(node.op)) | ||
for input in node.input: | ||
logging.info("\t\tinput: {}".format(input)) | ||
logging.info("\t\tdevice: {}".format(node.device)) | ||
logging.info("\t\tAttrValue: ") | ||
for key in node.attr.keys(): | ||
logging.info("\t\t\tkey: {} => value: {}" | ||
.format(key, node.attr[key])) | ||
logging.info(node.attr['shape'].shape) | ||
|
||
def _load_pb_file(self): | ||
"""Load single pb file""" | ||
graph = self._get_graph() | ||
with open(self._model_dir, "rb") as f: | ||
graph.ParseFromString(f.read()) | ||
return graph | ||
|
||
def _get_output_names(self, model_path): | ||
"""Return the concatenated output names""" | ||
try: | ||
import tensorflow as tf | ||
except ImportError as e: | ||
raise ImportError( | ||
"InputConfiguration: Unable to import tensorflow which is " | ||
"required to restore from saved model. {}".format(e)) | ||
|
||
with tf.Session() as sess: | ||
meta_graph_def = tf.saved_model.loader.load(sess, | ||
[tf.saved_model.tag_constants.SERVING], | ||
model_path) | ||
output_names = set() | ||
for k in meta_graph_def.signature_def.keys(): | ||
outputs_tensor_info = meta_graph_def.signature_def[k].outputs | ||
for output_tensor in outputs_tensor_info.values(): | ||
output_names.add(output_tensor.name) | ||
output_names = [i.replace(":0", "") for i in output_names] | ||
return ",".join(output_names) | ||
|
||
def _load_saved_model(self): | ||
"""Load the tensorflow saved model.""" | ||
try: | ||
import tensorflow as tf | ||
from tensorflow.python.tools import freeze_graph | ||
from tensorflow.python.framework import ops | ||
from tensorflow.python.framework import graph_util | ||
except ImportError as e: | ||
raise ImportError( | ||
"InputConfiguration: Unable to import tensorflow which is " | ||
"required to restore from saved model. {}".format(e)) | ||
|
||
saved_model_dir = self._model_dir | ||
output_graph_filename = os.path.join(self._tmp_dir.name, "neo_frozen_model.pb") | ||
input_saved_model_dir = saved_model_dir | ||
output_node_names = self._get_output_names(self._model_dir) | ||
|
||
input_binary = False | ||
input_saver_def_path = False | ||
restore_op_name = None | ||
filename_tensor_name = None | ||
clear_devices = True | ||
input_meta_graph = False | ||
checkpoint_path = None | ||
input_graph_filename = None | ||
saved_model_tags = tf.saved_model.tag_constants.SERVING | ||
|
||
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, | ||
input_binary, checkpoint_path, output_node_names, | ||
restore_op_name, filename_tensor_name, | ||
output_graph_filename, clear_devices, "", "", "", | ||
input_meta_graph, input_saved_model_dir, | ||
saved_model_tags) | ||
|
||
with ops.Graph().as_default(): | ||
output_graph_def = graph_pb2.GraphDef() | ||
with open(output_graph_filename, "rb") as f: | ||
output_graph_def.ParseFromString(f.read()) | ||
output_graph_def = graph_util.remove_training_nodes(output_graph_def) | ||
return output_graph_def | ||
|
||
def _load_ckpt(self): | ||
"""TODO: Load checkpoint model.""" | ||
raise RuntimeError("InputConfiguration: Loading tf checkpoint model is " | ||
"not supported yet.") | ||
|
||
def parse(self): | ||
self._load_model() | ||
return self._graph | ||
"""Parse tensorflow models: checkpoints, saved models, and single pb | ||
file. | ||
""" | ||
graph = None | ||
if os.path.isdir(self._model_dir): | ||
ckpt = os.path.join(self._model_dir, "checkpoint") | ||
if not os.path.isfile(ckpt): | ||
if not os.path.isdir(os.path.join(self._model_dir, "variables")): | ||
raise RuntimeError("InputConfiguration: Invalid model path.") | ||
graph = self._load_saved_model() | ||
else: | ||
graph = self._load_ckpt() | ||
elif os.path.isfile(self._model_dir): | ||
# Only .pb or .pbtxt is a valid suffix name. | ||
if self._model_dir.endswith(".pb") or \ | ||
self._model_dir.endswith(".pbtxt"): | ||
cur_dir = os.path.dirname(self._model_dir) | ||
else: | ||
raise RuntimeError("InputConfiguration: Invalid model format.") | ||
|
||
# It is a saved model if `variables` directory is present at the | ||
# same directory with the pb or pbtxt file. | ||
if os.path.isdir(os.path.join(cur_dir, "variables")): | ||
self._model_dir = cur_dir | ||
graph = self._load_saved_model() | ||
else: | ||
graph = self._load_pb_file() | ||
else: | ||
raise RuntimeError("InputConfiguration: Unrecognized model " | ||
"file or path.") | ||
|
||
self._set_graph(graph) | ||
return graph |