From 389bb67ec0fcbdfa57bc6b650193edf0e51e731c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Sun, 31 May 2020 04:30:42 +0000 Subject: [PATCH] rename is_input to is_var --- docs/langref/relay_pattern.rst | 6 ++--- python/tvm/relay/dataflow_pattern/__init__.py | 2 +- tests/python/relay/test_dataflow_pattern.py | 24 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 135966e8a30ee..88214905240a2 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -193,7 +193,7 @@ The next example is matching a diamond with two inputs at the top of the diamond def test_match_diamond(): # Pattern - is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + is_conv2d = is_op('nn.conv2d')(is_var(), is_var()) path1 = is_op('nn.relu')(is_conv2d) path2 = is_op('nn.leaky_relu')(is_conv2d) diamond = is_op('add')(path1, path2) @@ -213,7 +213,7 @@ The final example is matching diamonds with a post-dominator relationship. We em def test_match_dom_diamond(): # Pattern - is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + is_conv2d = is_op('nn.conv2d')(is_var(), is_var()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_elemwise, reduction) @@ -240,7 +240,7 @@ The high level design is to introduce a language of patterns for now we propose | pattern(pattern1, ... patternN) | has_type(pattern, type) | has_attr(pattern, attrs) - | is_input(name) + | is_var(name) | is_constant() | is_op(op_name) | is_tuple() diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index a803760752ad6..bdeabd7e09067 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -285,7 +285,7 @@ def __init__(self, name_hint="", type_annotation=None): self.__init_handle_by_constructor__( ffi.VarPattern, name_hint, type_annotation) -is_input = VarPattern +is_var = VarPattern @register_df_node diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 2541d5deb02d3..180244de1e167 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -38,7 +38,7 @@ def test_expr_pattern(): def test_var_pattern(): - v = is_input("x") + v = is_var("x") assert isinstance(v, VarPattern) assert v.name == "x" @@ -120,10 +120,10 @@ def test_match_op_or(): def test_match_call_commutive(): x = relay.var('x') y = relay.var('y') - add_pattern = is_op('add')(is_input("x"), is_input("y")) + add_pattern = is_op('add')(is_var("x"), is_var("y")) assert add_pattern.match(x + y) assert add_pattern.match(y + x) - mul_pattern = is_op('multiply')(is_input("x"), is_input("y")) + mul_pattern = is_op('multiply')(is_var("x"), is_var("y")) assert mul_pattern.match(x * y) assert mul_pattern.match(y * x) @@ -131,10 +131,10 @@ def test_match_call_commutive(): def test_no_match_call_commutive(): x = relay.var('x') y = relay.var('y') - add_pattern = is_op('subtract')(is_input("x"), is_input("y")) + add_pattern = is_op('subtract')(is_var("x"), is_var("y")) assert add_pattern.match(x - y) assert not add_pattern.match(y - x) - add_pattern = is_op('divide')(is_input("x"), is_input("y")) + add_pattern = is_op('divide')(is_var("x"), is_var("y")) assert add_pattern.match(x / y) assert not add_pattern.match(y / x) @@ -232,10 +232,10 @@ def test_match_tuple(): x = relay.var('x') y = relay.var('y') z = relay.op.op.get("add") - tuple_pattern = is_tuple((is_input("x"), wildcard(), is_op("add"))) + tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"))) assert tuple_pattern.match(relay.expr.Tuple((x, y, z))) - tuple_pattern = is_tuple((is_input("x"), wildcard(), is_op("add"))) + tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"))) tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1) assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1)) @@ -244,10 +244,10 @@ def test_no_match_tuple(): x = relay.var('x') y = relay.var('y') z = relay.op.op.get("add") - tuple_pattern = is_tuple((is_input('x'), wildcard(), is_op("add"), wildcard())) + tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"), wildcard())) assert not tuple_pattern.match(relay.expr.Tuple((x, y, z))) - tuple_pattern = is_tuple((is_input('x'), wildcard(), is_op("add"))) + tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"))) tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1) assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple( (x, y, z)), 2)) @@ -1182,7 +1182,7 @@ def test_partition_constant_embedding(): assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc)) # Check lifting of input matches - pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()), + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()), wildcard())) assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs @@ -1201,7 +1201,7 @@ def test_partition_constant_embedding(): assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) # Check lifting/embedding of Alt matches - pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input() + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var() | is_constant()), wildcard())) assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) @@ -1209,7 +1209,7 @@ def test_partition_constant_embedding(): # Check lifting/embedding of Alt matches with the other ordering pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), - is_constant() | is_input()), + is_constant() | is_var()), wildcard())) assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))