diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 028555faf883b..9852486b55ecb 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -599,6 +599,7 @@ def __init__( args: List[relay.dataflow_pattern.DFPattern] The new arguments of the CallPattern. """ + def reset_args( self, args: List["DFPattern"], @@ -734,6 +735,7 @@ def __init__(self, left: "DFPattern", right: "DFPattern"): left: relay.dataflow_pattern.DFPattern The new left of the AltPattern. """ + def reset_left( self, left: "DFPattern", @@ -747,6 +749,7 @@ def reset_left( right: relay.dataflow_pattern.DFPattern The new right of the AltPattern. """ + def reset_right( self, right: "DFPattern", diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 10d68e38e9650..726655862ec57 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1979,7 +1979,7 @@ def test_rewrite_with_pattern_recursion(): class TheRewrite(DFPatternCallback): def __init__(self, pattern): - super(TheRewrite, self).__init__(rewrite_once = True) + super(TheRewrite, self).__init__(rewrite_once=True) self.pattern = pattern def callback(self, pre, post, node_map): @@ -1997,7 +1997,7 @@ def test_reset_call_args(): def test_reset_alt_left(): dense_pattern = is_op("nn.dense")(wildcard(), wildcard()) or_pattern = wildcard() | dense_pattern - the_pattern = is_op("cast")( or_pattern ) + the_pattern = is_op("cast")(or_pattern) or_pattern.reset_left(the_pattern) actual = rewrite(TheRewrite(the_pattern), oup) @@ -2006,7 +2006,7 @@ def test_reset_alt_left(): def test_reset_alt_right(): dense_pattern = is_op("nn.dense")(wildcard(), wildcard()) or_pattern = dense_pattern | wildcard() - the_pattern = is_op("cast")( or_pattern ) + the_pattern = is_op("cast")(or_pattern) or_pattern.reset_right(the_pattern) actual = rewrite(TheRewrite(the_pattern), oup) @@ -2016,5 +2016,6 @@ def test_reset_alt_right(): test_reset_alt_left() test_reset_alt_right() + if __name__ == "__main__": tvm.testing.main()