Skip to content

Commit

Permalink
return module
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Feb 15, 2019
1 parent 18dcc41 commit f1d2030
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 50 deletions.
57 changes: 30 additions & 27 deletions include/tvm/relay/optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,29 +109,32 @@ class PassNode : public RelayNode {

/*!
* \brief Execute the optimization pass using a functor. This functor invokes
* the `run` method to perform a real optimization on a certain type of node.
* the `run` method to perform a real optimization on a certain type
* of node.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The context information that is used to help perform
* a given pass.
*
* \return The updated module.
*/
void operator()(Module* mod, const PassContext& pass_ctx) {
Run(mod, pass_ctx);
Module operator()(const Module& mod, const PassContext& pass_ctx) const {
return Run(mod, pass_ctx);
}

/*!
* \brief Execute the optimization pass. This is function should be specilized
* for different types of Relay nodes. For example, we mainly allow
* transformation of from Module/Function to Module/Function. Note that the
* module will be updated.
* for different types of Relay nodes. For example, we mainly allow
* transformation of from Module/Function to Module/Function. Note that
* the module will be updated.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The context information that is used to help perform
* a given pass.
*
* \return Return the updated module through mod.
* \return Return the updated module.
*/
virtual void Run(Module* mod, const PassContext& pass_ctx) const = 0;
virtual Module Run(const Module& mod, const PassContext& pass_ctx) const = 0;

void VisitAttrs(tvm::AttrVisitor* v) override {
v->Visit("name", &name);
Expand Down Expand Up @@ -186,9 +189,9 @@ class ModulePassNode : public PassNode {
* \param pass_ctx The context information that is used to help perform
* a given module pass.
*
* \return Return the updated module through mod.
* \return Return the updated module.
*/
void Run(Module* mod, const PassContext& pass_ctx) const override;
Module Run(const Module& mod, const PassContext& pass_ctx) const override;

TVM_DLL static ModulePass make(std::string name, int opt_level,
PassFunc<Module> pass_func,
Expand Down Expand Up @@ -233,9 +236,9 @@ class FunctionPassNode : public PassNode {
* \param pass_ctx The context information that is used to help perform
* a given pass.
*
* \return Return the updated module through mod.
* \return Return the updated module.
*/
void Run(Module* mod, const PassContext& pass_ctx) const override;
Module Run(const Module& mod, const PassContext& pass_ctx) const override;

TVM_DLL static FunctionPass make(std::string name, int opt_level,
PassFunc<Function> pass_func,
Expand Down Expand Up @@ -296,40 +299,40 @@ class Optimizer {
* overloaded to focus on different metrics, i.e. performance, memory
* footprint, etc.
*/
void Optimize() const;
Module Optimize();

private:
/* \brief The module where a host of passes are executed on. It is designed
* to be mutable because each optimization is likely to update the module
* on its completion.
/* \brief The module where a host of passes are executed on. It will be
* updated by each pass on its completion as they might need to update the
* module.
*/
mutable Module module_;
Module module_;
/* \brief The pass candidates for optimizations. */
tvm::Array<Pass> passes_;
/* \brief The auxiliary pass context/information that is used to help perform
* the given list of passes.*/
PassContext pass_ctx_;
friend void Optimize(const tvm::Array<Pass>& passes,
Module* mod,
const PassContext& pass_ctx);
friend Module Optimize(const tvm::Array<Pass>& passes,
const Module& mod,
const PassContext& pass_ctx);
};

/*!
* \brief Optimizes the functions and/or expressions in the module. This free
* function is designed as a template function that could take different types
* of Relay nodes.
* function is designed as a template function that could take different
* types of Relay nodes.
*
* \param passes The optimization passes.
* \param mod The module where optimizations are performed on.
* Note that the updated module will be stored and returned.
* \param pass_ctx The auxiliary pass context/information that is used to help
* perform the provided passes.
*
* \return Return the updated Module through mod.
* \return Return the updated module.
*/
void Optimize(const tvm::Array<Pass>& passes,
Module* mod,
const PassContext& pass_ctx);
Module Optimize(const tvm::Array<Pass>& passes,
const Module& mod,
const PassContext& pass_ctx);

} // namespace optimize
} // namespace relay
Expand Down
51 changes: 28 additions & 23 deletions src/relay/pass/optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,18 @@ ModulePass ModulePassNode::make(std::string name, int opt_level,
}

// Module -> Module optimizations.
// TODO(zhiics) Check and handle the required passes.
void ModulePassNode::Run(Module* mod, const PassContext& pass_ctx) const {
// TODO(zhiics) 1. Check and handle the required passes.
// 2. Probably use CoW for all places that use module instead of
// returning the updated one.
Module ModulePassNode::Run(const Module& mod,
const PassContext& pass_ctx) const {
LOG(INFO) << "Executing module pass : " << this->name
<< " with opt level: " << opt_level << "\n";
CHECK(mod->defined());
auto foreach = pass_func(*mod);
*mod = foreach(*mod);
CHECK(mod->defined());
CHECK(mod.defined());
auto foreach = pass_func(mod);
auto updated_mod = foreach(mod);
CHECK(updated_mod.defined());
return std::move(updated_mod);
}

FunctionPass FunctionPassNode::make(std::string name, int opt_level,
Expand All @@ -64,13 +68,14 @@ FunctionPass FunctionPassNode::make(std::string name, int opt_level,

// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
void FunctionPassNode::Run(Module* mod, const PassContext& pass_ctx) const {
Module FunctionPassNode::Run(const Module& mod,
const PassContext& pass_ctx) const {
LOG(INFO) << "Executing function pass : " << this->name
<< " with opt level: " << this->opt_level << "\n";
CHECK(mod->defined());
auto foreach = pass_func(*mod);
CHECK(mod.defined());
auto foreach = pass_func(mod);
std::vector<std::pair<GlobalVar, Function>> updated_funcs;
ModuleNode* mod_node = (*mod).operator->();
ModuleNode* mod_node = mod.operator->();
for (const auto& it : mod_node->functions) {
if (!SkipFunction(it.second)) {
auto updated_func = foreach(it.second);
Expand All @@ -83,6 +88,8 @@ void FunctionPassNode::Run(Module* mod, const PassContext& pass_ctx) const {
for (const auto& it : updated_funcs) {
mod_node->Update(it.first, it.second);
}

return GetRef<Module>(mod_node);
}

// TODO(zhiics) Create an enum attribute for FunctionNode
Expand All @@ -93,20 +100,21 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
return pval && pval->value != 0;
}

void Optimizer::Optimize() const {
Module Optimizer::Optimize() {
for (const Pass& pass : passes_) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
pass->Run(&module_, pass_ctx_);
module_ = pass->Run(module_, pass_ctx_);
}
return module_;
}

void Optimize(const tvm::Array<Pass>& passes,
Module* mod,
const PassContext& pass_ctx) {
Module Optimize(const tvm::Array<Pass>& passes,
const Module& mod,
const PassContext& pass_ctx) {
LOG(INFO) << "Start executing optimization passes." << "\n";
Optimizer pm(*mod, passes, pass_ctx);
Optimizer pm(mod, passes, pass_ctx);
pm.Optimize();
*mod = pm.module_;
return pm.module_;
}

TVM_REGISTER_NODE_TYPE(ModulePassNode);
Expand Down Expand Up @@ -138,8 +146,7 @@ TVM_REGISTER_API("relay._optimize.RunModulePass")
CHECK(pass.defined())
<< "Running a pass on undefined ModulePass is not allowed."
<< "\n";
pass->Run(&mod, pass_ctx);
*ret = mod;
*ret = pass->Run(mod, pass_ctx);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Expand Down Expand Up @@ -178,8 +185,7 @@ TVM_REGISTER_API("relay._optimize.RunFunctionPass")
CHECK(pass.defined())
<< "Running a pass on undefined ModulePass is not allowed."
<< "\n";
pass->Run(&mod, pass_ctx);
*ret = mod;
*ret = pass->Run(mod, pass_ctx);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Expand Down Expand Up @@ -207,8 +213,7 @@ TVM_REGISTER_API("relay._optimize.Optimize")
tvm::Array<Pass> passes = args[0];
Module mod = args[1];
PassContext pass_ctx = args[2];
Optimize(passes, &mod, pass_ctx);
*ret = mod;
*ret = Optimize(passes, mod, pass_ctx);
});

} // namespace optimize
Expand Down

0 comments on commit f1d2030

Please sign in to comment.