Skip to content

Commit

Permalink
Add more syntatic sugar
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Jun 1, 2020
1 parent 55aefc2 commit 748f1c2
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 113 deletions.
1 change: 1 addition & 0 deletions docs/api/python/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Python API
relay/transform
relay/analysis
relay/backend
relay/dataflow_pattern
relay/testing
autotvm
rpc
Expand Down
21 changes: 13 additions & 8 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Since there are not call nodes, we need to use specific pattern nodes to match t
x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
tuple_pattern = TuplePattern((wildcard(), wildcard(), wildcard()))
tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
The next example is matching a pattern of batch_norm -> get(0) -> relu:
Expand All @@ -123,7 +123,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
def test_match_tuple_get_item():
bn_node = is_op('nn.batch_norm')(wildcard(), wildcard(), wildcard(), wildcard(), wildcard())
tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
tuple_get_item_node = is_tuple_get_item(bn_node, 0)
pat = is_op('nn.relu')(tuple_get_item_node)
x = relay.var('x', shape=(1, 8))
Expand All @@ -142,7 +142,7 @@ if a specific parameter in a subgraph has been bound or not.
.. code-block:: python
def test_match_constant():
conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
x = relay.var('x', shape=(1, 3, 224, 224))
Expand All @@ -162,12 +162,12 @@ if a specific parameter in a subgraph has been bound or not.
assert pattern.match(mod['main'].body)
On the other hand, if you need to match the constant with a specific value, you can directly
use ``ExprPattern``. This could be useful for algebraic simplify.
use ``is_expr``. This could be useful for algebraic simplify.

.. code-block:: python
def test_match_plus_zero():
zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
pattern = wildcard() + zero
x = relay.Var('x')
Expand All @@ -193,7 +193,7 @@ The next example is matching a diamond with two inputs at the top of the diamond

def test_match_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
path1 = is_op('nn.relu')(is_conv2d)
path2 = is_op('nn.leaky_relu')(is_conv2d)
diamond = is_op('add')(path1, path2)
Expand All @@ -213,7 +213,7 @@ The final example is matching diamonds with a post-dominator relationship. We em

def test_match_dom_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_elemwise, reduction)

Expand All @@ -240,7 +240,12 @@ The high level design is to introduce a language of patterns for now we propose
| pattern(pattern1, ... patternN)
| has_type(pattern, type)
| has_attr(pattern, attrs)
| is_input(name)
| is_var(name)
| is_constant()
| is_expr(expr)
| is_op(op_name)
| is_tuple()
| is_tuple_get_item()
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)

Expand Down
Loading

0 comments on commit 748f1c2

Please sign in to comment.