Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
kfeng123 committed Jul 25, 2023
1 parent 77c15b3 commit aae21ec
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ def __init__(
args: List[relay.dataflow_pattern.DFPattern]
The new arguments of the CallPattern.
"""

def reset_args(
self,
args: List["DFPattern"],
Expand Down Expand Up @@ -734,6 +735,7 @@ def __init__(self, left: "DFPattern", right: "DFPattern"):
left: relay.dataflow_pattern.DFPattern
The new left of the AltPattern.
"""

def reset_left(
self,
left: "DFPattern",
Expand All @@ -747,6 +749,7 @@ def reset_left(
right: relay.dataflow_pattern.DFPattern
The new right of the AltPattern.
"""

def reset_right(
self,
right: "DFPattern",
Expand Down
7 changes: 4 additions & 3 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1979,7 +1979,7 @@ def test_rewrite_with_pattern_recursion():

class TheRewrite(DFPatternCallback):
def __init__(self, pattern):
super(TheRewrite, self).__init__(rewrite_once = True)
super(TheRewrite, self).__init__(rewrite_once=True)
self.pattern = pattern

def callback(self, pre, post, node_map):
Expand All @@ -1997,7 +1997,7 @@ def test_reset_call_args():
def test_reset_alt_left():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
or_pattern = wildcard() | dense_pattern
the_pattern = is_op("cast")( or_pattern )
the_pattern = is_op("cast")(or_pattern)
or_pattern.reset_left(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
Expand All @@ -2006,7 +2006,7 @@ def test_reset_alt_left():
def test_reset_alt_right():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
or_pattern = dense_pattern | wildcard()
the_pattern = is_op("cast")( or_pattern )
the_pattern = is_op("cast")(or_pattern)
or_pattern.reset_right(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
Expand All @@ -2016,5 +2016,6 @@ def test_reset_alt_right():
test_reset_alt_left()
test_reset_alt_right()


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit aae21ec

Please sign in to comment.