Skip to content

Commit

Permalink
[inductor] Don't clamp on split operation. (pytorch#141078)
Browse files Browse the repository at this point in the history
This PR turns clamping off for the `split` operation. By doing so, we generate fewer bound
guards and reduce the number of recompilations when varying the input size.

```python
@torch.compile(dynamic=True)
def f(x):
    return x.chunk(4)

>>> f(torch.arange(12))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

>>> f(torch.arange(11))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))

>>> f(torch.arange(10))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
```

Pull Request resolved: pytorch#141078
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#141077
  • Loading branch information
ysiraichi authored and pobin6 committed Dec 5, 2024
1 parent 569d040 commit cd2be98
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
17 changes: 16 additions & 1 deletion test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11802,6 +11802,8 @@ def test_chunk_recompiles(self):
def f(x):
return x.chunk(4)

# Runs f and its torch.compile-d version with a fresh 1D tensor
# of a specific size, and checks that the result is correct.
def run(size):
input = torch.randn(size)
expected_out = f(input)
Expand All @@ -11823,11 +11825,24 @@ def run(size):
run(4 * i)
self.assertEqual(cnts.frame_count, 2)

# Input size: 11
# Not a multiple of 4, but still generates 4 output tensors,
# where the last one has size > 1.
run(11)
self.assertEqual(cnts.frame_count, 2)

# Input size: 10
# Even though it still generates 4 output tensors, the last
# one has size 1, falling into our 0/1 specialization. Thus,
# this one also triggers recompilation.
run(10)
self.assertEqual(cnts.frame_count, 3)

# Input size: 9
# Yields one less output tensor, which should trigger a
# recompilation.
run(9)
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.frame_count, 4)


@dataclasses.dataclass
Expand Down
24 changes: 18 additions & 6 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,25 +1668,37 @@ def select(x, dim, idx):


@register_lowering(aten.split, type_promotion_kind=None)
def split(x, sizes, dim=0, clamp=True):
def split(x, sizes, dim=0):
dim = _validate_dim(x, dim, 0)
sizes_ = sizes

# If sizes is an integer (or a SymInt), we turn it into a list of sizes
# by computing what the actual size of each chunk should be.
if not isinstance(sizes, (list, tuple)):
x_size = x.get_size()[dim]
chunks = V.graph.sizevars.evaluate_static_shape(
FloorDiv(x.get_size()[dim] + sizes - 1, sizes)
FloorDiv(x_size + sizes - 1, sizes)
)
sizes = [sizes] * chunks
sizes_ = [sizes] * chunks
# The last chunk might have a smaller size than the rest.
sizes_[-1] = x_size - (chunks - 1) * sizes

# From this point, we assume that the sum of the sizes of all chunks
# equals the size of the base tensor.
result = []
start = 0
for size in sizes:
for size in sizes_:
end = start + size
result.append(slice_(x, dim, start, end, clamp=clamp))
# No need for clamping here, since we compute the exact
# start and end values.
result.append(slice_(x, dim, start, end, clamp=False))
start = end
return result


@register_lowering(aten.split_with_sizes, type_promotion_kind=None)
def split_with_sizes(x, sizes, dim=0):
return split(x, sizes, dim, clamp=False)
return split(x, sizes, dim)


@register_lowering(aten.unbind, type_promotion_kind=None)
Expand Down

0 comments on commit cd2be98

Please sign in to comment.