diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index e4aeae1a5ca84..a215d4529ac3a 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -11798,6 +11798,37 @@ def f(m, x):
                 f"Ref:\n{ref_grad_list}\nAct:\n{act_grad_list}",
             )
 
+    def test_chunk_recompiles(self):
+        def f(x):
+            return x.chunk(4)
+
+        def run(size):
+            input = torch.randn(size)
+            expected_out = f(input)
+            actual_out = optf(input)
+            self.assertEqual(expected_out, actual_out)
+
+        cnts = CompileCounterWithBackend("inductor")
+        optf = torch.compile(f, backend=cnts, fullgraph=True)
+
+        # The first run should compile once with static shapes.
+        run(4)
+        self.assertEqual(cnts.frame_count, 1)
+
+        # Varying the input size should trigger a recompilation.
+        # Since the input size is a multiple of 4 (i.e. all runs shall
+        # generate 4 output tensors), there should be no further
+        # recompilation.
+        for i in range(2, 12):
+            run(4 * i)
+        self.assertEqual(cnts.frame_count, 2)
+
+        # Input size: 9
+        # Yields one less output tensor, which should trigger a
+        # recompilation.
+        run(9)
+        self.assertEqual(cnts.frame_count, 3)
+
 
 @dataclasses.dataclass
 class TestFailure:
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 4e4bea2613e70..60fe25cc15544 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1670,13 +1670,11 @@ def select(x, dim, idx):
 @register_lowering(aten.split, type_promotion_kind=None)
 def split(x, sizes, dim=0, clamp=True):
     dim = _validate_dim(x, dim, 0)
-    if isinstance(sizes, sympy.Expr):
-        # TODO: We don't have to guard on sizes per se, but the number
-        # of splits must stay constant
-        sizes = V.graph.sizevars.evaluate_static_shape(sizes)
-    if isinstance(sizes, (int, sympy.Integer)):
-        x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
-        sizes = [sizes] * ((x_size + sizes - 1) // sizes)
+    if not isinstance(sizes, (list, tuple)):
+        chunks = V.graph.sizevars.evaluate_static_shape(
+            FloorDiv(x.get_size()[dim] + sizes - 1, sizes)
+        )
+        sizes = [sizes] * chunks
     result = []
     start = 0
     for size in sizes: