Skip to content

Commit

Permalink
Fix tf reshape (apache#4285)
Browse files Browse the repository at this point in the history
* Fix tf reshape

* Fix test

* Fix pylint

* Fix pylint
  • Loading branch information
kevinthesun authored and Xingyu Zhou committed Nov 13, 2019
1 parent 1688a7b commit 8e60ac3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
26 changes: 14 additions & 12 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function
Expand Down Expand Up @@ -613,22 +613,24 @@ def _reshape():
def _impl(inputs, attr, params):
pop_node = inputs.pop(1)

# We use reshape_like directly to deal with dynamic shape.
if isinstance(pop_node, tvm.relay.expr.Call):
if "shape_of" not in str(pop_node.op):
raise RuntimeError("If shape operator is used in reshape to "
"express reshape_like, shape_of must be "
"the direct ancestor of reshape when input "
"shape is symbolic.")
return _op.reshape_like(inputs[0], pop_node.args[0])

try:
shape_arg = _get_tuple_param(params, pop_node)
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
params_new = _infer_value(pop_node, params)
shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
try:
params_new = _infer_value(pop_node, params)
shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
except Exception:
# Deal with symbolic shape case.
# Currently only shape_of can be the direct ancestor.
if not isinstance(pop_node, tvm.relay.expr.Call) or \
"shape_of" not in str(pop_node.op):
raise RuntimeError("If shape operator is used in reshape to "
"express reshape_like, shape_of must be "
"the direct ancestor of reshape when input "
"shape is symbolic.")
return _op.reshape_like(inputs[0], pop_node.args[0])
return AttrCvt(
op_name="reshape",
extras={'newshape': shape_arg},
Expand Down
12 changes: 12 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,17 @@ def _test_reshape(data, out_shape):

compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')

def _test_reshape_with_call():
""" relay.expr.Call as shape """
data = np.zeros((6, 4, 2))
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out_shape = tf.constant([1, 2, 3], dtype="int32")
out_shape = tf.multiply(out_shape, 2)
array_ops.reshape(in_data, out_shape)

compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')

def _test_reshape_like(data, shape_like):
""" A special case for reshape. """

Expand All @@ -567,6 +578,7 @@ def test_forward_reshape():
_test_reshape(np.arange(6), [-1, 2])
_test_reshape(np.arange(6), [3, -1])
_test_reshape(np.arange(6), [-1])
_test_reshape_with_call()
_test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2)))

#######################################################################
Expand Down

0 comments on commit 8e60ac3

Please sign in to comment.