Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
kfeng123 committed Jul 25, 2023
1 parent 77c15b3 commit 75279fb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
42 changes: 21 additions & 21 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,17 +592,17 @@ def __init__(
):
self.__init_handle_by_constructor__(ffi.CallPattern, op, args)

"""Reset the arguments of the CallPattern
Parameters
----------
args: List[relay.dataflow_pattern.DFPattern]
The new arguments of the CallPattern.
"""
def reset_args(
self,
args: List["DFPattern"],
):
"""Reset the arguments of the CallPattern
Parameters
----------
args: List[relay.dataflow_pattern.DFPattern]
The new arguments of the CallPattern.
"""
ffi.CallPattern_reset_args(self, args)


Expand Down Expand Up @@ -727,30 +727,30 @@ class AltPattern(DFPattern):
def __init__(self, left: "DFPattern", right: "DFPattern"):
self.__init_handle_by_constructor__(ffi.AltPattern, left, right)

"""Reset the left of the AltPattern
Parameters
----------
left: relay.dataflow_pattern.DFPattern
The new left of the AltPattern.
"""
def reset_left(
self,
left: "DFPattern",
):
ffi.AltPattern_reset_left(self, left)
"""Reset the left of the AltPattern
"""Reset the right of the AltPattern
Parameters
----------
left: relay.dataflow_pattern.DFPattern
The new left of the AltPattern.
"""
ffi.AltPattern_reset_left(self, left)

Parameters
----------
right: relay.dataflow_pattern.DFPattern
The new right of the AltPattern.
"""
def reset_right(
self,
right: "DFPattern",
):
"""Reset the right of the AltPattern
Parameters
----------
right: relay.dataflow_pattern.DFPattern
The new right of the AltPattern.
"""
ffi.AltPattern_reset_right(self, right)


Expand Down
7 changes: 4 additions & 3 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1979,7 +1979,7 @@ def test_rewrite_with_pattern_recursion():

class TheRewrite(DFPatternCallback):
def __init__(self, pattern):
super(TheRewrite, self).__init__(rewrite_once = True)
super(TheRewrite, self).__init__(rewrite_once=True)
self.pattern = pattern

def callback(self, pre, post, node_map):
Expand All @@ -1997,7 +1997,7 @@ def test_reset_call_args():
def test_reset_alt_left():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
or_pattern = wildcard() | dense_pattern
the_pattern = is_op("cast")( or_pattern )
the_pattern = is_op("cast")(or_pattern)
or_pattern.reset_left(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
Expand All @@ -2006,7 +2006,7 @@ def test_reset_alt_left():
def test_reset_alt_right():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
or_pattern = dense_pattern | wildcard()
the_pattern = is_op("cast")( or_pattern )
the_pattern = is_op("cast")(or_pattern)
or_pattern.reset_right(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
Expand All @@ -2016,5 +2016,6 @@ def test_reset_alt_right():
test_reset_alt_left()
test_reset_alt_right()


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 75279fb

Please sign in to comment.