From 8be1ee139c4c74a1527542fa00d8e0ca66f06308 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 20 Nov 2024 18:44:43 -0300 Subject: [PATCH] [inductor] Don't specialize `split` on `sizes` parameter. (#141077) Fix: #139936 This PR modifies the lowering of `split` operation, so that it won't generate guards, specializing on the sizes parameter. Instead, it specializes on the number of output tensors being generated (i.e. function of the size of the base tensor, and the sizes parameter). As a result, operations such as `chunk` (whose number of output tensors usually is constant given a static chunk number) won't trigger recompiles when varying the size of the base tensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141077 Approved by: https://github.com/ezyang --- test/inductor/test_torchinductor.py | 31 +++++++++++++++++++++++++++++ torch/_inductor/lowering.py | 12 +++++------ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index e4aeae1a5ca84a..a215d4529ac3a4 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -11798,6 +11798,37 @@ def f(m, x): f"Ref:\n{ref_grad_list}\nAct:\n{act_grad_list}", ) + def test_chunk_recompiles(self): + def f(x): + return x.chunk(4) + + def run(size): + input = torch.randn(size) + expected_out = f(input) + actual_out = optf(input) + self.assertEqual(expected_out, actual_out) + + cnts = CompileCounterWithBackend("inductor") + optf = torch.compile(f, backend=cnts, fullgraph=True) + + # The first run should compile once with static shapes. + run(4) + self.assertEqual(cnts.frame_count, 1) + + # Varying the input size should trigger a recompilation. + # Since the input size is a multiple of 4 (i.e. all runs shall + # generate 4 output tensors), there should be no further + # recompilation. + for i in range(2, 12): + run(4 * i) + self.assertEqual(cnts.frame_count, 2) + + # Input size: 9 + # Yields one less output tensor, which should trigger a + # recompilation. + run(9) + self.assertEqual(cnts.frame_count, 3) + @dataclasses.dataclass class TestFailure: diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 4e4bea2613e70f..60fe25cc15544d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1670,13 +1670,11 @@ def select(x, dim, idx): @register_lowering(aten.split, type_promotion_kind=None) def split(x, sizes, dim=0, clamp=True): dim = _validate_dim(x, dim, 0) - if isinstance(sizes, sympy.Expr): - # TODO: We don't have to guard on sizes per se, but the number - # of splits must stay constant - sizes = V.graph.sizevars.evaluate_static_shape(sizes) - if isinstance(sizes, (int, sympy.Integer)): - x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) - sizes = [sizes] * ((x_size + sizes - 1) // sizes) + if not isinstance(sizes, (list, tuple)): + chunks = V.graph.sizevars.evaluate_static_shape( + FloorDiv(x.get_size()[dim] + sizes - 1, sizes) + ) + sizes = [sizes] * chunks result = [] start = 0 for size in sizes: