From 8e60ac322192e138d11a2f5fe8bf311df3ef28f3 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Mon, 11 Nov 2019 08:23:23 -0800 Subject: [PATCH] Fix tf reshape (#4285) * Fix tf reshape * Fix test * Fix pylint * Fix pylint --- python/tvm/relay/frontend/tensorflow.py | 26 ++++++++++--------- .../frontend/tensorflow/test_forward.py | 12 +++++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0abcb09d6ace..837b8d3782e4 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except """TF: Tensorflow frontend.""" from __future__ import absolute_import as _abs from __future__ import print_function @@ -613,22 +613,24 @@ def _reshape(): def _impl(inputs, attr, params): pop_node = inputs.pop(1) - # We use reshape_like directly to deal with dynamic shape. - if isinstance(pop_node, tvm.relay.expr.Call): - if "shape_of" not in str(pop_node.op): - raise RuntimeError("If shape operator is used in reshape to " - "express reshape_like, shape_of must be " - "the direct ancestor of reshape when input " - "shape is symbolic.") - return _op.reshape_like(inputs[0], pop_node.args[0]) - try: shape_arg = _get_tuple_param(params, pop_node) except AttributeError: # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. - params_new = _infer_value(pop_node, params) - shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) + try: + params_new = _infer_value(pop_node, params) + shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) + except Exception: + # Deal with symbolic shape case. + # Currently only shape_of can be the direct ancestor. + if not isinstance(pop_node, tvm.relay.expr.Call) or \ + "shape_of" not in str(pop_node.op): + raise RuntimeError("If shape operator is used in reshape to " + "express reshape_like, shape_of must be " + "the direct ancestor of reshape when input " + "shape is symbolic.") + return _op.reshape_like(inputs[0], pop_node.args[0]) return AttrCvt( op_name="reshape", extras={'newshape': shape_arg}, diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c397d05f62ef..ce1d326f87f4 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -551,6 +551,17 @@ def _test_reshape(data, out_shape): compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0') +def _test_reshape_with_call(): + """ relay.expr.Call as shape """ + data = np.zeros((6, 4, 2)) + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out_shape = tf.constant([1, 2, 3], dtype="int32") + out_shape = tf.multiply(out_shape, 2) + array_ops.reshape(in_data, out_shape) + + compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0') + def _test_reshape_like(data, shape_like): """ A special case for reshape. """ @@ -567,6 +578,7 @@ def test_forward_reshape(): _test_reshape(np.arange(6), [-1, 2]) _test_reshape(np.arange(6), [3, -1]) _test_reshape(np.arange(6), [-1]) + _test_reshape_with_call() _test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2))) #######################################################################