Skip to content

Commit

Permalink
[Relay] Fix operator fusion for multiple output (apache#3871)
Browse files Browse the repository at this point in the history
* save

* add test

* refactor

* fix indent

* save

* refactor
  • Loading branch information
MarisaKirisame authored and wweic committed Sep 16, 2019
1 parent e4f97f6 commit 86710c6
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 77 deletions.
18 changes: 11 additions & 7 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,16 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
if (memo_type_.count(var)) {
Doc val = memo_type_[var];
val << "-malformed-ir";
return val;
}
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
Expand All @@ -325,16 +327,18 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc AllocVar(const Var& var) {
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
Doc val = memo_[var];
val << "-malformed-ir";
return val;
}
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
val << "-malformed-ir";
}
memo_[var] = val;
if (var->type_annotation.defined()) {
val << ": " << Print(var->type_annotation);
Expand Down
129 changes: 59 additions & 70 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
*
* \file src/relay/pass/fuse_ops.cc
*
Expand Down Expand Up @@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
// pass the analysis back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
const auto* arg_type =
call->args[i]->checked_type().as<TensorTypeNode>();
// specifically check if result type
// specifically check if result type is the same as arguments type
OpPatternKind edge_pattern = op_pattern;
if (edge_pattern == kBroadcast &&
arg_type != nullptr &&
Expand Down Expand Up @@ -403,12 +403,12 @@ class DominatorTree {
return rhs;
}
/*!
* \brief Find the least common acenstor of the two nodes.
* \brief Find the least common ancestor of the two nodes.
* \param lhs The left node.
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of thw two.
* \return The least common ancestor of the two.
*/
static Node* LeastCommonAncestor(
Node* lhs,
Expand Down Expand Up @@ -436,17 +436,43 @@ class DominatorTree {
}
return lhs;
}
};

DominatorTree DominatorTree::PostDom(common::Arena* arena,
const IndexedForwardGraph& graph) {
DominatorTree tree;
tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
// reverse topo order
for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
size_t index = i - 1;
/*!
 * \brief Fold a list of graph edges into their least common ancestor in this tree.
 * \param input_nodes The edges whose target nodes are combined.
 * \param edge_pattern
 *        In/out: combined with the pattern of every edge in the list and with
 *        whatever the pairwise LeastCommonAncestor accumulates along the way.
 * \return The least common ancestor of all nodes, or nullptr for an empty list.
 */
Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
                          OpPatternKind* edge_pattern) {
  Node* ancestor = nullptr;
  // An explicit flag (rather than testing ancestor against nullptr) keeps a
  // null intermediate result flowing into the next pairwise call unchanged.
  bool is_first = true;
  for (auto cur = input_nodes.head; cur != nullptr; cur = cur->next) {
    const IndexedForwardGraph::Edge& edge = cur->value;
    // Map the graph node referenced by the edge onto its tree node.
    const size_t tree_index = edge.node->index;
    CHECK_LT(tree_index, nodes.size());
    Node* candidate = nodes[tree_index];
    CHECK(candidate != nullptr);
    if (is_first) {
      ancestor = candidate;
      is_first = false;
    } else {
      ancestor = LeastCommonAncestor(ancestor, candidate, edge_pattern);
    }
    *edge_pattern = CombinePattern(*edge_pattern, edge.pattern);
  }
  return ancestor;
}
/*!
* \brief Convert the Node from an IndexedForwardGraph Node into a DominatorTree Node.
* \param arena The Arena.
* \param gnode An IndexedForwardGraph Node.
* \return The DominatorTree Node.
*/
Node* GetNode(common::Arena* arena, IndexedForwardGraph::Node* gnode) {
Node* tnode = arena->make<Node>();
auto* gnode = graph.post_dfs_order[index];
tnode->gnode = gnode;
if (gnode->extern_ref) {
tnode->depth = 1;
Expand All @@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
} else {
// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
Node* parent = nullptr;
for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
size_t oindex = link->value.node->index;
CHECK_LT(oindex, tree.nodes.size());
Node* onode = tree.nodes[oindex];
CHECK(onode != nullptr);
if (parent != nullptr) {
parent = LeastCommonAncestor(parent, onode, &pattern);
} else {
parent = onode;
}
pattern = CombinePattern(pattern, link->value.pattern);
}
Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
tnode->depth = parent ? parent->depth + 1 : 1;
tnode->parent = parent;
tnode->pattern = pattern;
}
tree.nodes[index] = tnode;
return tnode;
}
};


