Skip to content

Commit

Permalink
default fskip for cse
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed May 28, 2019
1 parent 1d37f52 commit 25fcc77
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 27 deletions.
21 changes: 19 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@

#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -301,7 +303,7 @@ Pass CreateModulePass(
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, Module, PassContext)>& pass_func,
Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
Expand Down Expand Up @@ -441,7 +443,22 @@ TVM_DLL Pass InferType();
*
* \return The pass.
*/
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip);
TVM_DLL Pass EliminateCommonSubexpr(
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == HalideIR::Int(32)) {
*rv = true;
}
}
}
*rv = false;
})
);

/*!
* \brief Combine parallel 2d convolutions into a single convolution if the
Expand Down
21 changes: 1 addition & 20 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
*/
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
#include <memory>

Expand Down Expand Up @@ -307,24 +305,7 @@ class RelayBuildModule : public runtime::ModuleNode {
const std::unordered_map<std::string, runtime::NDArray>& params) {
Array<Pass> pass_seqs;
pass_seqs.push_back(transform::SimplifyInference());

// Can we move to the pass implementation file and make it as default?
auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == HalideIR::Int(32)) {
*rv = true;
}
}
}
*rv = false;
});

pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::EliminateCommonSubexpr());
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeOps());
Expand Down
6 changes: 3 additions & 3 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class SequentialNode : public PassNode {
*
* \return true if the pass is enabled. Otherwise, false.
*/
bool pass_enabled(const std::string& pass_name) const;
bool PassEnabled(const std::string& pass_name) const;

/*!
* \brief Resolve the pass dependency. It globs all required passes by
Expand Down Expand Up @@ -473,7 +473,7 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
return ret;
}

bool SequentialNode::pass_enabled(const std::string& pass_name) const {
bool SequentialNode::PassEnabled(const std::string& pass_name) const {
PassContext ctx = PassContext::Current();

auto required = RequiredPasses(ctx->required_pass);
Expand Down Expand Up @@ -512,7 +512,7 @@ Module SequentialNode::operator()(const Module& module,
PassInfo info = pass->Info();
const auto& pass_name = info->name;
// Execute the pass if it is enabled.
if (pass_enabled(pass_name)) {
if (PassEnabled(pass_name)) {
const auto* pn = pass.operator->();
mod = (*pn)(mod, pass_ctx);
}
Expand Down
3 changes: 1 addition & 2 deletions tests/cpp/relay_transform_sequential.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ TEST(Relay, Sequential) {
tvm::Array<relay::transform::Pass> pass_seqs;
pass_seqs.push_back(relay::transform::InferType());
pass_seqs.push_back(relay::transform::DeadCodeElimination());
pass_seqs.push_back(
relay::transform::EliminateCommonSubexpr(tvm::PackedFunc(nullptr)));
pass_seqs.push_back(relay::transform::EliminateCommonSubexpr());
pass_seqs.push_back(relay::transform::AlterOpLayout());

relay::GlobalVar var = relay::GlobalVarNode::make("main");
Expand Down

0 comments on commit 25fcc77

Please sign in to comment.