From 5e85c8fc3d28a4d8ca897cfcc3a5a22e6832b494 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 27 Feb 2020 20:42:32 +0530 Subject: [PATCH 1/3] tf frontend read variable op --- python/tvm/relay/frontend/tensorflow.py | 12 +++- .../frontend/tensorflow/test_forward.py | 61 +++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6f27d73315a1..40099626fb39 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2187,8 +2187,16 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): missing_operators = self._parse_import_prerequisites(graph) if missing_operators: - raise NotImplementedError( \ - "The following operators are not implemented: {}".format(missing_operators)) + # TODO: ReadVariableOp gets removed when graph is frozen. + # Add list of other operators as well which get removed + # and are not needed for inference. + # Other approach is instead of raising error, we can freeze the graph. + if 'ReadVariableOp' in missing_operators: + raise Exception("Found ReadVariableOp operator in the graph. " + "Graph is not frozen. Provide a frozen graph.") + else: + raise NotImplementedError( \ + "The following operators are not implemented: {}".format(missing_operators)) control_flow_node_map = defaultdict(set) for node in graph.node: diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 9cd978e2e147..62dd1da64c09 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -22,6 +22,7 @@ """ from __future__ import print_function import numpy as np +import pytest import tensorflow as tf from tensorflow.python.framework import constant_op from tensorflow.python.framework import graph_util @@ -1061,6 +1062,63 @@ def test_forward_variable(): _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) +def test_read_variable_op(): + """ Read Variable op test """ + + tf.reset_default_graph() + data = np.random.uniform(size=(32, 100)).astype('float32') + input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + + size = input_tensor.shape.dims[1] + var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32) + input_var = tf.Variable(var_data, name='var1', use_resource=True) + math_ops.matmul(input_tensor, input_var) + + out_name = ['MatMul:0'] + out_node = ['MatMul'] + in_name = ['Placeholder:0'] + in_node = ['Placeholder'] + in_data = [data] + + with tf.Session() as sess: + sess.run(variables.global_variables_initializer()) + + final_graph_def = sess.graph.as_graph_def(add_shapes=True) + tf_output = run_tf_graph(sess, in_data, in_name, out_name) + + shape_dict = {e: i.shape for e, i in zip(in_name, in_data)} + with pytest.raises(Exception) as exexcinfo: + mod, params = relay.frontend.from_tensorflow(final_graph_def, + layout=None, + shape=shape_dict, + outputs=None) + + assert exexcinfo.value.args[0] == "Found ReadVariableOp operator in the graph. " \ + "Graph is not frozen. Provide a frozen graph." + + # Now convert the variables to constant and run inference on the converted graph + final_graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + out_node, + ) + + for device in ["llvm", "cuda"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + + tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, + target=device, out_names=out_name, + num_output=len(out_name)) + for i in range(len(tf_output)): + tvm.testing.assert_allclose( + tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) + + sess.close() + + ####################################################################### # MatMul, BatchMatMul, BatchMatMulV2 # ---------------------------------- @@ -3038,3 +3096,6 @@ def test_forward_add_n(): test_forward_where() test_forward_matmul() test_forward_batch_matmul() + + # Internal misc. ops + test_read_variable_op() From 96f1e61fa18fc414b3c102386ea8d1d75c26d299 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 27 Feb 2020 20:52:03 +0530 Subject: [PATCH 2/3] pylint fix --- python/tvm/relay/frontend/tensorflow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 40099626fb39..cc9b45d7674e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2194,9 +2194,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): if 'ReadVariableOp' in missing_operators: raise Exception("Found ReadVariableOp operator in the graph. " "Graph is not frozen. Provide a frozen graph.") - else: - raise NotImplementedError( \ - "The following operators are not implemented: {}".format(missing_operators)) + + raise NotImplementedError( \ + "The following operators are not implemented: {}".format(missing_operators)) control_flow_node_map = defaultdict(set) for node in graph.node: From 73bf5427ced698fb42050c2f753dd823e4bb8308 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Mon, 2 Mar 2020 18:19:51 +0530 Subject: [PATCH 3/3] tf frontend freezed graph pruned ops --- python/tvm/relay/frontend/tensorflow.py | 17 ++++++++++------- .../python/frontend/tensorflow/test_forward.py | 3 +-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index cc9b45d7674e..14d2418da710 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1500,6 +1500,12 @@ def _impl(inputs, attr, params): # compatible operators that do NOT require any conversion. _identity_list = [] +# Operators that get pruned away when the complete graph is frozen. +# These operators are not needed for inference. +_freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable', + 'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp'] + + # _convert_map defines maps of name to converter functor(callable) # for 1 to 1 mapping, use Renamer if nothing but name is different # use AttrCvt if attributes need to be converted @@ -2187,13 +2193,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): missing_operators = self._parse_import_prerequisites(graph) if missing_operators: - # TODO: ReadVariableOp gets removed when graph is frozen. - # Add list of other operators as well which get removed - # and are not needed for inference. - # Other approach is instead of raising error, we can freeze the graph. - if 'ReadVariableOp' in missing_operators: - raise Exception("Found ReadVariableOp operator in the graph. " - "Graph is not frozen. Provide a frozen graph.") + freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list] + if freezed_ops: + raise Exception("Graph is not frozen. Provide a frozen graph. " + "Found operators {}".format(freezed_ops)) raise NotImplementedError( \ "The following operators are not implemented: {}".format(missing_operators)) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 62dd1da64c09..42408b706111 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1093,8 +1093,7 @@ def test_read_variable_op(): shape=shape_dict, outputs=None) - assert exexcinfo.value.args[0] == "Found ReadVariableOp operator in the graph. " \ - "Graph is not frozen. Provide a frozen graph." + assert exexcinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.") # Now convert the variables to constant and run inference on the converted graph final_graph_def = tf.graph_util.convert_variables_to_constants(