diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 53b8e3c42ed6..3ca505b3d01d 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -24,7 +24,6 @@ from . import parser from . import analysis from . import transform -from . import vm_compiler # Expr @@ -62,7 +61,6 @@ ExecBuilder = exec_builder.ExecBuilder VirtualMachine = vm.VirtualMachine load_exec_from_file = vm.load_exec_from_file -compile = vm_compiler.compile # Operator from .op.base import call_dps diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index cb3eb79ed883..3f0f654d0018 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -375,7 +375,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr: return var elif bind_free_vars: # introduce TIR variable to scope, e.g. for func params or rx.call_packed - var = tir.Var(var_name, "int32", self.to_tvm_span(expr.span)) + var = tir.Var(var_name, "int64", self.to_tvm_span(expr.span)) self.scope[var_name] = var return var else: @@ -387,7 +387,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr: elif isinstance(expr, ast.Constant): if not isinstance(expr.value, int): self.report_error("only integer constants are supported", expr.span) - return tir.const(expr.value, "int32", self.to_tvm_span(expr.span)) + return tir.const(expr.value, "int64", self.to_tvm_span(expr.span)) elif isinstance(expr, ast.Call): if not isinstance(expr.func_name, ast.Op): @@ -823,7 +823,7 @@ def parse_attr(self, expr: ast.Attr) -> rx.Expr: """ if expr.field.name == "shape": obj = self.transform_expr(expr.object) - attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32") + attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int64") return relay.Call( relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span) ) @@ -960,7 +960,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: elif isinstance(expr, ast.Constant): # FIXME(@altanh): use internal representation that doesn't have precision limits here if isinstance(expr.value, int): - return tir.IntImm("int32", expr.value, self.to_tvm_span(expr.span)) + return tir.IntImm("int64", expr.value, self.to_tvm_span(expr.span)) elif isinstance(expr.value, float): return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span)) elif isinstance(expr.value, str): diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 08237d51f850..210ea7735c8e 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Union, Dict +from typing import List, Optional, Union, Dict, Tuple import tvm from tvm.runtime import Object, Device, Module, PackedFunc from tvm._ffi.base import _LIB, check_call @@ -64,7 +64,6 @@ def __init__( memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, mod: Optional[Module] = None, ) -> None: - """ Construct a VirtualMachine wrapper object. @@ -133,3 +132,37 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) def __getitem__(self, key: str) -> PackedFunc: return self.module[key] + + +def build(mod: tvm.IRModule, + target: tvm.target.Target, + target_host: tvm.target.Target) -> Tuple[Executable, Module]: + """ + Build an IRModule to VM executable. + + Parameters + ---------- + mod: IRModule + The IR module. + + target : tvm.target.Target + A build target. + + target_host : tvm.target.Target + Host compilation target, if target is device. + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + to setup the dimensions and parameters correctly. + target_host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm intepreter is used. + + Returns + ------- + ex: tvm.relax.vm.Exectuable + An executable that can be loaded by virtual machine. + lib: tvm.runtime.Module + A runtime module that contains generated code. + """ + ex, lib = _ffi_api.VMBuild(mod, target, target_host) + return ex, lib diff --git a/python/tvm/relax/vm_compiler.py b/python/tvm/relax/vm_compiler.py deleted file mode 100644 index 99afa966100f..000000000000 --- a/python/tvm/relax/vm_compiler.py +++ /dev/null @@ -1,70 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin -""" -The Relax Virtual Machine compiler. -""" -from typing import List, Optional, Union, Dict -import tvm -from . import vm, _ffi_api - - -def compile(mod: tvm.IRModule) -> vm.Executable: - """Compile the module to VM executable. A helper function for VMCompiler. - - Parameters - ---------- - mod : tvm.IRModule - The Relay module to build. - - Returns - ------- - exec : tvm.relax.Executable - The VM executable that contains the bytecode. - """ - compiler = VMCompiler() - compiler.compile(mod) - return compiler.get_exec() - - -class VMCompiler(object): - """Compiler that compiles module to VM executable.""" - - def __init__(self): - self.mod = _ffi_api.VMCompiler() - self._compile = self.mod["compile"] - self._get_exec = self.mod["get_executable"] - - def compile(self, mod: tvm.IRModule) -> None: - """Compile the module to VM executable. - - Parameters - ---------- - mod : tvm.IRModule - The IRModule to build. - """ - self._compile(mod) - - def get_exec(self) -> vm.Executable: - """Get the VM executable. - - Returns - ------- - exec : tvm.relax.Executable - The VM executable that contains bytecode. - """ - return self._get_exec() diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index ad1b1238e65c..2b43128feb9a 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -199,6 +199,10 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { } Expr ExprMutator::VisitExpr_(const VarNode* op) { + auto it = var_remap_.find(GetRef(op)); + if (it != var_remap_.end()) { + return it->second; + } if (op->type_annotation.defined()) { Type type = this->VisitType(op->type_annotation.value()); if (!op->type_annotation.same_as(type)) { diff --git a/src/relax/transform/shape_lower.cc b/src/relax/transform/shape_lower.cc index f955842a1ab7..c251f8ad883b 100644 --- a/src/relax/transform/shape_lower.cc +++ b/src/relax/transform/shape_lower.cc @@ -33,7 +33,9 @@ namespace relax { class ShapeLowerMutator : public ExprMutator { public: - static DataType ShapeDType() { return DataType::Int(32); }; + static DataType ShapeDType() { + return DataType::Int(64); + }; explicit ShapeLowerMutator(IRModule mod) { mod_ = mod; } @@ -58,21 +60,24 @@ class ShapeLowerMutator : public ExprMutator { } void VisitMatchShape(const MatchShape& binding) override { - Expr value = binding->value; + Expr shape = ExprMutator::VisitExpr(binding->value); Array pattern = binding->pattern; Array indices; for (size_t i = 0; i < pattern.size(); ++i) { IntImm idx = expr2slot_.at(pattern[i]); indices.push_back(idx); } - builder_->Emit(Call(ExternFunc("decode_shape"), {value, shape_heap_, ShapeExpr(indices)}), "_"); + builder_->Emit(Call(ExternFunc("vm.builtin.decode_shape"), + {shape, shape_heap_, ShapeExpr(indices)}), "gv"); } Expr VisitExpr_(const ShapeExprNode* node) override { tir::PrimFunc func = CalculateShape(GetRef(node)); - GlobalVar shape_func_var(name_table_->GetUniqueName("shape_func")); + std::string shape_func_name = name_table_->GetUniqueName("shape_func"); + func = WithAttr(std::move(func), "global_symbol", runtime::String(shape_func_name)); + GlobalVar shape_func_var(shape_func_name); // TODO make sure shape_heap doesnt get redefined by local funcs? - builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); + builder_->Emit(Call(shape_func_var, {shape_heap_}), "gv"); ret_mod_->Add(shape_func_var, func); // construct shape @@ -80,8 +85,8 @@ class ShapeLowerMutator : public ExprMutator { for (PrimExpr e : node->values) { indices.push_back(expr2slot_.at(e)); } - return builder_->Emit(Call(ExternFunc("construct_shape"), {shape_heap_, ShapeExpr(indices)}), - "sh"); + return builder_->Emit(Call(ExternFunc("vm.builtin.make_shape"), + {shape_heap_, ShapeExpr(indices)}), "sh"); } Expr VisitExpr_(const FunctionNode* node) override { @@ -93,7 +98,7 @@ class ShapeLowerMutator : public ExprMutator { builder_->BeginBindingBlock(); builder_->Emit(VarBinding( - shape_heap_, Call(ExternFunc("relax.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); + shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); Expr new_body = this->Mutate(node->body); @@ -106,7 +111,7 @@ class ShapeLowerMutator : public ExprMutator { new_body = seq->body; } - builder_->Emit(Call(ExternFunc("relax.free_shape_heap"), {shape_heap_}), "_"); + builder_->Emit(Call(ExternFunc("vm.builtin.free_shape_heap"), {shape_heap_}), "gv"); blocks.push_back(builder_->EndBlock()); new_body = SeqExpr(blocks, new_body); @@ -131,6 +136,7 @@ class ShapeLowerMutator : public ExprMutator { tir::Stmt body = tir::SeqStmt(seq); Array params{heap}; Type ret_type = VoidType(); + return tir::PrimFunc(params, body, ret_type, buffer_map); } @@ -176,7 +182,8 @@ class ShapeLowerMutator : public ExprMutator { Map expr2slot_; }; -TVM_REGISTER_GLOBAL("relax.transform.shape_lower").set_body_typed([](IRModule mod) { +TVM_REGISTER_GLOBAL("relax.transform.shape_lower") +.set_body_typed([](IRModule mod) { return ShapeLowerMutator(mod).Lower(); }); diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc index 96392b6912ba..2dcfa05f1916 100644 --- a/src/relax/vm/builtin.cc +++ b/src/relax/vm/builtin.cc @@ -36,35 +36,41 @@ namespace relax_vm { using tvm::runtime::NDArray; -TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_typed([](NDArray arr) { return arr.Shape(); }); +TVM_REGISTER_GLOBAL("vm.builtin.shape_of") +.set_body_typed([](NDArray arr) { + return arr.Shape(); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap") +.set_body_typed([](ShapeTuple size) { + return NDArray::Empty(size, DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0}); +}); -TVM_REGISTER_GLOBAL("vm.builtin.alloc_heap").set_body_typed([](int64_t size) { - return NDArray::Empty(ShapeTuple({size}), DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0}); +TVM_REGISTER_GLOBAL("vm.builtin.free_shape_heap") +.set_body_typed([](NDArray arr) { + return static_cast(const_cast(arr.get()))->DecRef(); }); -TVM_REGISTER_GLOBAL("vm.builtin.match_shape") -.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) { - ShapeTuple shape = args[0]; - NDArray heap = args[1]; +TVM_REGISTER_GLOBAL("vm.builtin.decode_shape") +.set_body_typed([](ShapeTuple shape, NDArray heap, ShapeTuple indexes) { int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); - for (int i = 2; i < args.size(); ++i) { - int64_t heap_idx = args[i]; + for (size_t i = 0; i < indexes.size(); ++i) { + int64_t heap_idx = indexes[i]; ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]); - heap_data[heap_idx] = shape[i - 2]; + heap_data[heap_idx] = shape[i]; } }); TVM_REGISTER_GLOBAL("vm.builtin.make_shape") -.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) { - NDArray heap = args[0]; +.set_body_typed([](NDArray heap, ShapeTuple indexes) { int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); std::vector shape; - for (int i = 1; i < args.size(); ++i) { - int64_t heap_idx = args[i]; + for (size_t i = 0; i < indexes.size(); ++i) { + int64_t heap_idx = indexes[i]; ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]); shape.push_back(heap_data[heap_idx]); } - *rv = ShapeTuple(shape); + return ShapeTuple(shape); }); TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") diff --git a/src/relax/vm/compiler.cc b/src/relax/vm/compiler.cc index 2ebe6a3fd919..6dff8259094d 100644 --- a/src/relax/vm/compiler.cc +++ b/src/relax/vm/compiler.cc @@ -24,8 +24,11 @@ #include "compiler.h" +#include #include #include +#include +#include #include #include @@ -37,9 +40,11 @@ namespace relax_vm { using namespace relax; -class VMFunctionCompiler : public ExprVisitor { +class VMCompilerImpl : public ExprVisitor { public: - explicit VMFunctionCompiler(ExecBuilderNode* builder) { builder_ = GetRef(builder); } + explicit VMCompilerImpl(ExecBuilderNode* builder) { + builder_ = GetRef(builder); + } protected: /*! \brief A counter for naming local functions. */ @@ -85,18 +90,15 @@ class VMFunctionCompiler : public ExprVisitor { EmitAllocTensor(call_node, var); } else { // Normal packed function without attributes - std::vector args; - for (size_t i = 0; i < call_node->args.size(); ++i) { - if (call_node->args[i].as()) { - auto reg = this->var_register_map_.find(Downcast(call_node->args[i])); - ICHECK(reg != this->var_register_map_.end()); - args.push_back(Instruction::Arg(Instruction::kRegister, reg->second)); - } - } + std::vector args = ConvertArgs(call_node); // TODO(@yuchen): what if the packed func has void return (no need to write to the dst // register)? builder_->EmitCall(name, args, NewRegister(var)); } + } else if (auto* gvar = call_node->op.as()) { + String name = gvar->name_hint; + std::vector args = ConvertArgs(call_node); + builder_->EmitCall(name, args, NewRegister(var)); } else { LOG(FATAL) << "TODO: support compiling everything other than extern functions."; } @@ -172,6 +174,31 @@ class VMFunctionCompiler : public ExprVisitor { return reg; } + std::vector ConvertArgs(const Call& call) { + std::vector ret; + const auto& args = call->args; + for (size_t i = 0; i < call->args.size(); ++i) { + if (args[i]->IsInstance()) { + auto reg = this->var_register_map_.find(Downcast(args[i])); + ICHECK(reg != this->var_register_map_.end()); + ret.push_back(Instruction::Arg(Instruction::kRegister, reg->second)); + } else if (args[i]->IsInstance()) { + std::vector shape; + for (PrimExpr e : Downcast(args[i])->values) { + shape.push_back(Downcast(e)->value); + } + auto shape_tuple = ShapeTuple(shape); + TVMRetValue shape_tuple_value; + shape_tuple_value = shape_tuple; + Index index = builder_->EmitConstant(shape_tuple_value); + ret.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + } else { + LOG(FATAL) << "not supported argument type."; + } + } + return ret; + } + /*! \brief Internal ExecBuilder. */ relax::ExecBuilder builder_; /*! \brief Total number of virtual registers allocated. */ @@ -183,9 +210,9 @@ class VMFunctionCompiler : public ExprVisitor { PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "compile") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 1); + ICHECK_EQ(args.num_args, 3); IRModule mod = args[0]; - this->Compile(mod); + this->Compile(mod, args[1], args[2]); }); } else if (name == "get_executable") { return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetExec(); }); @@ -195,31 +222,55 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtrfunctions) { - auto gvar = func.first; - if (!func.second->IsInstance()) { - continue; - } +void VMCompiler::Compile(IRModule mod, Target target, Target target_host) { + // Reset internal builder + builder_ = relax::ExecBuilderNode::Create(); + + IRModule tir_mod; + IRModule rx_mod; + for (auto& p : mod->functions) { + auto gvar = p.first; - VMFunctionCompiler func_compiler(); - if (auto* n = func.second.as()) { - auto func = GetRef(n); - auto func_compiler = VMFunctionCompiler(builder_.operator->()); - func_compiler.VisitExpr(func); + BaseFunc func = p.second; + if (func.as()) { + tir_mod->Add(gvar, func); + } else if (func.as()) { + rx_mod->Add(gvar, func); + } else { + LOG(FATAL) << "Cannot handle such function node now:\n" << func; } } + lib_ = tvm::build(tir_mod, target, target_host); + + VMCompilerImpl compiler(builder_.operator->()); + for (auto& p : rx_mod->functions) { + compiler.VisitExpr(p.second); + } } -Executable VMCompiler::GetExec() { return builder_->Get(); } +Executable VMCompiler::GetExec() { + return builder_->Get(); +} + +runtime::Module VMCompiler::GetLib() { + return lib_; +} runtime::Module CreateVMCompiler() { auto compiler = make_object(); return runtime::Module(compiler); } -TVM_REGISTER_GLOBAL("relax.VMCompiler").set_body_typed([]() { return CreateVMCompiler(); }); +Array Build(IRModule mod, Target target, Target target_host) { + auto compiler = make_object(); + compiler->Compile(mod, target, target_host); + Executable exec = compiler->GetExec(); + Module lib = compiler->GetLib(); + return Array({exec, lib}); +} + +TVM_REGISTER_GLOBAL("relax.VMBuild") +.set_body_typed(Build); } // namespace relax_vm } // namespace runtime diff --git a/src/relax/vm/compiler.h b/src/relax/vm/compiler.h index c4e10493f730..55036c4e6f37 100644 --- a/src/relax/vm/compiler.h +++ b/src/relax/vm/compiler.h @@ -25,6 +25,7 @@ #ifndef TVM_RELAX_VM_COMPILER_H_ #define TVM_RELAX_VM_COMPILER_H_ +#include #include #include #include @@ -35,18 +36,25 @@ namespace tvm { namespace runtime { namespace relax_vm { +using tvm::Target; + class VMCompiler : public runtime::ModuleNode { public: /*! * \brief Compile the functions in a Module. * \param mod Input IRModule to be compiled. */ - void Compile(IRModule mod); + void Compile(IRModule mod, Target target, Target target_host); /*! * \brief Get the compiled executable. * \return The compiled executable. */ Executable GetExec(); + /*! + * \brief Get the compiled library. + * \return The compiled lirary. + */ + Module GetLib(); virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); @@ -54,7 +62,9 @@ class VMCompiler : public runtime::ModuleNode { protected: /*! \brief Internal executable builder. */ - relax::ExecBuilder builder_ = relax::ExecBuilderNode::Create(); + relax::ExecBuilder builder_; + /*! \brief Built library. */ + runtime::Module lib_; }; } // namespace relax_vm diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index b0a2df064a53..8479467f1aa5 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -96,16 +96,14 @@ def test_explicit_memory_rewrite(): s2 = block.bindings[1].value assert s2.op.global_symbol == "test.op.identity" - # rx.parser.pretty_print(func) - @rx.script class Mod: def foo(x: Tensor[_, "float32"]) -> Shape: - relax.match_shape(x.shape, (n, m)) + sh = relax.call_packed("vm.builtin.shape_of", x) + relax.match_shape(sh, (n, m)) return (n * 2, m * 3) - def test_shape_lowering(): mod = Mod() new_mod = rx.transform.shape_lower(mod) @@ -115,7 +113,7 @@ def test_shape_lowering(): code = rx.parser.astext(new_mod) assert "alloc_shape_heap" in code assert "decode_shape" in code - assert "construct_shape" in code + assert "make_shape" in code if __name__ == "__main__": diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 68d317e2ba60..52e5f9f0dc3b 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -167,21 +167,6 @@ def test_vm_shapeof(): for i, s in enumerate(res): assert s == shape[i] -def test_vm_heap(): - ib = rx.ExecBuilder() - shape = (32, 16) - arr = tvm.nd.array(np.random.rand(*shape)) - with ib.function("main", num_inputs=0): - ib.emit_call("vm.builtin.alloc_heap", args=[ib.imm(2)], dst=ib.r(0)) - ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(1)) - ib.emit_call("vm.builtin.match_shape", args=[ib.r(1), ib.r(0), ib.imm(0), ib.imm(1)]) - ib.emit_call("vm.builtin.make_shape", args=[ib.r(0), ib.imm(0), ib.imm(1)], dst=ib.r(2)) - ib.emit_ret(ib.r(2)) - ex = ib.get() - vm = rx.VirtualMachine(ex, tvm.cpu()) - res = vm["main"]() - for i, s in enumerate(res): - assert s == shape[i] def test_vm_storage(): ib = rx.ExecBuilder() @@ -202,7 +187,7 @@ def test_vm_storage(): assert res.device == tvm.cpu() assert res.shape == shape -def test_vm_compile(): +def test_vm_compile_stage0(): @rx.script class Mod: def foo(x: Tensor[(3, 4), "float32"]): @@ -212,11 +197,72 @@ def foo(x: Tensor[(3, 4), "float32"]): return z mod = Mod() - exec = rx.vm_compiler.compile(mod) - input = tvm.nd.array(np.random.rand(3,4).astype(np.float32)) - vm = rx.VirtualMachine(exec, tvm.cpu()) - res = vm["foo"](input) - np.testing.assert_allclose(input.asnumpy(), res.asnumpy()) + target = tvm.target.Target("llvm") + target_host = tvm.target.Target("llvm") + ex, lib = rx.vm.build(mod, target, target_host) + inp = tvm.nd.array(np.random.rand(3,4).astype(np.float32)) + vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib) + res = vm["foo"](inp) + np.testing.assert_allclose(inp.asnumpy(), res.asnumpy()) + + +def test_vm_compile_stage1(): + @rx.script + class Mod1: + @tvm.script.tir + def shape_func0(heap: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "shape_func0"}) + H = tir.match_buffer(heap, [tir.int64(4)], dtype="int64", elem_offset=tir.int64(0), align=128, offset_factor=1) + # body + tir.store(H.data, tir.int64(2), (tir.load("int64", H.data, tir.int64(0))*tir.int64(2)), True) + tir.store(H.data, tir.int64(3), (tir.load("int64", H.data, tir.int64(1))*tir.int64(3)), True) + + def foo(x: Tensor[_, "float32"]) -> Shape: + shape_heap: Tensor[(4,), "int64"] = relax.call_packed("vm.builtin.alloc_shape_heap", (4,)) + gv0 = relax.call_packed("vm.builtin.shape_of", x) + gv1 = relax.call_packed("vm.builtin.decode_shape", gv0, shape_heap, (0, 1)) + gv2 = shape_func0(shape_heap) + gv3 = relax.call_packed("vm.builtin.make_shape", shape_heap, (2, 3)) + return gv3 + + mod = Mod1() + code = rx.parser.astext(mod) + target = tvm.target.Target("llvm") + target_host = tvm.target.Target("llvm") + ex, lib = rx.vm.build(mod, target, target_host) + vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib) + + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + res = vm["foo"](arr) + assert res[0] == shape[0] * 2 + assert res[1] == shape[1] * 3 + + +def test_vm_compile_stage2(): + @rx.script + class Mod2: + def foo(x: Tensor[_, "float32"]) -> Shape: + sh = relax.call_packed("vm.builtin.shape_of", x) + relax.match_shape(sh, (n, m)) + return (n * 2, m * 3) + + mod = Mod2() + code = rx.parser.astext(mod) + new_mod = rx.transform.shape_lower(mod) + code = rx.parser.astext(new_mod) + target = tvm.target.Target("llvm") + target_host = tvm.target.Target("llvm") + ex, lib = rx.vm.build(new_mod, target, target_host) + vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib) + + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + res = vm["foo"](arr) + assert res[0] == shape[0] * 2 + assert res[1] == shape[1] * 3 + if __name__ == "__main__": test_vm_execute() @@ -227,6 +273,7 @@ def foo(x: Tensor[(3, 4), "float32"]): test_vm_serialize() test_vm_constant_serialize() test_vm_shapeof() - test_vm_heap() test_vm_storage() - test_vm_compile() + test_vm_compile_stage0() + test_vm_compile_stage1() + test_vm_compile_stage2()