[PatternLang] Don't rewrite expressions used outside of the pattern #5930

Merged 2 commits on Jun 26, 2020
62 changes: 40 additions & 22 deletions src/relay/ir/dataflow_matcher.cc
@@ -461,8 +461,8 @@ class PatternGrouper {
return gid_assignments_;
}
/* \brief Group expressions that match the pattern */
const std::vector<Group>& GroupMatches(const DFPattern& pattern, const Expr& pre) {
groups_ = {Group()};
const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, const Expr& pre) {
groups_.clear();
gid_assignments_.clear();

pattern_ = pattern;
@@ -487,15 +487,17 @@
for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) {
size_t index = i - 1;
Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
if (auto op = current.as<FunctionNode>()) {
if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
pre_partitioned.insert(current);
PostOrderVisit(op->body,
[&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); });
if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped
if (auto op = current.as<FunctionNode>()) {
if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
pre_partitioned.insert(current);
PostOrderVisit(op->body,
[&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); });
}
}
if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) {
CreateGroup(current);
}
}
if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) {
CreateGroup(current);
}
}
}
@@ -616,20 +618,37 @@ class PatternGrouper {
CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
// Check to make sure we aren't overlapping with another group or creating an invalid fusion
// The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
// pattern with the input FunctionVar* Variables. The resulting memoization map will only
// contain nodes in the expression that matched the pattern. If a non-input node of the pattern
// (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a
// situation where we try to rewrite the same node twice in the second rewriting or partition
// pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants
// because they exist more globally outside of the fusion.
for (auto kv : extractor.GetMemo()) {
if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
kv.first.as<ConstantNode>() == nullptr) {
// Exit due to overlapping partitions
return;
// Similarly, if interior nodes in a group are used outside of the group, fusing to a single
// output would create an invalid graph transformation, so we block the creation of such groups.
auto memo = extractor.GetMemo();
for (auto kv : memo) {
// Check to ensure that this node isn't an input or a global
if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>() == nullptr) {
if (gid_assignments_.count(kv.first) != 0) {
// check to see if the node is used in other groups;
// if so, exit due to overlapping partitions
return;
} else if (kv.second != body) {
// if the node isn't the output of the group
auto node = matcher_->expr_graph_.node_map_.at(kv.first);
for (auto* output : node->outputs_) {
// and the node is used by nodes outside of the group
if (memo.count(output->ref_) == 0) {
// Exit because nodes in this pattern's body are used outside the pattern;
// fusing it would be invalid
return;
}
}
}
}
}
// Assign Group Ids
@@ -639,8 +658,7 @@
}

// Save Group
groups_.emplace_back(std::move(group));
CHECK_EQ(groups_[gid_].gid, gid_);
groups_[group.gid] = std::move(group);
}

/* \brief EmbedConst implements rules for embedding constants into partitioned functions or
@@ -675,7 +693,7 @@
}
// Internal State
DFPattern pattern_;
std::vector<Group> groups_;
std::unordered_map<int, Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
DFPatternMatcher* matcher_ = nullptr;
IndexedGraph<DFPattern> pattern_graph_;
@@ -753,7 +771,7 @@ class PatternRewriter : protected MixedModeMutator {
}

DFPatternCallback callback_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<int, PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
};

@@ -805,7 +823,7 @@ class PatternPartitioner : protected MixedModeMutator {
}

Map<String, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<int, PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
PackedFunc check_;
};
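To make the rewrite-side behavior concrete, here is a minimal sketch (not part of this diff) against the public `tvm.relay.dataflow_pattern` API; the `ConvReluToTanh` callback, names, and shapes are illustrative assumptions. Because the matched conv2d is also consumed outside the pattern, replacing relu(conv2d) would orphan that outer use, so the rewrite is expected to come back unchanged:

```python
# Minimal sketch, assuming the tvm.relay.dataflow_pattern API of this era;
# ConvReluToTanh is an illustrative callback, not part of the PR.
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard, rewrite


class ConvReluToTanh(DFPatternCallback):
    """Rewrite relu(conv2d(x, w)) -> tanh(conv2d(x, w)) (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.conv = is_op("nn.conv2d")(wildcard(), wildcard())
        self.pattern = is_op("nn.relu")(self.conv)

    def callback(self, pre, post, node_map):
        # node_map maps each sub-pattern to the expressions it matched
        return relay.op.tanh(node_map[self.conv][0])


x = relay.var("x", shape=(1, 8, 16, 16))
w = relay.var("w", shape=(8, 8, 3, 3))
conv = relay.op.nn.conv2d(x, w)
out = relay.op.nn.relu(conv) + conv  # conv is also used outside the pattern

# With this change the rewrite is refused and the graph comes back unchanged;
# rewriting would leave the outer `+ conv` pointing at an orphaned node.
assert tvm.ir.structural_equal(rewrite(ConvReluToTanh(), out), out)
```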
31 changes: 31 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
@@ -1133,6 +1133,37 @@ def test_partition_double_batchnorm():
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)

def test_overlapping_partitions():
x = wildcard()
gamma = wildcard()
beta = wildcard()
moving_mean = wildcard()
moving_var = wildcard()
bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var)
tuple_get_item_node = TupleGetItemPattern(bn_node, 0)

x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
T1 = BN[0]
T2 = BN[0]
add = T1 + T2

assert tuple_get_item_node.partition(add) == add
Contributor:
Two questions about this test case:

  1. Should we use structural_equal here?
  2. Does that mean we do not even treat BN -> T2 as a match anymore?

Contributor (Author):
Partitioning either path is invalid, since either side would need intermediate nodes from the other, so we expect the original expression to come back unchanged, thus the ==. We treat it as a match, but we don't treat it as something we can independently rewrite.
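For illustration, a self-contained sketch of that distinction with hypothetical names, assuming the `match`/`partition` helpers on `DFPattern`: matching still succeeds on either path, while partitioning hands the expression back untouched.

```python
# Hedged sketch; mirrors the diamond from the test above with fresh names.
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard

bn = is_op("nn.batch_norm")(wildcard(), wildcard(), wildcard(), wildcard(), wildcard())
pat = TupleGetItemPattern(bn, 0)

x, gamma, beta, mean, var = [relay.var(n) for n in ("x", "gamma", "beta", "mean", "var")]
BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var)
diamond = BN[0] + BN[0]  # two TupleGetItem paths sharing one batch_norm

assert pat.match(diamond.args[0])         # either path still matches the pattern...
assert pat.partition(diamond) == diamond  # ...but neither can be partitioned out alone
```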

Contributor:

Ah I see. That makes sense.


def test_partition_overused():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

x = relay.var('input')
w = relay.var('weight')
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
out = relu + conv2d

assert pattern.partition(out) == out
Contributor:

ditto.

Contributor (Author):

Again, fusing the conv and relu would make the rest of the expr invalid, so we expect the expr to come back unchanged.
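For contrast, a hedged sketch of the positive case under the same assumed API: when nothing else consumes the conv2d, the partition goes through and the match is lifted into a call to a function carrying the "PartitionedFromPattern" attribute (the attribute name comes from `attr::kPartitionedFromPattern` in the diff above).

```python
# Hedged sketch of the non-overused case; names are illustrative.
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

x = relay.var("input")
w = relay.var("weight")
relu = relay.op.nn.relu(relay.op.nn.conv2d(x, w))  # conv2d has no other consumers

partitioned = pattern.partition(relu)
# Expect Call(Function(...), [input, weight]) rather than the original relu,
# with the function tagged by the "PartitionedFromPattern" attribute.
assert not tvm.ir.structural_equal(partitioned, relu)
assert isinstance(partitioned, relay.Call)
assert isinstance(partitioned.op, relay.Function)
```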


def test_partition_check():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))