diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 185a92898d54..249f4ccf7a44 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -302,12 +302,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { auto call_node = expr.as(); auto index_node = expr_to_node(expr); + size_t arg_counter{0}; for (auto node : index_node->inputs_) { if (!(call_node && node->ref() == call_node->op)) { + arg_counter += 1; memoize_ = true; - if (VisitDFPattern(op->parent, node->ref())) { - return true; - } else { + if (!VisitDFPattern(op->parent, node->ref())) { memoize_ = false; if (!VisitDFPattern(op->path, node->ref())) { return false; @@ -318,6 +318,9 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e } } } + if (!arg_counter) { + return false; + } return true; } @@ -605,8 +608,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // Don't treat fuzzy Dominator patterns input variables for partition if (auto op = node->ref().as()) { for (auto fuzzy_op : {op->parent, op->path}) { - for (auto match : node_map[fuzzy_op]) { - fuzzy_matches.insert(match); + if (node_map.count(fuzzy_op)) { + for (auto match : node_map[fuzzy_op]) { + fuzzy_matches.insert(match); + } } } } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index bcb665121b08..c4a83735cee9 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -750,6 +750,37 @@ def test_not_match_dominator(): assert not diamond.match(out) +def test_not_match_dominator2(): + # Pattern + P = is_op("nn.conv2d")(wildcard(), wildcard()) # 'parent' + I = is_op("nn.relu")(wildcard()) # 'intermediate' ('path' in the code) + C = is_op("add")(wildcard(), wildcard()) # 'child' + pattern = dominates(P, I, C) + + # n6(P) + # / \ + # n7 \ + # / \ + # n8(P) n9(I) + # \ / + # \ / + # \ / + # n10(C) + + x = relay.var("x") + w = relay.var("w") + n6 = relay.op.nn.conv2d(x, w) # matches P + n7 = relay.op.tanh(n6) # does not match I + n8 = relay.op.nn.conv2d(n7, w) # matches P + n9 = relay.op.nn.relu(n6) # matches I + n10 = relay.add(n8, n9) # matches C + + # Does not match: Can't match the parent pattern P at both 8 and 6. + # Note that if we did allow P to be used twice the implementation would + # need to be changed to not 'jump over' n7. + assert not pattern.match(n10) + + def test_match_typed_dominator(): # Pattern is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())