Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Fix operator fusion for multiple output #3871

Merged
merged 6 commits into from
Sep 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,16 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
if (memo_type_.count(var)) {
Doc val = memo_type_[var];
val << "-malformed-ir";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we change this to print in a more informative way? for example maybe use a colored highlight to show which part is malformed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no color highlighting in the current doc.

return val;
}
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
Expand All @@ -325,16 +327,18 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc AllocVar(const Var& var) {
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
Doc val = memo_[var];
val << "-malformed-ir";
return val;
}
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
val << "-malformed-ir";
}
memo_[var] = val;
if (var->type_annotation.defined()) {
val << ": " << Print(var->type_annotation);
Expand Down
129 changes: 59 additions & 70 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
Expand Down Expand Up @@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
// pass the analysis back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
const auto* arg_type =
call->args[i]->checked_type().as<TensorTypeNode>();
// specifically check if result type
// specifically check if result type is the same as arguments type
OpPatternKind edge_pattern = op_pattern;
if (edge_pattern == kBroadcast &&
arg_type != nullptr &&
Expand Down Expand Up @@ -403,12 +403,12 @@ class DominatorTree {
return rhs;
}
/*!
* \brief Find the least common acenstor of the two nodes.
* \brief Find the least common ancestor of the two nodes.
* \param lhs The left node.
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of thw two.
* \return The least common ancestor of the two.
*/
static Node* LeastCommonAncestor(
Node* lhs,
Expand Down Expand Up @@ -436,17 +436,43 @@ class DominatorTree {
}
return lhs;
}
};

DominatorTree DominatorTree::PostDom(common::Arena* arena,
const IndexedForwardGraph& graph) {
DominatorTree tree;
tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
// reverse topo order
for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
size_t index = i - 1;
/*!
* \brief Find the least common ancestor of a list of nodes.
* \param nodes the nodes.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of all nodes.
*/
Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
OpPatternKind* edge_pattern) {
auto link = input_nodes.head;
if (link == nullptr) {
return nullptr;
}
auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
size_t oindex = edge.node->index;
CHECK_LT(oindex, nodes.size());
Node* onode = nodes[oindex];
CHECK(onode != nullptr);
return onode;
};
Node* parent = get_node(link->value);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
link = link->next;
for (; link != nullptr; link = link->next) {
parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
}
return parent;
}
/*!
* \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
* \param arena The Arena.
* \param gnode An IndexedForwardGraph Node.
* \return The DominatorTree Node.
*/
Node* GetNode(common::Arena* arena, IndexedForwardGraph::Node* gnode) {
Node* tnode = arena->make<Node>();
auto* gnode = graph.post_dfs_order[index];
tnode->gnode = gnode;
if (gnode->extern_ref) {
tnode->depth = 1;
Expand All @@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
} else {
// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
Node* parent = nullptr;
for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
size_t oindex = link->value.node->index;
CHECK_LT(oindex, tree.nodes.size());
Node* onode = tree.nodes[oindex];
CHECK(onode != nullptr);
if (parent != nullptr) {
parent = LeastCommonAncestor(parent, onode, &pattern);
} else {
parent = onode;
}
pattern = CombinePattern(pattern, link->value.pattern);
}
Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
tnode->depth = parent ? parent->depth + 1 : 1;
tnode->parent = parent;
tnode->pattern = pattern;
}
tree.nodes[index] = tnode;
return tnode;
}
};


DominatorTree DominatorTree::PostDom(common::Arena* arena,
const IndexedForwardGraph& graph) {
DominatorTree tree;
tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
// reverse topo order
for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
size_t index = i - 1;
tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
}
return tree;
}
Expand Down Expand Up @@ -614,7 +640,7 @@ class GraphPartitioner {
// merge the current group to the parent if possible.
MergeFromTo(gnode, target);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
CommitFuse_(link->value.node, sink, target);;
CommitFuse_(link->value.node, sink, target);
}
}
/*!
Expand Down Expand Up @@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator {

Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group == gmap_.at(tuple)) {
if (ret_group->root_ref == tuple) {
return ExprMutator::VisitExpr_(tuple);
}
// This tuple is an intermediate node in the group
Expand All @@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator {
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
if (ret_group == gmap_.at(tuple_get)) {
if (ret_group->root_ref == tuple_get) {
MarisaKirisame marked this conversation as resolved.
Show resolved Hide resolved
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
Expand Down Expand Up @@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator {
}
};

// Temporary solution, should be handled by implementing a "FunctionPass"
// which applies fusion to each function.
struct GlobalVarLiveness : ExprVisitor {
Module module;
std::set<GlobalVar> visited;

explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {}

void VisitExpr_(const GlobalVarNode* gvar_node) {
auto gvar = GetRef<GlobalVar>(gvar_node);
if (visited.find(gvar) == visited.end()) {
visited.insert(gvar);
this->VisitExpr(this->module->Lookup(gvar));
}
}
};

std::set<GlobalVar> LiveGlobals(const Module& mod, const Expr& expr) {
auto gvl = GlobalVarLiveness(mod);
gvl.VisitExpr(expr);
return gvl.visited;
}

Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
if (!module.defined()) {
return FuseMutator().Transform(expr, fuse_opt_level);
} else {
auto lgvs = LiveGlobals(module, expr);
for (auto lv : lgvs) {
auto body = module->Lookup(lv);
auto e = FuseMutator().Transform(body, fuse_opt_level);
module->Add(lv, Downcast<Function>(e), true);
}
return FuseMutator().Transform(expr, fuse_opt_level);
}
return FuseMutator().Transform(expr, fuse_opt_level);
}

namespace transform {
Expand Down
13 changes: 13 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,18 @@ def expected():
assert relay.analysis.alpha_equal(new_mod, expected())


def test_split():
"""Test that the result is well formed."""
x = relay.var("x", shape=(6, 9))
y = relay.split(x, 3).astuple()
a = relay.TupleGetItem(y, 0)
b = relay.TupleGetItem(y, 1)
c = relay.TupleGetItem(y, 2)
mod = relay.module.Module()
mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
mod = transform.FuseOps()(mod)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -555,3 +567,4 @@ def expected():
test_inception_like()
test_fuse_parallel_injective()
test_immutable()
test_split()