From 26404c14d74ea0e9b6abb12ee4d942dd019748e6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 1 May 2020 18:01:50 +0000 Subject: [PATCH] address comments --- src/relay/transforms/partition_graph.cc | 107 +++++++++++------------- 1 file changed, 47 insertions(+), 60 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 89c07febc183..634434d3a3d0 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -67,14 +67,17 @@ struct RegionFuncMetadata { */ std::vector> args; - /*! \brief Map from each region output expr node to its output index and TupleGetItem node. */ - std::unordered_map, ObjectHash, ObjectEqual> out_expr_indices; + /*! \brief Map from each region output expr (compiler end) node to + * the corresponding function output expr. + */ + std::unordered_map region_func_out; - /*! \brief Map from each input expression to the corresponding input variable of this region. - * This cache is used to make sure a region function will not have duplicated inputs even - * it refers the same expr multuple times. + /*! \brief Map from each region input expression (compiler begin) to + * the corresponding function input variable. This cache is used to make sure + * a region function will not have duplicated inputs even if it refers to + * the same expr multiple times. */ - std::unordered_map in_expr_vars; + std::unordered_map region_func_in; }; /*! \brief This class partitions the expr labeled with begin and end annotations @@ -123,11 +126,6 @@ class Partitioner : public MixedModeMutator { auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, partitioning::compiler_end_op); regions_sets_[region_set] = f_func; - - // Initial region function metadata. - for (auto region : region_set) { - region_func_meta_[region]; - } } } @@ -156,8 +154,8 @@ class Partitioner : public MixedModeMutator { int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - if (region_func_meta_[sg].in_expr_vars.count(parent)) { - return region_func_meta_[sg].in_expr_vars[parent]; + if (region_func_meta_[sg].region_func_in.count(parent)) { + return region_func_meta_[sg].region_func_in[parent]; } else { // The type of the created variable is the same as the compiler_begin // node. @@ -172,7 +170,7 @@ class Partitioner : public MixedModeMutator { region_func_meta_[sg].args.end()) { region_func_meta_[sg].args.push_back(cand); } - region_func_meta_[sg].in_expr_vars[parent] = var; + region_func_meta_[sg].region_func_in[parent] = var; return std::move(var); } } else { @@ -195,12 +193,14 @@ class Partitioner : public MixedModeMutator { // If multiple outputs are there, a tuple node is inserted at the end. if (!region_func_meta_[region].func_call.defined()) { - // First time this region is encountered in the traversal. - // Creating the function. + // First time this region is encountered in the traversal. Creating the function. CreateFunction(region, call); } + // Retrieve this particular output of function. - return GetFunctionOutput(region, GetRef(call)); + Expr region_out_expr = Downcast(GetRef(call))->args[0]; + CHECK(region_func_meta_[region].region_func_out.count(region_out_expr)); + return region_func_meta_[region].region_func_out[region_out_expr]; } } @@ -266,21 +266,22 @@ class Partitioner : public MixedModeMutator { } /*! - * \brief This function is called first time that we encounter a compiler_end - * node to create the function for the subgraph. + * \brief Create a function and its function call for the given region. If the function has + * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes + * will be created to serve output consumers. */ - void CreateFunction(AnnotatedRegion region, const CallNode* call) { + void CreateFunction(AnnotatedRegion region, const CallNode* end_node) { // Create fields which is a unique list of outputs. Array fields; - int i = 0; - for (auto ret : region->GetOutputs()) { - auto ret_node = Downcast(ret)->args[0]; + std::unordered_map out_expr_to_idx; + int out_idx = 0; + for (auto region_end_node : region->GetOutputs()) { + auto ret_node = Downcast(region_end_node)->args[0]; // Don't duplicate outputs. - if (!region_func_meta_[region].out_expr_indices.count(ret_node)) { + if (!out_expr_to_idx.count(ret_node)) { auto ret_expr = MixedModeMutator::VisitExpr(ret_node); fields.push_back(ret_expr); - region_func_meta_[region].out_expr_indices[ret_node] = {i, TupleGetItem()}; - i++; + out_expr_to_idx[ret_node] = out_idx++; } } @@ -309,13 +310,13 @@ class Partitioner : public MixedModeMutator { if (fields.size() == 1) { // If there are only a single output; no need to add a tuple global_region_func = - Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs()); + Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs()); } else { auto tuple = Tuple(fields); global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); } - std::string target = call->attrs.as()->compiler; + std::string target = end_node->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); global_region_func = @@ -340,38 +341,24 @@ class Partitioner : public MixedModeMutator { GlobalVar glob_func(fname); module_->Add(glob_func, global_region_func); - // The return type of callnode is the same as the type of the compiler_end node. - auto ret = Call(glob_func, param_expr); - region_func_meta_[region].func_call = ret; - } + // Create a call node for the function. + auto call = Call(glob_func, param_expr); + region_func_meta_[region].func_call = call; - /*! - * \brief Get the return(output) of the function for compiler end node "end_arg". - * This will return either a Call (for a function with a single output) or a - * TupleGetItem (for a function with multiple outputs). - */ - Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) { - Expr arg = Downcast(end_arg)->args[0]; - // Function has one output. - if (region_func_meta_[region].out_expr_indices.size() == 1) { - return region_func_meta_[region].func_call; - } - - // Function has multiple outputs. - // Use already made TupleGetItem. - if (region_func_meta_[region].out_expr_indices.count(arg) && - region_func_meta_[region].out_expr_indices[arg].second.defined()) { - return region_func_meta_[region].out_expr_indices[arg].second; + // Create output expr(s) for the function call. + if (out_expr_to_idx.size() == 1) { + // Single output direcly uses the call node as the output expr. + region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call; + } else { + // Multiple outptus need to create TupleGetItem nodes as output exprs. + for (auto pair : out_expr_to_idx) { + Expr region_out_expr = pair.first; // The arg of a compiler end node of this region. + int idx = pair.second; // Corresponding function output tuple index. + auto tuple_get_item = TupleGetItem(call, idx); + tuple_get_item->checked_type_ = region_out_expr->checked_type_; + region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item; + } } - // Create new TupleGetItem. - CHECK(region_func_meta_[region].out_expr_indices.count(arg)); - int index = region_func_meta_[region].out_expr_indices[arg].first; - - auto func_call = region_func_meta_[region].func_call; - auto tuple_get_item_ = TupleGetItem(func_call, index); - tuple_get_item_->checked_type_ = arg->checked_type_; - region_func_meta_[region].out_expr_indices[arg].second = tuple_get_item_; - return std::move(tuple_get_item_); } /*! \brief Map from each region to its metadata of the generated function. */ @@ -388,7 +375,7 @@ class Partitioner : public MixedModeMutator { IRModule module_; }; -IRModule Remove(IRModule module) { +IRModule RemoveDefaultAnnotations(IRModule module) { class DefaultRemover : public ExprRewriter { public: DefaultRemover() = default; @@ -427,7 +414,7 @@ Pass PartitionGraph() { // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute // by treating them as un-annotated, but we don't have it yet. This workaround pass removes // all "default" annotations and should be deleted in the future. - auto new_m = partitioning::Remove(m); + auto new_m = partitioning::RemoveDefaultAnnotations(m); return partitioning::Partitioner(new_m).Partition(); }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});