Skip to content

Commit

Permalink
[Relay][VM] Support execution on devices (#3678)
Browse files Browse the repository at this point in the history
* [Relay][VM] Support execution on devices

* Reduce Copy calls

* Cleanup

* Lint

* CR comments

* Merge test into test_vm.py
  • Loading branch information
wweic authored and icemelon committed Aug 1, 2019
1 parent a279dd0 commit 5357f49
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 38 deletions.
9 changes: 8 additions & 1 deletion include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,10 @@ class VirtualMachine : public runtime::ModuleNode {
* \param contexts The set of TVM contexts.
*/
void Init(const std::vector<TVMContext>& contexts);
void Run();

/*! \brief Run VM dispatch loop.
*/
void RunLoop();

/*!
* \brief Load parameters from the parameter bytearray.
Expand All @@ -475,6 +478,10 @@ class VirtualMachine : public runtime::ModuleNode {
*/
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);

/*! \brief Get device context for params.
*/
TVMContext GetParamsContext() const;

/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
};
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def __init__(self, mod, ctx, target):
self.vm.init(ctx)

def _make_executor(self, expr=None):
assert expr is None
main = self.mod["main"]

def _vm_wrapper(*args, **kwargs):
Expand Down
43 changes: 38 additions & 5 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/logging.h>
#include <tvm/runtime/vm.h>

#include <algorithm>
#include <chrono>
#include <iostream>
#include <sstream>
Expand Down Expand Up @@ -569,20 +570,36 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os;
}

Object CopyTo(Object src, const DLContext& ctx) {
if (src->tag == ObjectTag::kTensor) {
auto tensor = ToNDArray(src);
if (tensor->ctx.device_type != ctx.device_type) {
auto copy = tensor.CopyTo(ctx);
return Object::Tensor(copy);
} else {
return src;
}
} else {
return src;
}
}

PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
auto ctx = this->GetParamsContext();
std::vector<Object> func_args;
for (int i = 1; i < args.size(); ++i) {
Object obj = args[i];
Object obj = CopyTo(args[i], ctx);
func_args.push_back(obj);
}
auto it = std::find_if(functions.begin(), functions.end(),
[func_name](const VMFunction& func) {
return func.name == func_name;
});

CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n";
CHECK_EQ(func_args.size() + params_.size(), it->params.size())
<< "The number of provided parameters doesn't match the number of arguments"
Expand Down Expand Up @@ -621,6 +638,18 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
}
}

TVMContext VirtualMachine::GetParamsContext() const {
// Use the fallback device if no device index is available.
int fallback_device_type = static_cast<int>(ctxs[0].device_type);
// TODO(wweic): For heterogeneous execution, get device information from byte

const auto& cit =
std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
return fallback_device_type == static_cast<int>(c.device_type);
});
return (cit == ctxs.end() ? ctxs[0] : *cit);
}

void VirtualMachine::LoadParams(const std::string& params) {
dmlc::MemoryStringStream mss(const_cast<std::string*>(&params));
dmlc::Stream* strm = &mss;
Expand All @@ -637,11 +666,13 @@ void VirtualMachine::LoadParams(const std::string& params) {
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) << "Invalid parameter file";

auto ctx = GetParamsContext();
for (size_t i = 0; i < size; i++) {
NDArray arr;
CHECK(arr.Load(strm)) << "Invalid parameter file";
runtime::Object obj = runtime::Object::Tensor(arr);
params_.emplace(std::make_pair(names[i], obj));
auto copy = CopyTo(obj, ctx);
params_.emplace(std::make_pair(names[i], copy));
}
}

Expand Down Expand Up @@ -678,7 +709,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
DLOG(INFO) << "Executing Function: " << std::endl << func;

InvokeGlobal(func, args);
Run();
RunLoop();
auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
return return_register;
Expand Down Expand Up @@ -762,7 +793,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
return result;
}

void VirtualMachine::Run() {
void VirtualMachine::RunLoop() {
CHECK(this->code);
this->pc = 0;
Index frame_start = frames.size();
Expand All @@ -786,7 +817,9 @@ void VirtualMachine::Run() {
throw std::runtime_error("VM encountered fatal error");
}
case Opcode::LoadConst: {
WriteRegister(instr.dst, this->constants[instr.const_index]);
auto constant_obj = this->constants[instr.const_index];
auto device_obj = CopyTo(constant_obj, ctxs[0]);
WriteRegister(instr.dst, device_obj);
pc++;
goto main_loop;
}
Expand Down
Loading

0 comments on commit 5357f49

Please sign in to comment.