diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c38ca1ae0469..137d60b2a202 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor { argument_registers.push_back(reg->second); } - // Next generate the invoke instruction. Target target; - if (targets_.size() == 1) { - // homogeneous execution. - for (auto kv : targets_) { - target = kv.second; - } + + if (!func->UseDefaultCompiler()) { + target = tvm::target::ext_dev(); } else { - // heterogeneous execution. - LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation"; + // Next generate the invoke instruction. + if (targets_.size() == 1) { + // homogeneous execution. + const auto& it = targets_.begin(); + target = (*it).second; + } else { + // heterogeneous execution. + LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation"; + } } auto key = CCacheKeyNode::make(func, target); auto cfunc = engine_->Lower(key); - // TODO(jroesch): support lowered funcs for multiple targets - CHECK_EQ(cfunc->funcs.size(), 1); auto op_index = -1; - if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { + if (!func->UseDefaultCompiler()) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); - context_->seen_funcs[cfunc->funcs[0]] = op_index; } else { - op_index = context_->seen_funcs[cfunc->funcs[0]]; + // TODO(jroesch): support lowered funcs for multiple targets + CHECK_EQ(cfunc->funcs.size(), 1); + if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { + op_index = context_->cached_funcs.size(); + context_->cached_funcs.push_back(cfunc); + context_->seen_funcs[cfunc->funcs[0]] = op_index; + } else { + op_index = context_->seen_funcs[cfunc->funcs[0]]; + } } Emit(Instruction::InvokePacked(op_index, @@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() { if (cached_funcs.size() == 0) { return; } - std::unordered_map> tgt_funcs; - for (auto &cfunc : cached_funcs) { + std::unordered_map> funcs; + for (auto& cfunc : cached_funcs) { std::string target_str = cfunc->target->str(); - if (tgt_funcs.count(target_str) == 0) { - tgt_funcs.emplace(target_str, Array{cfunc->funcs[0]}); + if (target_str == "ext_dev") { + continue; + } else if (funcs.count(target_str) == 0) { + funcs.emplace(target_str, Array{cfunc->funcs[0]}); } else { - tgt_funcs[target_str].push_back(cfunc->funcs[0]); + funcs[target_str].push_back(cfunc->funcs[0]); } } - Map> funcs; - for (auto &it : tgt_funcs) { - funcs.Set(Target::Create(it.first), it.second); - } - if (const auto *f = runtime::Registry::Get("relay.backend.build")) { - // The target is just a dummy arg because funcs already contains corresponding target - // therefore target won't be used in the build function - runtime::Module mod = (*f)(funcs, Target(), target_host_); + auto compile_engine = CompileEngine::Global(); + auto ext_mods = compile_engine->LowerExternalFunctions(); + runtime::Module mod; + if (funcs.size() > 0) { + mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current()); CHECK(mod.operator->()); - exec_->lib = mod; } else { - LOG(FATAL) << "relay.backend.build is not registered"; + CHECK_EQ(ext_mods.size(), 1U) + << "Expect to have a TVM DSOModule when multiple runtime modules exist"; + } + if (!ext_mods.empty()) { + if (funcs.size() == 0) { + mod = ext_mods[0]; + } else { + // Import all external runtime modules. + for (auto it : ext_mods) { + mod.Import(it); + } + } } + exec_->lib = mod; size_t primitive_index = 0; for (auto cfunc : cached_funcs) { - exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); + if (cfunc->target->str() == "ext_dev") { + exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); + } else { + exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); + } } } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 41fe71a9f9ed..a3b11d46a4fb 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { if (packed_funcs_.size() <= packed_index) { packed_funcs_.resize(packed_index + 1); } - packed_funcs_[packed_index] = lib.GetFunction(packed_name); + tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true); + CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name; + packed_funcs_[packed_index] = pf; } } diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index fb0a8a2494e9..2cf32e7786ee 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -26,36 +26,54 @@ from tvm import relay from tvm.contrib import util -def check_result(mod, map_inputs, out_shape, result, tol=1e-5): +def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", + ctx=tvm.cpu()): if sys.platform == "win32": print("Skip test on Windows for now") return - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): - json, lib, _ = relay.build(mod, "llvm") - test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - source_dir = os.path.join(test_dir, "..", "..", "..") - contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") - - kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] - tmp_path = util.tempdir() - lib_name = 'lib.so' - lib_path = tmp_path.relpath(lib_name) - lib.export_library(lib_path, fcompile=False, **kwargs) - lib = tvm.module.load(lib_path) - - ctx = tvm.cpu() - rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) - - for name, data in map_inputs.items(): - rt_mod.set_input(name, data) - - rt_mod.run() - out = tvm.nd.empty(out_shape, ctx=ctx) - out = rt_mod.get_output(0, out) - - tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + def update_lib(lib): + test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.join(test_dir, "..", "..", "..") + contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") + + kwargs = {} + kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + tmp_path = util.tempdir() + lib_name = 'lib.so' + lib_path = tmp_path.relpath(lib_name) + lib.export_library(lib_path, fcompile=False, **kwargs) + lib = tvm.module.load(lib_path) + + return lib + + def check_vm_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + exe = relay.vm.compile(mod, target=target) + code, lib = exe.save() + lib = update_lib(lib) + exe = relay.vm.Executable.load_exec(code, lib) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) + out = vm.run(**map_inputs) + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + def check_graph_runtime_result(): + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, _ = relay.build(mod, target=target) + lib = update_lib(lib) + rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) + + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.run() + out = tvm.nd.empty(out_shape, ctx=ctx) + out = rt_mod.get_output(0, out) + + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + check_vm_result() + check_graph_runtime_result() def set_external_func_attr(func, compiler, ext_symbol):