Skip to content

Commit

Permalink
[PatternLang] Support any index matching for TupleGetItem (#5909)
Browse files Browse the repository at this point in the history
* support any index matching

* update doc
  • Loading branch information
comaniac authored Jun 24, 2020
1 parent fcaba98 commit 7bee0ea
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ Since there are not call nodes, we need to use specific pattern nodes to match t
tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
The next example is matching a pattern of batch_norm -> get(0) -> relu:
The next example is matching a pattern of batch_norm -> get(0) -> relu. Note that you can also use `is_tuple_get_item(bn_node)` to match a `TupleGetItem` node with any index.

.. code-block:: python
Expand Down Expand Up @@ -280,7 +280,7 @@ The high level design is to introduce a language of patterns for now we propose
| is_expr(expr)
| is_op(op_name)
| is_tuple()
| is_tuple_get_item()
| is_tuple_get_item(pattern, index = None)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)

Expand Down
15 changes: 8 additions & 7 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def is_tuple(fields: tvm.ir.container.Array) -> "DFPattern":
return TuplePattern(fields)


def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern":
def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) -> "DFPattern":
"""
Syntatic sugar for creating an ExprPattern.
Expand All @@ -302,8 +302,8 @@ def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern":
tuple_value: tvm.relay.dataflow_pattern.DFPattern
The input tuple expression.
index: int
The index.
index: Optional[int]
The index to match; Default (None) to match a TupleGetItem with any index.
Returns
-------
Expand Down Expand Up @@ -555,12 +555,13 @@ class TupleGetItemPattern(DFPattern):
tuple_value: tvm.relay.dataflow_pattern.DFPattern
The input tuple expression.
index: int
The index.
index: Optional[int]
The index to match; Default (None) to match a TupleGetItem with any index.
"""

def __init__(self, tuple_value: "DFPattern", index: int):
self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index)
def __init__(self, tuple_value: "DFPattern", index: Optional[int] = None):
match_index = index if index is not None else -1
self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, match_index)


@register_df_node
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex
bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
matches = (op->index == tuple_get_item_node->index) &&
matches = (op->index == -1 || op->index == tuple_get_item_node->index) &&
VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
}
return matches;
Expand Down
5 changes: 5 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,11 @@ def test_match_tuple():
tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))

tuple_get_item_pattern = is_tuple_get_item(tuple_pattern) # Match any index
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 0))
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 2))


def test_no_match_tuple():
x = relay.var('x')
Expand Down

0 comments on commit 7bee0ea

Please sign in to comment.