diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h
index 3c136444229b..334fe169ad41 100644
--- a/include/tvm/build_module.h
+++ b/include/tvm/build_module.h
@@ -344,6 +344,19 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
                                  const std::string& name,
                                  const std::unordered_map<Tensor, Buffer>& binds,
                                  const BuildConfig& config);
+/*!
+* \brief Split host/device functions and run the necessary passes before build
+* \param funcs The functions to be built.
+* \param target The target device to build for.
+* \param target_host The target for building host code. To use the default, pass Target().
+* \param config The build configuration.
+* \return The Array<Array<LoweredFunc>> with 2 elements. First is the host function Array,
+          second is the device function Array.
+*/
+TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
+                                                        const Target& target,
+                                                        const Target& target_host,
+                                                        const BuildConfig& config);

 /*!
  * \brief Build a device and host module for a specific target from an array of lowered functions.
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 92a12a0da1b7..01ebcacf6180 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -423,10 +423,10 @@ Array<LoweredFunc> lower(Schedule sch,
   return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
 }

-runtime::Module build(const Array<LoweredFunc>& funcs,
-                      const Target& target,
-                      const Target& target_host,
-                      const BuildConfig& config) {
+Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
+                                                const Target& target,
+                                                const Target& target_host,
+                                                const BuildConfig& config) {
   std::unordered_set<std::string> all_names;
   for (const auto &x : funcs) {
     CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
@@ -493,6 +493,17 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
     func = ir::CombineContextCall(func);
     fhost.Set(i, func);
   }
+  return {fhost, fdevice};
+}
+
+runtime::Module build(const Array<LoweredFunc>& funcs,
+                      const Target& target,
+                      const Target& target_host,
+                      const BuildConfig& config) {
+  auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
+  auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
+  auto& fhost = host_dev_funcs[0];
+  auto& fdevice = host_dev_funcs[1];

   auto mhost = codegen::Build(fhost, target_host_val->str());
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
new file mode 100644
index 000000000000..b60a048e638a
--- /dev/null
+++ b/src/relay/backend/build_module.cc
@@ -0,0 +1,713 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file relay/backend/build_module.cc
+ * \brief Code generation for TVM's graph runtime.
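(Usage sketch for the new split_dev_host_funcs API above, not part of the patch: `funcs` is assumed to be an Array<LoweredFunc> produced by tvm::lower(); the relay build module added below applies the same split/build/import pattern once per target.)

    Target target = Target::create("cuda");
    Target target_host = Target::create("llvm");
    BuildConfig cfg = build_config();
    // Returns {host functions, device functions}.
    Array<Array<LoweredFunc> > host_dev = split_dev_host_funcs(funcs, target, target_host, cfg);
    runtime::Module mhost = codegen::Build(host_dev[0], target_host->str());
    if (host_dev[1].size() > 0) {
      // Device code is built separately and imported into the host module.
      mhost.Import(codegen::Build(host_dev[1], target->str()));
    }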
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relay { +namespace backend { + +/*! + * \brief Context name / index + * See: python/tvm/_ffi/runtime_ctypes.py + */ +struct ContextMap { + static const std::unordered_map mask2str; + static const std::unordered_map str2mask; + static std::string Mask2Str(int mask) { + CHECK_GT(mask2str.count(mask), 0) << "Unknown mask."; + return mask2str.at(mask); + } + static int Str2Mask(const std::string& str) { + CHECK_GT(str2mask.count(str), 0) << "Unknown context."; + return str2mask.at(str); + } +}; + +const std::unordered_map ContextMap::mask2str = { + {1, "cpu"}, + {2, "gpu"}, + {4, "opencl"}, + {5, "aocl"}, + {6, "sdaccel"}, + {7, "vulkan"}, + {8, "metal"}, + {9, "vpi"}, + {10, "rocm"}, + {11, "opengl"}, + {12, "ext_dev"} +}; + +const std::unordered_map ContextMap::str2mask = { + {"llvm", 1}, + {"cpu", 1}, + {"c", 1}, + {"gpu", 2}, + {"cuda", 2}, + {"nvptx", 2}, + {"cl", 4}, + {"opencl", 4}, + {"aocl", 5}, + {"aocl_sw_emu", 5}, + {"vulkan", 7}, + {"metal", 8}, + {"vpi", 9}, + {"rocm", 10}, + {"opengl", 11}, + {"ext_dev", 12} +}; + +/*! + * \brief A data structure to map the names of specific optimizations to + * numeric optimization levels + * + */ +struct OptPassLevel { + static const std::unordered_map _data; + /*! + * \brief Get level for an optimization pass + * + * \param key pass name + * \return int level + */ + int operator[](const std::string& key) const { + auto it = _data.find(key); + if (it == _data.end()) { + return -1; + } + return it->second; + } +}; + +const std::unordered_map OptPassLevel::_data = { + {"SimplifyInference", 0}, + {"OpFusion", 1}, + {"FoldConstant", 2}, + {"CombineParallelConv2D", 3}, + {"FoldScaleAxis", 3}, + {"AlterOpLayout", 3}, + {"CanonicalizeOps", 3}, + {"EliminateCommonSubexpr", 3} +}; + +/*! + * \brief Output of building module + * + */ +struct BuildOutput { + std::string graph_json; + runtime::Module mod; + std::unordered_map params; +}; + +/*! + * \brief Relay building config + * + */ +struct RelayBuildConfig { + int opt_level{2}; + std::string fallback_device{"llvm"}; + std::unordered_set enabled_pass; + std::unordered_set disabled_pass; + OptPassLevel OPT_PASS_LEVEL; + inline bool pass_enabled(const std::string& pass_name) const { + if (disabled_pass.count(pass_name)) { + return false; + } + if (enabled_pass.count(pass_name)) { + return true; + } + return opt_level >= OPT_PASS_LEVEL[pass_name]; + } +}; + +/*! + * \brief GraphCodegen module wrapper + * + */ +struct GraphCodegen { + public: + GraphCodegen() { + auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen"); + mod = (*pf)(); + } + ~GraphCodegen() {} + + void Init(runtime::Module* m, + Map targets) { + Array tgts; + for (auto kv : targets) { + tgts.push_back(kv.first); + tgts.push_back(kv.second); + } + CallFunc("init", m, tgts); + } + + void Codegen(const Function& func) { + CallFunc("codegen", func); + } + + std::string GetJSON() { + return CallFunc("get_graph_json", nullptr); + } + + Map > GetLoweredFunc() { + return CallFunc > >("get_lowered_funcs", nullptr); + } + + std::unordered_map GetParams() { + std::unordered_map ret; + auto names = CallFunc >("list_params_name", nullptr); + for (auto expr : names) { + auto key = expr.as()->value; + ret[key] = CallFunc("get_param_by_name", key); + } + return ret; + } + + protected: + tvm::runtime::Module mod; + template + R CallFunc(const std::string &name, Args... 
args) { + auto pf = mod.GetFunction(name, false); + return pf(std::forward(args)...); + } + template + void CallFunc(const std::string &name, Args... args) { + auto pf = mod.GetFunction(name, false); + pf(std::forward(args)...); + return; + } +}; + +template +R CallPackedFunc(const std::string &name, Args... args) { + auto pf = GetPackedFunc(name); + return (*pf)(std::forward(args)...); +} + +template +Function CallPackedFunc(const std::string &name, Args... args) { + auto pf = GetPackedFunc(name); + return (*pf)(std::forward(args)...); +} + +/*! + * \brief Relay build module + * + */ +class RelayBuildModule : public runtime::ModuleNode { + public: + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) final { + if (name == "get_graph_json") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetGraphJSON(); + }); + } else if (name == "get_module") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetModule(); + }); + } else if (name == "build") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 3); + Array tmp = args[1]; + std::unordered_map targets; + for (size_t i = 0; i < tmp.size(); i += 2) { + auto k = tmp[i].as()->value; + auto v = tmp[i + 1].as()->value; + targets[k] = v; + } + this->Build(args[0], targets, args[2]); + }); + } else if (name == "list_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->ListParamNames(); + }); + } else if (name == "get_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetParams(); + }); + } else if (name == "set_opt_level") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 1); + int level = args[0]; + this->SetOptLevel(level); + }); + } else if (name == "set_fallback_device") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string dev = args[0]; + this->SetFallBackDev(dev); + }); + } else if (name == "add_pass") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string pass_name = args[0]; + this->AddPass(pass_name); + }); + } else if (name == "disable_pass") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string pass_name = args[0]; + this->DisablePass(pass_name); + }); + } else if (name == "set_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Map params = args[0]; + for (const auto& kv : params) { + this->SetParam(kv.first, kv.second->data); + } + }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); + } + } + + /*! + * \brief Get the GraphJSON for runtime + * + * \return const std::string graph_json + */ + const std::string& GetGraphJSON() { + return ret_.graph_json; + } + /*! + * \brief Add extra pass into build cfg + * + * \param pass_name name of pass + */ + void AddPass(const std::string& pass_name) { + cfg_.enabled_pass.insert(pass_name); + } + /*! 
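(Aside on the "build" packed function above: the target dictionary arrives as a flattened [device, target, device, target, ...] array of StringImm nodes, which is exactly what the C++ test and the Python wrapper below construct. A hedged sketch, assuming `build_f` was fetched via GetFunction("build", false) and that the array element type is tvm::Expr:)

    Array<Expr> target_pairs;
    target_pairs.push_back(ir::StringImm::make("cpu"));   // device name
    target_pairs.push_back(ir::StringImm::make("llvm"));  // target string for that device
    build_f(func, target_pairs, "llvm");                  // last argument is target_host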
+ * \brief Disable a specific pass in cfg + * + * \param pass_name name of pass + */ + void DisablePass(const std::string& pass_name) { + cfg_.disabled_pass.insert(pass_name); + } + /*! + * \brief Set the Fallback device + * + * \param device name + */ + void SetFallBackDev(const std::string& dev) { + cfg_.fallback_device = dev; + } + /*! + * \brief Get the Module object + * + * \return runtime::Module + */ + runtime::Module GetModule() { + return ret_.mod; + } + + /*! + * \brief List all paramter names + * + * \return Array names of params + */ + Array ListParamNames() { + Array ret; + for (const auto& kv : params_) { + ret.push_back(ir::StringImm::make(kv.first)); + } + return ret; + } + + /*! + * \brief Get params dictionary + * + * \return Map params dictionary + */ + Map GetParams() { + Map ret; + for (const auto& kv : ret_.params) { + ret.Set(kv.first, ConstantNode::make(kv.second)); + } + return ret; + } + + /*! + * \brief Set the parameters + * + * \param name name of parameter + * \param data_in input DLTensor + */ + void SetParam(const std::string& name, runtime::NDArray data_in) { + params_[name] = data_in; + } + + /*! + * \brief Set the optimization level + * + * \param level + */ + void SetOptLevel(char level) { + cfg_.opt_level = level; + } + + /*! + * \brief type key + * + * \return const char* + */ + const char* type_key() const final { + return "RelayBuildModule"; + } + + /*! + * \brief Build relay function for graph runtime + * + * \param func Relay Function + * \param target Target device + * \param target_host Host target device + */ + void Build(Function func, + const std::unordered_map& targets, + const std::string& target_host) { + targets_ = targets; + target_host_ = target_host; + BuildRelay(func, cfg_, params_); + } + + protected: + /*! + * \brief Bind params to function by using name + * \param func Relay function + * \param params params dict + * \return relay::Function + */ + relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto &name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(arg); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto &kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + auto e = CallPackedFunc("relay._make.Constant", kv.second); + bind_dict[arg] = e; + } + return CallPackedFunc("relay._expr.Bind", func, tvm::Map(bind_dict)); + } + + /*! 
+ * \brief Optimize Relay function + * + * \param func Input function + * \param target target device + * \param cfg Relay build config + * \param params params dict + * \return relay::Function + */ + relay::Function Optimize(relay::Function func, + const std::unordered_map& targets, + const RelayBuildConfig& cfg, + const std::unordered_map& params) { + if (params.size()) { + func = BindParamsByName(func, params); + } + if (cfg.pass_enabled("SimplifyInference")) { + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.simplify_inference", func); + } + if (cfg.pass_enabled("EliminateCommonSubexpr")) { + auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; + } + } + } + *rv = false; + }); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip); + } + if (cfg.pass_enabled("CombineParallelConv2D")) { + const int min_num_branches = 3; + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches); + } + if (cfg.pass_enabled("FoldConstant")) { + func = CallPackedFunc("relay._ir_pass.FoldConstant", func); + } + if (cfg.pass_enabled("FoldScaleAxis")) { + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func); + func = CallPackedFunc("relay._ir_pass.FoldConstant", func); + } + if (cfg.pass_enabled("CanonicalizeOps")) { + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func); + } + if (cfg.pass_enabled("AlterOpLayout")) { + if (targets.size() == 1) { + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); + } else { + LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" + << " execution yet."; + } + } + if (cfg.pass_enabled("FoldConstant")) { + func = CallPackedFunc("relay._ir_pass.FoldConstant", func); + } + return func; + } + /*! + * \brief Update the target and fallback device required for heterogeneous + * compilation. CPU is used as the fallback device if it wasn't provided. + * Meanwhile, a CPU device type and "llvm" pair will be added to the target + * dictionary in this case. + * + * \param targets dictionary + * \param cfg + * \return Map + */ + Map UpdateHeterogeneousInputs( + const std::unordered_map& targets, + const RelayBuildConfig& cfg) { + Map device_target; + std::unordered_map tmp_map; + auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); + + for (const auto& kv : targets) { + tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second; + } + if (tmp_map.count(fallback_idx) == 0) { + tmp_map[fallback_idx] = cfg.fallback_device; + } + for (const auto& kv : tmp_map) { + device_target.Set( + ir::IntImm::make(HalideIR::Int(64), kv.first), + ir::StringImm::make(kv.second)); + } + return device_target; + } + /*! + * \brief Execute the device annotation passes to update the input program and + * target information. 
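(Aside on the pass gating used throughout Optimize() above: pass_enabled() resolves an explicit disable first, then an explicit enable, then falls back to comparing opt_level against the OptPassLevel table; a pass not listed in the table maps to level -1 and therefore always runs. A small sketch:)

    RelayBuildConfig cfg;
    cfg.opt_level = 2;
    cfg.disabled_pass.insert("OpFusion");        // level 1, but explicitly disabled
    cfg.enabled_pass.insert("AlterOpLayout");    // level 3, but explicitly enabled
    bool fuse  = cfg.pass_enabled("OpFusion");       // false: disable wins
    bool alter = cfg.pass_enabled("AlterOpLayout");  // true: enable wins over the level check
    bool fold  = cfg.pass_enabled("FoldConstant");   // true: level 2 <= opt_level 2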
+ * + * \param func + * \param cfg + * \param targets_map_ptr + * \return Function + */ + Function RunDeviceAnnotationPass( + Function func, + const RelayBuildConfig& cfg, + Map* targets_map_ptr) { + auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx); + auto device_map = CallPackedFunc >("relay._ir_pass.CollectDeviceInfo", + func, + nullptr); + if (device_map.size() == 0) { + auto annotation_map = + CallPackedFunc >("relay._ir_pass.CollectDeviceAnnotationOps", + func, + nullptr); + if (annotation_map.size() == 0) { + targets_map_ptr->Set( + ir::IntImm::make(HalideIR::Int(64), 0), + ir::StringImm::make(cfg.fallback_device)); + } else { + int64_t dev_type = -1; + for (auto kv : annotation_map) { + dev_type = kv.second->value; + break; + } + for (auto kv : annotation_map) { + CHECK_EQ(kv.second->value, dev_type) + << "Expressions in the function are " + << "annotated with various device types," + << "but not device copy operators " + << "found. Please check the " + << "RewriteAnnotation pass."; + } + targets_map_ptr->Set( + ir::IntImm::make(HalideIR::Int(64), 0), + ir::StringImm::make(ContextMap::Mask2Str(dev_type))); + } + } + return func; + } + /*! + * \brief Build module given lowered functions for each target + * + * \param lowered_funcs target_str -> Array map + * \param targets Targets map + * \param cfg Building configuration + */ + void BuildModule(const Map >& lowered_funcs, + const Map& targets, + const BuildConfig& cfg) { + auto target_host = Target::create(cfg_.fallback_device); + for (const auto& kv : lowered_funcs) { + std::unordered_set fname_set; + for (auto f : kv.second) { + if (fname_set.count(f->name)) { + LOG(FATAL) << "Duplicate function name " + << f->name; + } + fname_set.insert(f->name); + } + } + std::unordered_map target_map; + for (const auto& kv : lowered_funcs) { + target_map[kv.first] = Target::create(kv.first); + } + Array fhost_all; + std::vector device_module; + for (const auto& kv : lowered_funcs) { + auto target = target_map[kv.first]; + auto host_dev_funcs = split_dev_host_funcs(kv.second, target, target_host, cfg); + for (auto f : host_dev_funcs[0]) { + fhost_all.push_back(f); + } + if (host_dev_funcs[1].size()) { + auto mdev = codegen::Build(host_dev_funcs[1], target->str()); + device_module.push_back(mdev); + } + } + + auto mhost = codegen::Build(fhost_all, target_host->str()); + + for (auto mdev : device_module) { + mhost.Import(mdev); + } + ret_.mod = mhost; + } + + /*! 
+ * \brief Build relay function to runtime module + * + * \param func Relay Function + * \param cfg Relay build config + * \param params parameters + */ + void BuildRelay(Function func, + const RelayBuildConfig& cfg, + const std::unordered_map ¶ms) { + // convert + tvm_cfg_ = build_config(); + Map device_target; + if (targets_.size() > 1) { + device_target = UpdateHeterogeneousInputs(targets_, cfg); + } else { + for (auto &kv : targets_) { + device_target.Set( + ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)), + ir::StringImm::make(kv.second)); + } + } + func = Optimize(func, targets_, cfg, params); + if (device_target.size() > 1) { + func = RunDeviceAnnotationPass(func, cfg, &device_target); + } + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + + graph_codegen_ = std::unique_ptr(new GraphCodegen()); + graph_codegen_->Init(nullptr, device_target); + graph_codegen_->Codegen(func); + + ret_.graph_json = graph_codegen_->GetJSON(); + ret_.params = graph_codegen_->GetParams(); + + BuildModule(graph_codegen_->GetLoweredFunc(), + device_target, + tvm_cfg_); + } + + protected: + std::unique_ptr graph_codegen_; + /*! \brief target device */ + std::unordered_map targets_; + /*! \brief target host device */ + std::string target_host_; + /*! \brief frontend optimization configure */ + RelayBuildConfig cfg_; + /*! \brief parameters */ + std::unordered_map params_; + /*! \brief building output */ + BuildOutput ret_; + /*! \brief tvm building cfg */ + BuildConfig tvm_cfg_; +}; + +runtime::Module RelayBuildCreate() { + std::shared_ptr exec = std::make_shared(); + return runtime::Module(exec); +} + +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = RelayBuildCreate(); +}); + +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 4b5842c36020..a824c457107a 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -371,7 +371,9 @@ class CompileEngineImpl : public CompileEngineNode { cache_node->funcs = (*f)( spair.first, all_args, cache_node->func_name, key->source_func); } else { - LOG(FATAL) << "relay.backend.lower is not registred"; + tvm::BuildConfig bcfg = tvm::build_config(); + std::unordered_map binds; + cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); } value->cached_func = CachedFunc(cache_node); return value; diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7f16891da8a7..415e0ec9c2a5 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -416,7 +416,12 @@ class GraphRuntimeCodegen } else { // heterogeneous execution. 
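(Clarifying comment for the graph_runtime_codegen.cc change that follows; the rationale is an assumption drawn from the fallback behaviour itself:)

+        // call_dev_type == 0 is the "default device" left by the annotation passes;
+        // runtime::DeviceName has no entry for 0, so it is mapped to "llvm" below.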
const auto call_dev_key = std::to_string(call_dev_type); - const auto call_dev_name = runtime::DeviceName(call_dev_type); + std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = runtime::DeviceName(call_dev_type); + } if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) { LOG(FATAL) << "No target is provided for device " << call_dev_name; diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc new file mode 100644 index 000000000000..38481bfb8204 --- /dev/null +++ b/tests/cpp/relay_build_module_test.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +TVM_REGISTER_GLOBAL("test.sch") +.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) { + *rv = topi::generic::schedule_injective(args[0], args[1]); + }); + +TEST(Relay, BuildModule) { + using namespace tvm; + auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32)); + auto a = relay::VarNode::make("a", tensor_type); + auto b = relay::VarNode::make("b", tensor_type); + auto add_op = relay::Op::Get("add"); + auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {}); + auto c = relay::VarNode::make("c", tensor_type); + auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {}); + auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {}); + auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + auto pA = (float*)A.ToDLPack()->dl_tensor.data; + auto pB = (float*)B.ToDLPack()->dl_tensor.data; + auto pC = (float*)C.ToDLPack()->dl_tensor.data; + + for (int i = 0; i < 6; ++i) { + pA[i] = i; + pB[i] = i + 1; + pC[i] = i + 2; + } + // get schedule + auto reg = tvm::runtime::Registry::Get("relay.op._Register"); + auto s_i = tvm::runtime::Registry::Get("test.sch"); + if (!reg) { + LOG(FATAL) << "no _Register"; + } + if (!s_i) { + LOG(FATAL) << "no _Register"; + } + (*reg)("add", "FTVMSchedule", *s_i, 10); + // build + auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); + tvm::runtime::Module build_mod = (*pfb)(); + auto build_f = build_mod.GetFunction("build", false); + auto json_f = build_mod.GetFunction("get_graph_json", false); + auto mod_f = build_mod.GetFunction("get_module", false); + Array target_pair; + target_pair.push_back(ir::StringImm::make("cpu")); + target_pair.push_back(ir::StringImm::make("llvm")); + build_f(func, target_pair, "llvm"); + std::string json = json_f(); + tvm::runtime::Module 
mod = mod_f(); + // run + auto ctx = A->ctx; + auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create"); + tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id); + auto set_input_f = run_mod.GetFunction("set_input", false); + auto run_f = run_mod.GetFunction("run", false); + auto get_output_f = run_mod.GetFunction("get_output", false); + set_input_f("a", A); + set_input_f("b", B); + set_input_f("c", C); + run_f(); + tvm::runtime::NDArray Y = get_output_f(0); + auto pY = (float*)Y.ToDLPack()->dl_tensor.data; + for (int i = 0; i < 6; ++i) { + CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); + } +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py new file mode 100644 index 000000000000..c69d877d3b09 --- /dev/null +++ b/tests/python/relay/test_cpp_build_module.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np + +import tvm +from tvm import relay + +from tvm._ffi.function import _init_api +_init_api("tvm.relay.build_module") + +class BuildModule(object): + def __init__(self): + self.mod = relay.build_module._BuildModule() + self._get_graph_json = self.mod["get_graph_json"] + self._get_module = self.mod["get_module"] + self._build = self.mod["build"] + self._set_opt_level = self.mod["set_opt_level"] + self._set_params_func = self.mod["set_params"] + self._get_params_func = self.mod["get_params"] + + + def build(self, func, target, target_host, params): + tgts = [] + for kv in target.items(): + tgts.append(kv[0]) + tgts.append(kv[1]) + self._set_params(params) + self._build(func, tgts, target_host) + + def get_json(self): + return self._get_graph_json() + + def get_module(self): + return self._get_module() + + def set_opt_level(self, level): + self._set_opt_level(level) + + def _set_params(self, params): + inputs = {} + for name, param in params.items(): + inputs[name] = relay.Constant(param) + self._set_params_func(inputs) + + def get_params(self): + params = self._get_params_func() + ret = {} + for key, value in params.items(): + ret[key] = value.data + return ret + + +def test_build(): + m_bld = BuildModule() + tgt_name = "llvm" + tgt = "llvm" + ctx = tvm.cpu() + # func + a = relay.var("a", dtype="float32", shape=(16, 8)) + b = relay.var("b", dtype="float32", shape=(8, 8)) + c = relay.var("c", dtype="float32", shape=(16, 8)) + x = relay.nn.dense(a, b) + y = relay.nn.relu(x) + z = y + c + func = relay.Function([a, b, c], z) + A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx) + B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"), ctx=ctx) + C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx) + params = { + "b" : B, + "c" : C + } + # build + targets = { + tgt: tgt + } + m_bld.set_opt_level(3) + m_bld.build(func, targets, "llvm -mcpu=sse3", params=params) + g_json = m_bld.get_json() + mmod = m_bld.get_module() + params = m_bld.get_params() + + # test + rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) + rt.set_input("a", A) + rt.load_params(relay.save_param_dict(params)) + rt.run() + out = rt.get_output(0) + + np.testing.assert_allclose(out.asnumpy(), + np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5) +
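A note on the expected output in test_build: relay.nn.dense takes its weight argument in (units, in_features) layout and computes data · weightᵀ, so the NumPy reference is np.dot(A, B.T) rather than np.dot(A, B); the relu and the addition of C then match z = relu(dense(a, b)) + c from the function above. The "llvm -mcpu=sse3" string is passed through as target_host, presumably to exercise a host target string distinct from the plain "llvm" device target.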