[Frontend] [Tensorflow] ReadVariableOp operator support (#4952)
* tf frontend read variable op

* pylint fix

* tf frontend freezed graph pruned ops
maheshambule authored Mar 2, 2020
1 parent 0fb4836 commit 8502691
Showing 2 changed files with 71 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1500,6 +1500,12 @@ def _impl(inputs, attr, params):
# compatible operators that do NOT require any conversion.
_identity_list = []

# Operators that get pruned away when the complete graph is frozen.
# These operators are not needed for inference.
_freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable',
                                 'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp']


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
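For context, the two conventions described in the comment above look roughly like this in practice. A minimal sketch, assuming the AttrCvt helper from tvm.relay.frontend.common; the entries below are hypothetical illustrations, not part of this diff:

from tvm.relay.frontend.common import AttrCvt

# Hypothetical _convert_map entries, for illustration only.
_convert_map_sketch = {
    # 1-to-1 mapping: only the operator name differs.
    'Abs': AttrCvt('abs'),
    # Name change plus attribute conversion: TF's 'ksize' attribute
    # is renamed to Relay's 'pool_size' (illustrative mapping).
    'AvgPool': AttrCvt('avg_pool2d', transforms={'ksize': 'pool_size'}),
}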
@@ -2187,6 +2193,11 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        missing_operators = self._parse_import_prerequisites(graph)

        if missing_operators:
            freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
            if freezed_ops:
                raise Exception("Graph is not frozen. Provide a frozen graph. "
                                "Found operators {}".format(freezed_ops))

            raise NotImplementedError( \
                "The following operators are not implemented: {}".format(missing_operators))

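Taken together with the operator list above, the new check makes the requirement explicit: variable-handling ops must be folded away before import. A minimal sketch of the intended workflow, assuming TF 1.x APIs (graph and node names here are illustrative):

import numpy as np
import tensorflow as tf
import tvm.relay as relay

# Build a tiny graph that reads a resource variable.
x = tf.placeholder(tf.float32, shape=(32, 100), name='x')
w = tf.Variable(np.ones((100, 10), dtype=np.float32), name='w', use_resource=True)
y = tf.matmul(x, w, name='y')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Freezing folds the variable reads into Const nodes, removing the
    # ReadVariableOp/VarHandleOp nodes that the new check rejects.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(add_shapes=True), ['y'])

# Importing the un-frozen graph_def instead would raise
# "Graph is not frozen. Provide a frozen graph. Found operators [...]".
mod, params = relay.frontend.from_tensorflow(frozen, shape={'x:0': (32, 100)})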
60 changes: 60 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -22,6 +22,7 @@
"""
from __future__ import print_function
import numpy as np
import pytest
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
@@ -1061,6 +1062,62 @@ def test_forward_variable():
    _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))


def test_read_variable_op():
    """ Read Variable op test """

    tf.reset_default_graph()
    data = np.random.uniform(size=(32, 100)).astype('float32')
    input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

    size = input_tensor.shape.dims[1]
    var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
    input_var = tf.Variable(var_data, name='var1', use_resource=True)
    math_ops.matmul(input_tensor, input_var)

    out_name = ['MatMul:0']
    out_node = ['MatMul']
    in_name = ['Placeholder:0']
    in_node = ['Placeholder']
    in_data = [data]

    with tf.Session() as sess:
        sess.run(variables.global_variables_initializer())

        final_graph_def = sess.graph.as_graph_def(add_shapes=True)
        tf_output = run_tf_graph(sess, in_data, in_name, out_name)

        shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
        with pytest.raises(Exception) as excinfo:
            mod, params = relay.frontend.from_tensorflow(final_graph_def,
                                                         layout=None,
                                                         shape=shape_dict,
                                                         outputs=None)

        assert excinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.")

        # Now convert the variables to constants and run inference on the converted graph
        final_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(add_shapes=True),
            out_node,
        )

        for device in ["llvm", "cuda"]:
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                continue

            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
                                       target=device, out_names=out_name,
                                       num_output=len(out_name))
            for i in range(len(tf_output)):
                tvm.testing.assert_allclose(
                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)

        sess.close()


#######################################################################
# MatMul, BatchMatMul, BatchMatMulV2
# ----------------------------------
@@ -3038,3 +3095,6 @@ def test_forward_add_n():
    test_forward_where()
    test_forward_matmul()
    test_forward_batch_matmul()

    # Internal misc. ops
    test_read_variable_op()
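Once merged, the new test can be run on its own with pytest (assuming a built TVM and a TF 1.x environment):

pytest tests/python/frontend/tensorflow/test_forward.py::test_read_variable_op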
