From e8ffa8985687908039fa89d825c2b67891cdabad Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 30 Apr 2020 21:32:28 +0000 Subject: [PATCH] refactor shared output --- src/relay/transforms/partition_graph.cc | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 6d52af0a8204..89c07febc183 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -69,6 +69,12 @@ struct RegionFuncMetadata { /*! \brief Map from each region output expr node to its output index and TupleGetItem node. */ std::unordered_map, ObjectHash, ObjectEqual> out_expr_indices; + + /*! \brief Map from each input expression to the corresponding input variable of this region. + * This cache is used to make sure a region function will not have duplicated inputs even + * it refers the same expr multuple times. + */ + std::unordered_map in_expr_vars; }; /*! \brief This class partitions the expr labeled with begin and end annotations @@ -120,7 +126,7 @@ class Partitioner : public MixedModeMutator { // Initial region function metadata. for (auto region : region_set) { - region_func_meta_[region] = RegionFuncMetadata(); + region_func_meta_[region]; } } } @@ -150,8 +156,8 @@ class Partitioner : public MixedModeMutator { int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - if (shared_output_.count(parent) && shared_output_[parent].count(sg)) { - return shared_output_[parent][sg]; + if (region_func_meta_[sg].in_expr_vars.count(parent)) { + return region_func_meta_[sg].in_expr_vars[parent]; } else { // The type of the created variable is the same as the compiler_begin // node. @@ -166,7 +172,7 @@ class Partitioner : public MixedModeMutator { region_func_meta_[sg].args.end()) { region_func_meta_[sg].args.push_back(cand); } - shared_output_[parent][sg] = var; + region_func_meta_[sg].in_expr_vars[parent] = var; return std::move(var); } } else { @@ -378,13 +384,6 @@ class Partitioner : public MixedModeMutator { */ std::unordered_map regions_sets_; - /*! \brief Map from region output exprs to descendant region inputs (vars). This cache is used - * to make sure descendant region will not have duplicated inputs even it refers the same expr - * multuple times. - */ - using RegionOutputMap = std::unordered_map; - std::unordered_map shared_output_; - /*!\brief The IRModule used for partitioning. */ IRModule module_; };