End2End Lowering Stage2: Enable Lowering from ShapeExpr to VM Executable (apache#21)

* rebase.

* Update.

* Update shape lowering, make sure the lowering pipeline works.

* Address comment.
ZihengJiang authored and yongwww committed Aug 14, 2022
1 parent 3a9672e commit bb26b5f
Showing 11 changed files with 243 additions and 159 deletions.
2 changes: 0 additions & 2 deletions python/tvm/relax/__init__.py
@@ -24,7 +24,6 @@
from . import parser
from . import analysis
from . import transform
from . import vm_compiler


# Expr
@@ -62,7 +61,6 @@
ExecBuilder = exec_builder.ExecBuilder
VirtualMachine = vm.VirtualMachine
load_exec_from_file = vm.load_exec_from_file
compile = vm_compiler.compile

# Operator
from .op.base import call_dps
8 changes: 4 additions & 4 deletions python/tvm/relax/parser.py
@@ -375,7 +375,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
return var
elif bind_free_vars:
# introduce TIR variable to scope, e.g. for func params or rx.call_packed
var = tir.Var(var_name, "int32", self.to_tvm_span(expr.span))
var = tir.Var(var_name, "int64", self.to_tvm_span(expr.span))
self.scope[var_name] = var
return var
else:
@@ -387,7 +387,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
elif isinstance(expr, ast.Constant):
if not isinstance(expr.value, int):
self.report_error("only integer constants are supported", expr.span)
return tir.const(expr.value, "int32", self.to_tvm_span(expr.span))
return tir.const(expr.value, "int64", self.to_tvm_span(expr.span))

elif isinstance(expr, ast.Call):
if not isinstance(expr.func_name, ast.Op):
@@ -823,7 +823,7 @@ def parse_attr(self, expr: ast.Attr) -> rx.Expr:
"""
if expr.field.name == "shape":
obj = self.transform_expr(expr.object)
attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32")
attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int64")
return relay.Call(
relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span)
)
@@ -960,7 +960,7 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
elif isinstance(expr, ast.Constant):
# FIXME(@altanh): use internal representation that doesn't have precision limits here
if isinstance(expr.value, int):
return tir.IntImm("int32", expr.value, self.to_tvm_span(expr.span))
return tir.IntImm("int64", expr.value, self.to_tvm_span(expr.span))
elif isinstance(expr.value, float):
return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span))
elif isinstance(expr.value, str):
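The int32 → int64 switch above keeps every shape-related value the parser produces (symbolic dims, integer constants, shape_of results) consistent with the 64-bit shape heap the VM allocates at runtime. A minimal sketch of the invariant, with illustrative names:

    import tvm
    from tvm import tir

    # Symbolic shape dims and integer constants now carry int64,
    # matching the VM's int64 shape heap (see src/relax/vm/builtin.cc).
    n = tir.Var("n", "int64")
    c = tir.const(7, "int64")
    assert n.dtype == "int64" and c.dtype == "int64"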
37 changes: 35 additions & 2 deletions python/tvm/relax/vm.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import List, Optional, Union, Dict
from typing import List, Optional, Union, Dict, Tuple
import tvm
from tvm.runtime import Object, Device, Module, PackedFunc
from tvm._ffi.base import _LIB, check_call
@@ -64,7 +64,6 @@ def __init__(
memory_cfg: Optional[Union[str, Dict[Device, str]]] = None,
mod: Optional[Module] = None,
) -> None:

"""
Construct a VirtualMachine wrapper object.
@@ -133,3 +132,37 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]])

def __getitem__(self, key: str) -> PackedFunc:
return self.module[key]


def build(mod: tvm.IRModule,
target: tvm.target.Target,
target_host: tvm.target.Target) -> Tuple[Executable, Module]:
"""
Build an IRModule into a VM executable.
Parameters
----------
mod: IRModule
The IR module.
target : tvm.target.Target
A build target.
target_host : tvm.target.Target
Host compilation target, used when the build target is a device target.
When TVM compiles a device-specific program such as CUDA,
we also need host (CPU) side code to interact with the driver
to set up the dimensions and parameters correctly.
target_host is used to specify the host-side codegen target.
By default, llvm is used if it is enabled;
otherwise a stackvm interpreter is used.
Returns
-------
ex: tvm.relax.vm.Executable
An executable that can be loaded by the virtual machine.
lib: tvm.runtime.Module
A runtime module that contains generated code.
"""
ex, lib = _ffi_api.VMBuild(mod, target, target_host)
return ex, lib
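A minimal end-to-end sketch of the new API, assuming `mod` is a lowered IRModule that exposes an entry function; the function name "foo" and the input shape are hypothetical:

    import numpy as np
    import tvm
    from tvm import relax

    # Compile: `ex` is the VM executable, `lib` holds the generated kernels.
    target = tvm.target.Target("llvm")
    ex, lib = relax.vm.build(mod, target, target)

    # Load the executable into a VirtualMachine and invoke the entry function.
    vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
    data = tvm.nd.array(np.random.rand(3, 4).astype("float32"))
    result = vm["foo"](data)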
70 changes: 0 additions & 70 deletions python/tvm/relax/vm_compiler.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/relax/ir/expr_functor.cc
@@ -199,6 +199,10 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}

Expr ExprMutator::VisitExpr_(const VarNode* op) {
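// A binding mutated earlier may have remapped this var; return the replacement if so.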
auto it = var_remap_.find(GetRef<Var>(op));
if (it != var_remap_.end()) {
return it->second;
}
if (op->type_annotation.defined()) {
Type type = this->VisitType(op->type_annotation.value());
if (!op->type_annotation.same_as(type)) {
27 changes: 17 additions & 10 deletions src/relax/transform/shape_lower.cc
@@ -33,7 +33,9 @@ namespace relax {

class ShapeLowerMutator : public ExprMutator {
public:
static DataType ShapeDType() { return DataType::Int(32); };
static DataType ShapeDType() {
return DataType::Int(64);
};

explicit ShapeLowerMutator(IRModule mod) { mod_ = mod; }

@@ -58,30 +60,33 @@ class ShapeLowerMutator : public ExprMutator {
}

void VisitMatchShape(const MatchShape& binding) override {
Expr value = binding->value;
Expr shape = ExprMutator::VisitExpr(binding->value);
Array<PrimExpr> pattern = binding->pattern;
Array<PrimExpr> indices;
for (size_t i = 0; i < pattern.size(); ++i) {
IntImm idx = expr2slot_.at(pattern[i]);
indices.push_back(idx);
}
builder_->Emit(Call(ExternFunc("decode_shape"), {value, shape_heap_, ShapeExpr(indices)}), "_");
builder_->Emit(Call(ExternFunc("vm.builtin.decode_shape"),
{shape, shape_heap_, ShapeExpr(indices)}), "gv");
}

Expr VisitExpr_(const ShapeExprNode* node) override {
tir::PrimFunc func = CalculateShape(GetRef<ShapeExpr>(node));
GlobalVar shape_func_var(name_table_->GetUniqueName("shape_func"));
std::string shape_func_name = name_table_->GetUniqueName("shape_func");
func = WithAttr(std::move(func), "global_symbol", runtime::String(shape_func_name));
GlobalVar shape_func_var(shape_func_name);
// TODO: make sure shape_heap doesn't get redefined by local funcs?
builder_->Emit(Call(shape_func_var, {shape_heap_}), "_");
builder_->Emit(Call(shape_func_var, {shape_heap_}), "gv");
ret_mod_->Add(shape_func_var, func);

// construct shape
Array<PrimExpr> indices;
for (PrimExpr e : node->values) {
indices.push_back(expr2slot_.at(e));
}
return builder_->Emit(Call(ExternFunc("construct_shape"), {shape_heap_, ShapeExpr(indices)}),
"sh");
return builder_->Emit(Call(ExternFunc("vm.builtin.make_shape"),
{shape_heap_, ShapeExpr(indices)}), "sh");
}

Expr VisitExpr_(const FunctionNode* node) override {
@@ -93,7 +98,7 @@

builder_->BeginBindingBlock();
builder_->Emit(VarBinding(
shape_heap_, Call(ExternFunc("relax.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));

Expr new_body = this->Mutate(node->body);

@@ -106,7 +111,7 @@
new_body = seq->body;
}

builder_->Emit(Call(ExternFunc("relax.free_shape_heap"), {shape_heap_}), "_");
builder_->Emit(Call(ExternFunc("vm.builtin.free_shape_heap"), {shape_heap_}), "gv");
blocks.push_back(builder_->EndBlock());
new_body = SeqExpr(blocks, new_body);

@@ -131,6 +136,7 @@ class ShapeLowerMutator : public ExprMutator {
tir::Stmt body = tir::SeqStmt(seq);
Array<tir::Var> params{heap};
Type ret_type = VoidType();

return tir::PrimFunc(params, body, ret_type, buffer_map);
}

@@ -176,7 +182,8 @@ class ShapeLowerMutator : public ExprMutator {
Map<PrimExpr, IntImm> expr2slot_;
};

TVM_REGISTER_GLOBAL("relax.transform.shape_lower").set_body_typed([](IRModule mod) {
TVM_REGISTER_GLOBAL("relax.transform.shape_lower")
.set_body_typed([](IRModule mod) {
return ShapeLowerMutator(mod).Lower();
});

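A hedged sketch of driving the pass from Python through the global registered above; `mod` is an assumed input tvm.IRModule:

    import tvm

    # The mutator is exposed as a packed function; it returns a new
    # IRModule with shape computations lowered to TIR shape functions
    # plus vm.builtin.* calls on the int64 shape heap.
    shape_lower = tvm.get_global_func("relax.transform.shape_lower")
    lowered_mod = shape_lower(mod)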
36 changes: 21 additions & 15 deletions src/relax/vm/builtin.cc
@@ -36,35 +36,41 @@ namespace relax_vm {

using tvm::runtime::NDArray;

TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_typed([](NDArray arr) { return arr.Shape(); });
TVM_REGISTER_GLOBAL("vm.builtin.shape_of")
.set_body_typed([](NDArray arr) {
return arr.Shape();
});

TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap")
.set_body_typed([](ShapeTuple size) {
return NDArray::Empty(size, DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0});
});

TVM_REGISTER_GLOBAL("vm.builtin.alloc_heap").set_body_typed([](int64_t size) {
return NDArray::Empty(ShapeTuple({size}), DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0});
TVM_REGISTER_GLOBAL("vm.builtin.free_shape_heap")
.set_body_typed([](NDArray arr) {
return static_cast<NDArray::Container*>(const_cast<Object*>(arr.get()))->DecRef();
});

TVM_REGISTER_GLOBAL("vm.builtin.match_shape")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) {
ShapeTuple shape = args[0];
NDArray heap = args[1];
TVM_REGISTER_GLOBAL("vm.builtin.decode_shape")
.set_body_typed([](ShapeTuple shape, NDArray heap, ShapeTuple indexes) {
int64_t* heap_data = reinterpret_cast<int64_t*>(heap.ToDLPack()->dl_tensor.data);
for (int i = 2; i < args.size(); ++i) {
int64_t heap_idx = args[i];
for (size_t i = 0; i < indexes.size(); ++i) {
int64_t heap_idx = indexes[i];
ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]);
heap_data[heap_idx] = shape[i - 2];
heap_data[heap_idx] = shape[i];
}
});

TVM_REGISTER_GLOBAL("vm.builtin.make_shape")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) {
NDArray heap = args[0];
.set_body_typed([](NDArray heap, ShapeTuple indexes) {
int64_t* heap_data = reinterpret_cast<int64_t*>(heap.ToDLPack()->dl_tensor.data);
std::vector<int64_t> shape;
for (int i = 1; i < args.size(); ++i) {
int64_t heap_idx = args[i];
for (size_t i = 0; i < indexes.size(); ++i) {
int64_t heap_idx = indexes[i];
ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]);
shape.push_back(heap_data[heap_idx]);
}
*rv = ShapeTuple(shape);
return ShapeTuple(shape);
});

TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage")
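A hedged round-trip sketch of the shape builtins from Python, assuming the registrations above are present in the running TVM build; slot indices and dimension values are illustrative:

    import tvm
    from tvm.runtime import ShapeTuple

    alloc_heap = tvm.get_global_func("vm.builtin.alloc_shape_heap")
    decode_shape = tvm.get_global_func("vm.builtin.decode_shape")
    make_shape = tvm.get_global_func("vm.builtin.make_shape")

    heap = alloc_heap(ShapeTuple([4]))  # 4-slot int64 heap on CPU

    # Scatter the dims of a (2, 3) shape into heap slots 0 and 1,
    # then gather them back into a runtime ShapeTuple.
    decode_shape(ShapeTuple([2, 3]), heap, ShapeTuple([0, 1]))
    s = make_shape(heap, ShapeTuple([0, 1]))
    assert (s[0], s[1]) == (2, 3)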
Diffs for the remaining changed files were not loaded.