diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 09217870b07ad..ac60a1f7bb514 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -114,7 +114,7 @@ def convert(self, v): def __call__(self, args, attrs, type_args): if attrs is None: attrs = {} - if self.operator in (op.strided_slice): + if self.operator in (op.strided_slice,): x = self.operator(*args) elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to): x = self.operator(*args, dtype=attrs["dtype"]) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index d104c1b1c2f8b..f1cecda992c6d 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -272,82 +272,11 @@ def _reshape_shape_func_input_shape(data_shape, newshape, ndim): out[infer_idx] = old_size // new_size return out -@script -def _reshape_shape_func_input_data(data, newshape, ndim): - out = output_tensor((ndim,), "int64") - data_shape = allocate((len(data.shape),), "int64") - for x in const_range(len(data.shape)): - data_shape[x] = int64(data.shape[x]) - src_idx = 0 - dst_idx = 0 - infer_idx = -1 - copy = False - skip = 0 - for i in const_range(len(newshape)): - if skip > 0: - skip -= 1 - elif newshape[i] > 0: - out[dst_idx] = int64(newshape[i]) - src_idx += 1 - dst_idx += 1 - elif newshape[i] == 0: - out[dst_idx] = data_shape[src_idx] - src_idx += 1 - dst_idx += 1 - elif newshape[i] == -1: - assert infer_idx < 0, "One and only one dim can be inferred" - out[dst_idx] = int64(1) - infer_idx = i - dst_idx += 1 - elif newshape[i] == -2: - copy = True - elif newshape[i] == -3: - assert data_shape.shape[0] - src_idx > 1, \ - "Not enough dims in input shape for -3" - out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1] - src_idx += 2 - dst_idx += 1 - elif newshape[i] == -4: - assert len(newshape) - i > 2, "Not enough dims in new shape for -4" - if newshape[i+1] == -1: - assert newshape[i+2] != -1, "Split dims cannot both be -1." - out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2]) - out[dst_idx+1] = int64(newshape[i+2]) - else: - out[dst_idx] = int64(newshape[i+1]) - if newshape[i+2] == -1: - out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1]) - else: - out[dst_idx+1] = int64(newshape[i+2]) - assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\ - "Product of split dims doesn't match to input dim" - src_idx += 1 - dst_idx += 2 - skip = 2 - else: - assert False, "Invalid special values in new shape" - if len(data_shape.shape) > 0: - # if data is not constant, we can then handle -1 and -2 - if copy: - for i in range(src_idx, data_shape.shape[0]): - out[dst_idx] = data_shape[i] - dst_idx += 1 - if infer_idx >= 0: - old_size = int64(1) - for i in const_range(data_shape.shape[0]): - old_size *= data_shape[i] - new_size = int64(1) - for i in const_range(out.shape[0]): - new_size *= out[i] - out[infer_idx] = old_size // new_size - return out - -@_reg.register_shape_func("reshape", True) +@_reg.register_shape_func("reshape", False) def reshape_shape_func(attrs, inputs, out_ndims): - if attrs.newshape is None: - return [_reshape_shape_func_input_data(*inputs, out_ndims[0])] + newshape = get_const_tuple(attrs.newshape) return [_reshape_shape_func_input_shape(inputs[0], - convert(attrs.newshape), + convert(newshape), out_ndims[0])] @script diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 864917cefc95b..4d746c755734b 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -20,7 +20,6 @@ from . import _make from ..expr import TupleWrapper, const -from ...tir import expr as _expr def cast(data, dtype): diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index f034833b09055..7b3f1957811b1 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -44,7 +44,7 @@ class DynamicToStaticMutator : public MixedModeMutator { attrs->newshape = ToVector(shape->data); attrs->reverse = false; static const Op& reshape = Op::Get("reshape"); - return Call(reshape, call_node->args, Attrs(attrs), {}); + return Call(reshape, {call_node->args[0]}, Attrs(attrs), {}); } } return post; diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8e535a692b882..7f83cefbfb418 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -167,12 +167,11 @@ def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newsha newshape_var = relay.var('newshape', shape=(len(newshape),), dtype='int64') params.append(newshape_var) args.append(np.array(newshape, dtype='int64')) - newshape = newshape_var - - y = relay.reshape(relu_x, newshape=newshape) + y = relay.dyn.reshape(relu_x, newshape_var) + else: + y = relay.reshape(relu_x, newshape=newshape) mod = tvm.IRModule() mod["main"] = relay.Function(params, y) - for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(*args).asnumpy() @@ -184,9 +183,9 @@ def test_any_reshape(): # Variable newshape only supports that output rank is the same as newshape verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24), variable_newshape) verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12), variable_newshape) - verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4), variable_newshape) verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) + verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): x = relay.var('x', shape=x_shape, dtype=dtype)