Skip to content

Commit

Permalink
refactor shared output
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Apr 30, 2020
1 parent aaed2b4 commit e8ffa89
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ struct RegionFuncMetadata {

/*! \brief Map from each region output expr node to its output index and TupleGetItem node. */
std::unordered_map<Expr, std::pair<int, TupleGetItem>, ObjectHash, ObjectEqual> out_expr_indices;

/*! \brief Map from each input expression to the corresponding input variable of this region.
* This cache is used to make sure a region function will not have duplicated inputs even
* it refers the same expr multuple times.
*/
std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> in_expr_vars;
};

/*! \brief This class partitions the expr labeled with begin and end annotations
Expand Down Expand Up @@ -120,7 +126,7 @@ class Partitioner : public MixedModeMutator {

// Initial region function metadata.
for (auto region : region_set) {
region_func_meta_[region] = RegionFuncMetadata();
region_func_meta_[region];
}
}
}
Expand Down Expand Up @@ -150,8 +156,8 @@ class Partitioner : public MixedModeMutator {
int index = GetArgIdx(sg, GetRef<Call>(call));
CHECK_NE(index, -1);

if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
return shared_output_[parent][sg];
if (region_func_meta_[sg].in_expr_vars.count(parent)) {
return region_func_meta_[sg].in_expr_vars[parent];
} else {
// The type of the created variable is the same as the compiler_begin
// node.
Expand All @@ -166,7 +172,7 @@ class Partitioner : public MixedModeMutator {
region_func_meta_[sg].args.end()) {
region_func_meta_[sg].args.push_back(cand);
}
shared_output_[parent][sg] = var;
region_func_meta_[sg].in_expr_vars[parent] = var;
return std::move(var);
}
} else {
Expand Down Expand Up @@ -378,13 +384,6 @@ class Partitioner : public MixedModeMutator {
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;

/*! \brief Map from region output exprs to descendant region inputs (vars). This cache is used
* to make sure descendant region will not have duplicated inputs even it refers the same expr
* multuple times.
*/
using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;

/*!\brief The IRModule used for partitioning. */
IRModule module_;
};
Expand Down

0 comments on commit e8ffa89

Please sign in to comment.