move all passes to optimize
zhiics committed May 31, 2019
1 parent 83c2527 commit a845b0f
Showing 1 changed file with 52 additions and 74 deletions.
src/relay/backend/build_module.cc (126 changes: 52 additions & 74 deletions)
@@ -265,30 +265,6 @@ class RelayBuildModule : public runtime::ModuleNode {
     return ret;
   }
 
-  /*!
-   * \brief Optimize a Relay function.
-   *
-   * \param func The input Relay function for optimization.
-   * \param targets The device type to `Target` mapping.
-   * \param params The param name to value mapping.
-   *
-   * \return func The updated Relay function after optimization.
-   */
-  relay::Function Optimize(
-      relay::Function func,
-      const TargetsMap& targets,
-      const std::unordered_map<std::string, runtime::NDArray>& params) {
-    if (params.size()) {
-      func = BindParamsByName(func, params);
-    }
-
-    relay::Module relay_module = relay::ModuleNode::FromExpr(func);
-    relay_module = Optimize(relay_module, targets_, params);
-    CHECK(relay_module.defined());
-    GlobalVar var = relay_module->entry_func;
-    return relay_module->Lookup(var->name_hint);
-  }
-
   /*!
    * \brief Optimize a Relay module.
    *
@@ -340,6 +316,18 @@ class RelayBuildModule : public runtime::ModuleNode {
     } else {
       relay_module = seq(relay_module);
     }
+
+    // Handle heterogeneous compilation.
+    transform::PassContext pass_ctx = PassContext::Current();
+    if (targets_.size() > 1) {
+      relay_module =
+          RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
+    }
+
+    // Fuse the operations if needed.
+    relay_module = transform::FuseOps()(relay_module);
+    relay_module = transform::InferType()(relay_module);
+
     return relay_module;
   }
 
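The seq(relay_module) call in the context above is the usual Module-to-Module way of applying a Relay pass pipeline, and the hunk moves the annotation, fusion, and type-inference steps into that same place. Below is a minimal sketch of the pattern, assuming only the tvm::relay::transform API this file already uses; the pass list is illustrative, not the exact list Optimize builds.

// Minimal sketch: assemble and run a Relay pass pipeline on a module.
#include <tvm/relay/module.h>
#include <tvm/relay/transform.h>

tvm::relay::Module RunPipeline(tvm::relay::Module mod) {
  namespace transform = tvm::relay::transform;
  tvm::Array<transform::Pass> pass_list{
      transform::InferType(),  // establish types before fusion
      transform::FuseOps(),    // group operators into fused primitives
      transform::InferType()   // re-type the fused functions
  };
  transform::Sequential seq(pass_list);
  return seq(mod);  // Pass::operator() applies the whole pipeline
}

Running InferType again after FuseOps follows the same ordering the added lines use: fusion creates new functions whose types must be re-established before codegen.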
@@ -354,55 +342,58 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (name == "gpu") return Target::Create("cuda");
     return Target::Create(name);
   }
 
   /*!
    * \brief Update the target and fallback device required for heterogeneous
    * compilation. CPU is used as the fallback device if it wasn't provided.
    * In that case, the CPU device type and an "llvm" target are also added to
    * the target dictionary.
    *
-   * \param targets dictionary
    * \param fallback_device The fallback device for heterogeneous execution.
-   * \return Map<tvm::Integer, tvm::Target>
    */
-  TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets,
-                                       int fallback_device) {
-    TargetsMap device_target = targets;
+  void UpdateHeterogeneousInputs(int fallback_device) {
     std::unordered_map<int64_t, tvm::Target> tmp_map;
-    for (const auto& kv : targets) {
+    for (const auto& kv : targets_) {
       tmp_map[kv.first->value] = kv.second;
     }
     if (tmp_map.count(fallback_device) == 0) {
-      device_target.Set(
-          fallback_device,
-          CreateDefaultTarget(fallback_device));
+      targets_.Set(fallback_device, CreateDefaultTarget(fallback_device));
     }
-    return device_target;
   }
 
   /*!
    * \brief Execute the device annotation passes to update the input program and
    * target information.
    *
-   * \param func The input Relay function.
+   * \param relay_module The input Relay module.
    * \param fallback_device The fallback device for heterogeneous execution.
-   * \param targets_map_ptr The device type to `Target` map pointer.
    *
-   * \return func The updated function after device annotation.
+   * \return updated_module The updated module after device annotation.
    */
-  Function RunDeviceAnnotationPass(Function func,
-                                   int fallback_device,
-                                   TargetsMap* targets_map_ptr) {
-    relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+  relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module,
+                                        int fallback_device) {
+    UpdateHeterogeneousInputs(fallback_device);
     auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
-    relay_module = rewrite(relay_module);
-    CHECK(relay_module.defined());
-    func = relay_module->Lookup(relay_module->entry_func->name_hint);
-    CHECK(func.defined());
-    auto device_map = relay::CollectDeviceInfo(func);
-    if (device_map.size() == 0) {
-      auto annotation_map = relay::CollectDeviceAnnotationOps(func);
-      if (annotation_map.size() == 0) {
-        targets_map_ptr->Set(0, CreateDefaultTarget(fallback_device));
+    auto updated_module = rewrite(relay_module);
+    CHECK(updated_module.defined());
+
+    tvm::Map<Expr, Integer> device_map;
+    for (const auto& it : updated_module->functions) {
+      device_map = relay::CollectDeviceInfo(it.second);
+      if (!device_map.empty()) break;
+    }
+
+    if (device_map.empty()) {
+      tvm::Map<Expr, Integer> annotation_map;
+      for (const auto& it : relay_module->functions) {
+        annotation_map = relay::CollectDeviceAnnotationOps(it.second);
+        if (!annotation_map.empty()) break;
+      }
+      // No op is annotated; everything falls back to the default device.
+      if (annotation_map.empty()) {
+        targets_.Set(0, CreateDefaultTarget(fallback_device));
      } else {
+        // All ops are annotated with the same device type.
        int64_t dev_type = -1;
        for (auto kv : annotation_map) {
          dev_type = kv.second->value;
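UpdateHeterogeneousInputs now mutates the member targets_ in place instead of returning a copy, but its defaulting rule is unchanged: if no target is registered for the fallback device, synthesize one. A standalone restatement of that rule over plain STL types follows; the 1 = CPU / 2 = GPU codes follow the DLPack device-type convention, and the "llvm"/"cuda" defaults mirror what the doc comment and CreateDefaultTarget above describe. This is a sketch, not TVM API.

// Standalone sketch of the fallback-target defaulting rule.
#include <map>
#include <string>

void EnsureFallbackTarget(std::map<int, std::string>* targets,
                          int fallback_device) {
  if (targets->count(fallback_device) == 0) {
    // Same policy the doc comment describes: GPU gets "cuda", CPU gets "llvm".
    (*targets)[fallback_device] = (fallback_device == 2) ? "cuda" : "llvm";
  }
}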
@@ -416,10 +407,10 @@ class RelayBuildModule : public runtime::ModuleNode {
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
targets_.Set(0, CreateDefaultTarget(dev_type));
}
}
return func;
return updated_module;
}

/*!
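After this change, RunDeviceAnnotationPass resolves the default device (slot 0 of targets_) from three mutually exclusive cases: per-op device info already exists, nothing is annotated, or every annotation names one device. A standalone sketch of that decision is below; the function name and the STL map types are illustrative stand-ins, not TVM API.

// Sketch of the default-device decision encoded above.
#include <cassert>
#include <map>

// Returns the device type to install as the default target (key 0),
// or -1 when per-op device info already determines placement.
int ResolveDefaultDeviceType(const std::map<int, int>& device_map,
                             const std::map<int, int>& annotation_map,
                             int fallback_device) {
  if (!device_map.empty()) return -1;                  // per-op placement wins
  if (annotation_map.empty()) return fallback_device;  // nothing annotated
  int dev_type = annotation_map.begin()->second;
  for (const auto& kv : annotation_map) {
    // Mixed device types here would indicate a RewriteAnnotation bug.
    assert(kv.second == dev_type);
  }
  return dev_type;
}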
@@ -431,38 +422,27 @@ class RelayBuildModule : public runtime::ModuleNode {
   void BuildRelay(
       Function func,
       const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
-    transform::PassContext pass_ctx = PassContext::Current();
-
-    // convert
-    tvm_cfg_ = BuildConfig::Create();
-    TargetsMap device_target;
-    if (targets_.size() > 1) {
-      device_target =
-          UpdateHeterogeneousInputs(targets_, pass_ctx->fallback_device);
-    } else {
-      device_target = targets_;
-    }
-    func = Optimize(func, targets_, params);
-    if (device_target.size() > 1) {
-      func = RunDeviceAnnotationPass(func, pass_ctx->fallback_device,
-                                     &device_target);
+    if (params.size()) {
+      func = BindParamsByName(func, params);
     }
+
+    // Perform Module->Module optimizations.
     relay::Module relay_module = relay::ModuleNode::FromExpr(func);
-    relay_module = transform::InferType()(relay_module);
-    relay_module = transform::FuseOps()(relay_module);
-    relay_module = transform::InferType()(relay_module);
+    relay_module = Optimize(relay_module, targets_, params);
     CHECK(relay_module.defined());
+    // Get the updated function.
     func = relay_module->Lookup(relay_module->entry_func->name_hint);
 
+    // Generate code for the updated function.
     graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
-    graph_codegen_->Init(nullptr, device_target);
+    graph_codegen_->Init(nullptr, targets_);
     graph_codegen_->Codegen(func);
 
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
 
-    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_);
+    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
+                          BuildConfig::Current());
   }
 
  protected:
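With this commit, BuildRelay is a straight line: bind params, run the Module-level Optimize, look up the entry function, codegen, then tvm::build. A hypothetical driver is sketched below; the registry key and the "build"/"get_graph_json"/"get_module" packed-function names are assumptions about this file's GetFunction dispatch, which sits outside the diff.

// Hypothetical driver; packed-function names are assumptions, not shown here.
#include <string>
#include <tvm/build_module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

void BuildWithRelayModule(tvm::relay::Function func,
                          tvm::Map<tvm::Integer, tvm::Target> targets,
                          tvm::Target target_host) {
  const auto* create =
      tvm::runtime::Registry::Get("relay.build_module._BuildModule");
  if (create == nullptr) return;  // the registration name is an assumption
  tvm::runtime::Module build_mod = (*create)();
  build_mod.GetFunction("build")(func, targets, target_host);  // runs BuildRelay
  std::string graph_json = build_mod.GetFunction("get_graph_json")();
  tvm::runtime::Module lib = build_mod.GetFunction("get_module")();
  (void)graph_json;
  (void)lib;  // in real use these feed the graph runtime
}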
@@ -475,8 +455,6 @@ class RelayBuildModule : public runtime::ModuleNode {
   std::unordered_map<std::string, runtime::NDArray> params_;
   /*! \brief building output */
   BuildOutput ret_;
-  /*! \brief tvm building cfg */
-  BuildConfig tvm_cfg_;
 };
 
 runtime::Module RelayBuildCreate() {
