Skip to content

Commit

Permalink
[Relay] [Bugfix] Fix some bugs of dominator pattern (#15473)
Browse files Browse the repository at this point in the history
* fix some bugs

* add test
  • Loading branch information
kfeng123 authored Aug 4, 2023
1 parent 2f09064 commit ac99367
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,12 +302,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
auto call_node = expr.as<CallNode>();
auto index_node = expr_to_node(expr);
size_t arg_counter{0};
for (auto node : index_node->inputs_) {
if (!(call_node && node->ref() == call_node->op)) {
arg_counter += 1;
memoize_ = true;
if (VisitDFPattern(op->parent, node->ref())) {
return true;
} else {
if (!VisitDFPattern(op->parent, node->ref())) {
memoize_ = false;
if (!VisitDFPattern(op->path, node->ref())) {
return false;
Expand All @@ -318,6 +318,9 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e
}
}
}
if (!arg_counter) {
return false;
}
return true;
}

Expand Down Expand Up @@ -605,8 +608,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
// Don't treat fuzzy Dominator patterns input variables for partition
if (auto op = node->ref().as<DominatorPatternNode>()) {
for (auto fuzzy_op : {op->parent, op->path}) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
if (node_map.count(fuzzy_op)) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
}
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,37 @@ def test_not_match_dominator():
assert not diamond.match(out)


def test_not_match_dominator2():
# Pattern
P = is_op("nn.conv2d")(wildcard(), wildcard()) # 'parent'
I = is_op("nn.relu")(wildcard()) # 'intermediate' ('path' in the code)
C = is_op("add")(wildcard(), wildcard()) # 'child'
pattern = dominates(P, I, C)

# n6(P)
# / \
# n7 \
# / \
# n8(P) n9(I)
# \ /
# \ /
# \ /
# n10(C)

x = relay.var("x")
w = relay.var("w")
n6 = relay.op.nn.conv2d(x, w) # matches P
n7 = relay.op.tanh(n6) # does not match I
n8 = relay.op.nn.conv2d(n7, w) # matches P
n9 = relay.op.nn.relu(n6) # matches I
n10 = relay.add(n8, n9) # matches C

# Does not match: Can't match the parent pattern P at both 8 and 6.
# Note that if we did allow P to be used twice the implementation would
# need to be changed to not 'jump over' n7.
assert not pattern.match(n10)


def test_match_typed_dominator():
# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
Expand Down

0 comments on commit ac99367

Please sign in to comment.