diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 0ee76dc4c9aa..5197414992f9 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -304,14 +304,16 @@ class PrettyPrinter : * \return The corresponding name. */ Doc AllocTypeVar(const TypeVar& var) { + if (memo_type_.count(var)) { + Doc val = memo_type_[var]; + val << "-malformed-ir"; + return val; + } std::string name = var->var->name_hint; if (name.length() == 0 || !std::isalpha(name[0])) { name = "t" + name; } Doc val = GetUniqueName("%" + name); - if (memo_type_.count(var)) { - val << "-malformed-ir"; - } memo_type_[var] = val; if (var->kind != kType) { val << ": " << Print(var->kind); @@ -325,16 +327,18 @@ class PrettyPrinter : * \return The corresponding name. */ Doc AllocVar(const Var& var) { + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + Doc val = memo_[var]; + val << "-malformed-ir"; + return val; + } std::string name = var->name_hint(); // always make sure first name is alpha if (name.length() == 0 || !std::isalpha(name[0])) { name = "v" + name; } Doc val = GetUniqueName("%" + name); - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - val << "-malformed-ir"; - } memo_[var] = val; if (var->type_annotation.defined()) { val << ": " << Print(var->type_annotation); diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9dc180f26a44..b5faf4c4310f 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * * \file src/tvm/relay/pass/fuse_ops.cc * @@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { node->pattern = op_pattern; this->Update(call->op, nullptr, kOpaque); const auto* rtype = call->checked_type().as(); - // pass the message back to all the children it references. + // pass the analysis back to all the children it references. for (size_t i = 0; i < call->args.size(); ++i) { const auto* arg_type = call->args[i]->checked_type().as(); - // specifically check if result type + // specifically check if result type is the same as arguments type OpPatternKind edge_pattern = op_pattern; if (edge_pattern == kBroadcast && arg_type != nullptr && @@ -403,12 +403,12 @@ class DominatorTree { return rhs; } /*! - * \brief Find the least common acenstor of the two nodes. + * \brief Find the least common ancestor of the two nodes. * \param lhs The left node. * \param rhs The right node. * \param edge_pattern * The combined edge pattern across all the parents. - * \return The least common ancestor of thw two. + * \return The least common ancestor of the two. */ static Node* LeastCommonAncestor( Node* lhs, @@ -436,17 +436,43 @@ class DominatorTree { } return lhs; } -}; - -DominatorTree DominatorTree::PostDom(common::Arena* arena, - const IndexedForwardGraph& graph) { - DominatorTree tree; - tree.nodes.resize(graph.post_dfs_order.size(), nullptr); - // reverse topo order - for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { - size_t index = i - 1; + /*! + * \brief Find the least common ancestor of a list of nodes. + * \param nodes the nodes. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of all nodes. + */ + Node* LeastCommonAncestor(const LinkedList& input_nodes, + OpPatternKind* edge_pattern) { + auto link = input_nodes.head; + if (link == nullptr) { + return nullptr; + } + auto get_node = [&](const IndexedForwardGraph::Edge& edge) { + size_t oindex = edge.node->index; + CHECK_LT(oindex, nodes.size()); + Node* onode = nodes[oindex]; + CHECK(onode != nullptr); + return onode; + }; + Node* parent = get_node(link->value); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + link = link->next; + for (; link != nullptr; link = link->next) { + parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + } + return parent; + } + /*! + * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. + * \param arena The Arena. + * \param gnode An IndexedForwardGraph Node. + * \return The DominatorTree Node. + */ + Node* GetNode(common::Arena* arena, IndexedForwardGraph::Node* gnode) { Node* tnode = arena->make(); - auto* gnode = graph.post_dfs_order[index]; tnode->gnode = gnode; if (gnode->extern_ref) { tnode->depth = 1; @@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena, } else { // find the LCAs of all outputs. OpPatternKind pattern = kElemWise; - Node* parent = nullptr; - for (auto link = gnode->outputs.head; link != nullptr; link= link->next) { - size_t oindex = link->value.node->index; - CHECK_LT(oindex, tree.nodes.size()); - Node* onode = tree.nodes[oindex]; - CHECK(onode != nullptr); - if (parent != nullptr) { - parent = LeastCommonAncestor(parent, onode, &pattern); - } else { - parent = onode; - } - pattern = CombinePattern(pattern, link->value.pattern); - } + Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); tnode->depth = parent ? parent->depth + 1 : 1; tnode->parent = parent; tnode->pattern = pattern; } - tree.nodes[index] = tnode; + return tnode; + } +}; + + +DominatorTree DominatorTree::PostDom(common::Arena* arena, + const IndexedForwardGraph& graph) { + DominatorTree tree; + tree.nodes.resize(graph.post_dfs_order.size(), nullptr); + // reverse topo order + for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { + size_t index = i - 1; + tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); } return tree; } @@ -614,7 +640,7 @@ class GraphPartitioner { // merge the current group to the parent if possible. MergeFromTo(gnode, target); for (auto link = src->outputs.head; link != nullptr; link = link->next) { - CommitFuse_(link->value.node, sink, target);; + CommitFuse_(link->value.node, sink, target); } } /*! @@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator { Expr VisitExpr_(const TupleNode* tuple) { auto* ret_group = gmap_.at(tuple)->FindRoot(); - if (ret_group == gmap_.at(tuple)) { + if (ret_group->root_ref == tuple) { return ExprMutator::VisitExpr_(tuple); } // This tuple is an intermediate node in the group @@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator { auto* ret_group = gmap_.at(tuple_get)->FindRoot(); auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index); - if (ret_group == gmap_.at(tuple_get)) { + if (ret_group->root_ref == tuple_get) { if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) { // Isolated. This case occurs when tuple is created by an Opaque op // e.g. multibox_transform_loc @@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator { } }; -// Temporary solution, should be handled by implementing a "FunctionPass" -// which applies fusion to each function. -struct GlobalVarLiveness : ExprVisitor { - Module module; - std::set visited; - - explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {} - - void VisitExpr_(const GlobalVarNode* gvar_node) { - auto gvar = GetRef(gvar_node); - if (visited.find(gvar) == visited.end()) { - visited.insert(gvar); - this->VisitExpr(this->module->Lookup(gvar)); - } - } -}; - -std::set LiveGlobals(const Module& mod, const Expr& expr) { - auto gvl = GlobalVarLiveness(mod); - gvl.VisitExpr(expr); - return gvl.visited; -} - Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { - // First we convert all chains of fusable ops into - // abstracted functions which we mark as primtive - // then we convert these primtive functions into - // new operators. - if (!module.defined()) { - return FuseMutator().Transform(expr, fuse_opt_level); - } else { - auto lgvs = LiveGlobals(module, expr); - for (auto lv : lgvs) { - auto body = module->Lookup(lv); - auto e = FuseMutator().Transform(body, fuse_opt_level); - module->Add(lv, Downcast(e), true); - } - return FuseMutator().Transform(expr, fuse_opt_level); - } + return FuseMutator().Transform(expr, fuse_opt_level); } namespace transform { diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 4c03840ec4b3..f148502feece 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -541,6 +541,18 @@ def expected(): assert relay.analysis.alpha_equal(new_mod, expected()) +def test_split(): + """Test that the result is well formed.""" + x = relay.var("x", shape=(6, 9)) + y = relay.split(x, 3).astuple() + a = relay.TupleGetItem(y, 0) + b = relay.TupleGetItem(y, 1) + c = relay.TupleGetItem(y, 2) + mod = relay.module.Module() + mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c) + mod = transform.FuseOps()(mod) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -555,3 +567,4 @@ def expected(): test_inception_like() test_fuse_parallel_injective() test_immutable() + test_split()