
Commit

End2End Lowering Stage2: Enable Lowering from ShapeExpr to VM Executable (#21)

* rebase.

* Update.

* Update shape lowering, make sure the lowering pipeline works.

* Address comment.
ZihengJiang authored and junrushao committed Oct 14, 2022
1 parent 9efbd96 commit f5ff3bb
Showing 11 changed files with 243 additions and 159 deletions.
2 changes: 0 additions & 2 deletions python/tvm/relax/__init__.py
@@ -24,7 +24,6 @@
from . import parser
from . import analysis
from . import transform
from . import vm_compiler


# Expr
@@ -62,7 +61,6 @@
ExecBuilder = exec_builder.ExecBuilder
VirtualMachine = vm.VirtualMachine
load_exec_from_file = vm.load_exec_from_file
compile = vm_compiler.compile

# Operator
from .op.base import call_dps
8 changes: 4 additions & 4 deletions python/tvm/relax/parser.py
@@ -375,7 +375,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
return var
elif bind_free_vars:
# introduce TIR variable to scope, e.g. for func params or rx.call_packed
var = tir.Var(var_name, "int32", self.to_tvm_span(expr.span))
var = tir.Var(var_name, "int64", self.to_tvm_span(expr.span))
self.scope[var_name] = var
return var
else:
@@ -387,7 +387,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
elif isinstance(expr, ast.Constant):
if not isinstance(expr.value, int):
self.report_error("only integer constants are supported", expr.span)
return tir.const(expr.value, "int32", self.to_tvm_span(expr.span))
return tir.const(expr.value, "int64", self.to_tvm_span(expr.span))

elif isinstance(expr, ast.Call):
if not isinstance(expr.func_name, ast.Op):
@@ -823,7 +823,7 @@ def parse_attr(self, expr: ast.Attr) -> rx.Expr:
"""
if expr.field.name == "shape":
obj = self.transform_expr(expr.object)
attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32")
attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int64")
return relay.Call(
relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span)
)
@@ -960,7 +960,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
elif isinstance(expr, ast.Constant):
# FIXME(@altanh): use internal representation that doesn't have precision limits here
if isinstance(expr.value, int):
return tir.IntImm("int32", expr.value, self.to_tvm_span(expr.span))
return tir.IntImm("int64", expr.value, self.to_tvm_span(expr.span))
elif isinstance(expr.value, float):
return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span))
elif isinstance(expr.value, str):
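All four parser changes above move shape variables and constants from "int32" to "int64", keeping them consistent with the int64 shape heap that shape_lower.cc and builtin.cc introduce later in this commit. A quick sketch of the invariant using standard TIR APIs (the assertion itself is illustrative):

import tvm
from tvm import tir

n = tir.Var("n", "int64")
expr = n * tir.const(2, "int64")
# Shape expressions must stay int64 end to end, because the VM's shape
# heap is allocated as an int64 NDArray (see vm.builtin.alloc_shape_heap).
assert expr.dtype == "int64"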
37 changes: 35 additions & 2 deletions python/tvm/relax/vm.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import List, Optional, Union, Dict
from typing import List, Optional, Union, Dict, Tuple
import tvm
from tvm.runtime import Object, Device, Module, PackedFunc
from tvm._ffi.base import _LIB, check_call
@@ -64,7 +64,6 @@ def __init__(
memory_cfg: Optional[Union[str, Dict[Device, str]]] = None,
mod: Optional[Module] = None,
) -> None:

"""
Construct a VirtualMachine wrapper object.
@@ -133,3 +132,37 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]])

def __getitem__(self, key: str) -> PackedFunc:
return self.module[key]


def build(mod: tvm.IRModule,
target: tvm.target.Target,
target_host: tvm.target.Target) -> Tuple[Executable, Module]:
"""
Build an IRModule into a VM executable.
Parameters
----------
mod: IRModule
The IR module.
target : tvm.target.Target
A build target.
target_host : tvm.target.Target
Host compilation target, if target is a device target.
When TVM compiles device-specific programs such as CUDA,
we also need host (CPU) side code to interact with the driver
to set up the dimensions and parameters correctly.
target_host is used to specify the host-side codegen target.
By default, llvm is used if it is enabled;
otherwise a stackvm interpreter is used.
Returns
-------
ex: tvm.relax.vm.Executable
An executable that can be loaded by virtual machine.
lib: tvm.runtime.Module
A runtime module that contains generated code.
"""
ex, lib = _ffi_api.VMBuild(mod, target, target_host)
return ex, lib
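For orientation, here is a minimal build-and-run sketch of this new API, mirroring test_vm_compile_stage0 added later in this commit (assumes a TVM build with this Relax branch; Mod is any @rx.script module such as the one in the tests):

import numpy as np
import tvm
from tvm import relax as rx

mod = Mod()  # any @rx.script module, e.g. Mod from the tests below
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")

# Compile the Relax module: `ex` holds VM bytecode, `lib` the generated kernels.
ex, lib = rx.vm.build(mod, target, target_host)

# The kernel library must be passed to the VM alongside the executable.
vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib)
res = vm["foo"](tvm.nd.array(np.random.rand(3, 4).astype(np.float32)))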
70 changes: 0 additions & 70 deletions python/tvm/relax/vm_compiler.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/relax/ir/expr_functor.cc
@@ -199,6 +199,10 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}

Expr ExprMutator::VisitExpr_(const VarNode* op) {
auto it = var_remap_.find(GetRef<Var>(op));
if (it != var_remap_.end()) {
return it->second;
}
if (op->type_annotation.defined()) {
Type type = this->VisitType(op->type_annotation.value());
if (!op->type_annotation.same_as(type)) {
27 changes: 17 additions & 10 deletions src/relax/transform/shape_lower.cc
@@ -33,7 +33,9 @@ namespace relax {

class ShapeLowerMutator : public ExprMutator {
public:
static DataType ShapeDType() { return DataType::Int(32); };
static DataType ShapeDType() {
return DataType::Int(64);
};

explicit ShapeLowerMutator(IRModule mod) { mod_ = mod; }

@@ -58,30 +60,33 @@ class ShapeLowerMutator : public ExprMutator {
}

void VisitMatchShape(const MatchShape& binding) override {
Expr value = binding->value;
Expr shape = ExprMutator::VisitExpr(binding->value);
Array<PrimExpr> pattern = binding->pattern;
Array<PrimExpr> indices;
for (size_t i = 0; i < pattern.size(); ++i) {
IntImm idx = expr2slot_.at(pattern[i]);
indices.push_back(idx);
}
builder_->Emit(Call(ExternFunc("decode_shape"), {value, shape_heap_, ShapeExpr(indices)}), "_");
builder_->Emit(Call(ExternFunc("vm.builtin.decode_shape"),
{shape, shape_heap_, ShapeExpr(indices)}), "gv");
}

Expr VisitExpr_(const ShapeExprNode* node) override {
tir::PrimFunc func = CalculateShape(GetRef<ShapeExpr>(node));
GlobalVar shape_func_var(name_table_->GetUniqueName("shape_func"));
std::string shape_func_name = name_table_->GetUniqueName("shape_func");
func = WithAttr(std::move(func), "global_symbol", runtime::String(shape_func_name));
GlobalVar shape_func_var(shape_func_name);
// TODO: make sure shape_heap doesn't get redefined by local funcs
builder_->Emit(Call(shape_func_var, {shape_heap_}), "_");
builder_->Emit(Call(shape_func_var, {shape_heap_}), "gv");
ret_mod_->Add(shape_func_var, func);

// construct shape
Array<PrimExpr> indices;
for (PrimExpr e : node->values) {
indices.push_back(expr2slot_.at(e));
}
return builder_->Emit(Call(ExternFunc("construct_shape"), {shape_heap_, ShapeExpr(indices)}),
"sh");
return builder_->Emit(Call(ExternFunc("vm.builtin.make_shape"),
{shape_heap_, ShapeExpr(indices)}), "sh");
}

Expr VisitExpr_(const FunctionNode* node) override {
@@ -93,7 +98,7 @@ class ShapeLowerMutator : public ExprMutator {

builder_->BeginBindingBlock();
builder_->Emit(VarBinding(
shape_heap_, Call(ExternFunc("relax.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));

Expr new_body = this->Mutate(node->body);

@@ -106,7 +111,7 @@ class ShapeLowerMutator : public ExprMutator {
new_body = seq->body;
}

builder_->Emit(Call(ExternFunc("relax.free_shape_heap"), {shape_heap_}), "_");
builder_->Emit(Call(ExternFunc("vm.builtin.free_shape_heap"), {shape_heap_}), "gv");
blocks.push_back(builder_->EndBlock());
new_body = SeqExpr(blocks, new_body);

@@ -131,6 +136,7 @@ class ShapeLowerMutator : public ExprMutator {
tir::Stmt body = tir::SeqStmt(seq);
Array<tir::Var> params{heap};
Type ret_type = VoidType();

return tir::PrimFunc(params, body, ret_type, buffer_map);
}

@@ -176,7 +182,8 @@ class ShapeLowerMutator : public ExprMutator {
Map<PrimExpr, IntImm> expr2slot_;
};

TVM_REGISTER_GLOBAL("relax.transform.shape_lower").set_body_typed([](IRModule mod) {
TVM_REGISTER_GLOBAL("relax.transform.shape_lower")
.set_body_typed([](IRModule mod) {
return ShapeLowerMutator(mod).Lower();
});
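Taken together, the mutator turns symbolic shapes into explicit heap traffic: match_shape becomes a decode_shape scatter into numbered heap slots, and each ShapeExpr becomes a generated TIR shape_func plus a make_shape gather. A sketch of the before/after in the Relax script syntax this repository's tests use (slot numbering follows test_vm_compile_stage1 below; not standalone Python):

# Before lowering (Mod2 in the tests below):
def foo(x: Tensor[_, "float32"]) -> Shape:
    sh = relax.call_packed("vm.builtin.shape_of", x)
    relax.match_shape(sh, (n, m))
    return (n * 2, m * 3)

# After shape_lower, roughly (Mod1 in the tests below): n and m land in
# heap slots 0 and 1, the generated shape_func0 computes
# heap[2] = heap[0] * 2 and heap[3] = heap[1] * 3, and make_shape
# gathers slots 2 and 3 into the resulting ShapeTuple.
def foo(x: Tensor[_, "float32"]) -> Shape:
    shape_heap: Tensor[(4,), "int64"] = relax.call_packed("vm.builtin.alloc_shape_heap", (4,))
    gv0 = relax.call_packed("vm.builtin.shape_of", x)
    gv1 = relax.call_packed("vm.builtin.decode_shape", gv0, shape_heap, (0, 1))
    gv2 = shape_func0(shape_heap)
    gv3 = relax.call_packed("vm.builtin.make_shape", shape_heap, (2, 3))
    return gv3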

36 changes: 21 additions & 15 deletions src/relax/vm/builtin.cc
@@ -36,35 +36,41 @@ namespace relax_vm {

using tvm::runtime::NDArray;

TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_typed([](NDArray arr) { return arr.Shape(); });
TVM_REGISTER_GLOBAL("vm.builtin.shape_of")
.set_body_typed([](NDArray arr) {
return arr.Shape();
});

TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap")
.set_body_typed([](ShapeTuple size) {
return NDArray::Empty(size, DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0});
});

TVM_REGISTER_GLOBAL("vm.builtin.alloc_heap").set_body_typed([](int64_t size) {
return NDArray::Empty(ShapeTuple({size}), DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0});
TVM_REGISTER_GLOBAL("vm.builtin.free_shape_heap")
.set_body_typed([](NDArray arr) {
return static_cast<NDArray::Container*>(const_cast<Object*>(arr.get()))->DecRef();
});

TVM_REGISTER_GLOBAL("vm.builtin.match_shape")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) {
ShapeTuple shape = args[0];
NDArray heap = args[1];
TVM_REGISTER_GLOBAL("vm.builtin.decode_shape")
.set_body_typed([](ShapeTuple shape, NDArray heap, ShapeTuple indexes) {
int64_t* heap_data = reinterpret_cast<int64_t*>(heap.ToDLPack()->dl_tensor.data);
for (int i = 2; i < args.size(); ++i) {
int64_t heap_idx = args[i];
for (size_t i = 0; i < indexes.size(); ++i) {
int64_t heap_idx = indexes[i];
ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]);
heap_data[heap_idx] = shape[i - 2];
heap_data[heap_idx] = shape[i];
}
});

TVM_REGISTER_GLOBAL("vm.builtin.make_shape")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) {
NDArray heap = args[0];
.set_body_typed([](NDArray heap, ShapeTuple indexes) {
int64_t* heap_data = reinterpret_cast<int64_t*>(heap.ToDLPack()->dl_tensor.data);
std::vector<int64_t> shape;
for (int i = 1; i < args.size(); ++i) {
int64_t heap_idx = args[i];
for (size_t i = 0; i < indexes.size(); ++i) {
int64_t heap_idx = indexes[i];
ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]);
shape.push_back(heap_data[heap_idx]);
}
*rv = ShapeTuple(shape);
return ShapeTuple(shape);
});

TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage")
103 changes: 77 additions & 26 deletions src/relax/vm/compiler.cc
@@ -24,8 +24,11 @@

#include "compiler.h"

#include <tvm/target/target.h>
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/function.h>
#include <tvm/driver/driver_api.h>

#include <string>
#include <unordered_map>
@@ -37,9 +40,11 @@ namespace relax_vm {

using namespace relax;

class VMFunctionCompiler : public ExprVisitor {
class VMCompilerImpl : public ExprVisitor {
public:
explicit VMFunctionCompiler(ExecBuilderNode* builder) { builder_ = GetRef<ExecBuilder>(builder); }
explicit VMCompilerImpl(ExecBuilderNode* builder) {
builder_ = GetRef<ExecBuilder>(builder);
}

protected:
/*! \brief A counter for naming local functions. */
@@ -85,18 +90,15 @@ class VMFunctionCompiler : public ExprVisitor {
EmitAllocTensor(call_node, var);
} else {
// Normal packed function without attributes
std::vector<Instruction::Arg> args;
for (size_t i = 0; i < call_node->args.size(); ++i) {
if (call_node->args[i].as<VarNode>()) {
auto reg = this->var_register_map_.find(Downcast<Var>(call_node->args[i]));
ICHECK(reg != this->var_register_map_.end());
args.push_back(Instruction::Arg(Instruction::kRegister, reg->second));
}
}
std::vector<Instruction::Arg> args = ConvertArgs(call_node);
// TODO(@yuchen): what if the packed func has void return (no need to write to the dst
// register)?
builder_->EmitCall(name, args, NewRegister(var));
}
} else if (auto* gvar = call_node->op.as<GlobalVarNode>()) {
String name = gvar->name_hint;
std::vector<Instruction::Arg> args = ConvertArgs(call_node);
builder_->EmitCall(name, args, NewRegister(var));
} else {
LOG(FATAL) << "TODO: support compiling everything other than extern functions.";
}
@@ -172,6 +174,31 @@ class VMFunctionCompiler : public ExprVisitor {
return reg;
}

std::vector<Instruction::Arg> ConvertArgs(const Call& call) {
std::vector<Instruction::Arg> ret;
const auto& args = call->args;
for (size_t i = 0; i < call->args.size(); ++i) {
if (args[i]->IsInstance<VarNode>()) {
auto reg = this->var_register_map_.find(Downcast<Var>(args[i]));
ICHECK(reg != this->var_register_map_.end());
ret.push_back(Instruction::Arg(Instruction::kRegister, reg->second));
} else if (args[i]->IsInstance<ShapeExprNode>()) {
std::vector<int64_t> shape;
for (PrimExpr e : Downcast<ShapeExpr>(args[i])->values) {
shape.push_back(Downcast<IntImm>(e)->value);
}
auto shape_tuple = ShapeTuple(shape);
TVMRetValue shape_tuple_value;
shape_tuple_value = shape_tuple;
Index index = builder_->EmitConstant(shape_tuple_value);
ret.push_back(Instruction::Arg(Instruction::kConstIdx, index));
} else {
LOG(FATAL) << "not supported argument type.";
}
}
return ret;
}

/*! \brief Internal ExecBuilder. */
relax::ExecBuilder builder_;
/*! \brief Total number of virtual registers allocated. */
@@ -183,9 +210,9 @@ class VMFunctionCompiler : public ExprVisitor {
PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "compile") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.num_args, 1);
ICHECK_EQ(args.num_args, 3);
IRModule mod = args[0];
this->Compile(mod);
this->Compile(mod, args[1], args[2]);
});
} else if (name == "get_executable") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetExec(); });
@@ -195,31 +222,55 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Obje
}
}

void VMCompiler::Compile(IRModule mod) {
// TODO(@yuchen, @ziheng): support lowering PrimFuncs
for (auto& func : mod->functions) {
auto gvar = func.first;
if (!func.second->IsInstance<FunctionNode>()) {
continue;
}
void VMCompiler::Compile(IRModule mod, Target target, Target target_host) {
// Reset internal builder
builder_ = relax::ExecBuilderNode::Create();

IRModule tir_mod;
IRModule rx_mod;
for (auto& p : mod->functions) {
auto gvar = p.first;

VMFunctionCompiler func_compiler();
if (auto* n = func.second.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
auto func_compiler = VMFunctionCompiler(builder_.operator->());
func_compiler.VisitExpr(func);
BaseFunc func = p.second;
if (func.as<tir::PrimFuncNode>()) {
tir_mod->Add(gvar, func);
} else if (func.as<FunctionNode>()) {
rx_mod->Add(gvar, func);
} else {
LOG(FATAL) << "Cannot handle such function node now:\n" << func;
}
}
lib_ = tvm::build(tir_mod, target, target_host);

VMCompilerImpl compiler(builder_.operator->());
for (auto& p : rx_mod->functions) {
compiler.VisitExpr(p.second);
}
}

Executable VMCompiler::GetExec() { return builder_->Get(); }
Executable VMCompiler::GetExec() {
return builder_->Get();
}

runtime::Module VMCompiler::GetLib() {
return lib_;
}

runtime::Module CreateVMCompiler() {
auto compiler = make_object<VMCompiler>();
return runtime::Module(compiler);
}

TVM_REGISTER_GLOBAL("relax.VMCompiler").set_body_typed([]() { return CreateVMCompiler(); });
Array<ObjectRef> Build(IRModule mod, Target target, Target target_host) {
auto compiler = make_object<VMCompiler>();
compiler->Compile(mod, target, target_host);
Executable exec = compiler->GetExec();
Module lib = compiler->GetLib();
return Array<ObjectRef>({exec, lib});
}

TVM_REGISTER_GLOBAL("relax.VMBuild")
.set_body_typed(Build);

} // namespace relax_vm
} // namespace runtime
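The relax.VMBuild global registered above is what the Python-side rx.vm.build (earlier in this diff) reaches through _ffi_api; it can also be called directly. A sketch (assumes `mod` is an IRModule; unpacking works because Build returns an Array holding the Executable/Module pair):

import tvm

vm_build = tvm.get_global_func("relax.VMBuild")
# Build returns Array({exec, lib}); Python-side iteration unpacks it.
ex, lib = vm_build(mod, tvm.target.Target("llvm"), tvm.target.Target("llvm"))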
14 changes: 12 additions & 2 deletions src/relax/vm/compiler.h
@@ -25,6 +25,7 @@
#ifndef TVM_RELAX_VM_COMPILER_H_
#define TVM_RELAX_VM_COMPILER_H_

#include <tvm/target/target.h>
#include <tvm/ir/module.h>
#include <tvm/relax/vm/exec_builder.h>
#include <tvm/relax/vm/executable.h>
@@ -35,26 +36,35 @@ namespace tvm {
namespace runtime {
namespace relax_vm {

using tvm::Target;

class VMCompiler : public runtime::ModuleNode {
public:
/*!
* \brief Compile the functions in a Module.
* \param mod Input IRModule to be compiled.
*/
void Compile(IRModule mod);
void Compile(IRModule mod, Target target, Target target_host);
/*!
* \brief Get the compiled executable.
* \return The compiled executable.
*/
Executable GetExec();
/*!
* \brief Get the compiled library.
* \return The compiled library.
*/
Module GetLib();

virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

const char* type_key() const { return "relax.VMCompiler"; }

protected:
/*! \brief Internal executable builder. */
relax::ExecBuilder builder_ = relax::ExecBuilderNode::Create();
relax::ExecBuilder builder_;
/*! \brief Built library. */
runtime::Module lib_;
};

} // namespace relax_vm
8 changes: 3 additions & 5 deletions tests/python/relax/test_transform.py
@@ -96,16 +96,14 @@ def test_explicit_memory_rewrite():
s2 = block.bindings[1].value
assert s2.op.global_symbol == "test.op.identity"

# rx.parser.pretty_print(func)


@rx.script
class Mod:
def foo(x: Tensor[_, "float32"]) -> Shape:
relax.match_shape(x.shape, (n, m))
sh = relax.call_packed("vm.builtin.shape_of", x)
relax.match_shape(sh, (n, m))
return (n * 2, m * 3)


def test_shape_lowering():
mod = Mod()
new_mod = rx.transform.shape_lower(mod)
@@ -115,7 +113,7 @@ def test_shape_lowering():
code = rx.parser.astext(new_mod)
assert "alloc_shape_heap" in code
assert "decode_shape" in code
assert "construct_shape" in code
assert "make_shape" in code


if __name__ == "__main__":
93 changes: 70 additions & 23 deletions tests/python/relax/test_vm.py
@@ -167,21 +167,6 @@ def test_vm_shapeof():
for i, s in enumerate(res):
assert s == shape[i]

def test_vm_heap():
ib = rx.ExecBuilder()
shape = (32, 16)
arr = tvm.nd.array(np.random.rand(*shape))
with ib.function("main", num_inputs=0):
ib.emit_call("vm.builtin.alloc_heap", args=[ib.imm(2)], dst=ib.r(0))
ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(1))
ib.emit_call("vm.builtin.match_shape", args=[ib.r(1), ib.r(0), ib.imm(0), ib.imm(1)])
ib.emit_call("vm.builtin.make_shape", args=[ib.r(0), ib.imm(0), ib.imm(1)], dst=ib.r(2))
ib.emit_ret(ib.r(2))
ex = ib.get()
vm = rx.VirtualMachine(ex, tvm.cpu())
res = vm["main"]()
for i, s in enumerate(res):
assert s == shape[i]

def test_vm_storage():
ib = rx.ExecBuilder()
@@ -202,7 +187,7 @@ def test_vm_storage():
assert res.device == tvm.cpu()
assert res.shape == shape

def test_vm_compile():
def test_vm_compile_stage0():
@rx.script
class Mod:
def foo(x: Tensor[(3, 4), "float32"]):
@@ -212,11 +197,72 @@ def foo(x: Tensor[(3, 4), "float32"]):
return z

mod = Mod()
exec = rx.vm_compiler.compile(mod)
input = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
vm = rx.VirtualMachine(exec, tvm.cpu())
res = vm["foo"](input)
np.testing.assert_allclose(input.asnumpy(), res.asnumpy())
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = rx.vm.build(mod, target, target_host)
inp = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib)
res = vm["foo"](inp)
np.testing.assert_allclose(inp.asnumpy(), res.asnumpy())


def test_vm_compile_stage1():
@rx.script
class Mod1:
@tvm.script.tir
def shape_func0(heap: ty.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "shape_func0"})
H = tir.match_buffer(heap, [tir.int64(4)], dtype="int64", elem_offset=tir.int64(0), align=128, offset_factor=1)
# body
tir.store(H.data, tir.int64(2), (tir.load("int64", H.data, tir.int64(0))*tir.int64(2)), True)
tir.store(H.data, tir.int64(3), (tir.load("int64", H.data, tir.int64(1))*tir.int64(3)), True)

def foo(x: Tensor[_, "float32"]) -> Shape:
shape_heap: Tensor[(4,), "int64"] = relax.call_packed("vm.builtin.alloc_shape_heap", (4,))
gv0 = relax.call_packed("vm.builtin.shape_of", x)
gv1 = relax.call_packed("vm.builtin.decode_shape", gv0, shape_heap, (0, 1))
gv2 = shape_func0(shape_heap)
gv3 = relax.call_packed("vm.builtin.make_shape", shape_heap, (2, 3))
return gv3

mod = Mod1()
code = rx.parser.astext(mod)
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = rx.vm.build(mod, target, target_host)
vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib)

shape = (32, 16)
arr = tvm.nd.array(np.random.rand(*shape))
res = vm["foo"](arr)
assert res[0] == shape[0] * 2
assert res[1] == shape[1] * 3


def test_vm_compile_stage2():
@rx.script
class Mod2:
def foo(x: Tensor[_, "float32"]) -> Shape:
sh = relax.call_packed("vm.builtin.shape_of", x)
relax.match_shape(sh, (n, m))
return (n * 2, m * 3)

mod = Mod2()
code = rx.parser.astext(mod)
new_mod = rx.transform.shape_lower(mod)
code = rx.parser.astext(new_mod)
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = rx.vm.build(new_mod, target, target_host)
vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib)

shape = (32, 16)
arr = tvm.nd.array(np.random.rand(*shape))
res = vm["foo"](arr)
assert res[0] == shape[0] * 2
assert res[1] == shape[1] * 3


if __name__ == "__main__":
test_vm_execute()
@@ -227,6 +273,7 @@ def foo(x: Tensor[(3, 4), "float32"]):
test_vm_serialize()
test_vm_constant_serialize()
test_vm_shapeof()
test_vm_heap()
test_vm_storage()
test_vm_compile()
test_vm_compile_stage0()
test_vm_compile_stage1()
test_vm_compile_stage2()
