From cd2be98fc4b3220e9340eb28c0c87192cbdbfc48 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Wed, 20 Nov 2024 18:44:43 -0300
Subject: [PATCH] [inductor] Don't clamp on `split` operation. (#141078)

This PR turns clamping off for the `split` operation. By doing so, we
generate fewer bound guards and reduce the number of recompilations when
varying the input size.

```python
@torch.compile(dynamic=True)
def f(x):
    return x.chunk(4)

>>> f(torch.arange(12))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

>>> f(torch.arange(11))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))

>>> f(torch.arange(10))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
```
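A rough way to observe the reduced recompilations outside the test suite
is a sketch like the following (not part of this patch). It counts
compiled frames with `CompileCounterWithBackend` from
`torch._dynamo.testing`, the same counter the updated test relies on;
exact frame counts depend on warm-up and 0/1 specialization:

```python
import torch
from torch._dynamo.testing import CompileCounterWithBackend

cnts = CompileCounterWithBackend("inductor")

@torch.compile(backend=cnts, dynamic=True)
def f(x):
    return x.chunk(4)

f(torch.arange(12))  # compiles once
f(torch.arange(11))  # still 4 chunks: with this patch, no recompilation
print(cnts.frame_count)
```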
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141078
Approved by: https://github.com/ezyang
ghstack dependencies: #141077
---
 test/inductor/test_torchinductor.py | 17 ++++++++++++++++-
 torch/_inductor/lowering.py         | 24 ++++++++++++++++++------
 2 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index a215d4529ac3a..6b6e69b5d0576 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -11802,6 +11802,8 @@ def test_chunk_recompiles(self):
         def f(x):
             return x.chunk(4)
 
+        # Runs f and its torch.compile-d version with a fresh 1D tensor
+        # of a specific size, and checks that the result is correct.
         def run(size):
             input = torch.randn(size)
             expected_out = f(input)
@@ -11823,11 +11825,24 @@ def run(size):
             run(4 * i)
         self.assertEqual(cnts.frame_count, 2)
 
+        # Input size: 11
+        # Not a multiple of 4, but still generates 4 output tensors,
+        # where the last one has size > 1.
+        run(11)
+        self.assertEqual(cnts.frame_count, 2)
+
+        # Input size: 10
+        # Even though it still generates 4 output tensors, the last
+        # one has size 1, falling into our 0/1 specialization. Thus,
+        # this one also triggers recompilation.
+        run(10)
+        self.assertEqual(cnts.frame_count, 3)
+
         # Input size: 9
         # Yields one less output tensor, which should trigger a
         # recompilation.
         run(9)
-        self.assertEqual(cnts.frame_count, 3)
+        self.assertEqual(cnts.frame_count, 4)
 
 
 @dataclasses.dataclass
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 60fe25cc15544..995ad3e75f94b 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1668,25 +1668,37 @@ def select(x, dim, idx):
 
 
 @register_lowering(aten.split, type_promotion_kind=None)
-def split(x, sizes, dim=0, clamp=True):
+def split(x, sizes, dim=0):
     dim = _validate_dim(x, dim, 0)
+    sizes_ = sizes
+
+    # If sizes is an integer (or a SymInt), we turn it into a list of sizes
+    # by computing what the actual size of each chunk should be.
     if not isinstance(sizes, (list, tuple)):
+        x_size = x.get_size()[dim]
         chunks = V.graph.sizevars.evaluate_static_shape(
-            FloorDiv(x.get_size()[dim] + sizes - 1, sizes)
+            FloorDiv(x_size + sizes - 1, sizes)
         )
-        sizes = [sizes] * chunks
+        sizes_ = [sizes] * chunks
+        # The last chunk might have a smaller size than the rest.
+        sizes_[-1] = x_size - (chunks - 1) * sizes
+
+    # From this point, we assume that the sum of the sizes of all chunks
+    # equals the size of the base tensor.
     result = []
     start = 0
-    for size in sizes:
+    for size in sizes_:
         end = start + size
-        result.append(slice_(x, dim, start, end, clamp=clamp))
+        # No need for clamping here, since we compute the exact
+        # start and end values.
+        result.append(slice_(x, dim, start, end, clamp=False))
         start = end
     return result
 
 
 @register_lowering(aten.split_with_sizes, type_promotion_kind=None)
 def split_with_sizes(x, sizes, dim=0):
-    return split(x, sizes, dim)
+    return split(x, sizes, dim)
 
 
 @register_lowering(aten.unbind, type_promotion_kind=None)
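For intuition, the following standalone sketch (a hypothetical helper, not
part of the patch) mirrors the size computation the new `split` lowering
performs when `sizes` is a scalar. Because the last chunk absorbs the
remainder, the chunk sizes sum exactly to the dimension size, so every
computed `(start, end)` slice is already in bounds and `clamp=False` is safe:

```python
def chunk_sizes(dim_size: int, chunk: int) -> list[int]:
    # Ceil division, mirroring FloorDiv(dim_size + chunk - 1, chunk).
    chunks = (dim_size + chunk - 1) // chunk
    sizes = [chunk] * chunks
    # The last chunk takes whatever is left over.
    sizes[-1] = dim_size - (chunks - 1) * chunk
    return sizes

# x.chunk(4) on 10..12 elements uses chunk size ceil(n / 4) == 3,
# matching the outputs in the commit message above:
assert chunk_sizes(12, 3) == [3, 3, 3, 3]
assert chunk_sizes(11, 3) == [3, 3, 3, 2]
assert chunk_sizes(10, 3) == [3, 3, 3, 1]
```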