[Frontend] [Tensorflow] ReadVariableOp operator support #4952

Merged 3 commits on Mar 2, 2020
python/tvm/relay/frontend/tensorflow.py (11 additions, 0 deletions)
@@ -1500,6 +1500,12 @@ def _impl(inputs, attr, params):
# compatible operators that do NOT require any conversion.
_identity_list = []

# Operators that get pruned away when the complete graph is frozen.
# These operators are not needed for inference.
_freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable',
                                 'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp']


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
@@ -2187,6 +2193,11 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        missing_operators = self._parse_import_prerequisites(graph)

        if missing_operators:
            freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
            if freezed_ops:
                raise Exception("Graph is not frozen. Provide a frozen graph. "
                                "Found operators {}".format(freezed_ops))

            raise NotImplementedError(
                "The following operators are not implemented: {}".format(missing_operators))

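For anyone who hits the new "Graph is not frozen" error: a minimal sketch of freezing a TF 1.x graph before handing it to the frontend, using the same `tf.graph_util.convert_variables_to_constants` call the test below relies on. The graph, tensor names, and shapes here are illustrative, not from this PR:

```python
import numpy as np
import tensorflow as tf  # sketch assumes TF 1.x APIs

# Build a tiny graph with a resource variable, mirroring the new test.
tf.reset_default_graph()
x = tf.placeholder(tf.float32, shape=(1, 4), name="x")
w = tf.Variable(np.ones((4, 4), dtype=np.float32), name="w", use_resource=True)
y = tf.matmul(x, w, name="y")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fold variables into constants; this prunes ReadVariableOp,
    # VarHandleOp, and friends out of the GraphDef.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(add_shapes=True), ["y"])

# None of the pruned ops survive freezing, so the frontend accepts the graph.
assert not {n.op for n in frozen.node} & {"ReadVariableOp", "VarHandleOp"}
```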
tests/python/frontend/tensorflow/test_forward.py (60 additions, 0 deletions)
@@ -22,6 +22,7 @@
"""
from __future__ import print_function
import numpy as np
import pytest
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
@@ -1061,6 +1062,62 @@ def test_forward_variable():
    _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))


def test_read_variable_op():
    """ReadVariableOp test"""

    tf.reset_default_graph()
    data = np.random.uniform(size=(32, 100)).astype('float32')
    input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

    size = input_tensor.shape.dims[1]
    var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
    input_var = tf.Variable(var_data, name='var1', use_resource=True)
    math_ops.matmul(input_tensor, input_var)

    out_name = ['MatMul:0']
    out_node = ['MatMul']
    in_name = ['Placeholder:0']
    in_node = ['Placeholder']
    in_data = [data]

    with tf.Session() as sess:
        sess.run(variables.global_variables_initializer())

        final_graph_def = sess.graph.as_graph_def(add_shapes=True)
        tf_output = run_tf_graph(sess, in_data, in_name, out_name)

        shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
        # The unfrozen graph still contains variable ops, so conversion must fail.
        with pytest.raises(Exception) as excinfo:
            mod, params = relay.frontend.from_tensorflow(final_graph_def,
                                                         layout=None,
                                                         shape=shape_dict,
                                                         outputs=None)

        assert excinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.")

        # Now convert the variables to constants and run inference on the converted graph
        final_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(add_shapes=True),
            out_node,
        )

        for device in ["llvm", "cuda"]:
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                continue

            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
                                       target=device, out_names=out_name,
                                       num_output=len(out_name))
            for i in range(len(tf_output)):
                tvm.testing.assert_allclose(
                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)

        sess.close()

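Side note for reviewers: resource variables (`use_resource=True`) are read through `ReadVariableOp`/`VarHandleOp` nodes rather than plain `VariableV2`, which is exactly what the new check catches. A quick sketch for listing which of the pruned-op names appear in a GraphDef, reusing the unfrozen `final_graph_def` captured earlier in the test (variable names are from the test above):

```python
# Sketch: intersect the GraphDef's op types with the pruned-op list.
var_ops = {'ReadVariableOp', 'ResourceGather', 'Variable', 'VariableV2',
           'VarHandleOp', 'Assign', 'AssignVariableOp'}
present = sorted({node.op for node in final_graph_def.node} & var_ops)
print(present)  # e.g. ['ReadVariableOp', 'VarHandleOp'] before freezing
```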

#######################################################################
# MatMul, BatchMatMul, BatchMatMulV2
# ----------------------------------
@@ -3038,3 +3095,6 @@ def test_forward_add_n():
    test_forward_where()
    test_forward_matmul()
    test_forward_batch_matmul()

    # Internal misc. ops
    test_read_variable_op()