[inductor] Don't specialize split on sizes parameter.
Fix: #139936

This PR modifies the lowering of the `split` operation so that it won't generate
guards specializing on the sizes parameter. Instead, it specializes on the number
of output tensors being generated (i.e. a function of the size of the base tensor
and the sizes parameter).

As a result, operations such as `chunk` (whose number of output tensors is usually
constant for a fixed chunk count) won't trigger recompiles when the size of the
base tensor varies.
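
A minimal sketch of the behavior this enables (illustration only; it mirrors the
new test below and assumes PyTorch's default automatic dynamic shapes, where the
second distinct input size triggers one dynamic-shape recompile):

```python
import torch

def f(x):
    # chunk(4) yields 4 output tensors whenever the size is a multiple of 4
    return x.chunk(4)

opt_f = torch.compile(f, fullgraph=True)

opt_f(torch.randn(4))    # first call: compiled with static shapes
opt_f(torch.randn(8))    # new size: one recompile with dynamic shapes
opt_f(torch.randn(400))  # still 4 outputs, so no further recompile
opt_f(torch.randn(9))    # chunk sizes become 3, only 3 outputs: recompiles
```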

ghstack-source-id: 791bb9b7265fc455db5f7d4d99dcef5f2e919e87
Pull Request resolved: pytorch/pytorch#141077
ysiraichi committed Nov 20, 2024
1 parent 12e95aa commit 8b432e5
Showing 2 changed files with 36 additions and 7 deletions.
test/inductor/test_torchinductor.py: 31 additions & 0 deletions
@@ -11784,6 +11784,37 @@ def f(m, x):
                 f"Ref:\n{ref_grad_list}\nAct:\n{act_grad_list}",
             )
 
+    def test_chunk_recompiles(self):
+        def f(x):
+            return x.chunk(4)
+
+        def run(size):
+            input = torch.randn(size)
+            expected_out = f(input)
+            actual_out = optf(input)
+            self.assertEqual(expected_out, actual_out)
+
+        cnts = CompileCounterWithBackend("inductor")
+        optf = torch.compile(f, backend=cnts, fullgraph=True)
+
+        # The first run should compile once with static shapes.
+        run(4)
+        self.assertEqual(cnts.frame_count, 1)
+
+        # Varying the input size should trigger a recompilation.
+        # Since the input size is a multiple of 4 (i.e. all runs shall
+        # generate 4 output tensors), there should be no further
+        # recompilation.
+        for i in range(2, 12):
+            run(4 * i)
+        self.assertEqual(cnts.frame_count, 2)
+
+        # Input size: 9
+        # Yields one less output tensor, which should trigger a
+        # recompilation.
+        run(9)
+        self.assertEqual(cnts.frame_count, 3)
+
+
 @dataclasses.dataclass
 class TestFailure:
torch/_inductor/lowering.py: 5 additions & 7 deletions
@@ -1666,13 +1666,11 @@ def select(x, dim, idx):
 @register_lowering(aten.split, type_promotion_kind=None)
 def split(x, sizes, dim=0, clamp=True):
     dim = _validate_dim(x, dim, 0)
-    if isinstance(sizes, sympy.Expr):
-        # TODO: We don't have to guard on sizes per se, but the number
-        # of splits must stay constant
-        sizes = V.graph.sizevars.evaluate_static_shape(sizes)
-    if isinstance(sizes, (int, sympy.Integer)):
-        x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
-        sizes = [sizes] * ((x_size + sizes - 1) // sizes)
+    if not isinstance(sizes, (list, tuple)):
+        chunks = V.graph.sizevars.evaluate_static_shape(
+            FloorDiv(x.get_size()[dim] + sizes - 1, sizes)
+        )
+        sizes = [sizes] * chunks
     result = []
     start = 0
     for size in sizes:
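
The new branch computes the number of output tensors with a ceiling division and
guards only on that count, not on the split size itself. A plain-Python sketch of
the arithmetic (the `num_chunks` helper is hypothetical, shown only to illustrate
the rounding):

```python
def num_chunks(dim_size: int, split_size: int) -> int:
    # Ceiling division: how many output tensors `split` produces.
    return (dim_size + split_size - 1) // split_size

assert num_chunks(8, 2) == 4    # even split
assert num_chunks(44, 11) == 4  # same count for a larger input: no recompile
assert num_chunks(9, 3) == 3    # 9 elements in chunks of 3: one fewer output
```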