DominatorTree DominatorTree::PostDom(common::Arena* arena,
                                     const IndexedForwardGraph& graph) {
  const size_t num_nodes = graph.post_dfs_order.size();
  DominatorTree tree;
  tree.nodes.resize(num_nodes, nullptr);
  // Walk the post-DFS order from the last entry to the first (reverse topo
  // order); GetNode looks up the tree nodes of a node's outputs and CHECKs
  // that they have already been created.
  size_t remaining = num_nodes;
  while (remaining > 0) {
    --remaining;
    tree.nodes[remaining] = tree.GetNode(arena, graph.post_dfs_order[remaining]);
  }
  return tree;
}
Expand Down Expand Up @@ -614,7 +640,7 @@ class GraphPartitioner {
// merge the current group to the parent if possible.
MergeFromTo(gnode, target);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
CommitFuse_(link->value.node, sink, target);;
CommitFuse_(link->value.node, sink, target);
}
}
/*!
Expand Down Expand Up @@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator {

Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group == gmap_.at(tuple)) {
if (ret_group->root_ref == tuple) {
return ExprMutator::VisitExpr_(tuple);
}
// This tuple is an intermediate node in the group
Expand All @@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator {
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
if (ret_group == gmap_.at(tuple_get)) {
if (ret_group->root_ref == tuple_get) {
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
Expand Down Expand Up @@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator {
}
};

// Temporary solution, should be handled by implementing a "FunctionPass"
// which applies fusion to each function.
struct GlobalVarLiveness : ExprVisitor {
Module module;
std::set<GlobalVar> visited;

explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {}

void VisitExpr_(const GlobalVarNode* gvar_node) {
auto gvar = GetRef<GlobalVar>(gvar_node);
if (visited.find(gvar) == visited.end()) {
visited.insert(gvar);
this->VisitExpr(this->module->Lookup(gvar));
}
}
};

std::set<GlobalVar> LiveGlobals(const Module& mod, const Expr& expr) {
auto gvl = GlobalVarLiveness(mod);
gvl.VisitExpr(expr);
return gvl.visited;
}

/*!
 * \brief Run FuseMutator over an expression and, when a module is given,
 *  over the body of every global returned by LiveGlobals(module, expr).
 * \param expr The expression to transform.
 * \param fuse_opt_level Forwarded unchanged to FuseMutator::Transform.
 * \param module The module; an undefined module skips the per-global step.
 * \return The result of FuseMutator().Transform(expr, fuse_opt_level).
 */
Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primitive
// then we convert these primitive functions into
// new operators.
if (!module.defined()) {
// No module: only the expression itself is transformed.
return FuseMutator().Transform(expr, fuse_opt_level);
} else {
auto lgvs = LiveGlobals(module, expr);
for (auto lv : lgvs) {
auto body = module->Lookup(lv);
auto e = FuseMutator().Transform(body, fuse_opt_level);
// Store the transformed body back under the same global.
// NOTE(review): the trailing `true` presumably requests an update of the
// existing definition -- confirm against Module::Add.
module->Add(lv, Downcast<Function>(e), true);
}
return FuseMutator().Transform(expr, fuse_opt_level);
}
// NOTE(review): both branches above return, so this statement is never reached.
return FuseMutator().Transform(expr, fuse_opt_level);
}

namespace transform {
Expand Down
13 changes: 13 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,18 @@ def expected():
assert relay.analysis.alpha_equal(new_mod, expected())


def test_split():
    """Run FuseOps on a function that consumes all three outputs of a split.

    There are no explicit assertions: the test passes when the pass
    completes without raising.
    """
    data = relay.var("x", shape=(6, 9))
    parts = relay.split(data, 3).astuple()
    first, second, third = [relay.TupleGetItem(parts, idx) for idx in range(3)]
    mod = relay.module.Module()
    # The middle output goes through a reference create/read pair before
    # being added to the other two.
    body = first + relay.RefRead(relay.RefCreate(second)) + third
    mod["main"] = relay.Function([data], body)
    mod = transform.FuseOps()(mod)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -555,3 +567,4 @@ def expected():
test_inception_like()
test_fuse_parallel_injective()
test_immutable()
test_split()

0 comments on commit 86710c6

Please sign in to comment.