From a4653fdd4f3d0f944b9c66cbb067ea9c08b67974 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Thu, 26 Mar 2020 11:58:28 +0000 Subject: [PATCH] [RELAY] Re-wrote the Graph Partitioner to support multiple outputs *removed the expected use-case as we are taking broken-down PR approach *code style fixes *some trivial one liners --- src/relay/transforms/partition_graph.cc | 51 +++++++++---------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 756284a1301f4..f7240b3904994 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -51,8 +51,8 @@ namespace partitioning { // Cache compiler_begin and compiler_end annotation ops for equivalence check to // reduce registry lookup overhead. -static const Op &compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op &compiler_end_op = Op::Get("annotation.compiler_end"); +static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); +static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); /*! * \brief The checker that verifies if a Relay program is annotated correctly @@ -120,26 +120,11 @@ class AnnotationChecker : public ExprVisitor { * as TupleGetItemNode index. * 6) Therefore, functions will be created for all annotated regions. The name for each * global function is created using "Region" id and the compiler name. - * - * Expected Usecase : - * This pass is intended to run as the last pass in a series of passes as follows : - * 1) Annotate Supported Single Ops - annotated each single op with supported backends. - * We use supported_begin and supported_end annotations. - * 2) Annotate Supported Composite Ops - annotate each composite op (that consist of - * multiple single ops). - * We use supported_begin and supported_end - * annotations. - * 3) Deconflict Pass - Make sure each op is annotated by only a single backend. - * In other words, each Annotated Region will be disjoint. - * We promote supported_* annotations to compiler_* annotations. - * 4) Merge Supported Pass - Merge the disjoint compiler_* Annotated regions belonging - * to same backend. - * 5) *Partition Graph* - Convert Disjoint Annotated Regions into Functions. */ class Partitioner : public ExprMutator { public: - explicit Partitioner(const IRModule &module) : module_(module) { + explicit Partitioner(const IRModule& module) : module_(module) { for (auto f : module->functions) { GlobalVar f_var = f.first; BaseFunc f_func = f.second; @@ -282,7 +267,7 @@ class Partitioner : public ExprMutator { if (region->GetOutputs().size() == 1) { // If there is only a single output; no need to add a tuplegetitem node - return Call(glob_func, param_expr); + return ret; } else { // Add a tuplegetitem node to select this output out of many auto tuple_get_item_ = TupleGetItem(ret, index); @@ -388,7 +373,7 @@ class Partitioner : public ExprMutator { IRModule Partition() { auto glob_funcs = module_->functions; - for (const auto &pair : glob_funcs) { + for (const auto& pair : glob_funcs) { if (auto *fn = pair.second.as()) { auto func = GetRef(fn); func = Function(func->params, @@ -405,9 +390,9 @@ class Partitioner : public ExprMutator { private: /*! * \brief Get the region an expression belongs to - * if its in a region. + * if its in a region. */ - AnnotatedRegion GetRegion(const Expr &e) { + AnnotatedRegion GetRegion(const Expr& e) { for (auto sg_set_it : regions_sets_) { auto sg_set = sg_set_it.first; AnnotatedRegion sg = sg_set->GetRegion(e); @@ -420,9 +405,9 @@ class Partitioner : public ExprMutator { /*! * \brief Get the function an expression belongs to - * if its in a region. + * if its in a region. */ - BaseFunc GetFunc(const Expr &e) { + BaseFunc GetFunc(const Expr& e) { for (auto sg_set_it : regions_sets_) { auto sg_set = sg_set_it.first; auto func = sg_set_it.second; @@ -436,10 +421,10 @@ class Partitioner : public ExprMutator { } /*! - * \brief Get the index of the argument; - * this is to be used as tuplegetitem idx - */ - int GetArgIdx(AnnotatedRegion sg, const Expr &arg) { + * \brief Get the index of the argument; + * this is to be used as tuplegetitem idx + */ + int GetArgIdx(AnnotatedRegion sg, const Expr& arg) { int idx = 0; for (auto arg_ : sg->GetInputs()) { if (arg == arg_) { @@ -452,9 +437,9 @@ class Partitioner : public ExprMutator { /*! * \brief Get the index of the return(output); - * this is to be used as tuplegetitem idx + * this is to be used as tuplegetitem idx */ - int GetRetIdx(AnnotatedRegion sg, const Expr &arg) { + int GetRetIdx(AnnotatedRegion sg, const Expr& arg) { int idx = 0; for (auto arg_ : sg->GetOutputs()) { if (arg == arg_) { @@ -467,20 +452,20 @@ class Partitioner : public ExprMutator { /*! * \brief This map maintains the already created function calls. - * This is required in the multi-output scenario, to link rest of the outputs to call + * This is required in the multi-output scenario, to link rest of the outputs to call */ std::unordered_map region_function_calls; /*! * \brief This map maintains arguments (of region) visits through visitor patterns. - * Those arguement var and expression will be used to when creating the function. + * Those arguement var and expression will be used to when creating the function. */ std::unordered_map>, ObjectHash, ObjectEqual> region_args; /*! * \brief Each region set is associated with a function in the module. - * This map maintains the mapping between regionsets and the function it belongs to + * This map maintains the mapping between regionsets and the function it belongs to */ std::unordered_map regions_sets_; IRModule module_;