From 48dd312343484d63fa1227a3f9dbf1f7b406bfc0 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Tue, 17 Mar 2020 20:26:26 +0000 Subject: [PATCH] Fix for dilation2d --- python/tvm/relay/frontend/tensorflow.py | 6 +++--- tests/python/frontend/tensorflow/test_control_flow.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 2fd6fccc8de3d..d68f4bd7e731e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -409,12 +409,12 @@ def _impl(inputs, attr, params, mod): # Dilation2d def _dilation2d(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): if 'data_format' not in attr: attr['data_format'] = 'NHWC' - input_shape = attr['_input_shapes'][inputs[0]] - weights_shape = attr['_input_shapes'][inputs[1]] + input_shape = _infer_shape(inputs[0], mod) + weights_shape = _infer_shape(inputs[1], mod) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index bb6fbabfdf4a4..9777a8dc4462c 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -312,7 +312,7 @@ def test_vanilla_loop_bound(): dtype = "float32" dname = "data" np_data = np.random.uniform(size=dshape).astype(dtype) - data = tf.compat.v1.placeholder(shape=dshape, dtype=dtype, name=dname) + data = tf.placeholder(shape=dshape, dtype=dtype, name=dname) x = tf.slice(data, [1, 4], [1, 4]) outer = x + 5.0 def body(x, y): @@ -339,7 +339,7 @@ def test_nested_loop_bound(): dtype = "float32" dname = "data" np_data = np.random.uniform(size=dshape).astype(dtype) - data = tf.compat.v1.placeholder(shape=dshape, dtype=dtype, name=dname) + data = tf.placeholder(shape=dshape, dtype=dtype, name=dname) x = tf.slice(data, [1, 4], [1, 4]) outer = x + 5.0 def body(x, y):