From d92bed602a86dcec86856b1766e5726d33bb49d4 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne
Date: Fri, 13 Mar 2020 16:54:14 +0000
Subject: [PATCH 1/5] [RELAY] Re-wrote the Graph Partitioner to support
 multiple outputs

Input : A Relay module that has functions with disjoint annotated regions
using compiler_begin and compiler_end. There could be multiple outputs.

Output : A Relay module with global functions for such disjoint annotated
regions, with calls inserted at the respective locations.

Dependencies : AnnotatedRegionSet utility class.

Methodology :
  1) The AnnotatedRegionSet utility class is able to construct a collection
     of nodes that are bound by a given annotation -- here we use
     compiler_begin and compiler_end.
  2) Initially, AnnotatedRegionSets are populated for each function in the
     module.
  3) Then, the visitor pass traverses the function until it encounters a
     compiler_end node that belongs to a "region".
  4) When the first compiler_end of a given annotated region is found, a
     function is formed and inserted.
     a) If the region has multiple outputs, a Tuple node (capturing all
        outputs) is returned.
  5) Thereafter, if we encounter another output of the same annotated region,
     the function has already been formed; therefore, it is looked up and a
     TupleGetItemNode is inserted.
     a) We will use the location index of "rets" of each "Region" of the
        AnnotatedRegionSet as the TupleGetItemNode index.
  6) Therefore, functions will be created for all annotated regions. The name
     for each global function is created using the "Region" id and the
     compiler name.

Change-Id: I1372f02a845b6d3da03b561763e03a378dca263c
---
 src/relay/transforms/partition_graph.cc       | 470 +++++++++++-------
 .../python/relay/test_pass_partition_graph.py |  76 +++
 2 files changed, 362 insertions(+), 184 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 3e4a1820b731..756284a1301f 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -25,7 +25,7 @@
  * These nodes are used as boundaries to partition the Relay function into
  * multiple regions that can be offloaded to different accelerators/backends.
  *
- * Each of these paritioned functions, a.k.a subgraphs, will be viewed as
+ * Each of these partitioned functions, a.k.a. regions, will be viewed as
  * external functions, and they will use the provided compiler for codegen.
  */
 
@@ -36,13 +36,14 @@
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
-#include 
 #include 
 
 #include "../backend/utils.h"
+#include "../analysis/annotated_region_set.h"
+
 namespace tvm {
 namespace relay {
@@ -50,22 +51,8 @@ namespace partitioning {
 
 // Cache compiler_begin and compiler_end annotation ops for equivalence check to
 // reduce registry lookup overhead.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-
-/*!
- * \brief The subgraph properties for partitioning.
- */
-struct Subgraph {
-  /*! \brief The subgraph ID. */
-  int id;
-
-  /*! \brief The input arguments of this subgraph. */
-  std::vector> args;
-
-  /*! \brief Nodes in this subgraph. */
-  std::unordered_set nodes;
-};
+static const Op &compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op &compiler_end_op = Op::Get("annotation.compiler_end");
 
 /*!
 * \brief The checker that verifies if a Relay program is annotated correctly
@@ -86,7 +73,7 @@ class AnnotationChecker : public ExprVisitor {
     return true;
   }
 
-  void VisitExpr_(const CallNode* call) final {
+  void VisitExpr_(const CallNode *call) final {
     auto op_node = call->op.as();
     if (op_node == nullptr || call->attrs.as() == nullptr) {
       return;
@@ -102,61 +89,71 @@ class AnnotationChecker : public ExprVisitor {
   bool found_end_{false};
 };
 
-/*! \brief This class partitions the expr labeled with begin and end annoations
+/*! \brief This class partitions the expr labeled with begin and end annotations
  * into function containing multiple regions. Each region is labeled with
  * a compiler attribute so that it will be handled by any compilers that are not
  * in the TVM stack.
  *
- * TODO(@zhiics) This following algorithm is not adequate to handle all cases,
- * i.e. multiple `compiler_end` nodes.
+ * Input : A Relay module that has functions with disjoint annotated regions
+ * using compiler_begin and compiler_end. There could be multiple outputs.
+ *
+ * Output : A Relay module with global functions for such disjoint annotated regions,
+ * with calls inserted at the respective locations.
+ *
+ * Dependencies : AnnotatedRegionSet utility class.
+ *
+ * Methodology :
+ * 1) The AnnotatedRegionSet utility class is able to construct a collection of
+ * nodes that are bound by a given annotation -- here we use compiler_begin
+ * and compiler_end.
+ * 2) Initially, AnnotatedRegionSets are populated for each function in the module.
+ * 3) Then, the visitor pass traverses the function until it encounters a compiler_end
+ * node that belongs to a "region".
+ * 4) When the first compiler_end of a given annotated region is found, a function is
+ * formed and inserted.
+ *    a) If the region has multiple outputs, a Tuple node (capturing all outputs)
+ *    is returned.
+ * 5) Thereafter, if we encounter another output of the same annotated region,
+ * the function has already been formed; therefore, it is looked up and
+ * a TupleGetItemNode is inserted.
+ *    a) We will use the location index of "rets" of each "Region" of the AnnotatedRegionSet
+ * as TupleGetItemNode index.
+ * 6) Therefore, functions will be created for all annotated regions. The name for each
+ * global function is created using "Region" id and the compiler name.
+ *
+ * Expected Usecase :
+ * This pass is intended to run as the last pass in a series of passes as follows :
+ * 1) Annotate Supported Single Ops - annotated each single op with supported backends.
+ *    We use supported_begin and supported_end annotations.
+ * 2) Annotate Supported Composite Ops - annotate each composite op (that consist of
+ *    multiple single ops).
+ *    We use supported_begin and supported_end
+ *    annotations.
+ * 3) Deconflict Pass - Make sure each op is annotated by only a single backend.
+ *    In other words, each Annotated Region will be disjoint.
+ *    We promote supported_* annotations to compiler_* annotations.
+ * 4) Merge Supported Pass - Merge the disjoint compiler_* Annotated regions belonging
+ *    to same backend.
+ * 5) *Partition Graph* - Convert Disjoint Annotated Regions into Functions.
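+ *
+ * Illustration (a sketch only -- ops, shapes and the region id are arbitrary).
+ * Before partitioning, one region with two outputs:
+ *   %b = compiler_begin(%data)
+ *   %0 = abs(%b); %1 = relu(%0)
+ *   out0 = compiler_end(%1); out1 = compiler_end(%0)
+ * After partitioning, the region becomes a global function returning a tuple,
+ * and each compiler_end is replaced by a TupleGetItem on a single call:
+ *   def @test_target_0(%i0) { %0 = abs(%i0); %1 = relu(%0); (%1, %0) }
+ *   %call = @test_target_0(%data)
+ *   out0 = %call.0; out1 = %call.1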
*/ + class Partitioner : public ExprMutator { public: - explicit Partitioner(const IRModule& module) : module_(module) {} - - std::shared_ptr GetSubgraph(const Expr node) { - for (auto candidate : this->subgraphs_) { - if (candidate->nodes.find(node) != candidate->nodes.end()) { - return candidate; - } + explicit Partitioner(const IRModule &module) : module_(module) { + for (auto f : module->functions) { + GlobalVar f_var = f.first; + BaseFunc f_func = f.second; + + // Creating regionset per function in the module + auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, + partitioning::compiler_end_op); + regions_sets_[region_set] = f_func; } - return nullptr; } - void MergeSubgraph(std::shared_ptr subgraph1, - std::shared_ptr subgraph2) { - if (subgraph1 == subgraph2) { - return; - } - - // Merge subgraph 2 to subgraph 1 and erase subgraph 2. - subgraph1->nodes.insert(subgraph2->nodes.begin(), subgraph2->nodes.end()); - for (auto arg : subgraph2->args) { - subgraph1->args.push_back(arg); - } - this->subgraphs_.erase(subgraph2); - } - - void AddToSubgraph(std::shared_ptr subgraph, const Expr expr) { - auto subgraph2 = GetSubgraph(expr); - if (subgraph2) { - MergeSubgraph(subgraph, subgraph2); - } else { - subgraph->nodes.insert(expr); - } - } - - Expr VisitExpr_(const CallNode* call) final { + Expr VisitExpr_(const CallNode *call) final { auto op_node = call->op.as(); - if (op_node == nullptr || call->attrs.as() == nullptr) { - // Propogate subgraph to arguments - auto subgraph = GetSubgraph(GetRef(call)); - if (subgraph) { - for (auto arg : call->args) { - AddToSubgraph(subgraph, arg); - } - } return ExprMutator::VisitExpr_(call); } else if (call->op == compiler_begin_op) { // The annotation node is inserted on edge so it must have only one argument. @@ -165,101 +162,142 @@ class Partitioner : public ExprMutator { // Traverse the rest graph. auto input_expr = VisitExpr(call->args[0]); - // Replace the begin annotation with an external call input variable. - auto compiler_attrs = call->attrs.as(); + AnnotatedRegion sg = GetRegion(GetRef(call)); + int index = GetArgIdx(sg, GetRef(call)); + CHECK_NE(index, -1); // The type of the created variable is the same as the compiler_begin // node. - auto var = Var(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), - call->checked_type_); - - // Find the corresponding subgraph and add the argument. - auto subgraph = GetSubgraph(GetRef(call)); - if (!subgraph) { - throw Error(ErrorBuilder() - << "Cannot find the corresponding subgraph for start annotation:\n" - << AsText(GetRef(call), false)); - } - subgraph->args.push_back({var, input_expr}); + std::string target = call->attrs.as()->compiler; + std::string varname = target + "_" + std::to_string(sg->GetID()) + + "_i" + std::to_string(index); + auto var = Var(varname, GetRef(call)->checked_type_); + + auto cand = std::make_pair(var, input_expr); + if (std::find(region_args[sg].begin(), + region_args[sg].end(), cand) == region_args[sg].end()) { + region_args[sg].push_back(cand); + } + return std::move(var); } else { CHECK_EQ(call->op, compiler_end_op); // The annotation node is inserted on edge so it must have only one argument. 
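+      // Each compiler_end corresponds to exactly one output of an annotated
+      // region; the handling below depends on whether the region's function
+      // has already been created by an earlier output.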
 CHECK_EQ(call->args.size(), 1U);
-      auto compiler_attrs = call->attrs.as();
+      AnnotatedRegion region = GetRegion(GetRef(call));
 
-      // Check if the argument already belongs to an existing subgraph
-      auto subgraph = GetSubgraph(call->args[0]);
-      if (!subgraph) {
-        auto ret = this->subgraphs_.emplace(std::make_shared());
-        subgraph = *ret.first;
-        subgraph->nodes.insert(call->args[0]);
-        subgraph->id = this->subgraph_id_++;
-      }
-      subgraph->nodes.insert(GetRef(call));
+      // TODO(@manupa-arm) : need to use the name/key of the parent function
+      // (to which the region belongs) for the functions that are created
+      BaseFunc f = GetFunc(GetRef(call));
 
       // Traverse subgraph inputs.
       auto input = VisitExpr(call->args[0]);
-      Array params;
-      Array args;
-      std::unordered_map params_bind;
-
-      // The subgraph may be merged so we need to update it again.
-      subgraph = GetSubgraph(GetRef(call));
-      CHECK(subgraph);
-
-      // Record the constants for propagation.
-      for (auto pair : subgraph->args) {
-        params.push_back(pair.first);
-        if (const auto* cn = pair.second.as()) {
-          params_bind[pair.first->name_hint()] = cn->data;
+      CHECK(region.defined()) << "Region not defined for " << GetRef(call);
+      // Functions are created for each annotated region
+      // when their first output is encountered.
+      // If there are multiple outputs, a tuple node is inserted at the end.
+      // region_function_calls is a map that maintains
+      // (each annotated region) --> created function
+
+      if (region_function_calls.find(region) != region_function_calls.end()) {
+        // This section is executed only if there are multiple outputs in the region.
+        // Thus, the function has already been created, and it ends in a tuple node.
+        // Therefore, we insert a tuple get item node.
+
+        // Use the already created function call
+        auto sg_call = region_function_calls[region];
+        int index = GetRetIdx(region, GetRef(call));
+        CHECK_NE(index, -1);
+
+        auto tuple_get_item_ = TupleGetItem(sg_call, index);
+        tuple_get_item_->checked_type_ = GetRef(call)->args[0]->checked_type_;
+        return std::move(tuple_get_item_);
+      } else {
+        // This is the first time this region is encountered in the traversal,
+        // so create the function.
+
+        Array fields;
+
+        for (auto ret : region->GetOutputs()) {
+          auto ret_expr = VisitExpr(Downcast(ret)->args[0]);
+          fields.push_back(ret_expr);
+        }
+        int index = GetRetIdx(region, GetRef(call));
+        CHECK_NE(index, -1);
+
+        Array params;
+        Array param_expr;
+        std::unordered_map params_bind;
+
+        for (auto pair : region_args[region]) {
+          params.push_back(pair.first);
+          if (const auto* cn = pair.second.as()) {
+            params_bind[pair.first->name_hint()] = cn->data;
+          } else {
+            param_expr.push_back(pair.second);
+          }
+        }
+
+        Function global_region_func;
+        if (region->GetOutputs().size() == 1) {
+          // If there is only a single output, there is no need to add a tuple
+          global_region_func = Function(params, fields[0],
+                                        call->args[0]->checked_type_, {}, DictAttrs());
         } else {
-          args.push_back(pair.second);
+          auto tuple = Tuple(fields);
+          global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+        }
+
+        std::string target = call->attrs.as()->compiler;
+        std::string name = target + "_" + std::to_string(region->GetID());
+
+        global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
+                                      tir::StringImmNode::make(name));
+        global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive,
+                                      tvm::Integer(1));
+        global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                      tvm::tir::StringImmNode::make(target));
+
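+        // Together, these attributes route the new function to external codegen:
+        // kExternalSymbol names the emitted symbol, kCompiler selects the external
+        // compiler, and kPrimitive/kInline control how later Relay passes treat it.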
global_region_func = WithAttr(std::move(global_region_func), attr::kInline, + tvm::Integer(1)); + + // Constant propagation + if (!params_bind.empty()) { + global_region_func = backend::BindParamsByName(global_region_func, params_bind); } - } - auto subgraph_func = - Function(params, input, call->checked_type_, {}); - - std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kExternalSymbol, tir::StringImmNode::make(name)); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kPrimitive, tvm::Integer(1)); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kCompiler, - tvm::tir::StringImmNode::make(compiler_attrs->compiler)); - subgraph_func = - WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1)); - - // Constant propagation - if (!params_bind.empty()) { - subgraph_func = backend::BindParamsByName(subgraph_func, params_bind); + std::string fname = name; + CHECK(!module_->ContainGlobalVar(fname)) + << "Global function " << fname << " already exists"; + // Create a global function and add it to the IRModule for the region. + // This way we lift the functions that should be handled by external + // codegen to the module scope and rely on the pass manager to prevent relay + // function level passes (i.e. simplify inference and fusion) optimizing it. + GlobalVar glob_func(fname); + module_->Add(glob_func, global_region_func); + + // The return type of callnode is the same as the type of the + // compiler_end node. + auto ret = Call(glob_func, param_expr); + region_function_calls[region] = ret; + + if (region->GetOutputs().size() == 1) { + // If there is only a single output; no need to add a tuplegetitem node + return Call(glob_func, param_expr); + } else { + // Add a tuplegetitem node to select this output out of many + auto tuple_get_item_ = TupleGetItem(ret, index); + tuple_get_item_->checked_type_ = GetRef(call)->args[0]->checked_type_; + return std::move(tuple_get_item_); + } } - CHECK(!module_->ContainGlobalVar(name)) - << "Global function " << name << " already exists"; - // Create a global function and add it to the IRModule for the subgraph. - // This way we lift the functions that should be handled by external - // codegen to the module scope and rely on the pass manager to prevent relay - // function level passes (i.e. simplify inference and fusion) optimizing it. - GlobalVar glob_func(name); - module_->Add(glob_func, subgraph_func); - // The return type of callnode is the same as the type of the - // compiler_end node. 
- auto ret = Call(glob_func, args); - ret->checked_type_ = call->checked_type_; - return std::move(ret); } } - Expr VisitExpr_(const TupleNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const TupleNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - for (auto field : op->fields) { - AddToSubgraph(subgraph, field); - } Array fields; for (auto field : op->fields) { fields.push_back(VisitExpr(field)); @@ -268,26 +306,22 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const TupleGetItemNode* g) final { - auto subgraph = GetSubgraph(GetRef(g)); - if (!subgraph) { + Expr VisitExpr_(const TupleGetItemNode *g) final { + auto region = GetRegion(GetRef(g)); + if (!region.defined()) { return ExprMutator::VisitExpr_(g); } else { - AddToSubgraph(subgraph, g->tuple); auto t = VisitExpr(g->tuple); return TupleGetItem(t, g->index); } } - Expr VisitExpr_(const FunctionNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const FunctionNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { Array params; - for (auto param : op->params) { - AddToSubgraph(subgraph, param); - } for (auto param : op->params) { Var new_param = Downcast(VisitExpr(param)); params.push_back(new_param); @@ -297,30 +331,23 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const LetNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const LetNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->var); - AddToSubgraph(subgraph, op->value); - AddToSubgraph(subgraph, op->body); Var var = Downcast(VisitExpr(op->var)); auto value = VisitExpr(op->value); auto body = VisitExpr(op->body); - return Let(var, value, body); } } - Expr VisitExpr_(const IfNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const IfNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->cond); - AddToSubgraph(subgraph, op->true_branch); - AddToSubgraph(subgraph, op->false_branch); auto guard = VisitExpr(op->cond); auto true_b = VisitExpr(op->true_branch); auto false_b = VisitExpr(op->false_branch); @@ -328,34 +355,31 @@ class Partitioner : public ExprMutator { } } - Expr VisitExpr_(const RefCreateNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const RefCreateNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->value); Expr value = VisitExpr(op->value); return RefCreate(value); } } - Expr VisitExpr_(const RefReadNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const RefReadNode *op) final { + auto region = GetRegion(GetRef(op)); + if (!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->ref); Expr ref = VisitExpr(op->ref); return RefRead(ref); } } - Expr VisitExpr_(const RefWriteNode* op) final { - auto subgraph = GetSubgraph(GetRef(op)); - if (!subgraph) { + Expr VisitExpr_(const RefWriteNode *op) final { + auto region = GetRegion(GetRef(op)); + if 
(!region.defined()) { return ExprMutator::VisitExpr_(op); } else { - AddToSubgraph(subgraph, op->ref); Expr ref = VisitExpr(op->ref); Expr value = VisitExpr(op->value); return RefWrite(ref, value); @@ -364,14 +388,14 @@ class Partitioner : public ExprMutator { IRModule Partition() { auto glob_funcs = module_->functions; - for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { + for (const auto &pair : glob_funcs) { + if (auto *fn = pair.second.as()) { auto func = GetRef(fn); func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); module_->Update(pair.first, func); } } @@ -379,21 +403,99 @@ class Partitioner : public ExprMutator { } private: - int var_id_{0}; - int subgraph_id_{0}; - std::unordered_set> subgraphs_; + /*! + * \brief Get the region an expression belongs to + * if its in a region. + */ + AnnotatedRegion GetRegion(const Expr &e) { + for (auto sg_set_it : regions_sets_) { + auto sg_set = sg_set_it.first; + AnnotatedRegion sg = sg_set->GetRegion(e); + if (sg.defined()) { + return sg; + } + } + return AnnotatedRegion(nullptr); + } + + /*! + * \brief Get the function an expression belongs to + * if its in a region. + */ + BaseFunc GetFunc(const Expr &e) { + for (auto sg_set_it : regions_sets_) { + auto sg_set = sg_set_it.first; + auto func = sg_set_it.second; + + AnnotatedRegion sg = sg_set->GetRegion(e); + if (sg.defined()) { + return func; + } + } + return BaseFunc(nullptr); + } + + /*! + * \brief Get the index of the argument; + * this is to be used as tuplegetitem idx + */ + int GetArgIdx(AnnotatedRegion sg, const Expr &arg) { + int idx = 0; + for (auto arg_ : sg->GetInputs()) { + if (arg == arg_) { + return idx; + } + idx++; + } + return -1; + } + + /*! + * \brief Get the index of the return(output); + * this is to be used as tuplegetitem idx + */ + int GetRetIdx(AnnotatedRegion sg, const Expr &arg) { + int idx = 0; + for (auto arg_ : sg->GetOutputs()) { + if (arg == arg_) { + return idx; + } + idx++; + } + return -1; + } + + /*! + * \brief This map maintains the already created function calls. + * This is required in the multi-output scenario, to link rest of the outputs to call + */ + std::unordered_map region_function_calls; + + /*! + * \brief This map maintains arguments (of region) visits through visitor patterns. + * Those arguement var and expression will be used to when creating the function. + */ + std::unordered_map>, + ObjectHash, ObjectEqual> region_args; + + /*! + * \brief Each region set is associated with a function in the module. 
+ * This map maintains the mapping between regionsets and the function it belongs to + */ + std::unordered_map regions_sets_; IRModule module_; }; + } // namespace partitioning namespace transform { Pass PartitionGraph() { runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { - return partitioning::Partitioner(m).Partition(); - }; + [=](IRModule m, PassContext pc) { + return partitioning::Partitioner(m).Partition(); + }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); return Sequential({partitioned, InferType()}); } diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index fc8dfb619124..f748a8f8936c 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -678,6 +678,81 @@ def expected(): check_result(mod, {"y": y_data}, (8, 8), np.log(np_add)) +def test_multiple_outputs(): + def create_merged_graph(): + data = relay.var('data', shape=(10, 10)) + + cb_1 = compiler_begin(data, 'test_target') + O_1 = relay.abs(cb_1) + ce_2 = compiler_end(O_1, 'test_target') + O_2 = relay.nn.relu(O_1) + ce_3 = compiler_end(O_2, 'test_target') + + X = relay.tanh(ce_2) + + cb_3 = compiler_begin(ce_3, 'test_target') + cb_4 = compiler_begin(X, 'test_target') + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, 'test_target') + + func = relay.Function([data], ce_4) + return func + + def expected(): + mod = tvm.IRModule() + + # function 1 + f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10)) + f1_O_1 = relay.abs(f1_cb1) + f1_O_2 = relay.nn.relu(f1_O_1) + f1_out = relay.Tuple((f1_O_2,f1_O_1)) + func1 = relay.Function([f1_cb1], f1_out) + + func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func1 = func1.with_attr("Compiler", + tvm.tir.StringImm("test_target")) + func1 = func1.with_attr("ExternalSymbol", + tvm.tir.StringImm("test_target_1")) + gv1 = relay.GlobalVar("test_target_1") + mod[gv1] = func1 + + # function 0 + f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10)) + f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10)) + f2_O_3 = relay.add(f2_cb3, f2_cb4) + func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) + + func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Compiler", + tvm.tir.StringImm("test_target")) + func0 = func0.with_attr("ExternalSymbol", + tvm.tir.StringImm("test_target_0")) + gv0 = relay.GlobalVar("test_target_0") + mod[gv0] = func0 + + # body + data = relay.var('data', shape=(10, 10)) + tuple_out = gv1(data) + ce_2 = relay.TupleGetItem(tuple_out, 1) + ce_3 = relay.TupleGetItem(tuple_out, 0) + + X = relay.tanh(ce_2) + ce_4 = gv0(ce_3, X) + func = relay.Function([data], ce_4) + mod["main"] = func + + return mod + + # print(create_merged_graph()) + mod = tvm.IRModule() + mod["main"] = create_merged_graph() + + ref_mod = expected(); + partitioned = transform.PartitionGraph()(mod) + assert relay.analysis.alpha_equal(partitioned, ref_mod) + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() @@ -688,3 +763,4 @@ def expected(): test_function_lifting() test_function_lifting_inline() test_constant_propagation() + test_multiple_outputs() From 259ec292d80d9f1dd27f0a5bf92f2aaa1139fc94 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 26 Mar 2020 11:58:28 +0000 Subject: [PATCH 2/5] [RELAY] Re-wrote the Graph 
Partitioner to support multiple outputs

*removed the expected use case, as we are taking a broken-down PR approach
*code style fixes
*some trivial one-liners
---
 src/relay/transforms/partition_graph.cc | 51 +++++++++----------
 1 file changed, 18 insertions(+), 33 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 756284a1301f..f7240b390499 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -51,8 +51,8 @@ namespace partitioning {
 
 // Cache compiler_begin and compiler_end annotation ops for equivalence check to
 // reduce registry lookup overhead.
-static const Op &compiler_begin_op = Op::Get("annotation.compiler_begin");
-static const Op &compiler_end_op = Op::Get("annotation.compiler_end");
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
 
 /*!
  * \brief The checker that verifies if a Relay program is annotated correctly
@@ -120,26 +120,11 @@ class AnnotationChecker : public ExprVisitor {
  * as TupleGetItemNode index.
 * 6) Therefore, functions will be created for all annotated regions. The name for each
 * global function is created using "Region" id and the compiler name.
- *
- * Expected Usecase :
- * This pass is intended to run as the last pass in a series of passes as follows :
- * 1) Annotate Supported Single Ops - annotated each single op with supported backends.
- *    We use supported_begin and supported_end annotations.
- * 2) Annotate Supported Composite Ops - annotate each composite op (that consist of
- *    multiple single ops).
- *    We use supported_begin and supported_end
- *    annotations.
- * 3) Deconflict Pass - Make sure each op is annotated by only a single backend.
- *    In other words, each Annotated Region will be disjoint.
- *    We promote supported_* annotations to compiler_* annotations.
- * 4) Merge Supported Pass - Merge the disjoint compiler_* Annotated regions belonging
- *    to same backend.
- * 5) *Partition Graph* - Convert Disjoint Annotated Regions into Functions.
  */
 
 class Partitioner : public ExprMutator {
  public:
-  explicit Partitioner(const IRModule &module) : module_(module) {
+  explicit Partitioner(const IRModule& module) : module_(module) {
     for (auto f : module->functions) {
       GlobalVar f_var = f.first;
       BaseFunc f_func = f.second;
@@ -282,7 +267,7 @@ class Partitioner : public ExprMutator {
 
       if (region->GetOutputs().size() == 1) {
         // If there is only a single output; no need to add a tuplegetitem node
-        return Call(glob_func, param_expr);
+        return ret;
       } else {
         // Add a tuplegetitem node to select this output out of many
         auto tuple_get_item_ = TupleGetItem(ret, index);
@@ -388,7 +373,7 @@ class Partitioner : public ExprMutator {
 
   IRModule Partition() {
     auto glob_funcs = module_->functions;
-    for (const auto &pair : glob_funcs) {
+    for (const auto& pair : glob_funcs) {
       if (auto *fn = pair.second.as()) {
         auto func = GetRef(fn);
         func = Function(func->params,
@@ -405,9 +390,9 @@ class Partitioner : public ExprMutator {
  private:
   /*!
    * \brief Get the region an expression belongs to
-   * if its in a region. 
+   * if its in a region.
    */
-  AnnotatedRegion GetRegion(const Expr &e) {
+  AnnotatedRegion GetRegion(const Expr& e) {
     for (auto sg_set_it : regions_sets_) {
       auto sg_set = sg_set_it.first;
       AnnotatedRegion sg = sg_set->GetRegion(e);
@@ -420,9 +405,9 @@ class Partitioner : public ExprMutator {
 
   /*!
    * \brief Get the function an expression belongs to
-   * if its in a region.
+   * if its in a region.
    */
-  BaseFunc GetFunc(const Expr &e) {
+  BaseFunc GetFunc(const Expr& e) {
     for (auto sg_set_it : regions_sets_) {
       auto sg_set = sg_set_it.first;
       auto func = sg_set_it.second;
 
       AnnotatedRegion sg = sg_set->GetRegion(e);
       if (sg.defined()) {
         return func;
       }
     }
     return BaseFunc(nullptr);
   }
 
   /*!
-   * \brief Get the index of the argument;
-   * this is to be used as tuplegetitem idx
-   */
-  int GetArgIdx(AnnotatedRegion sg, const Expr &arg) {
+   * \brief Get the index of the argument;
+   * this is to be used as tuplegetitem idx
+   */
+  int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
     int idx = 0;
     for (auto arg_ : sg->GetInputs()) {
       if (arg == arg_) {
@@ -452,9 +437,9 @@ class Partitioner : public ExprMutator {
 
   /*!
    * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
+   * this is to be used as tuplegetitem idx
    */
-  int GetRetIdx(AnnotatedRegion sg, const Expr &arg) {
+  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
     int idx = 0;
     for (auto arg_ : sg->GetOutputs()) {
       if (arg == arg_) {
@@ -467,20 +452,20 @@ class Partitioner : public ExprMutator {
 
   /*!
    * \brief This map maintains the already created function calls.
-   * This is required in the multi-output scenario, to link rest of the outputs to call
+   * This is required in the multi-output scenario, to link rest of the outputs to call
   */
   std::unordered_map region_function_calls;
 
   /*!
    * \brief This map maintains arguments (of region) visits through visitor patterns.
-   * Those arguement var and expression will be used to when creating the function.
+   * Those arguement var and expression will be used to when creating the function.
   */
   std::unordered_map>,
                      ObjectHash, ObjectEqual> region_args;
 
   /*!
    * \brief Each region set is associated with a function in the module.
-   * This map maintains the mapping between regionsets and the function it belongs to
+   * This map maintains the mapping between regionsets and the function it belongs to
   */
   std::unordered_map regions_sets_;
   IRModule module_;

From 911c421296517a4f26f509f96d3cfba8f9473977 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne
Date: Thu, 26 Mar 2020 13:50:52 +0000
Subject: [PATCH 3/5] [RELAY] Re-wrote the Graph Partitioner to support
 multiple outputs

*fixed an implicit copy to a move
---
 src/relay/transforms/partition_graph.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index f7240b390499..e66adf15106e 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -267,7 +267,7 @@ class Partitioner : public ExprMutator {
 
       if (region->GetOutputs().size() == 1) {
         // If there is only a single output; no need to add a tuplegetitem node
-        return ret;
+        return std::move(ret);
       } else {
         // Add a tuplegetitem node to select this output out of many
         auto tuple_get_item_ = TupleGetItem(ret, index);

From cf2d501d56cba4243b2bc736396372b97dbad09f Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne
Date: Fri, 27 Mar 2020 14:04:03 +0000
Subject: [PATCH 4/5] [RELAY] Re-wrote the Graph Partitioner to support
 multiple outputs

*code style changes for comments
*renamed test case: multiple outputs --> mixed single multiple outputs,
 since the existing test case checks for both single and multiple output
 scenarios
*added a new test case with conv2d + batch_norm
*some var name changes in the test
---
 src/relay/transforms/partition_graph.cc      |  42 +++----
 .../python/relay/test_pass_partition_graph.py | 117 +++++++++++++++++-
 2 files changed, 133 insertions(+), 26 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index e66adf15106e..32db40ad9afa 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -389,9 +389,9 @@ class Partitioner : public ExprMutator {
 
  private:
   /*!
-   * \brief Get the region an expression belongs to
-   * if its in a region.
-   */
+   * \brief Get the region an expression belongs to,
+   * if it is in a region.
+   */
   AnnotatedRegion GetRegion(const Expr& e) {
     for (auto sg_set_it : regions_sets_) {
       auto sg_set = sg_set_it.first;
@@ -404,9 +404,9 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the function an expression belongs to
-   * if its in a region.
-   */
+   * \brief Get the function an expression belongs to,
+   * if it is in a region.
+   */
   BaseFunc GetFunc(const Expr& e) {
     for (auto sg_set_it : regions_sets_) {
       auto sg_set = sg_set_it.first;
@@ -421,9 +421,9 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the index of the argument;
-   * this is to be used as tuplegetitem idx
-   */
+   * \brief Get the index of the argument in the region's inputs;
+   * this is used when naming the created function's parameters.
+   */
   int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
     int idx = 0;
     for (auto arg_ : sg->GetInputs()) {
       if (arg == arg_) {
@@ -436,9 +436,9 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
-   */
+   * \brief Get the index of the return (output);
+   * this is to be used as the TupleGetItem index.
+   */
   int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
     int idx = 0;
     for (auto arg_ : sg->GetOutputs()) {
       if (arg == arg_) {
@@ -451,22 +451,22 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief This map maintains the already created function calls.
-   * This is required in the multi-output scenario, to link rest of the outputs to call
-   */
+   * \brief This map maintains the already created function calls.
+   * This is required in the multi-output scenario, to link the rest of the outputs to the call.
+   */
   std::unordered_map region_function_calls;
 
   /*!
-   * \brief This map maintains arguments (of region) visits through visitor patterns.
-   * Those arguement var and expression will be used to when creating the function.
-   */
+   * \brief This map maintains the arguments (of each region) encountered during the visit.
+   * Those argument vars and expressions will be used when creating the function.
+   */
   std::unordered_map>,
                      ObjectHash, ObjectEqual> region_args;
 
   /*!
-   * \brief Each region set is associated with a function in the module.
-   * This map maintains the mapping between regionsets and the function it belongs to
-   */
+   * \brief Each region set is associated with a function in the module.
+   * This map maintains the mapping between region sets and the function each belongs to.
+   */
   std::unordered_map regions_sets_;
   IRModule module_;
 };

diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index f748a8f8936c..594faf7ae462 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -679,7 +679,114 @@ def expected():
 
 
 def test_multiple_outputs():
-    def create_merged_graph():
+
+    def create_graph():
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+        bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+        bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+        data_cb = compiler_begin(data, 'test_target')
+        weight_cb = compiler_begin(weight, 'test_target')
+        bn_gamma_cb = compiler_begin(bn_gamma, 'test_target')
+        bn_beta_cb = compiler_begin(bn_beta, 'test_target')
+        bn_mean_cb = compiler_begin(bn_mean, 'test_target')
+        bn_var_cb = compiler_begin(bn_var, 'test_target')
+
+        conv_o = relay.nn.conv2d(
+            data=data_cb,
+            weight=weight_cb,
+            kernel_size=(3, 3),
+            channels=16,
+            padding=(1, 1))
+
+        bn_o = relay.nn.batch_norm(conv_o, bn_gamma_cb, bn_beta_cb, bn_mean_cb,
+                                   bn_var_cb)
+
+        relu_o = relay.nn.relu(bn_o[0])
+        relu_o_ce = compiler_end(relu_o, 'test_target')
+
+        bn_omean = bn_o[1]
+        bn_omean_ce = compiler_end(bn_omean, 'test_target')
+        bn_ovar = bn_o[2]
+        bn_ovar_ce = compiler_end(bn_ovar, 'test_target')
+
+        dummy_mean_abs = relay.abs(bn_omean_ce)
+        dummy_ovar_abs = relay.abs(bn_ovar_ce)
+        dummy_tuple = relay.Tuple((relu_o_ce, dummy_mean_abs, dummy_ovar_abs))
+
+        func = relay.Function([data, weight, bn_gamma, bn_beta,
+                               bn_mean, bn_var], dummy_tuple)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
+        bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
+        bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
+
+        conv_o = relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=(3, 3),
+            channels=16,
+            padding=(1, 1))
+
+        bn_o = relay.nn.batch_norm(conv_o, bn_gamma, bn_beta, bn_mean,
+                                   bn_var)
+
+        relu_o = relay.nn.relu(bn_o[0])
+        tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
+
+        func0 = relay.Function([data, weight, bn_gamma, bn_beta,
+                                bn_mean, bn_var], tuple_o)
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler",
+                                tvm.tir.StringImm("test_target"))
+        func0 = func0.with_attr("ExternalSymbol",
+                                tvm.tir.StringImm("test_target_0"))
+        gv0 = relay.GlobalVar("test_target_0")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+
bn_mean = relay.var("bn_mean", relay.TensorType((16, ), "float32")) + bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32")) + + f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var) + f0_relu_o = relay.TupleGetItem(f0_o, 0) + f0_mean_o = relay.TupleGetItem(f0_o, 1) + f0_var_o = relay.TupleGetItem(f0_o, 2) + + f0_mean_abs = relay.abs(f0_mean_o) + f0_var_abs = relay.abs(f0_var_o) + main_tuple = relay.Tuple((f0_relu_o, f0_mean_abs, f0_var_abs)) + + func = relay.Function([data, weight, bn_gamma, + bn_beta, bn_mean, bn_var], main_tuple) + mod["main"] = func + return mod + + mod = tvm.IRModule() + mod["main"] = create_graph() + ref_mod = expected() + partitioned = transform.PartitionGraph()(mod) + assert relay.analysis.alpha_equal(partitioned, ref_mod) + + +def test_mixed_single_multiple_outputs(): + def create_graph(): data = relay.var('data', shape=(10, 10)) cb_1 = compiler_begin(data, 'test_target') @@ -705,7 +812,7 @@ def expected(): f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10)) f1_O_1 = relay.abs(f1_cb1) f1_O_2 = relay.nn.relu(f1_O_1) - f1_out = relay.Tuple((f1_O_2,f1_O_1)) + f1_out = relay.Tuple((f1_O_2, f1_O_1)) func1 = relay.Function([f1_cb1], f1_out) func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) @@ -745,11 +852,10 @@ def expected(): return mod - # print(create_merged_graph()) mod = tvm.IRModule() - mod["main"] = create_merged_graph() + mod["main"] = create_graph() - ref_mod = expected(); + ref_mod = expected() partitioned = transform.PartitionGraph()(mod) assert relay.analysis.alpha_equal(partitioned, ref_mod) @@ -764,3 +870,4 @@ def expected(): test_function_lifting_inline() test_constant_propagation() test_multiple_outputs() + test_mixed_single_multiple_outputs() From ed2415f9ab801f3ebd154639dedfb6a8c4c1b22f Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 30 Mar 2020 17:42:46 +0100 Subject: [PATCH 5/5] [RELAY] Re-wrote the Graph Partitioner to support multiple outputs *rebased --- tests/python/relay/test_pass_partition_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 594faf7ae462..0dfc89d469ca 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -782,7 +782,7 @@ def expected(): mod["main"] = create_graph() ref_mod = expected() partitioned = transform.PartitionGraph()(mod) - assert relay.analysis.alpha_equal(partitioned, ref_mod) + assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) def test_mixed_single_multiple_outputs(): @@ -857,7 +857,7 @@ def expected(): ref_mod = expected() partitioned = transform.PartitionGraph()(mod) - assert relay.analysis.alpha_equal(partitioned, ref_mod) + assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) if __name__ == "__main__": test_multi_node_compiler()
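
For reference, the behaviour these patches add can be exercised end to end from
Python roughly as follows. This is a minimal sketch modelled on the
test_mixed_single_multiple_outputs test above; 'test_target' and the tensor
shape are arbitrary placeholders.

    import tvm
    from tvm import relay
    from tvm.relay import transform
    from tvm.relay.op.annotation import compiler_begin, compiler_end

    # Annotate one region with two outputs (abs and relu).
    data = relay.var('data', shape=(10, 10))
    cb = compiler_begin(data, 'test_target')
    o1 = relay.abs(cb)
    o2 = relay.nn.relu(o1)
    ce1 = compiler_end(o1, 'test_target')
    ce2 = compiler_end(o2, 'test_target')

    mod = tvm.IRModule()
    mod["main"] = relay.Function([data], relay.Tuple((ce1, ce2)))

    # PartitionGraph lifts the region into a global function named
    # test_target_<region id> that returns a tuple; each compiler_end
    # becomes a TupleGetItem on the single call to that function.
    partitioned = transform.PartitionGraph()(mod)
    print(partitioned)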