Skip to content

Commit

Permalink
make constant binding in PartitionGraph optional
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 13, 2021
1 parent 2b35cfd commit ab01b3a
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ struct RegionFuncMetadata {

class Partitioner : public MixedModeMutator {
public:
explicit Partitioner(const IRModule& module) : module_(module) {
Partitioner(const IRModule& module, bool bind_constants)
: module_(module), bind_constants_(bind_constants) {
std::set<std::string> func_names;
for (auto f : module->functions) {
GlobalVar f_var = f.first;
Expand Down Expand Up @@ -293,7 +294,7 @@ class Partitioner : public MixedModeMutator {
Map<Var, Expr> params_bind;
for (auto pair : region_func_meta_[region].args) {
params.push_back(pair.first);
if (IsConstant(pair.second)) {
if (bind_constants_ && IsConstant(pair.second)) {
params_bind.Set(pair.first, pair.second);
} else {
param_expr.push_back(pair.second);
Expand Down Expand Up @@ -401,6 +402,9 @@ class Partitioner : public MixedModeMutator {

/*!\brief The IRModule used for partitioning. */
IRModule module_;

/*!\brief Whether or not to bind constants in partitioned subgraphs. */
bool bind_constants_{false};
};

IRModule RemoveDefaultAnnotations(IRModule module) {
Expand Down Expand Up @@ -562,7 +566,7 @@ class NameMangleExtFuncs : public MixedModeMutator {

namespace transform {

Pass PartitionGraph(String mod_name) {
Pass PartitionGraph(String mod_name, bool bind_constants) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> flatten_tuples = [=](IRModule m,
PassContext pc) {
// There could be compiler_end annotations on tuples
Expand All @@ -581,8 +585,10 @@ Pass PartitionGraph(String mod_name) {
return partitioning::RemoveDefaultAnnotations(m);
};

runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
PassContext pc) {
return partitioning::Partitioner(m, bind_constants).Partition();
};

auto name_mangling_fn = [mod_name](String name) {
return runtime::get_name_mangled(mod_name, name);
Expand All @@ -601,9 +607,10 @@ Pass PartitionGraph(String mod_name) {
{flatten_tuples_pass, remove_default_pass, partition_pass, name_mangling_pass, InferType()});
}

TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed([](String mod_name) {
return transform::PartitionGraph(mod_name);
});
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
.set_body_typed([](String mod_name, bool bind_constants = false) {
return transform::PartitionGraph(mod_name, bind_constants);
});

} // namespace transform

Expand Down

0 comments on commit ab01b3a

Please sign in to comment.