From c3394f20a19b10a2368e98ff0395910fee7bab2d Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sat, 31 Aug 2019 20:01:03 -0700 Subject: [PATCH 1/6] save --- src/relay/ir/pretty_printer.cc | 18 +++++++---- src/relay/pass/fuse_ops.cc | 59 +++++++--------------------------- 2 files changed, 23 insertions(+), 54 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 0ee76dc4c9aa..5197414992f9 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -304,14 +304,16 @@ class PrettyPrinter : * \return The corresponding name. */ Doc AllocTypeVar(const TypeVar& var) { + if (memo_type_.count(var)) { + Doc val = memo_type_[var]; + val << "-malformed-ir"; + return val; + } std::string name = var->var->name_hint; if (name.length() == 0 || !std::isalpha(name[0])) { name = "t" + name; } Doc val = GetUniqueName("%" + name); - if (memo_type_.count(var)) { - val << "-malformed-ir"; - } memo_type_[var] = val; if (var->kind != kType) { val << ": " << Print(var->kind); @@ -325,16 +327,18 @@ class PrettyPrinter : * \return The corresponding name. */ Doc AllocVar(const Var& var) { + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + Doc val = memo_[var]; + val << "-malformed-ir"; + return val; + } std::string name = var->name_hint(); // always make sure first name is alpha if (name.length() == 0 || !std::isalpha(name[0])) { name = "v" + name; } Doc val = GetUniqueName("%" + name); - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - val << "-malformed-ir"; - } memo_[var] = val; if (var->type_annotation.defined()) { val << ": " << Print(var->type_annotation); diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9dc180f26a44..6f801367d87e 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * * \file src/tvm/relay/pass/fuse_ops.cc * @@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { node->pattern = op_pattern; this->Update(call->op, nullptr, kOpaque); const auto* rtype = call->checked_type().as(); - // pass the message back to all the children it references. + // pass the analysis back to all the children it references. for (size_t i = 0; i < call->args.size(); ++i) { const auto* arg_type = call->args[i]->checked_type().as(); - // specifically check if result type + // specifically check if result type is the same as arguments type OpPatternKind edge_pattern = op_pattern; if (edge_pattern == kBroadcast && arg_type != nullptr && @@ -408,7 +408,7 @@ class DominatorTree { * \param rhs The right node. * \param edge_pattern * The combined edge pattern across all the parents. - * \return The least common ancestor of thw two. + * \return The least common ancestor of the two. */ static Node* LeastCommonAncestor( Node* lhs, @@ -456,15 +456,17 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena, // find the LCAs of all outputs. OpPatternKind pattern = kElemWise; Node* parent = nullptr; + bool init = true; for (auto link = gnode->outputs.head; link != nullptr; link= link->next) { size_t oindex = link->value.node->index; CHECK_LT(oindex, tree.nodes.size()); Node* onode = tree.nodes[oindex]; CHECK(onode != nullptr); - if (parent != nullptr) { - parent = LeastCommonAncestor(parent, onode, &pattern); - } else { + if (init) { parent = onode; + init = false; + } else { + parent = LeastCommonAncestor(parent, onode, &pattern); } pattern = CombinePattern(pattern, link->value.pattern); } @@ -614,7 +616,7 @@ class GraphPartitioner { // merge the current group to the parent if possible. MergeFromTo(gnode, target); for (auto link = src->outputs.head; link != nullptr; link = link->next) { - CommitFuse_(link->value.node, sink, target);; + CommitFuse_(link->value.node, sink, target); } } /*! @@ -863,7 +865,7 @@ class FuseMutator : private ExprMutator { auto* ret_group = gmap_.at(tuple_get)->FindRoot(); auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index); - if (ret_group == gmap_.at(tuple_get)) { + if (ret_group->root_ref == tuple_get) { if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) { // Isolated. This case occurs when tuple is created by an Opaque op // e.g. multibox_transform_loc @@ -922,45 +924,8 @@ class FuseMutator : private ExprMutator { } }; -// Temporary solution, should be handled by implementing a "FunctionPass" -// which applies fusion to each function. -struct GlobalVarLiveness : ExprVisitor { - Module module; - std::set visited; - - explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {} - - void VisitExpr_(const GlobalVarNode* gvar_node) { - auto gvar = GetRef(gvar_node); - if (visited.find(gvar) == visited.end()) { - visited.insert(gvar); - this->VisitExpr(this->module->Lookup(gvar)); - } - } -}; - -std::set LiveGlobals(const Module& mod, const Expr& expr) { - auto gvl = GlobalVarLiveness(mod); - gvl.VisitExpr(expr); - return gvl.visited; -} - Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { - // First we convert all chains of fusable ops into - // abstracted functions which we mark as primtive - // then we convert these primtive functions into - // new operators. - if (!module.defined()) { - return FuseMutator().Transform(expr, fuse_opt_level); - } else { - auto lgvs = LiveGlobals(module, expr); - for (auto lv : lgvs) { - auto body = module->Lookup(lv); - auto e = FuseMutator().Transform(body, fuse_opt_level); - module->Add(lv, Downcast(e), true); - } - return FuseMutator().Transform(expr, fuse_opt_level); - } + return FuseMutator().Transform(expr, fuse_opt_level); } namespace transform { From 2ab7f446cf3da78bea023e6b3c12d9cb63b8923e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sun, 1 Sep 2019 20:58:55 -0700 Subject: [PATCH 2/6] add test --- tests/python/relay/test_pass_fuse_ops.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 4c03840ec4b3..f148502feece 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -541,6 +541,18 @@ def expected(): assert relay.analysis.alpha_equal(new_mod, expected()) +def test_split(): + """Test that the result is well formed.""" + x = relay.var("x", shape=(6, 9)) + y = relay.split(x, 3).astuple() + a = relay.TupleGetItem(y, 0) + b = relay.TupleGetItem(y, 1) + c = relay.TupleGetItem(y, 2) + mod = relay.module.Module() + mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c) + mod = transform.FuseOps()(mod) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -555,3 +567,4 @@ def expected(): test_inception_like() test_fuse_parallel_injective() test_immutable() + test_split() From e4716f071549021f0b2eb2dc3ad7f265b0ff26e7 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 4 Sep 2019 10:47:33 -0700 Subject: [PATCH 3/6] refactor --- src/relay/pass/fuse_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 6f801367d87e..6c85034b9ec9 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -852,8 +852,8 @@ class FuseMutator : private ExprMutator { } Expr VisitExpr_(const TupleNode* tuple) { - auto* ret_group = gmap_.at(tuple)->FindRoot(); - if (ret_group == gmap_.at(tuple)) { + auto* ret_group = gmap_.at(tuple)->FindRoot(); + if (ret_group->root_ref == tuple) { return ExprMutator::VisitExpr_(tuple); } // This tuple is an intermediate node in the group From 6a1777060536074b02161a5bf5d22f96206bee8e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 4 Sep 2019 11:25:42 -0700 Subject: [PATCH 4/6] fix indent --- src/relay/pass/fuse_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 6c85034b9ec9..74c635b4d17f 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -852,8 +852,8 @@ class FuseMutator : private ExprMutator { } Expr VisitExpr_(const TupleNode* tuple) { - auto* ret_group = gmap_.at(tuple)->FindRoot(); - if (ret_group->root_ref == tuple) { + auto* ret_group = gmap_.at(tuple)->FindRoot(); + if (ret_group->root_ref == tuple) { return ExprMutator::VisitExpr_(tuple); } // This tuple is an intermediate node in the group From b165549deb79da5aec0d8a7ae70325cf5c370a65 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 4 Sep 2019 12:34:51 -0700 Subject: [PATCH 5/6] save --- src/relay/pass/fuse_ops.cc | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 74c635b4d17f..3aad5b2b32b4 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -436,17 +436,14 @@ class DominatorTree { } return lhs; } -}; - -DominatorTree DominatorTree::PostDom(common::Arena* arena, - const IndexedForwardGraph& graph) { - DominatorTree tree; - tree.nodes.resize(graph.post_dfs_order.size(), nullptr); - // reverse topo order - for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { - size_t index = i - 1; + /*! + * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. + * \param arena The Arena. + * \param gnode An IndexedForwardGraph Node. + * \return The DominatorTree Node. + */ + Node* GetNode(common::Arena* arena, IndexedForwardGraph::Node* gnode) { Node* tnode = arena->make(); - auto* gnode = graph.post_dfs_order[index]; tnode->gnode = gnode; if (gnode->extern_ref) { tnode->depth = 1; @@ -459,8 +456,8 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena, bool init = true; for (auto link = gnode->outputs.head; link != nullptr; link= link->next) { size_t oindex = link->value.node->index; - CHECK_LT(oindex, tree.nodes.size()); - Node* onode = tree.nodes[oindex]; + CHECK_LT(oindex, nodes.size()); + Node* onode = nodes[oindex]; CHECK(onode != nullptr); if (init) { parent = onode; @@ -474,7 +471,19 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena, tnode->parent = parent; tnode->pattern = pattern; } - tree.nodes[index] = tnode; + return tnode; + } +}; + + +DominatorTree DominatorTree::PostDom(common::Arena* arena, + const IndexedForwardGraph& graph) { + DominatorTree tree; + tree.nodes.resize(graph.post_dfs_order.size(), nullptr); + // reverse topo order + for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { + size_t index = i - 1; + tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); } return tree; } From 9f55f715e6348df1d8f2ff5d67842fd25a0c588c Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 4 Sep 2019 19:10:52 -0700 Subject: [PATCH 6/6] refactor --- src/relay/pass/fuse_ops.cc | 47 +++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 3aad5b2b32b4..b5faf4c4310f 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -403,7 +403,7 @@ class DominatorTree { return rhs; } /*! - * \brief Find the least common acenstor of the two nodes. + * \brief Find the least common ancestor of the two nodes. * \param lhs The left node. * \param rhs The right node. * \param edge_pattern @@ -436,6 +436,35 @@ class DominatorTree { } return lhs; } + /*! + * \brief Find the least common ancestor of a list of nodes. + * \param nodes the nodes. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of all nodes. + */ + Node* LeastCommonAncestor(const LinkedList& input_nodes, + OpPatternKind* edge_pattern) { + auto link = input_nodes.head; + if (link == nullptr) { + return nullptr; + } + auto get_node = [&](const IndexedForwardGraph::Edge& edge) { + size_t oindex = edge.node->index; + CHECK_LT(oindex, nodes.size()); + Node* onode = nodes[oindex]; + CHECK(onode != nullptr); + return onode; + }; + Node* parent = get_node(link->value); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + link = link->next; + for (; link != nullptr; link = link->next) { + parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + } + return parent; + } /*! * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. * \param arena The Arena. @@ -452,21 +481,7 @@ class DominatorTree { } else { // find the LCAs of all outputs. OpPatternKind pattern = kElemWise; - Node* parent = nullptr; - bool init = true; - for (auto link = gnode->outputs.head; link != nullptr; link= link->next) { - size_t oindex = link->value.node->index; - CHECK_LT(oindex, nodes.size()); - Node* onode = nodes[oindex]; - CHECK(onode != nullptr); - if (init) { - parent = onode; - init = false; - } else { - parent = LeastCommonAncestor(parent, onode, &pattern); - } - pattern = CombinePattern(pattern, link->value.pattern); - } + Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); tnode->depth = parent ? parent->depth + 1 : 1; tnode->parent = parent; tnode->pattern = pattern;