From 021213832cb98703dda54f631215ac17fbabff7b Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Mon, 30 Mar 2020 21:59:10 +0100 Subject: [PATCH] [RELAY] Add MergeCompilerRegions pass (#5134) * [RELAY] Add MergeCompilerRegions pass This pass is part of the flow to support creating compiler regions with multiple outputs. It should be called after AnnotateTarget and will merge together regions that share the same target to create larger compiler regions that can be off-loaded to external codegens. This pass implements an algorithm to ensure that during the merging, no data dependency issues are created. See the tests for an example of this case. Co-authored-by: Ramana Radhakrishnan Co-authored-by: Manupa Karunaratne Change-Id: Ibd99083564608d888482f57c5080109f3eefec88 * [RELAY] Annotate compiler_ends on each edge This alters the behaviour of the AnnotateTarget pass to enforce the property that all compiler annotations exist along a single data flow edge. Specifically, this means they should have exactly one parent and one child. Change-Id: I0e74803a77767f4f377d17755a13a74a30909797 * Fix comment * Rebase *Node::make * Moved block outside for loop * Code style * Update make API * Remove comment * Remove redundant 'else's * Make one line * Fix comment * RefWrite * Fix merge ordering * Add the RFC example as a test * [FIX] Fixed merging behaviour in AnnotateRegionSet Deleting items from a list while iterating it seems to result in undefined behaviour which sometimes segfaults. This makes sure all the item deletion happens separately. 
* Added checks * Move comment * Update comments --- python/tvm/relay/transform/transform.py | 11 + src/relay/analysis/annotated_region_set.cc | 6 +- src/relay/transforms/annotate_target.cc | 150 ++++++-- .../transforms/merge_compiler_regions.cc | 344 ++++++++++++++++++ tests/python/relay/test_annotate_target.py | 36 ++ .../relay/test_pass_merge_compiler_regions.py | 206 +++++++++++ 6 files changed, 726 insertions(+), 27 deletions(-) create mode 100644 src/relay/transforms/merge_compiler_regions.cc create mode 100644 tests/python/relay/test_pass_merge_compiler_regions.py diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index aa17c7f3de1c..41aa04095277 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -397,6 +397,17 @@ def MergeComposite(pattern_table): return _ffi_api.MergeComposite(pattern_names, patterns) +def MergeCompilerRegions(): + """Merge together compiler regions. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that merges compiler regions. + """ + return _ffi_api.MergeCompilerRegions() + + def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_deivce`, mark which device an expression should be scheduled to. 
diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index f8e951bac780..df2eb9643f61 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -55,14 +55,18 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, } // if any of the outputs of src are inputs of dest, they become internal nodes // so remove them from outs + std::vector ins_to_remove; for (const auto& input : dest->ins) { auto call = Downcast(input); auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]); if (it != src->outs.end()) { dest->outs.remove(*it); - dest->ins.remove(input); + ins_to_remove.push_back(input); } } + for (const auto& input : ins_to_remove) { + dest->ins.remove(input); + } regions_.erase(src); } diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index e6f4a18c8b06..c2f7b804cb6a 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -38,38 +38,136 @@ class AnnotateTargetWrapper : public ExprMutator { public: explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {} + Expr Annotate(const Expr& expr) { + return InsertEnd(Mutate(expr)); + } + + bool IsSupported(const Expr& expr) { + if (expr->IsInstance()) { + Call call = Downcast(expr); + auto fannotate = Op::GetAttr("target." + target_); + Op op = Downcast(call->op); + CHECK(op.defined()); + if (fannotate.count(op)) { + return fannotate[op](call->attrs, call->args); + } + } + return false; + } + + Expr InsertEnd(const Expr& arg) { + if (IsSupported(arg)) { + const auto *end_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + CHECK(end_op); + Expr end = (*end_op)(arg, target_); + return end; + } + return arg; + } + Expr VisitExpr_(const CallNode* cn) { // TODO(@zhiics, @comaniac) Handle composite functions. 
auto new_e = ExprMutator::VisitExpr_(cn); Call call = Downcast(new_e); - auto fannotate = Op::GetAttr("target." + target_); - Op op = Downcast(call->op); - CHECK(op.defined()); - - if (fannotate.count(op)) { - bool external = fannotate[op](call->attrs, call->args); - if (external) { - tvm::Array compiler_begins; - for (const auto& it : call->args) { - const auto* begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - CHECK(begin_op); - Expr begin = (*begin_op)(it, target_); - compiler_begins.push_back(begin); - } - Expr update_call = Call(call->op, compiler_begins, call->attrs); - const auto* end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - CHECK(end_op); - Expr end = (*end_op)(update_call, target_); - return end; + + // add end annotations if the args are supported + Array compiler_ends; + for (const auto& it : call->args) { + compiler_ends.push_back(InsertEnd(it)); + } + call = Call(call->op, compiler_ends, call->attrs); + + // add begin annotations if the call node is supported + if (IsSupported(call)) { + tvm::Array compiler_begins; + const auto* begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + for (const auto& it : call->args) { + CHECK(begin_op); + Expr begin = (*begin_op)(it, target_); + compiler_begins.push_back(begin); } - } else { - LOG(WARNING) << op->name << " in " << target_ - << " is not registered. 
It will be executed on CPU."; + call = Call(call->op, compiler_begins, call->attrs); } - return new_e; + + return std::move(call); + } + + Expr VisitExpr_(const TupleNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto tup = Downcast(new_e); + Array new_fields; + for (auto field : tup->fields) { + new_fields.push_back(InsertEnd(field)); + } + return Tuple(new_fields); + } + + Expr VisitExpr_(const TupleGetItemNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto get = Downcast(new_e); + return TupleGetItem( + InsertEnd(get->tuple), + get->index); + } + + Expr VisitExpr_(const FunctionNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto func = Downcast(new_e); + return Function( + func->params, + InsertEnd(func->body), + func->ret_type, + func->type_params, + func->attrs); + } + + Expr VisitExpr_(const LetNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto let = Downcast(new_e); + return Let( + let->var, + InsertEnd(let->value), + InsertEnd(let->body)); + } + + Expr VisitExpr_(const IfNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto iff = Downcast(new_e); + return If( + InsertEnd(iff->cond), + InsertEnd(iff->true_branch), + InsertEnd(iff->false_branch)); + } + + Expr VisitExpr_(const RefCreateNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto create = Downcast(new_e); + return RefCreate(InsertEnd(create->value)); + } + + Expr VisitExpr_(const RefReadNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto read = Downcast(new_e); + return RefRead(InsertEnd(read->ref)); + } + + Expr VisitExpr_(const RefWriteNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto write = Downcast(new_e); + return RefWrite( + InsertEnd(write->ref), + InsertEnd(write->value)); } private: @@ -77,7 +175,7 @@ class AnnotateTargetWrapper : public ExprMutator { }; Expr AnnotateTarget(const Expr& expr, const std::string& target) { - return AnnotateTargetWrapper(target).Mutate(expr); + return 
AnnotateTargetWrapper(target).Annotate(expr); } } // namespace annotate_target diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc new file mode 100644 index 000000000000..e6ec93aecd42 --- /dev/null +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file src/relay/transforms/merge_compiler_regions.cc + * + * \brief After operators have been annotated with the targets that support + * them, this pass creates regions of the operators for each target. It + * is guaranteed that the regions will have a topological ordering so that + * no data dependency issues exist. + * + * This pass only introduces annotations to indicate the regions. + * partition_graph must subsequently be called to lift these regions out + * as external functions. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../analysis/annotated_region_set.h" + + +namespace tvm { +namespace relay { +namespace partitioning { + +// Cache compiler_begin and compiler_end annotation ops for equivalence check to +// reduce registry lookup overhead. 
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); +static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); + +/*! \brief This is a pre-requisite pass to the MergeCompilerRegions pass. + * The AnnotateRestDefault pass will put "default" Compiler Annotations to + * nodes that are not annotated already. This is there to ensure that the + * user will not leave un-annotated nodes when the MergeCompilerRegions pass is run. + * Why? Because the MergeCompilerRegions pass assumes every node to be annotated. + */ +class AnnotateRestDefault : public ExprMutator { + public: + explicit AnnotateRestDefault(const Expr& expr) { + regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); + } + + Expr Annotate(const Expr& expr) { + // It's a function that is being passed on to annotate + func_ = Downcast(expr); + + // Corner Case CC1 : If the last node does not belong + // to a region, we need to add a compiler_end + auto region = regions_->GetRegion(func_->body); + auto mutated_expr = this->VisitExpr(expr); + if (!region.defined()) { + func_ = Downcast(mutated_expr); + // CC1 : add that compiler end after mutation + auto body = AddCompilerEnd_(func_->body); + func_ = Function(func_->params, body, + body->checked_type_, {}, DictAttrs()); + return Downcast(func_); + } + return mutated_expr; + } + + /*! \brief This function adds compiler ends to nodes that + * do not have a region AND are not arguments of the + * original function + * \param expr The expression to add a compiler end to. + * \return expr The expression with or without a compiler end added. 
+ */ + Expr AddCompilerEnd(const Expr& expr) { + auto region = regions_->GetRegion(expr); + auto visited_expr = VisitExpr(expr); + + // The compiler ends are added to nodes that do not have a region + // AND are not arguments of the original function + if (!region.defined() && + std::find(func_->params.begin(), + func_->params.end(), visited_expr) + == func_->params.end()) { + return AddCompilerEnd_(visited_expr); + } + return visited_expr; + } + + Expr AddCompilerEnd_(const Expr& expr) { + const auto* end_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + CHECK(end_op); + Expr end = (*end_op)(expr, target_); + return end; + } + + Expr VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as(); + auto ret = GetRef(call); + + Array args; + + // Add compiler ends if the parent is supported + for (auto arg : call->args) { + args.push_back(AddCompilerEnd(arg)); + } + + if (op_node == nullptr || call->attrs.as() == nullptr) { + // Skip annotation ops, only add default compiler to actual compute nodes + + auto region = regions_->GetRegion(ret); + if (!region.defined()) { + // if the current node does not belong to an annotated region + // annotate all the incoming edges (args) + // with "default" compiler_begin and compiler_end annotations. 
+ tvm::Array compiler_begins; + for (auto arg : args) { + const auto* begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + CHECK(begin_op); + Expr begin = (*begin_op)(arg, target_); + compiler_begins.push_back(begin); + } + Expr update_call = Call(call->op, compiler_begins, call->attrs); + return update_call; + } + } + return Call(call->op, args, call->attrs); + }; + + Expr VisitExpr_(const TupleNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto tup = Downcast(new_e); + Array new_fields; + for (auto field : tup->fields) { + new_fields.push_back(AddCompilerEnd(field)); + } + return Tuple(new_fields); + } + + Expr VisitExpr_(const TupleGetItemNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto get = Downcast(new_e); + return TupleGetItem(AddCompilerEnd(get->tuple), get->index); + } + + Expr VisitExpr_(const LetNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto let = Downcast(new_e); + return Let( + let->var, + AddCompilerEnd(let->value), + AddCompilerEnd(let->body)); + } + + Expr VisitExpr_(const IfNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto iff = Downcast(new_e); + return If( + AddCompilerEnd(iff->cond), + AddCompilerEnd(iff->true_branch), + AddCompilerEnd(iff->false_branch)); + } + + Expr VisitExpr_(const RefCreateNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto create = Downcast(new_e); + return RefCreate(AddCompilerEnd(create->value)); + } + + Expr VisitExpr_(const RefReadNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto read = Downcast(new_e); + return RefRead(AddCompilerEnd(read->ref)); + } + + Expr VisitExpr_(const RefWriteNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto write = Downcast(new_e); + return RefWrite( + AddCompilerEnd(write->ref), + AddCompilerEnd(write->value)); + } + + private: + AnnotatedRegionSet regions_; + const std::string target_ = "default"; + Function func_; +}; + +class MergeAnnotations : public ExprMutator { + 
public: + explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} + + Expr VisitExpr_(const CallNode* call) final { + if (call->op == compiler_begin_op) { + if (call->args[0]->IsInstance()) { + auto arg = Downcast(call->args[0]); + if (arg->op == compiler_end_op) { + auto region1 = regions_->GetRegion(GetRef(call)); + auto region2 = regions_->GetRegion(arg); + if (region1 == region2) { + return ExprMutator::VisitExpr(arg->args[0]); + } + } + } + } + return ExprMutator::VisitExpr_(call); + } + + private: + AnnotatedRegionSet regions_; +}; + +class RegionMerger : public ExprVisitor { + public: + explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} + + void VisitExpr_(const CallNode* call) final { + if (call->op == compiler_end_op) { + auto region = regions_->GetRegion(GetRef(call)); + // set the region target + auto compiler_attrs = call->attrs.as(); + region_targets_[region->GetID()] = compiler_attrs->compiler; + std::vector mergeable_regions; + // first look at the region args to determine the parent regions + for (const auto& arg : region->GetInputs()) { + // all args should be begin annotations + auto begin = Downcast(arg); + CHECK_EQ(begin->op, compiler_begin_op); + // the arguments of the begin annotations will be in the parent regions + auto parent_region = regions_->GetRegion(begin->args[0]); + // if there is no parent region, move on + if (!parent_region.defined()) continue; + // merge the parent region if it hasn't been done already + if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { + VisitExpr(begin->args[0]); + } + mergeable_regions.push_back(parent_region); + } + auto& region_restrictions = region_restrictions_[region->GetID()]; + for (const auto& parent_region : mergeable_regions) { + // add all the parent restrictions to the current region + auto parent_restrictions = region_restrictions_[parent_region->GetID()]; + region_restrictions.insert(parent_restrictions.begin(), + 
parent_restrictions.end()); + } + for (const auto& parent_region : mergeable_regions) { + bool merged = false; + // check the parent region has the same target + if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) { + // check the parent region isn't in the restrictions + if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) { + // merge the parent region into the current region + regions_->MergeRegions(parent_region, region); + // update the restrictions of all other regions to reflect the change in id + for (const auto& r : regions_) { + auto& restrictions = region_restrictions_[r->GetID()]; + if (restrictions.find(parent_region->GetID()) != restrictions.end()) { + restrictions.erase(parent_region->GetID()); + restrictions.insert(region->GetID()); + } + } + merged = true; + } + } + // if the parent wasn't merged, add it as a restriction to the current region + if (!merged) + region_restrictions.insert(parent_region->GetID()); + } + merged_regions_.insert(region->GetID()); + } + ExprVisitor::VisitExpr_(call); + } + + private: + AnnotatedRegionSet regions_; + std::unordered_set merged_regions_; + std::map> region_restrictions_; + std::map region_targets_; +}; + + +Expr MergeCompilerRegions(const Expr& expr) { + // Annotate all the nodes that aren't annotated as 'default'. + AnnotateRestDefault anno_default(expr); + auto expr_all_annotated = anno_default.Annotate(expr); + + // Create regions using the annotations. + AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr_all_annotated, + compiler_begin_op, compiler_end_op); + + // By now, all the nodes have some sort of annotation. + // Region merger is an ExprVisitor that will update the + // AnnotatedRegionSet, merging all the regions that can be merged. + RegionMerger merger(regions); + merger.VisitExpr(expr_all_annotated); + + // This updates the expression to remove annotations that are now + // 'internal' to a merged region. 
+ MergeAnnotations merge_anno(regions); + return merge_anno.Mutate(expr_all_annotated); +} + +} // namespace partitioning + +namespace transform { + +Pass MergeCompilerRegions() { + runtime::TypedPackedFunc part_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(partitioning::MergeCompilerRegions(f)); + }; + auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {}); + return Sequential({partitioned, InferType()}); +} + +TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") +.set_body_typed(transform::MergeCompilerRegions); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 12a15dcb2c3a..23f56c4febad 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -22,6 +22,7 @@ import tvm import tvm.relay.testing +import tvm.relay.op as reg import tvm.relay.transform as transform from tvm import relay from tvm import runtime @@ -183,6 +184,41 @@ def test_extern_dnnl_mobilenet(): (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) +@reg.register("nn.relu", "target.test") +def relu(attrs, args): + return True + + +def test_multiple_ends(): + def before(): + x = relay.var("x", shape=(10, 10)) + r = relay.nn.relu(x) + a_1 = relay.abs(r) + a_2 = relay.abs(r) + out = relay.add(a_1, a_2) + f = relay.Function([x], out) + mod = tvm.IRModule.from_expr(f) + return mod + + def after(): + x = relay.var("x", shape=(10, 10)) + cb_1 = relay.annotation.compiler_begin(x, "test") + r = relay.nn.relu(cb_1) + ce_1 = relay.annotation.compiler_end(r, "test") + ce_2 = relay.annotation.compiler_end(r, "test") + a_1 = relay.abs(ce_1) + a_2 = relay.abs(ce_2) + out = relay.add(a_1, a_2) + f = relay.Function([x], out) + mod = tvm.IRModule.from_expr(f) + return mod + + result = transform.AnnotateTarget("test")(before()) + expected = transform.InferType()(after()) + assert 
relay.analysis.alpha_equal(expected, result) + + if __name__ == "__main__": + test_multiple_ends() test_extern_dnnl() test_extern_dnnl_mobilenet() diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py new file mode 100644 index 000000000000..04ff46e337fd --- /dev/null +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for merge compiler regions.""" +import tvm +from tvm import relay +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay.testing import run_opt_pass + + +def test_diamond_graph_fanouts(): + """ + This tests that the data dependencies present in a diamond-shaped + graph are correctly resolved by the merging pass. + + O = supported by target + X = not supported by target + + O O + / \ / \ + O X --> O + + X + \ / \ / + O O + + Note that we can't just merge the three supported operators together, + otherwise both subgraphs would depend on the other. 
+ """ + def diamond_graph_fanouts(): + data = relay.var('data', shape=(10, 10)) + cb_1 = compiler_begin(data, "test") + O_1 = relay.abs(cb_1) + ce_1 = compiler_end(O_1, "test") + ce_2 = compiler_end(O_1, "test") + cb_2 = compiler_begin(ce_1, "test") + O_2 = relay.nn.relu(cb_2) + ce_3 = compiler_end(O_2, "test") + + X = relay.tanh(ce_2) + + cb_3 = compiler_begin(ce_3, "test") + cb_4 = compiler_begin(X, "test") + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, "test") + + diamond = relay.Function([data], ce_4) + return diamond + + def expected(): + data = relay.var('data', shape=(10, 10)) + cb_1 = compiler_begin(data, "test") + O_1 = relay.abs(cb_1) + ce_2 = compiler_end(O_1, "test") + O_2 = relay.nn.relu(O_1) + ce_3 = compiler_end(O_2, "test") + + cb_x = compiler_begin(ce_2, "default") + X = relay.tanh(cb_x) + ce_x1 = compiler_end(X, "default") + ce_x2 = compiler_end(X, "default") + + cb_3 = compiler_begin(ce_3, "test") + cb_4 = compiler_begin(ce_x1, "test") + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, "test") + + func = relay.Function([data], ce_4) + return func + + result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions()) + golden = run_opt_pass(expected(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, golden) + + +def test_example_graph(): + """This tests the merging algorithm on the example used in the RFC. + + See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830 + Blue nodes are adds, red nodes are subtracts. 
+ """ + def annotated(): + in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') + in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') + in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') + in_4 = relay.var('in_4', shape=(10, 10), dtype='float32') + in_5 = relay.var('in_5', shape=(10, 10), dtype='float32') + in_6 = relay.var('in_6', shape=(10, 10), dtype='float32') + in_7 = relay.var('in_7', shape=(10, 10), dtype='float32') + in_8 = relay.var('in_8', shape=(10, 10), dtype='float32') + in_9 = relay.var('in_9', shape=(10, 10), dtype='float32') + in_10 = relay.var('in_10', shape=(10, 10), dtype='float32') + + begin0 = compiler_begin(in_1, "test") + begin1 = compiler_begin(in_2, "test") + begin2 = compiler_begin(in_3, "test") + begin3 = compiler_begin(in_4, "test") + node0 = relay.add(begin0, begin1) + node1 = relay.add(begin2, begin3) + end0 = compiler_end(node0, "test") + end1 = compiler_end(node1, "test") + begin4 = compiler_begin(end0, "test") + begin5 = compiler_begin(end1, "test") + node2 = relay.add(begin4, begin5) + end2 = compiler_end(node2, "test") + + node3 = relay.subtract(in_5, in_6) + node4 = relay.subtract(in_7, node3) + + begin6 = compiler_begin(end2, "test") + begin7 = compiler_begin(node4, "test") + node5 = relay.add(begin6, begin7) + end3 = compiler_end(node5, "test") + end4 = compiler_end(node5, "test") + node6 = relay.subtract(in_8, end3) + begin8 = compiler_begin(in_9, "test") + begin9 = compiler_begin(end4, "test") + node7 = relay.add(begin8, begin9) + end5 = compiler_end(node7, "test") + + begin10 = compiler_begin(node6, "test") + begin11 = compiler_begin(end5, "test") + node8 = relay.add(begin10, begin11) + end6 = compiler_end(node8, "test") + begin12 = compiler_begin(in_10, "test") + begin13 = compiler_begin(end6, "test") + node9 = relay.add(begin12, begin13) + end7 = compiler_end(node9, "test") + + f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7) + mod = tvm.IRModule.from_expr(f) + 
return mod + + def expected(): + in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') + in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') + in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') + in_4 = relay.var('in_4', shape=(10, 10), dtype='float32') + in_5 = relay.var('in_5', shape=(10, 10), dtype='float32') + in_6 = relay.var('in_6', shape=(10, 10), dtype='float32') + in_7 = relay.var('in_7', shape=(10, 10), dtype='float32') + in_8 = relay.var('in_8', shape=(10, 10), dtype='float32') + in_9 = relay.var('in_9', shape=(10, 10), dtype='float32') + in_10 = relay.var('in_10', shape=(10, 10), dtype='float32') + + begin0 = compiler_begin(in_1, "test") + begin1 = compiler_begin(in_2, "test") + begin2 = compiler_begin(in_3, "test") + begin3 = compiler_begin(in_4, "test") + node0 = relay.add(begin0, begin1) + node1 = relay.add(begin2, begin3) + node2 = relay.add(node0, node1) + + begin4 = compiler_begin(in_5, "default") + begin5 = compiler_begin(in_6, "default") + begin6 = compiler_begin(in_7, "default") + node3 = relay.subtract(begin4, begin5) + node4 = relay.subtract(begin6, node3) + end0 = compiler_end(node4, "default") + + begin7 = compiler_begin(end0, "test") + begin8 = compiler_begin(in_9, "test") + + node5 = relay.add(node2, begin7) + end1 = compiler_end(node5, "test") + + begin9 = compiler_begin(end1, "default") + begin10 = compiler_begin(in_8, "default") + node6 = relay.subtract(begin10, begin9) + end2 = compiler_end(node6, "default") + + node7 = relay.add(begin8, node5) + end3 = compiler_end(node7, "test") + begin11 = compiler_begin(end3, "test") + begin12 = compiler_begin(end2, "test") + + node8 = relay.add(begin12, begin11) + + begin13 = compiler_begin(in_10, "test") + node9 = relay.add(begin13, node8) + end4 = compiler_end(node9, "test") + + f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4) + mod = tvm.IRModule.from_expr(f) + return mod + + mod = annotated() + mod = 
relay.transform.MergeCompilerRegions()(mod) + ref_mod = expected() + assert relay.analysis.alpha_equal(mod, ref_mod) + + +if __name__ == "__main__": + test_diamond_graph_fanouts() + test_example_graph()