Skip to content

Commit

Permalink
inline using passfunc
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 3, 2019
1 parent 0aa749a commit e98f312
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
19 changes: 6 additions & 13 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,6 @@ class Pass : public NodeRef {
using ContainerType = PassNode;
};

// Define pass function at different granularity. It runs on a certain Relay
// node type and yields a new node with the same type. Each pass function
// sketches an optimization, i.e. how we want to mutate an AST, and it is used
// as a packed function that will be invoked when called by the functor of each
// pass.
using ModulePassFunc = runtime::TypedPackedFunc<Module(Module, PassContext)>;
using FunctionPassFunc =
runtime::TypedPackedFunc<Function(Function, PassContext)>;

/*
* \brief Create a module pass.
*
Expand All @@ -158,8 +149,9 @@ using FunctionPassFunc =
*
* \return The created module pass.
*/
Pass CreateModulePass(const std::string& name, int opt_level,
const ModulePassFunc& pass_func);
Pass CreateModulePass(
const std::string& name, int opt_level,
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func);

/*
* \brief Create a function pass.
Expand All @@ -170,8 +162,9 @@ Pass CreateModulePass(const std::string& name, int opt_level,
*
* \return The created function pass.
*/
Pass CreateFunctionPass(const std::string& name, int opt_level,
const FunctionPassFunc& pass_func);
Pass CreateFunctionPass(
const std::string& name, int opt_level,
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func);
/*
* \brief Create a sequential pass.
*
Expand Down
35 changes: 21 additions & 14 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ModulePassNode : public PassNode {
* implement the algorithm in the `pass_func` and let it run on a module. It
* will then remove the dead code including the unused functions in the module.
*/
ModulePassFunc pass_func;
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;

ModulePassNode() = default;

Expand Down Expand Up @@ -55,8 +55,9 @@ class ModulePassNode : public PassNode {
*/
void SetContext(const PassContext& pass_ctx) final;

TVM_DLL static ModulePass make(std::string name, int opt_level,
ModulePassFunc pass_func);
TVM_DLL static ModulePass make(
std::string name, int opt_level,
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func);

static constexpr const char* _type_key = "relay.ModulePass";
TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode);
Expand Down Expand Up @@ -88,7 +89,7 @@ class FunctionPassNode : public PassNode {
* `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module.
*/
FunctionPassFunc pass_func;
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func;

FunctionPassNode() = default;

Expand Down Expand Up @@ -116,8 +117,9 @@ class FunctionPassNode : public PassNode {
*/
void SetContext(const PassContext& pass_ctx) final;

TVM_DLL static FunctionPass make(std::string name, int opt_level,
FunctionPassFunc pass_func);
TVM_DLL static FunctionPass make(
std::string name, int opt_level,
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func);

static constexpr const char* _type_key = "relay.FunctionPass";
TVM_DECLARE_NODE_TYPE_INFO(FunctionPassNode, PassNode);
Expand Down Expand Up @@ -234,8 +236,9 @@ PassContext PassContextNode::make() {
return PassContext(ctx);
}

ModulePass ModulePassNode::make(std::string name, int opt_level,
ModulePassFunc pass_func) {
ModulePass ModulePassNode::make(
std::string name, int opt_level,
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func) {
auto n = make_node<ModulePassNode>();
n->name = std::move(name);
n->opt_level = std::move(opt_level);
Expand Down Expand Up @@ -266,8 +269,9 @@ void ModulePassNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx;
}

FunctionPass FunctionPassNode::make(std::string name, int opt_level,
FunctionPassFunc pass_func) {
FunctionPass FunctionPassNode::make(
std::string name, int opt_level,
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func) {
auto n = make_node<FunctionPassNode>();
n->name = std::move(name);
n->opt_level = std::move(opt_level);
Expand Down Expand Up @@ -368,13 +372,16 @@ void SequentialPassNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx;
}

Pass CreateModulePass(const std::string& name, int opt_level,
const ModulePassFunc& pass_func) {
Pass CreateModulePass(
const std::string& name, int opt_level,
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func) {
return ModulePassNode::make(name, opt_level, pass_func);
}

Pass CreateFunctionPass(const std::string& name, int opt_level,
const FunctionPassFunc& pass_func) {
Pass CreateFunctionPass(
const std::string& name, int opt_level,
const runtime::TypedPackedFunc<Function(Function, PassContext)>&
pass_func) {
return FunctionPassNode::make(name, opt_level, pass_func);
}

Expand Down

0 comments on commit e98f312

Please sign in to comment.