Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Fix operator fusion for multiple output #3871

Merged
merged 6 commits into from
Sep 5, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,16 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
if (memo_type_.count(var)) {
Doc val = memo_type_[var];
val << "-malformed-ir";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we change this to print in a more informative way? for example maybe use a colored highlight to show which part is malformed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no color highlighting in the current doc.

return val;
}
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
Expand All @@ -325,16 +327,18 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc AllocVar(const Var& var) {
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
Doc val = memo_[var];
val << "-malformed-ir";
return val;
}
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
val << "-malformed-ir";
}
memo_[var] = val;
if (var->type_annotation.defined()) {
val << ": " << Print(var->type_annotation);
Expand Down
59 changes: 12 additions & 47 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
Expand Down Expand Up @@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
// pass the analysis back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
const auto* arg_type =
call->args[i]->checked_type().as<TensorTypeNode>();
// specifically check if result type
// specifically check if result type is the same as arguments type
OpPatternKind edge_pattern = op_pattern;
if (edge_pattern == kBroadcast &&
arg_type != nullptr &&
Expand Down Expand Up @@ -408,7 +408,7 @@ class DominatorTree {
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of thw two.
* \return The least common ancestor of the two.
*/
static Node* LeastCommonAncestor(
Node* lhs,
Expand Down Expand Up @@ -456,15 +456,17 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
Node* parent = nullptr;
bool init = true;
for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
size_t oindex = link->value.node->index;
CHECK_LT(oindex, tree.nodes.size());
Node* onode = tree.nodes[oindex];
CHECK(onode != nullptr);
if (parent != nullptr) {
parent = LeastCommonAncestor(parent, onode, &pattern);
} else {
if (init) {
parent = onode;
init = false;
} else {
parent = LeastCommonAncestor(parent, onode, &pattern);
}
pattern = CombinePattern(pattern, link->value.pattern);
}
Expand Down Expand Up @@ -614,7 +616,7 @@ class GraphPartitioner {
// merge the current group to the parent if possible.
MergeFromTo(gnode, target);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
CommitFuse_(link->value.node, sink, target);;
CommitFuse_(link->value.node, sink, target);
}
}
/*!
Expand Down Expand Up @@ -863,7 +865,7 @@ class FuseMutator : private ExprMutator {
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
if (ret_group == gmap_.at(tuple_get)) {
if (ret_group->root_ref == tuple_get) {
MarisaKirisame marked this conversation as resolved.
Show resolved Hide resolved
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
Expand Down Expand Up @@ -922,45 +924,8 @@ class FuseMutator : private ExprMutator {
}
};

// Temporary solution, should be handled by implementing a "FunctionPass"
// which applies fusion to each function.
struct GlobalVarLiveness : ExprVisitor {
Module module;
std::set<GlobalVar> visited;

explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {}

void VisitExpr_(const GlobalVarNode* gvar_node) {
auto gvar = GetRef<GlobalVar>(gvar_node);
if (visited.find(gvar) == visited.end()) {
visited.insert(gvar);
this->VisitExpr(this->module->Lookup(gvar));
}
}
};

std::set<GlobalVar> LiveGlobals(const Module& mod, const Expr& expr) {
auto gvl = GlobalVarLiveness(mod);
gvl.VisitExpr(expr);
return gvl.visited;
}

Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
if (!module.defined()) {
return FuseMutator().Transform(expr, fuse_opt_level);
} else {
auto lgvs = LiveGlobals(module, expr);
for (auto lv : lgvs) {
auto body = module->Lookup(lv);
auto e = FuseMutator().Transform(body, fuse_opt_level);
module->Add(lv, Downcast<Function>(e), true);
}
return FuseMutator().Transform(expr, fuse_opt_level);
}
return FuseMutator().Transform(expr, fuse_opt_level);
}

namespace transform {
Expand Down
13 changes: 13 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,18 @@ def expected():
assert relay.analysis.alpha_equal(new_mod, expected())


def test_split():
"""Test that the result is well formed."""
x = relay.var("x", shape=(6, 9))
y = relay.split(x, 3).astuple()
a = relay.TupleGetItem(y, 0)
b = relay.TupleGetItem(y, 1)
c = relay.TupleGetItem(y, 2)
mod = relay.module.Module()
mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
mod = transform.FuseOps()(mod)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -555,3 +567,4 @@ def expected():
test_inception_like()
test_fuse_parallel_injective()
test_immutable()
test_split()