Commit

[inductor] Don't specialize split on sizes parameter. (pytorch#141077)

Fix: pytorch#139936

This PR modifies the lowering of the `split` operation so that it no longer generates guards
that specialize on the sizes parameter. Instead, it specializes on the number of output
tensors being generated (i.e. a function of the size of the base tensor and the sizes
parameter).

As a result, operations such as `chunk` (whose number of output tensors is usually
constant for a given static chunk count) won't trigger recompiles when the size of the
base tensor varies.
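
To illustrate the intended effect, here is a minimal sketch (not code from this commit; the function and sizes below are illustrative):

```python
import torch

def f(x):
    # chunk(4) lowers to aten.split with split_size = ceil(x.size(0) / 4)
    return x.chunk(4)

opt_f = torch.compile(f)

# After the initial compilation(s), every input whose size is a multiple
# of 4 still produces 4 output tensors, so the guard on the number of
# outputs holds and no further recompilation should be needed.
for size in (8, 12, 16, 40):
    opt_f(torch.randn(size))
```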

Pull Request resolved: pytorch#141077
Approved by: https://github.com/ezyang
ysiraichi authored and pobin6 committed Dec 5, 2024
1 parent 1b6d610 commit 569d040
Showing 2 changed files with 36 additions and 7 deletions.
31 changes: 31 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -11798,6 +11798,37 @@ def f(m, x):
f"Ref:\n{ref_grad_list}\nAct:\n{act_grad_list}",
)

def test_chunk_recompiles(self):
def f(x):
return x.chunk(4)

def run(size):
input = torch.randn(size)
expected_out = f(input)
actual_out = optf(input)
self.assertEqual(expected_out, actual_out)

cnts = CompileCounterWithBackend("inductor")
optf = torch.compile(f, backend=cnts, fullgraph=True)

# The first run should compile once with static shapes.
run(4)
self.assertEqual(cnts.frame_count, 1)

# Varying the input size should trigger a recompilation.
# Since the input size is a multiple of 4 (i.e. all runs shall
# generate 4 output tensors), there should be no further
# recompilation.
for i in range(2, 12):
run(4 * i)
self.assertEqual(cnts.frame_count, 2)

# Input size: 9
# Yields one less output tensor, which should trigger a
# recompilation.
run(9)
self.assertEqual(cnts.frame_count, 3)


@dataclasses.dataclass
class TestFailure:
12 changes: 5 additions & 7 deletions torch/_inductor/lowering.py
@@ -1670,13 +1670,11 @@ def select(x, dim, idx):
 @register_lowering(aten.split, type_promotion_kind=None)
 def split(x, sizes, dim=0, clamp=True):
     dim = _validate_dim(x, dim, 0)
-    if isinstance(sizes, sympy.Expr):
-        # TODO: We don't have to guard on sizes per se, but the number
-        # of splits must stay constant
-        sizes = V.graph.sizevars.evaluate_static_shape(sizes)
-    if isinstance(sizes, (int, sympy.Integer)):
-        x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
-        sizes = [sizes] * ((x_size + sizes - 1) // sizes)
+    if not isinstance(sizes, (list, tuple)):
+        chunks = V.graph.sizevars.evaluate_static_shape(
+            FloorDiv(x.get_size()[dim] + sizes - 1, sizes)
+        )
+        sizes = [sizes] * chunks
     result = []
     start = 0
     for size in sizes:
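For intuition: the new code guards on the number of output tensors, computed by ceiling division. A minimal sketch of the same arithmetic in plain Python (the helper name is illustrative, not inductor's API):

```python
import math

def num_split_outputs(x_size: int, split_size: int) -> int:
    # FloorDiv(x_size + split_size - 1, split_size) is the ceiling-division
    # idiom: it equals ceil(x_size / split_size).
    return (x_size + split_size - 1) // split_size

# Sizes 10, 11, and 12 all yield 4 outputs for split_size=3, so the guard holds.
assert [num_split_outputs(s, 3) for s in (10, 11, 12)] == [4, 4, 4]
# Size 9 yields only 3 outputs; the guard fails and a recompile is triggered.
assert num_split_outputs(9, 3) == math.ceil(9 / 3) == 3
```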
