Skip to content

Commit

Permalink
Fix vm build. (apache#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin authored and junrushao committed Oct 14, 2022
1 parent ab35a8e commit 99401ac
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 52 deletions.
22 changes: 21 additions & 1 deletion python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

from typing import List, Optional, Union, Dict, Tuple
import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.runtime import Object, Device, Module, PackedFunc
from tvm._ffi.base import _LIB, check_call
from tvm.tir.function import PrimFunc
from . import _ffi_api
from . import transform
from ..rpc.base import RPC_SESS_MASK
Expand Down Expand Up @@ -169,5 +172,22 @@ def build(mod: tvm.IRModule,
new_mod = transform.call_dps_rewrite(new_mod)
new_mod = transform.vm_memory_lower(new_mod)
new_mod = transform.vm_shape_lower(new_mod)
ex, lib = _ffi_api.VMBuild(new_mod, target, target_host)

# split primfunc and relax function
rx_mod, tir_mod = _split_tir_relax(new_mod)

lib = tvm.build(tir_mod, target, target_host)
ex = _ffi_api.VMCodeGen(rx_mod)
return ex, lib

def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]:
rx_mod = IRModule({})
tir_mod = IRModule({})
for gv in mod.get_global_vars():
if isinstance(mod[gv], PrimFunc):
tir_mod[gv] = mod[gv]
elif isinstance(mod[gv], relax.Function):
rx_mod[gv] = mod[gv]
else:
raise ValueError("An IRModule should contain contain relax function and TIR primfunc.")
return rx_mod, tir_mod
48 changes: 13 additions & 35 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file src/relax/backend/vm/codegen_vm.cc
* \brief A compiler to compile an IRModule to VM executable.
* \brief A codegen to generate VM executable from an IRModule with relax functions.
*/

#include "codegen_vm.h"
Expand Down Expand Up @@ -64,7 +64,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
// TODO(@yuchen): handle local functions that capture local vars outside the func
// TODO(@yuchen): a renaming pass to resolve name conflicts, e.g. the input module has a
// function named "local_funcN"
// lift the local func to a global func and compile it normally
// lift the local func to a global func and process it normally
builder_->EmitFunction("local_func" + std::to_string(local_func_counter_++),
func_node->params.size());
}
Expand Down Expand Up @@ -287,49 +287,27 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
const Op& load_shape_op_ = Op::Get("relax.vm.builtin.load_shape");
};

void VMCompiler::Compile(IRModule mod, Target target, Target target_host) {
void VMCodeGen::CodeGen(IRModule rx_mod) {
builder_ = relax::ExecBuilderNode::Create();

IRModule tir_mod;
IRModule rx_mod;
for (auto& p : mod->functions) {
auto gvar = p.first;

BaseFunc func = p.second;
if (func.as<tir::PrimFuncNode>()) {
tir_mod->Add(gvar, func);
} else if (func.as<FunctionNode>()) {
rx_mod->Add(gvar, func);
} else {
LOG(FATAL) << "Cannot handle such function node now:\n" << func;
}
}
lib_ = tvm::build(tir_mod, target, target_host);

CodeGenVM compiler(builder_.operator->());
CodeGenVM codegen(builder_.operator->());
for (auto& p : rx_mod->functions) {
compiler.VisitExpr(p.second);
codegen.VisitExpr(p.second);
}
}

Executable VMCompiler::GetExec() {
Executable VMCodeGen::GetExec() {
return builder_->Get();
}

runtime::Module VMCompiler::GetLib() {
return lib_;
}

Array<ObjectRef> Build(IRModule mod, Target target, Target target_host) {
auto compiler = make_object<VMCompiler>();
compiler->Compile(mod, target, target_host);
Executable exec = compiler->GetExec();
Module lib = compiler->GetLib();
return Array<ObjectRef>({exec, lib});
Executable CodeGen(IRModule mod) {
auto codegen = make_object<VMCodeGen>();
codegen->CodeGen(mod);
Executable exec = codegen->GetExec();
return exec;
}

TVM_REGISTER_GLOBAL("relax.VMBuild")
.set_body_typed(Build);
TVM_REGISTER_GLOBAL("relax.VMCodeGen")
.set_body_typed(CodeGen);

} // namespace relax_vm
} // namespace relax
Expand Down
24 changes: 8 additions & 16 deletions src/relax/backend/vm/codegen_vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

/*!
* \file src/relax/backend/vm/codegen_vm.h
* \brief A compiler to compile an IRModule to VM executable.
* \brief A codegen to generate VM executable from an IRModule with relax functions.
*/

#ifndef TVM_RELAX_BACKEND_VM_COMPILER_H_
#define TVM_RELAX_BACKEND_VM_COMPILER_H_
#ifndef TVM_RELAX_BACKEND_CODEGEN_VM_H_
#define TVM_RELAX_BACKEND_CODEGEN_VM_H_

#include <tvm/ir/module.h>
#include <tvm/relax/vm/exec_builder.h>
Expand All @@ -40,37 +40,29 @@ using tvm::Target;
using namespace tvm::runtime::relax_vm;
using namespace tvm::runtime;

class VMCompiler : public Object {
class VMCodeGen : public Object {
public:
/*!
* \brief Compile the functions in a Module.
* \param mod Input IRModule to be compiled.
* \param rx_mod Input IRModule that constains relax functions.
*/
void Compile(IRModule mod, Target target, Target target_host);
void CodeGen(IRModule rx_mod);
/*!
* \brief Get the compiled executable.
* \return The compiled executable.
*/
Executable GetExec();
/*!
* \brief Get the compiled library.
* \return The compiled lirary.
*/
Module GetLib();

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.VMCompiler";
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutableNode, Object);
static constexpr const char* _type_key = "relax.VMCodeGen";

protected:
/*! \brief Internal executable builder. */
relax::ExecBuilder builder_;
/*! \brief Built library. */
runtime::Module lib_;
};

} // namespace relax_vm
} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_BACKEND_VM_COMPILER_H_
#endif // TVM_RELAX_BACKEND_CODEGEN_VM_H_

0 comments on commit 99401ac

Please sign in to comment.