diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index b237f14d613f..6cacff25c2f1 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -148,7 +148,7 @@ Since there are not call nodes, we need to use specific pattern nodes to match t tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard())) assert tuple_pattern.match(relay.expr.Tuple((x,y,z))) -The next example is matching a pattern of batch_norm -> get(0) -> relu: +The next example is matching a pattern of batch_norm -> get(0) -> relu. Note that you can also use `is_tuple_get_item(bn_node)` to match a `TupleGetItem` node with any index. .. code-block:: python @@ -280,7 +280,7 @@ The high level design is to introduce a language of patterns for now we propose | is_expr(expr) | is_op(op_name) | is_tuple() - | is_tuple_get_item() + | is_tuple_get_item(pattern, index = None) | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 915842c8e5fa..317d28e1dbea 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -293,7 +293,7 @@ def is_tuple(fields: tvm.ir.container.Array) -> "DFPattern": return TuplePattern(fields) -def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern": +def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) -> "DFPattern": """ Syntatic sugar for creating an ExprPattern. @@ -302,8 +302,8 @@ def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern": tuple_value: tvm.relay.dataflow_pattern.DFPattern The input tuple expression. - index: int - The index. + index: Optional[int] + The index to match; Default (None) to match a TupleGetItem with any index. Returns ------- @@ -555,12 +555,13 @@ class TupleGetItemPattern(DFPattern): tuple_value: tvm.relay.dataflow_pattern.DFPattern The input tuple expression. - index: int - The index. + index: Optional[int] + The index to match; Default (None) to match a TupleGetItem with any index. """ - def __init__(self, tuple_value: "DFPattern", index: int): - self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index) + def __init__(self, tuple_value: "DFPattern", index: Optional[int] = None): + match_index = index if index is not None else -1 + self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, match_index) @register_df_node diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index c9bf11e884ab..d33891a69bc6 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -359,7 +359,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { bool matches = false; if (const auto* tuple_get_item_node = expr.as()) { - matches = (op->index == tuple_get_item_node->index) && + matches = (op->index == -1 || op->index == tuple_get_item_node->index) && VisitDFPattern(op->tuple, tuple_get_item_node->tuple); } return matches; diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 9727e53bab0a..4fce4732669c 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -253,6 +253,11 @@ def test_match_tuple(): tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1) assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1)) + tuple_get_item_pattern = is_tuple_get_item(tuple_pattern) # Match any index + assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 0)) + assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1)) + assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2)) + def test_no_match_tuple(): x = relay.var('x')