
[Frontend] [Tensorflow] ReadVariableOp operator support #4952

Merged (3 commits, Mar 2, 2020)
Changes from 2 commits
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -2187,6 +2187,14 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        missing_operators = self._parse_import_prerequisites(graph)

        if missing_operators:
            # TODO: ReadVariableOp is removed when the graph is frozen.
            # Add the other operators that are likewise removed during
            # freezing and are not needed for inference.
            # Another approach: instead of raising an error, we could
            # freeze the graph here ourselves.
            if 'ReadVariableOp' in missing_operators:
Review comment, Contributor: Is there any other op falling into this case? If so, I suggest we create a list containing all these ops.

Reply, @maheshambule (Author), Feb 28, 2020: Yes, there are. I will compile the list and add it in.

Reply, @maheshambule (Author), Mar 2, 2020: Added the other operators as well.
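As context for this exchange, here is a minimal sketch of the list-driven check the reviewer suggests. The name _freezed_graph_pruned_op_list and the list's membership are illustrative assumptions, not necessarily what the follow-up commit merged:

    # Assumption: ops that graph freezing removes and that inference does
    # not need. The exact membership of this list is a guess for illustration.
    _freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather',
                                     'Variable', 'VariableV2', 'VarHandleOp',
                                     'Assign', 'AssignVariableOp']

    if missing_operators:
        # Separate "graph is not frozen" ops from genuinely unsupported ops.
        freezed_ops = [op for op in missing_operators
                       if op in _freezed_graph_pruned_op_list]
        if freezed_ops:
            raise Exception("Graph is not frozen. Provide a frozen graph. "
                            "Found operators: {}".format(freezed_ops))
        raise NotImplementedError(
            "The following operators are not implemented: {}"
            .format(missing_operators))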

                raise Exception("Found ReadVariableOp operator in the graph. "
                                "Graph is not frozen. Provide a frozen graph.")

            raise NotImplementedError(
                "The following operators are not implemented: {}".format(missing_operators))

61 changes: 61 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -22,6 +22,7 @@
"""
from __future__ import print_function
import numpy as np
import pytest
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
@@ -1061,6 +1062,63 @@ def test_forward_variable():
    _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))


def test_read_variable_op():
    """ Read Variable op test """

    tf.reset_default_graph()
    data = np.random.uniform(size=(32, 100)).astype('float32')
    input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

    size = input_tensor.shape.dims[1]
    var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
    input_var = tf.Variable(var_data, name='var1', use_resource=True)
    math_ops.matmul(input_tensor, input_var)

    out_name = ['MatMul:0']
    out_node = ['MatMul']
    in_name = ['Placeholder:0']
    in_node = ['Placeholder']
    in_data = [data]

    with tf.Session() as sess:
        sess.run(variables.global_variables_initializer())

        final_graph_def = sess.graph.as_graph_def(add_shapes=True)
        tf_output = run_tf_graph(sess, in_data, in_name, out_name)

        shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
        # The unfrozen graph still contains ReadVariableOp, so the frontend
        # must reject it with the new exception.
        with pytest.raises(Exception) as excinfo:
            mod, params = relay.frontend.from_tensorflow(final_graph_def,
                                                         layout=None,
                                                         shape=shape_dict,
                                                         outputs=None)

        assert excinfo.value.args[0] == "Found ReadVariableOp operator in the graph. " \
                                        "Graph is not frozen. Provide a frozen graph."

        # Now convert the variables to constants and run inference on the
        # converted graph.
        final_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(add_shapes=True),
            out_node,
        )

        for device in ["llvm", "cuda"]:
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                continue

            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
                                       target=device, out_names=out_name,
                                       num_output=len(out_name))
            for i in range(len(tf_output)):
                tvm.testing.assert_allclose(
                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)

        sess.close()


#######################################################################
# MatMul, BatchMatMul, BatchMatMulV2
# ----------------------------------
@@ -3038,3 +3096,6 @@ def test_forward_add_n():
    test_forward_where()
    test_forward_matmul()
    test_forward_batch_matmul()

    # Internal misc. ops
    test_read_variable_op()