[Relay][VM] Fix constant folding issue in VM compiler #4077

Merged · 12 commits · Oct 10, 2019
Changes from 7 commits
40 changes: 17 additions & 23 deletions python/tvm/relay/backend/profiler_vm.py
@@ -20,37 +20,18 @@

Provides extra APIs for profiling vm execution.
"""
import tvm
from . import vm, _vm

def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")

tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts

class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
def __init__(self):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
self._set_params_func = self.mod["set_params"]

def compile(self, mod, target=None, target_host=None):
def compile(self, mod, target=None, target_host=None, params=None):
"""
Parameters
----------
@@ -71,13 +52,26 @@ def compile(self, mod, target=None, target_host=None):
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.

params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.

Returns
-------
vm : VirtualMachineProfiler
The profiling VM runtime.

"""
target = _update_target(target)
self._compile(mod, target, target_host)
target = self.update_target(target)
target_host = self.update_target_host(target, target_host)

if params:
self.set_params(params)

tophub_context = self.tophub_context(target)

with tophub_context:
self._compile(mod, target, target_host)
return VirtualMachineProfiler(self._get_vm())

class VirtualMachineProfiler(vm.VirtualMachine):
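For context, a minimal usage sketch of the new `params` argument on the profiler compiler. This is not part of the PR; the toy model and the parameter name `w` are hypothetical:

```python
import numpy as np
from tvm import relay
from tvm.relay.backend import profiler_vm

# Toy model: "w" is a named parameter we want folded at compile time.
x = relay.var("x", shape=(1, 8))
w = relay.var("w", shape=(8, 8))
mod = relay.Module.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

params = {"w": np.random.rand(8, 8).astype("float32")}
compiler = profiler_vm.VMCompilerProfiler()
# "w" is bound as a constant and folded before code generation.
vm_prof = compiler.compile(mod, target="llvm", params=params)
```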
101 changes: 63 additions & 38 deletions python/tvm/relay/backend/vm.py
@@ -25,30 +25,11 @@
import tvm
from tvm import autotvm
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm import ndarray as _nd  # needed by set_params below
from tvm.relay import expr as _expr
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor


def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")

tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts

def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg))
@@ -150,8 +131,58 @@ def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
self._set_params_func = self.mod["set_params"]

def set_params(self, params):
"""Set constant parameters for the model"""
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
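A minimal sketch of calling `set_params` directly; the parameter name `weight` is hypothetical:

```python
import numpy as np
from tvm.relay.backend.vm import VMCompiler

compiler = VMCompiler()
# numpy arrays are wrapped into NDArrays, then into Relay constants,
# before being handed to the C++ "set_params" packed function.
compiler.set_params({"weight": np.zeros((8, 8), dtype="float32")})
```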

def update_target(self, target):
"""Update target"""
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts

def update_target_host(self, target, target_host):
"""Update target host"""
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
if device_type.value == tvm.nd.cpu(0).device_type:
target_host = tgt
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
return tvm.target.create(target_host)
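To illustrate the resolution order, a sketch (the outcomes in the comments follow from the code above, assuming an llvm-enabled build):

```python
from tvm.relay.backend.vm import VMCompiler

compiler = VMCompiler()
tgts = compiler.update_target("llvm")
# tgts maps the CPU device type to the created llvm Target.
host = compiler.update_target_host(tgts, None)
# No explicit host given, so the CPU entry ("llvm") is reused; if no CPU
# target existed, it would fall back to "llvm" or "stackvm".
```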

def compile(self, mod, target=None, target_host=None):
def tophub_context(self, target):
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
return tophub_context

def compile(self, mod, target=None, target_host=None, params=None):
"""
Parameters
----------
@@ -172,34 +203,28 @@ def compile(self, mod, target=None, target_host=None):
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.

params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.

Returns
-------
vm : VirtualMachine
The VM runtime.

"""
target = _update_target(target)
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
if device_type.value == tvm.nd.cpu(0).device_type:
target_host = tgt
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host)
target = self.update_target(target)
target_host = self.update_target_host(target, target_host)

# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
if params:
self.set_params(params)

tophub_context = self.tophub_context(target)

with tophub_context:
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())


class VMExecutor(Executor):
"""
An implementation of the executor interface for
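The refactor leaves two equivalent ways to supply parameters; a sketch, reusing `mod` and `params` from the earlier example:

```python
from tvm.relay.backend.vm import VMCompiler

# Pass params at compile time...
vm_rt = VMCompiler().compile(mod, target="llvm", params=params)

# ...or set them beforehand. Either way the constants are bound into the
# module before optimization, so constant folding can evaluate them.
compiler = VMCompiler()
compiler.set_params(params)
vm_rt2 = compiler.compile(mod, target="llvm")
```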
60 changes: 57 additions & 3 deletions src/relay/backend/vm/compiler.cc
@@ -266,6 +266,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

void VisitExpr_(const ConstantNode* const_node) {
size_t konst_idx = context_->constants.size();
std::string name = "p" + std::to_string(konst_idx);
context_->constant_indices[name] = konst_idx;
context_->constants.push_back(const_node->data);
Emit(Instruction::LoadConst(konst_idx, NewRegister()));
}
@@ -395,6 +397,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
}
size_t konst_idx = context_->constants.size();
std::string name = "p" + std::to_string(konst_idx);
context_->constant_indices[name] = konst_idx;
context_->constants.push_back(shape_tensor);
Emit(Instruction::LoadConst(konst_idx, NewRegister()));
return last_register_;
@@ -780,31 +784,81 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
if (name == "compile") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
this->Compile(args[0], args[1], args[2]);
Module mod = args[0];
this->Compile(mod, args[1], args[2]);
});
} else if (name == "get_vm") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_);
});
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
for (const auto& kv : params) {
this->SetParam(kv.first, kv.second->data);
}
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}

void VMCompiler::Compile(const Module& mod_ref,
void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}

relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}
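A rough Python analogue of `BindParamsByName`, shown only to make the intent clear (a sketch, not code from this PR):

```python
from tvm import relay

def bind_params_by_name(func, params):
    """Replace named function parameters with Constant nodes."""
    name_to_var, repeated = {}, set()
    for arg in func.params:
        if arg.name_hint in name_to_var:
            repeated.add(arg.name_hint)
        else:
            name_to_var[arg.name_hint] = arg
    bind_dict = {}
    for name, value in params.items():
        if name not in name_to_var:
            continue  # params without a matching argument are ignored
        if name in repeated:
            raise ValueError("Multiple args in the function have name " + name)
        bind_dict[name_to_var[name]] = relay.const(value)
    return relay.bind(func, bind_dict)
```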


void VMCompiler::Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}

InitVM();
targets_ = targets;
target_host_ = target_host;

// Run some optimizations first, this code should
// be moved to pass manager.
context_.module = OptimizeModule(mod_ref, targets_);
context_.module = OptimizeModule(mod, targets_);

// Populate the global map.
//
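The effect of the new `Compile` preamble, in rough Python terms: the stored params are bound into `main` before `OptimizeModule` runs, so the constant-folding pass can evaluate them away. A sketch using the helper above (that FoldConstant sits in the VM pass list is an assumption here):

```python
from tvm import relay

mod["main"] = bind_params_by_name(mod["main"], params)  # mirrors mod->Add(gvar, f)
mod = relay.transform.FoldConstant()(mod)               # folds the bound constants
```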
24 changes: 23 additions & 1 deletion src/relay/backend/vm/compiler.h
@@ -72,6 +72,8 @@ struct VMCompilerContext {
TagMap tag_map;
// Map from global var to a unique integer
GlobalMap global_map;
// Map from name to constant index
std::unordered_map<std::string, size_t> constant_indices;
// List of constants
std::vector<NDArray> constants;
// List of cached functions
@@ -100,11 +102,29 @@ class VMCompiler : public runtime::ModuleNode {
vm_ = std::make_shared<VirtualMachine>();
}

void Compile(const Module& mod_ref,
/*!
* \brief Set the parameters
*
* \param name name of parameter
* \param data_in input DLTensor
*/
void SetParam(const std::string& name, runtime::NDArray data_in);

void Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host);

protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);

Module OptimizeModule(const Module& mod, const TargetsMap& targets);

void PopulateGlobalMap();
@@ -120,6 +140,8 @@
VMCompilerContext context_;
/*! \brief Compiled virtual machine. */
std::shared_ptr<VirtualMachine> vm_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
};

} // namespace vm
5 changes: 5 additions & 0 deletions src/runtime/vm/profiler/vm.cc
@@ -98,6 +98,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
Index output_size,
const std::vector<Object>& args) {
auto ctx = VirtualMachine::GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);

auto op_begin = std::chrono::high_resolution_clock::now();
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
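The added warmup run absorbs one-time costs (lazy initialization, kernel caching), and the synchronize ensures the device has drained it before timing starts. The same pattern in a Python sketch (the trailing synchronize is assumed; the diff is truncated here):

```python
import time

def timed_invoke(run, sync):
    run()    # warmup: absorb one-time setup costs
    sync()   # wait for the device to finish the warmup call
    begin = time.perf_counter()
    run()    # the measured invocation
    sync()   # include the async work in the measurement
    return time.perf_counter() - begin
```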