diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 3e4a1820b7315..756284a1301f4 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -25,7 +25,7 @@ * These nodes are used as boundaries to partition the Relay function into * multiple regions that can be offloaded to different accelerators/backends. * - * Each of these paritioned functions, a.k.a subgraphs, will be viewed as + * Each of these paritioned functions, a.k.a regions, will be viewed as * external functions, and they will use the provided compiler for codegen. */ @@ -36,13 +36,14 @@ #include #include -#include +#include #include #include -#include #include #include "../backend/utils.h" +#include "../analysis/annotated_region_set.h" + namespace tvm { namespace relay { @@ -50,22 +51,8 @@ namespace partitioning { // Cache compiler_begin and compiler_end annotation ops for equivalence check to // reduce registry lookup overhead. -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); - -/*! - * \brief The subgraph properties for partitioning. - */ -struct Subgraph { - /*! \brief The subgraph ID. */ - int id; - - /*! \brief The input arguments of this subgraph. */ - std::vector> args; - - /*! \brief Nodes in this subgraph. */ - std::unordered_set nodes; -}; +static const Op &compiler_begin_op = Op::Get("annotation.compiler_begin"); +static const Op &compiler_end_op = Op::Get("annotation.compiler_end"); /*! 
* \brief The checker that verifies if a Relay program is annotated correctly @@ -86,7 +73,7 @@ class AnnotationChecker : public ExprVisitor { return true; } - void VisitExpr_(const CallNode* call) final { + void VisitExpr_(const CallNode *call) final { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { return; @@ -102,61 +89,71 @@ class AnnotationChecker : public ExprVisitor { bool found_end_{false}; }; -/*! \brief This class partitions the expr labeled with begin and end annoations +/*! \brief This class partitions the expr labeled with begin and end annotations * into function containing multiple regions. Each region is labeled with * a compiler attribute so that it will be handled by any compilers that are not * in the TVM stack. * - * TODO(@zhiics) This following algorithm is not adequate to handle all cases, - * i.e. multiple `compiler_end` nodes. + * Input : A Relay module that have functions with disjoint annotated regions + * using compiler_begin and compiler_end. There could be multiple outputs. + * + * Output : A Relay module with global functions for such disjoint annotated regions + * with calls inserted at the respective location + * + * Dependencies : RegionSet Utility class. + * + * Methodology : + * 1) The RegionSet utility class is able to construct a collection of + * nodes that are bound by a given annotation -- here we use compiler_begin + * and compiler_end + * 2) Initially, for each function in the module RegionSets are populated. + * 3) Then, Vistor pass is traversed until a compiler_end node is encountered + * that belongs to a "region". + * 4) When the first compiler_end of a given annotated region is found, a function is + * formed and inserted. + * a) if the region has multiple outputs, a Tuple node (capturing all outputs) + * is returned. + * 5) Thereafter, if we encounter an another output of the same annotated region, + * it is important to note that the function is already formed. 
Therefore, it will + * look up the function and add a TupleGetItemNode. + * a) We will use the location index of "rets" of each "Region" of RegionSet + * as TupleGetItemNode index. + * 6) Therefore, functions will be created for all annotated regions. The name for each + * global function is created using "Region" id and the compiler name. + * + * Expected Usecase : + * This pass is intended to run as the last pass in a series of passes as follows : + * 1) Annotate Supported Single Ops - annotate each single op with supported backends. + * We use supported_begin and supported_end annotations. + * 2) Annotate Supported Composite Ops - annotate each composite op (that consists of + * multiple single ops). + * We use supported_begin and supported_end + * annotations. + * 3) Deconflict Pass - Make sure each op is annotated by only a single backend. + * In other words, each Annotated Region will be disjoint. + * We promote supported_* annotations to compiler_* annotations. + * 4) Merge Supported Pass - Merge the disjoint compiler_* Annotated regions belonging + * to the same backend. + * 5) *Partition Graph* - Convert Disjoint Annotated Regions into Functions. 
*/ + class Partitioner : public ExprMutator { public: - explicit Partitioner(const IRModule& module) : module_(module) {} - - std::shared_ptr GetSubgraph(const Expr node) { - for (auto candidate : this->subgraphs_) { - if (candidate->nodes.find(node) != candidate->nodes.end()) { - return candidate; - } + explicit Partitioner(const IRModule &module) : module_(module) { + for (auto f : module->functions) { + GlobalVar f_var = f.first; + BaseFunc f_func = f.second; + + // Creating regionset per function in the module + auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, + partitioning::compiler_end_op); + regions_sets_[region_set] = f_func; } - return nullptr; } - void MergeSubgraph(std::shared_ptr subgraph1, - std::shared_ptr subgraph2) { - if (subgraph1 == subgraph2) { - return; - } - - // Merge subgraph 2 to subgraph 1 and erase subgraph 2. - subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end()); - for (auto arg : subgraph2->args) { - subgraph1->args.push_back(arg); - } - this->subgraphs_.erase(subgraph2); - } - - void AddToSubgraph(std::shared_ptr subgraph, const Expr expr) { - auto subgraph2 = GetSubgraph(expr); - if (subgraph2) { - MergeSubgraph(subgraph, subgraph2); - } else { - subgraph->nodes.insert(expr); - } - } - - Expr VisitExpr_(const CallNode* call) final { + Expr VisitExpr_(const CallNode *call) final { auto op_node = call->op.as(); - if (op_node == nullptr || call->attrs.as() == nullptr) { - // Propogate subgraph to arguments - auto subgraph = GetSubgraph(GetRef(call)); - if (subgraph) { - for (auto arg : call->args) { - AddToSubgraph(subgraph, arg); - } - } return ExprMutator::VisitExpr_(call); } else if (call->op == compiler_begin_op) { // The annotation node is inserted on edge so it must have only one argument. @@ -165,101 +162,142 @@ class Partitioner : public ExprMutator { // Traverse the rest graph. 
auto input_expr = VisitExpr(call->args[0]); - // Replace the begin annotation with an external call input variable. - auto compiler_attrs = call->attrs.as(); + AnnotatedRegion sg = GetRegion(GetRef(call)); + int index = GetArgIdx(sg, GetRef(call)); + CHECK_NE(index, -1); // The type of the created variable is the same as the compiler_begin // node. - auto var = Var(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), - call->checked_type_); - - // Find the corresponding subgraph and add the argument. - auto subgraph = GetSubgraph(GetRef(call)); - if (!subgraph) { - throw Error(ErrorBuilder() - << "Cannot find the corresponding subgraph for start annotation:\n" - << AsText(GetRef(call), false)); - } - subgraph->args.push_back({var, input_expr}); + std::string target = call->attrs.as()->compiler; + std::string varname = target + "_" + std::to_string(sg->GetID()) + + "_i" + std::to_string(index); + auto var = Var(varname, GetRef(call)->checked_type_); + + auto cand = std::make_pair(var, input_expr); + if (std::find(region_args[sg].begin(), + region_args[sg].end(), cand) == region_args[sg].end()) { + region_args[sg].push_back(cand); + } + return std::move(var); } else { CHECK_EQ(call->op, compiler_end_op); // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); - auto compiler_attrs = call->attrs.as(); + AnnotatedRegion region = GetRegion(GetRef(call)); - // Check if the argument already belongs to an existing subgraph - auto subgraph = GetSubgraph(call->args[0]); - if (!subgraph) { - auto ret = this->subgraphs_.emplace(std::make_shared()); - subgraph = *ret.first; - subgraph->nodes.insert(call->args[0]); - subgraph->id = this->subgraph_id_++; - } - subgraph->nodes.insert(GetRef(call)); + // TODO(@manupa-arm) : need to use the parent function (to which region + // belongs to) name/key for the funtions that are created + BaseFunc f = GetFunc(GetRef(call)); // Traverse subgraph inputs. 
auto input = VisitExpr(call->args[0]); - Array params; - Array args; - std::unordered_map params_bind; - - // The subgraph may be merged so we need to update it again. - subgraph = GetSubgraph(GetRef(call)); - CHECK(subgraph); - - // Record the constants for propagation. - for (auto pair : subgraph->args) { - params.push_back(pair.first); - if (const auto* cn = pair.second.as()) { - params_bind[pair.first->name_hint()] = cn->data; + CHECK(region.defined()) << "Region not defined for " << GetRef(call); + // functions are created for each annotated regions, + // when their first output is encountered. + // If multiple outputs are there, a tuple node is inserted at the end. + // region_function_calls is map that maintains + // (each annotated regions) --> created function + + if (region_function_calls.find(region) != region_function_calls.end()) { + // This section is executed only if there are multiple outputs in the region + // Thus, the function is always created and at the end there would be a tuple node + // Therefore, we insert a tuple get item node. 
+ + // Use the already created tuple node + auto sg_call = region_function_calls[region]; + int index = GetRetIdx(region, GetRef(call)); + CHECK_NE(index, -1); + + auto tuple_get_item_ = TupleGetItem(sg_call, index); + tuple_get_item_->checked_type_ = GetRef(call)->args[0]->checked_type_; + return std::move(tuple_get_item_); + } else { + // First time this region is encountered in the traversal + // Creating the function + + Array fields; + + for (auto ret : region->GetOutputs()) { + auto ret_expr = VisitExpr(Downcast(ret)->args[0]); + fields.push_back(ret_expr); + } + int index = GetRetIdx(region, GetRef(call)); + CHECK_NE(index, -1); + + Array params; + Array param_expr; + std::unordered_map params_bind; + + for (auto pair : region_args[region]) { + params.push_back(pair.first); + if (const auto* cn = pair.second.as()) { + params_bind[pair.first->name_hint()] = cn->data; + } else { + param_expr.push_back(pair.second); + } + } + + Function global_region_func; + if (region->GetOutputs().size() == 1) { + // If there are only a single output; no need to add a tuple + global_region_func = Function(params, fields[0], + call->args[0]->checked_type_, {}, DictAttrs()); } else { - args.push_back(pair.second); + auto tuple = Tuple(fields); + global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); + } + + std::string target = call->attrs.as()->compiler; + std::string name = target + "_" + std::to_string(region->GetID()); + + global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol, + tir::StringImmNode::make(name)); + global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, + tvm::Integer(1)); + global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, + tvm::tir::StringImmNode::make(target)); + global_region_func = WithAttr(std::move(global_region_func), attr::kInline, + tvm::Integer(1)); + + // Constant propagation + if (!params_bind.empty()) { + global_region_func = 
backend::BindParamsByName(global_region_func, params_bind); } - } - auto subgraph_func = - Function(params, input, call->checked_type_, {}); - - std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kExternalSymbol, tir::StringImmNode::make(name)); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kPrimitive, tvm::Integer(1)); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kCompiler, - tvm::tir::StringImmNode::make(compiler_attrs->compiler)); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1)); - - // Constant propagation - if (!params_bind.empty()) { - subgraph_func = backend::BindParamsByName(subgraph_func, params_bind); + std::string fname = name; + CHECK(!module_->ContainGlobalVar(fname)) + << "Global function " << fname << " already exists"; + // Create a global function and add it to the IRModule for the region. + // This way we lift the functions that should be handled by external + // codegen to the module scope and rely on the pass manager to prevent relay + // function level passes (i.e. simplify inference and fusion) optimizing it. + GlobalVar glob_func(fname); + module_->Add(glob_func, global_region_func); + + // The return type of callnode is the same as the type of the + // compiler_end node. 
+ auto ret = Call(glob_func, param_expr); + region_function_calls[region] = ret; + + if (region->GetOutputs().size() == 1) { + // If there is only a single output; no need to add a tuplegetitem node + return Call(glob_func, param_expr); + } else { + // Add a tuplegetitem node to select this output out of many + auto tuple_get_item_ = TupleGetItem(ret, index); + tuple_get_item_->checked_type_ = GetRef(call)->args[0]->checked_type_; + return std::move(tuple_get_item_); + } } - CHECK(!module_->ContainGlobalVar(name)) - << "Global function " << name << " already exists"; - // Create a global function and add it to the IRModule for the subgraph. - // This way we lift the functions that should be handled by external - // codegen to the module scope and rely on the pass manager to prevent relay - // function level passes (i.e. simplify inference and fusion) optimizing it. - GlobalVar glob_func(name); - module_->Add(glob_func, subgraph_func); - // The return type of callnode is the same as the type of the - // compiler_end node. 
- auto ret = Call(glob_func, args); - ret->checked_type_ = call->checked_type_; - return std::move(ret); } } - Expr VisitExpr_(const TupleNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const TupleNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - for (auto field : op->fields) { - AddToSubgraph(subgraph, field); - } Array fields; for (auto field : op->fields) { fields.push_back(VisitExpr(field)); @@ -268,26 +306,22 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const TupleGetItemNode* g) final { - auto subgraph = GetSubgraph(GetRef(g)); - if (!subgraph) { + Expr VisitExpr_(const TupleGetItemNode *g) final { + auto region = GetRegion(GetRef(g)); + if (!region.defined()) { return ExprMutator::VisitExpr_(g); } else { - AddToSubgraph(subgraph, g->tuple); auto t = VisitExpr(g->tuple); return TupleGetItem(t, g->index); } } - Expr VisitExpr_(const FunctionNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const FunctionNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { Array params; - for (auto param : op->params) { - AddToSubgraph(subgraph, param); - } for (auto param : op->params) { Var new_param = Downcast(VisitExpr(param)); params.push_back(new_param); @@ -297,30 +331,23 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const LetNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const LetNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->var); - AddToSubgraph(subgraph, op->value); - AddToSubgraph(subgraph, op->body); Var var = Downcast(VisitExpr(op->var)); auto value = VisitExpr(op->value); auto body = VisitExpr(op->body); - return 
Let(var, value, body); } } - Expr VisitExpr_(const IfNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const IfNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->cond); - AddToSubgraph(subgraph, op->true_branch); - AddToSubgraph(subgraph, op->false_branch); auto guard = VisitExpr(op->cond); auto true_b = VisitExpr(op->true_branch); auto false_b = VisitExpr(op->false_branch); @@ -328,34 +355,31 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const RefCreateNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const RefCreateNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->value); Expr value = VisitExpr(op->value); return RefCreate(value); } } - Expr VisitExpr_(const RefReadNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const RefReadNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->ref); Expr ref = VisitExpr(op->ref); return RefRead(ref); } } - Expr VisitExpr_(const RefWriteNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const RefWriteNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->ref); Expr ref = VisitExpr(op->ref); Expr value = VisitExpr(op->value); return RefWrite(ref, value); @@ -364,14 +388,14 @@ class Partitioner : public ExprMutator { IRModule Partition() { auto glob_funcs = module_->functions; - for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { + for (const auto &pair : glob_funcs) { + if (auto 
*fn = pair.second.as()) { auto func = GetRef(fn); func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); module_->Update(pair.first, func); } } @@ -379,21 +403,99 @@ class Partitioner : public ExprMutator { } private: - int var_id_{0}; - int subgraph_id_{0}; - std::unordered_set> subgraphs_; + /*! + * \brief Get the region an expression belongs to + * if its in a region. + */ + AnnotatedRegion GetRegion(const Expr &e) { + for (auto sg_set_it : regions_sets_) { + auto sg_set = sg_set_it.first; + AnnotatedRegion sg = sg_set->GetRegion(e); + if (sg.defined()) { + return sg; + } + } + return AnnotatedRegion(nullptr); + } + + /*! + * \brief Get the function an expression belongs to + * if its in a region. + */ + BaseFunc GetFunc(const Expr &e) { + for (auto sg_set_it : regions_sets_) { + auto sg_set = sg_set_it.first; + auto func = sg_set_it.second; + + AnnotatedRegion sg = sg_set->GetRegion(e); + if (sg.defined()) { + return func; + } + } + return BaseFunc(nullptr); + } + + /*! + * \brief Get the index of the argument; + * this is to be used as tuplegetitem idx + */ + int GetArgIdx(AnnotatedRegion sg, const Expr &arg) { + int idx = 0; + for (auto arg_ : sg->GetInputs()) { + if (arg == arg_) { + return idx; + } + idx++; + } + return -1; + } + + /*! + * \brief Get the index of the return(output); + * this is to be used as tuplegetitem idx + */ + int GetRetIdx(AnnotatedRegion sg, const Expr &arg) { + int idx = 0; + for (auto arg_ : sg->GetOutputs()) { + if (arg == arg_) { + return idx; + } + idx++; + } + return -1; + } + + /*! + * \brief This map maintains the already created function calls. + * This is required in the multi-output scenario, to link rest of the outputs to call + */ + std::unordered_map region_function_calls; + + /*! + * \brief This map maintains arguments (of region) visits through visitor patterns. 
+ * Those arguement var and expression will be used to when creating the function. + */ + std::unordered_map>, + ObjectHash, ObjectEqual> region_args; + + /*! + * \brief Each region set is associated with a function in the module. + * This map maintains the mapping between regionsets and the function it belongs to + */ + std::unordered_map regions_sets_; IRModule module_; }; + } // namespace partitioning namespace transform { Pass PartitionGraph() { runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { - return partitioning::Partitioner(m).Partition(); - }; + [=](IRModule m, PassContext pc) { + return partitioning::Partitioner(m).Partition(); + }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); return Sequential({partitioned, InferType()}); } diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 1f37ab84d4a5f..893f7c37eea8c 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -678,6 +678,81 @@ def expected(): check_result(mod, {"y": y_data}, (8, 8), np.log(np_add)) +def test_multiple_outputs(): + def create_merged_graph(): + data = relay.var('data', shape=(10, 10)) + + cb_1 = compiler_begin(data, 'test_target') + O_1 = relay.abs(cb_1) + ce_2 = compiler_end(O_1, 'test_target') + O_2 = relay.nn.relu(O_1) + ce_3 = compiler_end(O_2, 'test_target') + + X = relay.tanh(ce_2) + + cb_3 = compiler_begin(ce_3, 'test_target') + cb_4 = compiler_begin(X, 'test_target') + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, 'test_target') + + func = relay.Function([data], ce_4) + return func + + def expected(): + mod = tvm.IRModule() + + # function 1 + f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10)) + f1_O_1 = relay.abs(f1_cb1) + f1_O_2 = relay.nn.relu(f1_O_1) + f1_out = relay.Tuple((f1_O_2,f1_O_1)) + func1 = relay.Function([f1_cb1], f1_out) + + func1 = func1.with_attr("Primitive", 
tvm.tir.IntImm("int32", 1)) + func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func1 = func1.with_attr("Compiler", + tvm.tir.StringImm("test_target")) + func1 = func1.with_attr("ExternalSymbol", + tvm.tir.StringImm("test_target_1")) + gv1 = relay.GlobalVar("test_target_1") + mod[gv1] = func1 + + # function 0 + f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10)) + f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10)) + f2_O_3 = relay.add(f2_cb3, f2_cb4) + func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) + + func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Compiler", + tvm.tir.StringImm("test_target")) + func0 = func0.with_attr("ExternalSymbol", + tvm.tir.StringImm("test_target_0")) + gv0 = relay.GlobalVar("test_target_0") + mod[gv0] = func0 + + # body + data = relay.var('data', shape=(10, 10)) + tuple_out = gv1(data) + ce_2 = relay.TupleGetItem(tuple_out, 1) + ce_3 = relay.TupleGetItem(tuple_out, 0) + + X = relay.tanh(ce_2) + ce_4 = gv0(ce_3, X) + func = relay.Function([data], ce_4) + mod["main"] = func + + return mod + + # print(create_merged_graph()) + mod = tvm.IRModule() + mod["main"] = create_merged_graph() + + ref_mod = expected(); + partitioned = transform.PartitionGraph()(mod) + assert relay.analysis.alpha_equal(partitioned, ref_mod) + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() @@ -688,3 +763,4 @@ def expected(): test_function_lifting() test_function_lifting_inline() test_constant_propagation() + test_multiple_outputs()