From ab01b3aae36ef88ec299ff5e9d1bbdf5fd7268f6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 13 Dec 2021 11:55:06 +0900 Subject: [PATCH] make constant binding in PartitionGraph optional --- src/relay/transforms/partition_graph.cc | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 4a21bc87411b..151fcb630700 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -115,7 +115,8 @@ struct RegionFuncMetadata { class Partitioner : public MixedModeMutator { public: - explicit Partitioner(const IRModule& module) : module_(module) { + Partitioner(const IRModule& module, bool bind_constants) + : module_(module), bind_constants_(bind_constants) { std::set func_names; for (auto f : module->functions) { GlobalVar f_var = f.first; @@ -293,7 +294,7 @@ class Partitioner : public MixedModeMutator { Map params_bind; for (auto pair : region_func_meta_[region].args) { params.push_back(pair.first); - if (IsConstant(pair.second)) { + if (bind_constants_ && IsConstant(pair.second)) { params_bind.Set(pair.first, pair.second); } else { param_expr.push_back(pair.second); @@ -401,6 +402,9 @@ class Partitioner : public MixedModeMutator { /*!\brief The IRModule used for partitioning. */ IRModule module_; + + /*!\brief Whether or not to bind constants in partitioned subgraphs. */ + bool bind_constants_{false}; }; IRModule RemoveDefaultAnnotations(IRModule module) { @@ -562,7 +566,7 @@ class NameMangleExtFuncs : public MixedModeMutator { namespace transform { -Pass PartitionGraph(String mod_name) { +Pass PartitionGraph(String mod_name, bool bind_constants) { runtime::TypedPackedFunc flatten_tuples = [=](IRModule m, PassContext pc) { // There could be compiler_end annotations on tuples @@ -581,8 +585,10 @@ Pass PartitionGraph(String mod_name) { return partitioning::RemoveDefaultAnnotations(m); }; - runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); }; + runtime::TypedPackedFunc part_func = [=](IRModule m, + PassContext pc) { + return partitioning::Partitioner(m, bind_constants).Partition(); + }; auto name_mangling_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); @@ -601,9 +607,10 @@ Pass PartitionGraph(String mod_name) { {flatten_tuples_pass, remove_default_pass, partition_pass, name_mangling_pass, InferType()}); } -TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed([](String mod_name) { - return transform::PartitionGraph(mod_name); -}); +TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph") + .set_body_typed([](String mod_name, bool bind_constants = false) { + return transform::PartitionGraph(mod_name, bind_constants); + }); } // namespace transform