Skip to content

Commit

Permalink
rename is_input to is_var
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed May 31, 2020
1 parent 18c934b commit 389bb67
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
6 changes: 3 additions & 3 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ The next example is matching a diamond with two inputs at the top of the diamond

def test_match_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
path1 = is_op('nn.relu')(is_conv2d)
path2 = is_op('nn.leaky_relu')(is_conv2d)
diamond = is_op('add')(path1, path2)
Expand All @@ -213,7 +213,7 @@ The final example is matching diamonds with a post-dominator relationship. We em

def test_match_dom_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_elemwise, reduction)

Expand All @@ -240,7 +240,7 @@ The high level design is to introduce a language of patterns for now we propose
| pattern(pattern1, ... patternN)
| has_type(pattern, type)
| has_attr(pattern, attrs)
| is_input(name)
| is_var(name)
| is_constant()
| is_op(op_name)
| is_tuple()
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def __init__(self, name_hint="", type_annotation=None):
self.__init_handle_by_constructor__(
ffi.VarPattern, name_hint, type_annotation)

is_input = VarPattern
is_var = VarPattern


@register_df_node
Expand Down
24 changes: 12 additions & 12 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_expr_pattern():


def test_var_pattern():
v = is_input("x")
v = is_var("x")
assert isinstance(v, VarPattern)
assert v.name == "x"

Expand Down Expand Up @@ -120,21 +120,21 @@ def test_match_op_or():
def test_match_call_commutive():
x = relay.var('x')
y = relay.var('y')
add_pattern = is_op('add')(is_input("x"), is_input("y"))
add_pattern = is_op('add')(is_var("x"), is_var("y"))
assert add_pattern.match(x + y)
assert add_pattern.match(y + x)
mul_pattern = is_op('multiply')(is_input("x"), is_input("y"))
mul_pattern = is_op('multiply')(is_var("x"), is_var("y"))
assert mul_pattern.match(x * y)
assert mul_pattern.match(y * x)


def test_no_match_call_commutive():
x = relay.var('x')
y = relay.var('y')
add_pattern = is_op('subtract')(is_input("x"), is_input("y"))
add_pattern = is_op('subtract')(is_var("x"), is_var("y"))
assert add_pattern.match(x - y)
assert not add_pattern.match(y - x)
add_pattern = is_op('divide')(is_input("x"), is_input("y"))
add_pattern = is_op('divide')(is_var("x"), is_var("y"))
assert add_pattern.match(x / y)
assert not add_pattern.match(y / x)

Expand Down Expand Up @@ -232,10 +232,10 @@ def test_match_tuple():
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
tuple_pattern = is_tuple((is_input("x"), wildcard(), is_op("add")))
tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
assert tuple_pattern.match(relay.expr.Tuple((x, y, z)))

tuple_pattern = is_tuple((is_input("x"), wildcard(), is_op("add")))
tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))

Expand All @@ -244,10 +244,10 @@ def test_no_match_tuple():
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
tuple_pattern = is_tuple((is_input('x'), wildcard(), is_op("add"), wildcard()))
tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"), wildcard()))
assert not tuple_pattern.match(relay.expr.Tuple((x, y, z)))

tuple_pattern = is_tuple((is_input('x'), wildcard(), is_op("add")))
tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add")))
tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple(
(x, y, z)), 2))
Expand Down Expand Up @@ -1182,7 +1182,7 @@ def test_partition_constant_embedding():
assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))

# Check lifting of input matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs
Expand All @@ -1201,15 +1201,15 @@ def test_partition_constant_embedding():
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check lifting/embedding of Alt matches
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()
| is_constant()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))

# Check lifting/embedding of Alt matches with the other ordering
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
is_constant() | is_input()),
is_constant() | is_var()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
Expand Down

0 comments on commit 389bb67

Please sign in to comment.