Skip to content

Commit

Permalink
move more to transform
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed May 30, 2019
1 parent 031bfe9 commit 83c2527
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 25 deletions.
17 changes: 1 addition & 16 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,22 +480,7 @@ TVM_DLL Pass InferType();
*
* \return The pass.
*/
TVM_DLL Pass EliminateCommonSubexpr(
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == HalideIR::Int(32)) {
*rv = true;
}
}
}
*rv = false;
})
);
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);

/*!
* \brief Combine parallel 2d convolutions into a single convolution if the
Expand Down
35 changes: 26 additions & 9 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,21 @@ class RelayBuildModule : public runtime::ModuleNode {
const std::unordered_map<std::string, runtime::NDArray>& params) {
Array<Pass> pass_seqs;
pass_seqs.push_back(transform::SimplifyInference());
pass_seqs.push_back(transform::EliminateCommonSubexpr());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == HalideIR::Int(32)) {
*rv = true;
}
}
}
*rv = false;
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeOps());
Expand Down Expand Up @@ -377,9 +391,11 @@ class RelayBuildModule : public runtime::ModuleNode {
Function RunDeviceAnnotationPass(Function func,
int fallback_device,
TargetsMap* targets_map_ptr) {
auto new_func = relay::InferType(func, Module(nullptr));
new_func = relay::RewriteAnnotatedOps(new_func, fallback_device);
func = Downcast<Function>(new_func);
relay::Module relay_module = relay::ModuleNode::FromExpr(func);
auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
relay_module = rewrite(relay_module);
CHECK(relay_module.defined());
func = relay_module->Lookup(relay_module->entry_func->name_hint);
CHECK(func.defined());
auto device_map = relay::CollectDeviceInfo(func);
if (device_map.size() == 0) {
Expand Down Expand Up @@ -432,11 +448,12 @@ class RelayBuildModule : public runtime::ModuleNode {
&device_target);
}

auto new_func = relay::InferType(func, Module(nullptr));
new_func = relay::FuseOps(new_func, pass_ctx->opt_level, Module(nullptr));
new_func = relay::InferType(new_func, Module(nullptr));
func = Downcast<Function>(new_func);
CHECK(func.defined());
relay::Module relay_module = relay::ModuleNode::FromExpr(func);
relay_module = transform::InferType()(relay_module);
relay_module = transform::FuseOps()(relay_module);
relay_module = transform::InferType()(relay_module);
CHECK(relay_module.defined());
func = relay_module->Lookup(relay_module->entry_func->name_hint);

graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
graph_codegen_->Init(nullptr, device_target);
Expand Down

0 comments on commit 83c2527

Please sign in to comment.