Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Frontend][TF] fix _parse_param bug #4711

Merged
merged 1 commit into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,7 @@ def _parse_param(self, key, value, name, shape):
if np_array.dtype == np.dtype(object):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
if shape:
if shape and name in shape:
var_shape = shape[name]
else:
var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
Expand Down
18 changes: 11 additions & 7 deletions tests/python/frontend/tensorflow/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow

def run_relay(graph, *vars):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
def run_relay(graph, shape_dict=None, *vars):
mod, params = from_tensorflow(
graph.as_graph_def(add_shapes=True),
shape=shape_dict)
ex = relay.create_executor('debug', mod=mod)
return ex.evaluate()(*vars)

def test_assert_true():
g = tf.Graph()
shape = (1, 2)
with g.as_default():
x = tf.placeholder(tf.float32, shape=())
assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])
x = tf.placeholder(tf.float32, shape=shape, name="input")
assert_op = tf.Assert(tf.reduce_all(tf.less_equal(x, x)), ["it failed"])

with tf.Session() as sess:
x_value = np.random.rand()
x_value = np.random.rand(*shape)
assert sess.run(assert_op, feed_dict={x: x_value}) is None

# In TVM, tf.assert is converted to a no-op which is actually a 0,
Expand All @@ -44,7 +47,7 @@ def test_assert_true():
# do that, it's happening in Relay, and that optimization shouldn't
# affect the arity of the main function. We should have to pass in
# x_value here.
np.testing.assert_allclose(0, run_relay(g).asnumpy())
np.testing.assert_allclose(0, run_relay(g, {'input':shape}).asnumpy())

def test_assert_true_var_capture():
g = tf.Graph()
Expand All @@ -65,7 +68,8 @@ def test_assert_true_var_capture():
# the graph as a boolean, which is not correct - as you can see above,
# TF believes that the value of this graph is None. In addition, the
# arity of the translated function should be 1, not 2.
np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy())
np.testing.assert_allclose(True,
run_relay(g, None, x_value, x_value).asnumpy())

def test_assert_false():
g = tf.Graph()
Expand Down