From 526b5a519225ec09868e31211a917f7c51d40e52 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 21 Jul 2020 10:10:16 -0700 Subject: [PATCH] [Relay][VM] Add ReshapeTensor instruction in the VM to replace the reshape op (#6089) * [VM] Add reshape tensor instruction * update * lint * fix * fix --- include/tvm/relay/attrs/vm.h | 11 +++ include/tvm/runtime/vm.h | 14 ++++ python/tvm/relay/backend/compile_engine.py | 4 +- python/tvm/relay/backend/vm.py | 4 +- python/tvm/relay/build_module.py | 2 +- python/tvm/relay/op/vm/vm.py | 17 +++++ python/tvm/relay/transform/memory_alloc.py | 80 +++++++++++++++++----- python/tvm/relay/ty.py | 4 +- src/relay/analysis/util.cc | 2 +- src/relay/backend/vm/compiler.cc | 10 +++ src/relay/op/tensor/transform.cc | 2 + src/relay/op/vm/vm.cc | 37 ++++++++++ src/runtime/vm/executable.cc | 10 +++ src/runtime/vm/vm.cc | 54 +++++++++++++-- tests/python/relay/test_vm.py | 51 ++++++++++++++ 15 files changed, 270 insertions(+), 32 deletions(-) diff --git a/include/tvm/relay/attrs/vm.h b/include/tvm/relay/attrs/vm.h index 9144f4734e12..7eb1008004de 100644 --- a/include/tvm/relay/attrs/vm.h +++ b/include/tvm/relay/attrs/vm.h @@ -42,6 +42,17 @@ struct ShapeFuncAttrs : public tvm::AttrsNode { } }; +/*! + * \brief Attributes for VM reshape_tensor operator. + */ +struct ReshapeTensorAttrs : public tvm::AttrsNode { + Array newshape; + + TVM_DECLARE_ATTRS(ReshapeTensorAttrs, "relay.attrs.ReshapeTensorAttrs") { + TVM_ATTR_FIELD(newshape).describe("The new shape of output tensor"); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_VM_H_ diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 0cce533afe62..cb98715c615f 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -115,6 +115,7 @@ enum class Opcode { Fatal = 15U, AllocStorage = 16U, ShapeOf = 17U, + ReshapeTensor = 18U, }; /*! \brief A single virtual machine instruction. @@ -249,6 +250,10 @@ struct Instruction { struct /* ShapeOf Operands */ { RegName tensor; } shape_of; + struct /* ReshapeTensor Operands */ { + RegName tensor; + RegName newshape; + } reshape_tensor; }; /*! @@ -401,6 +406,15 @@ struct Instruction { */ static Instruction ShapeOf(RegName tensor, RegName dst); + /*! + * \brief Reshape the tensor given the new shape. + * \param tensor The input tensor. + * \param newshape The shape tensor. + * \param dst The destination to store the output tensor with new shape. + * \return The reshape tensor instruction. + */ + static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst); + Instruction(); Instruction(const Instruction& instr); Instruction& operator=(const Instruction& instr); diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 8e6698e4a164..25c75b16c7ef 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -246,9 +246,9 @@ def lower_call(call, inputs, target): new_fields.append(field) ret_type = _ty.TupleType(new_fields) - is_dyn = _ty.type_has_any(call.checked_type) + is_dyn = _ty.is_dynamic(call.checked_type) for arg in call.args: - is_dyn = is_dyn or _ty.type_has_any(arg.checked_type) + is_dyn = is_dyn or _ty.is_dynamic(arg.checked_type) # check if in the AutoTVM tracing mode, and disable if op is not in wanted list env = autotvm.task.TaskExtractEnv.current diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 75a11b31e628..16d472452549 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -27,7 +27,7 @@ import tvm.runtime.vm as vm_rt from tvm import autotvm from tvm.relay import expr as _expr -from tvm.relay.ty import type_has_any +from tvm.relay.ty import is_dynamic from tvm.relay.backend.interpreter import Executor from . import _vm @@ -257,7 +257,7 @@ def _make_executor(self, expr=None): def _vm_wrapper(*args, **kwargs): args = self._convert_args(main, args, kwargs) ret_type = self.mod["main"].checked_type.ret_type - if type_has_any(ret_type) and "llvm" not in str(self.target) and "arm" not in str( + if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str( self.target): raise ValueError( "Virtual Machine only supports dynamic graphs on CPU, got output type", diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 4ffabf41c3bf..2f285efc8aa2 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -359,7 +359,7 @@ def _make_executor(self, expr=None): if expr: self.mod["main"] = expr ret_type = self.mod["main"].checked_type.ret_type - if _ty.type_has_any(ret_type): + if _ty.is_dynamic(ret_type): raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type) num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 diff --git a/python/tvm/relay/op/vm/vm.py b/python/tvm/relay/op/vm/vm.py index 761188ace03a..0fb7acec314e 100644 --- a/python/tvm/relay/op/vm/vm.py +++ b/python/tvm/relay/op/vm/vm.py @@ -81,3 +81,20 @@ def shape_func(func, inputs, outputs, is_inputs): The shape function expression. """ return _ffi_api.shape_func(func, inputs, outputs, is_inputs) + + +def reshape_tensor(data, shape, newshape): + """Invoke the VM ReshapeTensor instruction. + + Parameters + ---------- + data : tvm.relay.Expr + The input data. + + shape : tvm.relay.Expr + The newshape tensor. + + newshape : List[tvm.ir.PrimExpr] + The new shape. + """ + return _ffi_api.reshape_tensor(data, shape, newshape) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 805905c0c18f..ae7db3384214 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,7 +19,7 @@ A pass for manifesting explicit memory allocations. """ import numpy as np -from ..expr_functor import ExprMutator +from ..expr_functor import ExprVisitor, ExprMutator from ..scope_builder import ScopeBuilder from . import transform from .. import op @@ -38,6 +38,31 @@ def is_primitive(call): return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \ hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1 + +class CheckReshapeOnly(ExprVisitor): + """A pass to check if the fused op contains only reshape ops.""" + def __init__(self): + super().__init__() + self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"), + op.get("dyn.reshape")] + self.reshape_only = True + + def visit_call(self, call): + if not self.reshape_only: + return + if call.op not in self._reshape_ops: + self.reshape_only = False + for arg in call.args: + self.visit(arg) + + +def is_reshape_only(func): + """Check if the primitive function contains only reshape ops.""" + check = CheckReshapeOnly() + check.visit(func) + return check.reshape_only + + class ManifestAllocPass(ExprMutator): """A pass for explicitly manifesting all memory allocations in Relay.""" @@ -45,6 +70,7 @@ def __init__(self, target_host): self.invoke_tvm = op.vm.invoke_tvm_op self.shape_func = op.vm.shape_func self.shape_of = op.vm.shape_of + self.reshape_tensor = op.vm.reshape_tensor self.scopes = [ScopeBuilder()] self.target_host = target_host self.default_context = cpu(0) @@ -121,8 +147,8 @@ def visit_let(self, let): return scope.get() - def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): - """Generate the code for invoking a TVM op with a dynamic shape.""" + def emit_shape_func(self, scope, func, new_args): + """Insert the shape function given a primitive function.""" shape_func_ins = [] engine = compile_engine.get() cfunc = engine.lower_shape_func(func, self.target_host) @@ -165,9 +191,14 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): expr.Tuple(out_shapes), is_inputs) scope.let("shape_func", shape_call) + return out_shapes + + def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): + """Generate the code for invoking a TVM op with a dynamic shape.""" + out_shapes = self.emit_shape_func(scope, func, new_args) storages = [] - for out_shape, out_type in zip(out_shapes, out_types): + for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)): size = self.compute_storage_in_relay( out_shape, out_type.dtype) alignment = self.compute_alignment(out_type.dtype) @@ -191,8 +222,18 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): scope.let("", invoke) return to_tuple_type(ret_type, tuple_outs.fields) + def emit_reshape_tensor(self, scope, func, new_args, ret_type): + if self.is_dynamic(ret_type): + out_shapes = self.emit_shape_func(scope, func, new_args) + shape_expr = out_shapes[0] + else: + # constant output shape + shape = [int(dim) for dim in ret_type.shape] + shape_expr = expr.const(shape, dtype=self.compute_dtype) + return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape) + def is_dynamic(self, ret_type): - is_dynamic = ty.type_has_any(ret_type) + is_dynamic = ty.is_dynamic(ret_type) # TODO(@jroesch): restore this code, more complex then it seems # for arg in call.args: # is_dynamic = is_dynamic or arg.checked_type.is_dynamic() @@ -208,22 +249,25 @@ def visit_call(self, call): ret_type = call.checked_type out_types = flatten_tuple_type(ret_type) + if is_reshape_only(call.op): + # Handle fused op that only contains reshape op + return self.emit_reshape_tensor(scope, call.op, new_args, ret_type) + if self.is_dynamic(ret_type): # Handle dynamic case. return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type) - else: - # Handle static case. - outs = [] - for i, out_ty in enumerate(out_types): - out = self.make_static_allocation(scope, out_ty, i) - outs.append(out) - - output = expr.Tuple(outs) - invoke = self.invoke_tvm(call.op, ins, output) - scope.let("", invoke) - return to_tuple_type(ret_type, output.fields) - else: - return super().visit_call(call) + + # Handle static case. + outs = [] + for i, out_ty in enumerate(out_types): + out = self.make_static_allocation(scope, out_ty, i) + outs.append(out) + + output = expr.Tuple(outs) + invoke = self.invoke_tvm(call.op, ins, output) + scope.let("", invoke) + return to_tuple_type(ret_type, output.fields) + return super().visit_call(call) @transform.function_pass(opt_level=0) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 19cc10aba41e..84bd1ee63fe4 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -25,8 +25,8 @@ Any = _ffi_api.Any -def type_has_any(tensor_type): - """Check whether type has any as a shape. +def is_dynamic(tensor_type): + """Check whether type has any or symbolic variables as a shape. tensor_type : Type The type to be inspected diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index c8dbb49e15db..b1c512478072 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -424,7 +424,7 @@ struct IsDynamicVisitor : public TypeVisitor { bool is_dyn{false}; void VisitType_(const TensorTypeNode* tt) { for (auto dim : tt->shape) { - if (dim.as()) { + if (dim.as() == nullptr) { is_dyn = true; break; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 585b8033be8d..ab11c6c65919 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -284,6 +284,7 @@ class VMFunctionCompiler : ExprFunctor { case Opcode::AllocClosure: case Opcode::AllocStorage: case Opcode::ShapeOf: + case Opcode::ReshapeTensor: case Opcode::Move: case Opcode::InvokeClosure: last_register_ = instr.dst; @@ -601,6 +602,15 @@ class VMFunctionCompiler : ExprFunctor { this->VisitExpr(args[0]); Emit(Instruction::ShapeOf(last_register_, NewRegister())); }) + .Match("vm.reshape_tensor", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 2u); + this->VisitExpr(args[0]); + auto tensor_reg = last_register_; + this->VisitExpr(args[1]); + auto shape_reg = last_register_; + Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister())); + }) .Match("memory.kill", [](const Array& args, const Attrs& attrs, const Array& type_arg) { LOG(FATAL) << "memory.kill is not yet supported"; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 1b072532d8b7..7ebca6635264 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -576,6 +576,8 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, infer_dim = indexdiv(infer_dim, oshape[i]); } } + arith::Analyzer ana; + infer_dim = ana.Simplify(infer_dim); oshape.Set(infer_idx, infer_dim); } diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc index ffe276e4493c..6e611d623d35 100644 --- a/src/relay/op/vm/vm.cc +++ b/src/relay/op/vm/vm.cc @@ -37,6 +37,7 @@ namespace tvm { namespace relay { +// vm.shape_func TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs); RELAY_REGISTER_OP("vm.shape_of") @@ -133,6 +134,7 @@ RELAY_REGISTER_OP("vm.shape_func") return {topi::identity(inputs[0])}; }); +// vm.invoke_tvm_op bool InvokeTVMOpRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4u); @@ -181,5 +183,40 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op") return {topi::identity(inputs[0])}; }); +// vm.reshape +TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs); + +bool ReshapeTensorRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3u); + auto reshape_attrs = attrs.as(); + CHECK(reshape_attrs); + auto tt = types[0].as(); + CHECK(tt) << "input must be tensor type"; + reporter->Assign(types[2], TensorType(reshape_attrs->newshape, tt->dtype)); + return true; +} + +RELAY_REGISTER_OP("vm.reshape_tensor") + .describe(R"code(Use VM reshape_tensor instruction to reshape the tensor. +)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor") + .add_argument("shape", "Tensor", "The output shape tensor") + .add_type_rel("ReshapeTensor", ReshapeTensorRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("TNonComputational", true) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + +TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor") + .set_body_typed([](Expr data, Expr shape, Array newshape) { + static const Op& op = Op::Get("vm.reshape_tensor"); + auto attrs = make_object(); + attrs->newshape = std::move(newshape); + return Call(op, {data, shape}, Attrs(attrs), {}); + }); + } // namespace relay } // namespace tvm diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index f5204044ac78..4944778110d5 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -422,6 +422,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.shape_of.tensor, instr.dst}); break; } + case Opcode::ReshapeTensor: { + // Number of fields = 3 + fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst}); + break; + } default: LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); break; @@ -693,6 +698,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { DCHECK_EQ(instr.fields.size(), 2U); return Instruction::ShapeOf(instr.fields[0], instr.fields[1]); } + case Opcode::ReshapeTensor: { + // Number of fields = 3 + DCHECK_EQ(instr.fields.size(), 3U); + return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]); + } default: LOG(FATAL) << "Invalid opcode" << instr.opcode; return Instruction(); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 6b10a89d969a..24fc1107ae86 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -148,6 +148,10 @@ Instruction::Instruction(const Instruction& instr) { case Opcode::ShapeOf: this->shape_of.tensor = instr.shape_of.tensor; return; + case Opcode::ReshapeTensor: + this->reshape_tensor.tensor = instr.reshape_tensor.tensor; + this->reshape_tensor.newshape = instr.reshape_tensor.newshape; + return; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -265,6 +269,7 @@ Instruction::~Instruction() { case Opcode::LoadConsti: case Opcode::AllocStorage: case Opcode::ShapeOf: + case Opcode::ReshapeTensor: case Opcode::Fatal: return; case Opcode::AllocTensor: @@ -320,7 +325,7 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out Instruction Instruction::AllocTensor(RegName storage, RegName offset, const std::vector& shape, DLDataType dtype, - Index dst) { + RegName dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; @@ -336,7 +341,7 @@ Instruction Instruction::AllocTensor(RegName storage, RegName offset, } Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register, - DLDataType dtype, Index dst) { + DLDataType dtype, RegName dst) { Instruction instr; instr.op = Opcode::AllocTensorReg; instr.dst = dst; @@ -348,7 +353,7 @@ Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName } Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - Index dst) { + RegName dst) { Instruction instr; instr.op = Opcode::AllocStorage; instr.dst = dst; @@ -358,7 +363,7 @@ Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType return instr; } -Instruction Instruction::ShapeOf(RegName tensor, Index dst) { +Instruction Instruction::ShapeOf(RegName tensor, RegName dst) { Instruction instr; instr.op = Opcode::ShapeOf; instr.dst = dst; @@ -366,8 +371,17 @@ Instruction Instruction::ShapeOf(RegName tensor, Index dst) { return instr; } +Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName dst) { + Instruction instr; + instr.op = Opcode::ReshapeTensor; + instr.dst = dst; + instr.reshape_tensor.tensor = tensor; + instr.reshape_tensor.newshape = newshape; + return instr; +} + Instruction Instruction::AllocADT(Index tag, Index num_fields, - const std::vector& datatype_fields, Index dst) { + const std::vector& datatype_fields, RegName dst) { Instruction instr; instr.op = Opcode::AllocADT; instr.dst = dst; @@ -381,7 +395,7 @@ Instruction Instruction::AllocADT(Index tag, Index num_fields, } Instruction Instruction::AllocClosure(Index func_index, Index free_vars, - const std::vector& free_var_register, Index dst) { + const std::vector& free_var_register, RegName dst) { Instruction instr; instr.op = Opcode::AllocClosure; instr.dst = dst; @@ -604,6 +618,11 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor; break; } + case Opcode::ReshapeTensor: { + os << "reshape_tensor $" << instr.dst << " $" << instr.reshape_tensor.tensor << " $" + << instr.reshape_tensor.newshape; + break; + } default: LOG(FATAL) << "should never hit this case" << static_cast(instr.op); break; @@ -1103,6 +1122,29 @@ void VirtualMachine::RunLoop() { goto main_loop; } } + case Opcode::ReshapeTensor: { + DLContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor); + NDArray tensor_arr = Downcast(tensor_obj); + // Read the shape from shape tensor + auto shape_obj = ReadRegister(instr.reshape_tensor.newshape); + NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_ctx)); + const DLTensor* dl_tensor = shape_tensor.operator->(); + CHECK_EQ(dl_tensor->dtype.code, 0u); + CHECK_EQ(dl_tensor->dtype.bits, 64); + int64_t* dims = reinterpret_cast(dl_tensor->data); + int64_t ndim = shape_tensor->shape[0]; + std::vector shape(dims, dims + ndim); + // Reshape the input tensor + auto out_tensor = tensor_arr.CreateView(shape, tensor_arr->dtype); + WriteRegister(instr.dst, out_tensor); + pc_++; + goto main_loop; + } + default: + LOG(FATAL) << "Unknown instruction opcode: " << int(instr.op); } } } diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index f2b15ec26f32..91214cbbea3b 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -39,7 +39,11 @@ def check_result(args, expected_result, mod=None): expected_result: The expected result of running the expression. """ + # TODO(@zhiics, @icemelon9): Disable the gpu test for now until the heterogeneous support + # is ready for target, ctx in ctx_list(): + if "cuda" in target: + continue vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod) rts_result = vm.evaluate()(*args) @@ -622,5 +626,52 @@ def body_with_free_var(i, acc): mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret) check_result(args, expected, mod=mod) +def test_vm_reshape_tensor(): + x_np = np.random.uniform(size=(8, 16)).astype("float32") + x = relay.var("x", shape=(8, 16), dtype="float32") + y = relay.reshape(x, [-1, 4, 8]) + mod = tvm.IRModule() + mod["main"] = relay.Function([x], y) + with tvm.transform.PassContext(opt_level=3): + exec = relay.vm.compile(mod, "llvm") + assert "reshape_tensor" in exec.bytecode + check_result([x_np], x_np.reshape([4, 4, 8]), mod) + + x = relay.var("x", shape=(8, 16), dtype="float32") + y = relay.reshape(x, [16, -1]) + y = relay.reverse_reshape(y, [-1, 4, 0]) + mod = tvm.IRModule() + mod["main"] = relay.Function([x], y) + with tvm.transform.PassContext(opt_level=3): + exec = relay.vm.compile(mod, "llvm") + assert exec.bytecode.count("reshape_tensor") == 1 + check_result([x_np], x_np.reshape([4, 4, 8]), mod) + + # reshape with symbolic/any shape + for n in [tvm.tir.Any(), tvm.te.size_var('n')]: + x = relay.var("x", shape=(n, 16), dtype="float32") + y = relay.reshape(x, [-1, 4]) + y = relay.reshape(y, [0, 2, -1]) + mod = tvm.IRModule() + mod["main"] = relay.Function([x], y) + with tvm.transform.PassContext(opt_level=3): + exec = relay.vm.compile(mod, "llvm") + assert exec.bytecode.count("reshape_tensor") == 1 + check_result([x_np], x_np.reshape([32, 2, 2]), mod) + + # dyn.reshape + x = relay.var("x", shape=(8, 16), dtype="float32") + y = relay.var("y", shape=(3,), dtype="int32") + z = relay.reshape(x, [-1, 4, 8]) + z = relay.reshape(z, y) + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + with tvm.transform.PassContext(opt_level=3): + exec = relay.vm.compile(mod, "llvm") + assert exec.bytecode.count("reshape_tensor") == 2 + assert "reshape_tensor" in exec.bytecode + y_np = np.array([8, 2, 8]).astype("int32") + check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod) + if __name__ == "__main__": pytest.main([__file__])