diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index aa17c7f3de1c..41aa04095277 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -397,6 +397,17 @@ def MergeComposite(pattern_table): return _ffi_api.MergeComposite(pattern_names, patterns) +def MergeCompilerRegions(): + """Merge together compiler regions. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that merges compiler regions. + """ + return _ffi_api.MergeCompilerRegions() + + def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_deivce`, mark which device an expression should be scheduled to. diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index f8e951bac780..df2eb9643f61 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -55,14 +55,18 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, } // if any of the outputs of src are inputs of dest, they become internal nodes // so remove them from outs + std::vector ins_to_remove; for (const auto& input : dest->ins) { auto call = Downcast(input); auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]); if (it != src->outs.end()) { dest->outs.remove(*it); - dest->ins.remove(input); + ins_to_remove.push_back(input); } } + for (const auto& input : ins_to_remove) { + dest->ins.remove(input); + } regions_.erase(src); } diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index e6f4a18c8b06..c2f7b804cb6a 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -38,38 +38,136 @@ class AnnotateTargetWrapper : public ExprMutator { public: explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {} + Expr Annotate(const Expr& expr) { + return 
InsertEnd(Mutate(expr)); + } + + bool IsSupported(const Expr& expr) { + if (expr->IsInstance()) { + Call call = Downcast(expr); + auto fannotate = Op::GetAttr("target." + target_); + Op op = Downcast(call->op); + CHECK(op.defined()); + if (fannotate.count(op)) { + return fannotate[op](call->attrs, call->args); + } + } + return false; + } + + Expr InsertEnd(const Expr& arg) { + if (IsSupported(arg)) { + const auto *end_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + CHECK(end_op); + Expr end = (*end_op)(arg, target_); + return end; + } + return arg; + } + Expr VisitExpr_(const CallNode* cn) { // TODO(@zhiics, @comaniac) Handle composite functions. auto new_e = ExprMutator::VisitExpr_(cn); Call call = Downcast(new_e); - auto fannotate = Op::GetAttr("target." + target_); - Op op = Downcast(call->op); - CHECK(op.defined()); - - if (fannotate.count(op)) { - bool external = fannotate[op](call->attrs, call->args); - if (external) { - tvm::Array compiler_begins; - for (const auto& it : call->args) { - const auto* begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - CHECK(begin_op); - Expr begin = (*begin_op)(it, target_); - compiler_begins.push_back(begin); - } - Expr update_call = Call(call->op, compiler_begins, call->attrs); - const auto* end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - CHECK(end_op); - Expr end = (*end_op)(update_call, target_); - return end; + + // add end annotations if the args are supported + Array compiler_ends; + for (const auto& it : call->args) { + compiler_ends.push_back(InsertEnd(it)); + } + call = Call(call->op, compiler_ends, call->attrs); + + // add begin annotations if the call node is supported + if (IsSupported(call)) { + tvm::Array compiler_begins; + const auto* begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + for (const auto& it : call->args) { + CHECK(begin_op); + Expr begin = (*begin_op)(it, target_); + 
compiler_begins.push_back(begin); } - } else { - LOG(WARNING) << op->name << " in " << target_ - << " is not registered. It will be executed on CPU."; + call = Call(call->op, compiler_begins, call->attrs); } - return new_e; + + return std::move(call); + } + + Expr VisitExpr_(const TupleNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto tup = Downcast(new_e); + Array new_fields; + for (auto field : tup->fields) { + new_fields.push_back(InsertEnd(field)); + } + return Tuple(new_fields); + } + + Expr VisitExpr_(const TupleGetItemNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto get = Downcast(new_e); + return TupleGetItem( + InsertEnd(get->tuple), + get->index); + } + + Expr VisitExpr_(const FunctionNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto func = Downcast(new_e); + return Function( + func->params, + InsertEnd(func->body), + func->ret_type, + func->type_params, + func->attrs); + } + + Expr VisitExpr_(const LetNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto let = Downcast(new_e); + return Let( + let->var, + InsertEnd(let->value), + InsertEnd(let->body)); + } + + Expr VisitExpr_(const IfNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto iff = Downcast(new_e); + return If( + InsertEnd(iff->cond), + InsertEnd(iff->true_branch), + InsertEnd(iff->false_branch)); + } + + Expr VisitExpr_(const RefCreateNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto create = Downcast(new_e); + return RefCreate(InsertEnd(create->value)); + } + + Expr VisitExpr_(const RefReadNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto read = Downcast(new_e); + return RefRead(InsertEnd(read->ref)); + } + + Expr VisitExpr_(const RefWriteNode* op) { + auto new_e = ExprMutator::VisitExpr_(op); + + auto write = Downcast(new_e); + return RefWrite( + InsertEnd(write->ref), + InsertEnd(write->value)); } private: @@ -77,7 +175,7 @@ class AnnotateTargetWrapper : public ExprMutator { }; Expr 
AnnotateTarget(const Expr& expr, const std::string& target) { - return AnnotateTargetWrapper(target).Mutate(expr); + return AnnotateTargetWrapper(target).Annotate(expr); } } // namespace annotate_target diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc new file mode 100644 index 000000000000..e6ec93aecd42 --- /dev/null +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file src/relay/transforms/merge_compiler_regions.cc + * + * \brief After operators have been annotated with the targets that support + * them, this pass creates regions of the operators for each target. It + * is guaranteed that the regions will have a topological ordering so that + * no data dependency issues exist. + * + * This pass only introduces annotations to indicate the regions. + * partition_graph must subsequently be called to lift these regions out + * as external functions. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../analysis/annotated_region_set.h" + + +namespace tvm { +namespace relay { +namespace partitioning { + +// Cache compiler_begin and compiler_end annotation ops for equivalence check to +// reduce registry lookup overhead. +static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); +static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); + +/*! \brief This is a prerequisite pass to the MergeCompilerRegions pass. + * The AnnotateRestDefault pass will put "default" Compiler Annotations to + * nodes that are not annotated already. This is there to ensure that the + * user will not leave un-annotated nodes before the MergeCompilerRegions pass is run. + * Why? Because the MergeCompilerRegions pass assumes every node to be annotated. + */ +class AnnotateRestDefault : public ExprMutator { + public: + explicit AnnotateRestDefault(const Expr& expr) { + regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); + } + + Expr Annotate(const Expr& expr) { + // It's a function that is being passed in to annotate + func_ = Downcast(expr); + + // Corner Case CC1 : If the last node does not belong + // to a region, we need to add a compiler_end + auto region = regions_->GetRegion(func_->body); + auto mutated_expr = this->VisitExpr(expr); + if (!region.defined()) { + func_ = Downcast(mutated_expr); + // CC1 : add that compiler end after mutation + auto body = AddCompilerEnd_(func_->body); + func_ = Function(func_->params, body, + body->checked_type_, {}, DictAttrs()); + return Downcast(func_); + } + return mutated_expr; + } + + /*! \brief This function adds compiler ends to nodes that + * do not have a region AND are not arguments of the + * original function + * \param expr The expression to add a compiler end to. + * \return expr The expression with or without a compiler end added. 
+ */ + Expr AddCompilerEnd(const Expr& expr) { + auto region = regions_->GetRegion(expr); + auto visited_expr = VisitExpr(expr); + + // The compiler ends are added to nodes that do not have a region + // AND are not arguments of the original function + if (!region.defined() && + std::find(func_->params.begin(), + func_->params.end(), visited_expr) + == func_->params.end()) { + return AddCompilerEnd_(visited_expr); + } + return visited_expr; + } + + Expr AddCompilerEnd_(const Expr& expr) { + const auto* end_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + CHECK(end_op); + Expr end = (*end_op)(expr, target_); + return end; + } + + Expr VisitExpr_(const CallNode* call) final { + auto op_node = call->op.as(); + auto ret = GetRef(call); + + Array args; + + // Add compiler ends to the args that are not in a region (see AddCompilerEnd) + for (auto arg : call->args) { + args.push_back(AddCompilerEnd(arg)); + } + + if (op_node == nullptr || call->attrs.as() == nullptr) { + // Skip annotation ops, only add default compiler to actual compute nodes + + auto region = regions_->GetRegion(ret); + if (!region.defined()) { + // if the current node does not belong to an annotated region + // annotate all the incoming edges (args) + // with "default" compiler_begin and compiler_end annotations. 
+ tvm::Array compiler_begins; + for (auto arg : args) { + const auto* begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + CHECK(begin_op); + Expr begin = (*begin_op)(arg, target_); + compiler_begins.push_back(begin); + } + Expr update_call = Call(call->op, compiler_begins, call->attrs); + return update_call; + } + } + return Call(call->op, args, call->attrs); + }; + + Expr VisitExpr_(const TupleNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto tup = Downcast(new_e); + Array new_fields; + for (auto field : tup->fields) { + new_fields.push_back(AddCompilerEnd(field)); + } + return Tuple(new_fields); + } + + Expr VisitExpr_(const TupleGetItemNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto get = Downcast(new_e); + return TupleGetItem(AddCompilerEnd(get->tuple), get->index); + } + + Expr VisitExpr_(const LetNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto let = Downcast(new_e); + return Let( + let->var, + AddCompilerEnd(let->value), + AddCompilerEnd(let->body)); + } + + Expr VisitExpr_(const IfNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto iff = Downcast(new_e); + return If( + AddCompilerEnd(iff->cond), + AddCompilerEnd(iff->true_branch), + AddCompilerEnd(iff->false_branch)); + } + + Expr VisitExpr_(const RefCreateNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto create = Downcast(new_e); + return RefCreate(AddCompilerEnd(create->value)); + } + + Expr VisitExpr_(const RefReadNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto read = Downcast(new_e); + return RefRead(AddCompilerEnd(read->ref)); + } + + Expr VisitExpr_(const RefWriteNode *op) { + auto new_e = ExprMutator::VisitExpr_(op); + auto write = Downcast(new_e); + return RefWrite( + AddCompilerEnd(write->ref), + AddCompilerEnd(write->value)); + } + + private: + AnnotatedRegionSet regions_; + const std::string target_ = "default"; + Function func_; +}; + +class MergeAnnotations : public ExprMutator { + 
public: + explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} + + Expr VisitExpr_(const CallNode* call) final { + if (call->op == compiler_begin_op) { + if (call->args[0]->IsInstance()) { + auto arg = Downcast(call->args[0]); + if (arg->op == compiler_end_op) { + auto region1 = regions_->GetRegion(GetRef(call)); + auto region2 = regions_->GetRegion(arg); + if (region1 == region2) { + return ExprMutator::VisitExpr(arg->args[0]); + } + } + } + } + return ExprMutator::VisitExpr_(call); + } + + private: + AnnotatedRegionSet regions_; +}; + +class RegionMerger : public ExprVisitor { + public: + explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} + + void VisitExpr_(const CallNode* call) final { + if (call->op == compiler_end_op) { + auto region = regions_->GetRegion(GetRef(call)); + // set the region target + auto compiler_attrs = call->attrs.as(); + region_targets_[region->GetID()] = compiler_attrs->compiler; + std::vector mergeable_regions; + // first look at the region args to determine the parent regions + for (const auto& arg : region->GetInputs()) { + // all args should be begin annotations + auto begin = Downcast(arg); + CHECK_EQ(begin->op, compiler_begin_op); + // the arguments of the begin annotations will be in the parent regions + auto parent_region = regions_->GetRegion(begin->args[0]); + // if there is no parent region, move on + if (!parent_region.defined()) continue; + // merge the parent region if it hasn't been done already + if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { + VisitExpr(begin->args[0]); + } + mergeable_regions.push_back(parent_region); + } + auto& region_restrictions = region_restrictions_[region->GetID()]; + for (const auto& parent_region : mergeable_regions) { + // add all the parent restrictions to the current region + auto parent_restrictions = region_restrictions_[parent_region->GetID()]; + region_restrictions.insert(parent_restrictions.begin(), + 
parent_restrictions.end()); + } + for (const auto& parent_region : mergeable_regions) { + bool merged = false; + // check the parent region has the same target + if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) { + // check the parent region isn't in the restrictions + if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) { + // merge the parent region into the current region + regions_->MergeRegions(parent_region, region); + // update the restrictions of all other regions to reflect the change in id + for (const auto& r : regions_) { + auto& restrictions = region_restrictions_[r->GetID()]; + if (restrictions.find(parent_region->GetID()) != restrictions.end()) { + restrictions.erase(parent_region->GetID()); + restrictions.insert(region->GetID()); + } + } + merged = true; + } + } + // if the parent wasn't merged, add it as a restriction to the current region + if (!merged) + region_restrictions.insert(parent_region->GetID()); + } + merged_regions_.insert(region->GetID()); + } + ExprVisitor::VisitExpr_(call); + } + + private: + AnnotatedRegionSet regions_; + std::unordered_set merged_regions_; + std::map> region_restrictions_; + std::map region_targets_; +}; + + +Expr MergeCompilerRegions(const Expr& expr) { + // Annotate all the nodes that aren't annotated as 'default'. + AnnotateRestDefault anno_default(expr); + auto expr_all_annotated = anno_default.Annotate(expr); + + // Create regions using the annotations. + AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr_all_annotated, + compiler_begin_op, compiler_end_op); + + // By now, all the nodes have some sort of annotation. + // Region merger is an ExprVisitor that will update the + // AnnotatedRegionSet, merging all the regions that can be merged. + RegionMerger merger(regions); + merger.VisitExpr(expr_all_annotated); + + // This updates the expression to remove annotations that are now + // 'internal' to a merged region. 
+ MergeAnnotations merge_anno(regions); + return merge_anno.Mutate(expr_all_annotated); +} + +} // namespace partitioning + +namespace transform { + +Pass MergeCompilerRegions() { + runtime::TypedPackedFunc part_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(partitioning::MergeCompilerRegions(f)); + }; + auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {}); + return Sequential({partitioned, InferType()}); +} + +TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") +.set_body_typed(transform::MergeCompilerRegions); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 12a15dcb2c3a..23f56c4febad 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -22,6 +22,7 @@ import tvm import tvm.relay.testing +import tvm.relay.op as reg import tvm.relay.transform as transform from tvm import relay from tvm import runtime @@ -183,6 +184,41 @@ def test_extern_dnnl_mobilenet(): (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) +@reg.register("nn.relu", "target.test") +def relu(attrs, args): + return True + + +def test_multiple_ends(): + def before(): + x = relay.var("x", shape=(10, 10)) + r = relay.nn.relu(x) + a_1 = relay.abs(r) + a_2 = relay.abs(r) + out = relay.add(a_1, a_2) + f = relay.Function([x], out) + mod = tvm.IRModule.from_expr(f) + return mod + + def after(): + x = relay.var("x", shape=(10, 10)) + cb_1 = relay.annotation.compiler_begin(x, "test") + r = relay.nn.relu(cb_1) + ce_1 = relay.annotation.compiler_end(r, "test") + ce_2 = relay.annotation.compiler_end(r, "test") + a_1 = relay.abs(ce_1) + a_2 = relay.abs(ce_2) + out = relay.add(a_1, a_2) + f = relay.Function([x], out) + mod = tvm.IRModule.from_expr(f) + return mod + + result = transform.AnnotateTarget("test")(before()) + expected = transform.InferType()(after()) + assert 
relay.analysis.alpha_equal(expected, result) + + if __name__ == "__main__": + test_multiple_ends() test_extern_dnnl() test_extern_dnnl_mobilenet() diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py new file mode 100644 index 000000000000..04ff46e337fd --- /dev/null +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for merge compiler regions.""" +import tvm +from tvm import relay +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay.testing import run_opt_pass + + +def test_diamond_graph_fanouts(): + """ + This tests that the data dependencies present in a diamond-shaped + graph are correctly resolved by the merging pass. + + O = supported by target + X = not supported by target + + O O + / \ / \ + O X --> O + + X + \ / \ / + O O + + Note that we can't just merge the three supported operators together, + otherwise both subgraphs would depend on the other. 
+ """ + def diamond_graph_fanouts(): + data = relay.var('data', shape=(10, 10)) + cb_1 = compiler_begin(data, "test") + O_1 = relay.abs(cb_1) + ce_1 = compiler_end(O_1, "test") + ce_2 = compiler_end(O_1, "test") + cb_2 = compiler_begin(ce_1, "test") + O_2 = relay.nn.relu(cb_2) + ce_3 = compiler_end(O_2, "test") + + X = relay.tanh(ce_2) + + cb_3 = compiler_begin(ce_3, "test") + cb_4 = compiler_begin(X, "test") + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, "test") + + diamond = relay.Function([data], ce_4) + return diamond + + def expected(): + data = relay.var('data', shape=(10, 10)) + cb_1 = compiler_begin(data, "test") + O_1 = relay.abs(cb_1) + ce_2 = compiler_end(O_1, "test") + O_2 = relay.nn.relu(O_1) + ce_3 = compiler_end(O_2, "test") + + cb_x = compiler_begin(ce_2, "default") + X = relay.tanh(cb_x) + ce_x1 = compiler_end(X, "default") + ce_x2 = compiler_end(X, "default") + + cb_3 = compiler_begin(ce_3, "test") + cb_4 = compiler_begin(ce_x1, "test") + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, "test") + + func = relay.Function([data], ce_4) + return func + + result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions()) + golden = run_opt_pass(expected(), relay.transform.InferType()) + assert relay.analysis.alpha_equal(result, golden) + + +def test_example_graph(): + """This tests the merging algorithm on the example used in the RFC. + + See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830 + Blue nodes are adds, red nodes are subtracts. 
+ """ + def annotated(): + in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') + in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') + in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') + in_4 = relay.var('in_4', shape=(10, 10), dtype='float32') + in_5 = relay.var('in_5', shape=(10, 10), dtype='float32') + in_6 = relay.var('in_6', shape=(10, 10), dtype='float32') + in_7 = relay.var('in_7', shape=(10, 10), dtype='float32') + in_8 = relay.var('in_8', shape=(10, 10), dtype='float32') + in_9 = relay.var('in_9', shape=(10, 10), dtype='float32') + in_10 = relay.var('in_10', shape=(10, 10), dtype='float32') + + begin0 = compiler_begin(in_1, "test") + begin1 = compiler_begin(in_2, "test") + begin2 = compiler_begin(in_3, "test") + begin3 = compiler_begin(in_4, "test") + node0 = relay.add(begin0, begin1) + node1 = relay.add(begin2, begin3) + end0 = compiler_end(node0, "test") + end1 = compiler_end(node1, "test") + begin4 = compiler_begin(end0, "test") + begin5 = compiler_begin(end1, "test") + node2 = relay.add(begin4, begin5) + end2 = compiler_end(node2, "test") + + node3 = relay.subtract(in_5, in_6) + node4 = relay.subtract(in_7, node3) + + begin6 = compiler_begin(end2, "test") + begin7 = compiler_begin(node4, "test") + node5 = relay.add(begin6, begin7) + end3 = compiler_end(node5, "test") + end4 = compiler_end(node5, "test") + node6 = relay.subtract(in_8, end3) + begin8 = compiler_begin(in_9, "test") + begin9 = compiler_begin(end4, "test") + node7 = relay.add(begin8, begin9) + end5 = compiler_end(node7, "test") + + begin10 = compiler_begin(node6, "test") + begin11 = compiler_begin(end5, "test") + node8 = relay.add(begin10, begin11) + end6 = compiler_end(node8, "test") + begin12 = compiler_begin(in_10, "test") + begin13 = compiler_begin(end6, "test") + node9 = relay.add(begin12, begin13) + end7 = compiler_end(node9, "test") + + f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7) + mod = tvm.IRModule.from_expr(f) + 
return mod + + def expected(): + in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') + in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') + in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') + in_4 = relay.var('in_4', shape=(10, 10), dtype='float32') + in_5 = relay.var('in_5', shape=(10, 10), dtype='float32') + in_6 = relay.var('in_6', shape=(10, 10), dtype='float32') + in_7 = relay.var('in_7', shape=(10, 10), dtype='float32') + in_8 = relay.var('in_8', shape=(10, 10), dtype='float32') + in_9 = relay.var('in_9', shape=(10, 10), dtype='float32') + in_10 = relay.var('in_10', shape=(10, 10), dtype='float32') + + begin0 = compiler_begin(in_1, "test") + begin1 = compiler_begin(in_2, "test") + begin2 = compiler_begin(in_3, "test") + begin3 = compiler_begin(in_4, "test") + node0 = relay.add(begin0, begin1) + node1 = relay.add(begin2, begin3) + node2 = relay.add(node0, node1) + + begin4 = compiler_begin(in_5, "default") + begin5 = compiler_begin(in_6, "default") + begin6 = compiler_begin(in_7, "default") + node3 = relay.subtract(begin4, begin5) + node4 = relay.subtract(begin6, node3) + end0 = compiler_end(node4, "default") + + begin7 = compiler_begin(end0, "test") + begin8 = compiler_begin(in_9, "test") + + node5 = relay.add(node2, begin7) + end1 = compiler_end(node5, "test") + + begin9 = compiler_begin(end1, "default") + begin10 = compiler_begin(in_8, "default") + node6 = relay.subtract(begin10, begin9) + end2 = compiler_end(node6, "default") + + node7 = relay.add(begin8, node5) + end3 = compiler_end(node7, "test") + begin11 = compiler_begin(end3, "test") + begin12 = compiler_begin(end2, "test") + + node8 = relay.add(begin12, begin11) + + begin13 = compiler_begin(in_10, "test") + node9 = relay.add(begin13, node8) + end4 = compiler_end(node9, "test") + + f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4) + mod = tvm.IRModule.from_expr(f) + return mod + + mod = annotated() + mod = 
relay.transform.MergeCompilerRegions()(mod) + ref_mod = expected() + assert relay.analysis.alpha_equal(mod, ref_mod) + + +if __name__ == "__main__": + test_diamond_graph_fanouts() + test_example_graph()