Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Don't specialize split on sizes parameter. #141077

Closed
wants to merge 3 commits into from

Conversation

ysiraichi
Copy link
Collaborator

@ysiraichi ysiraichi commented Nov 20, 2024

Stack from ghstack (oldest at bottom):

Fix: #139936

This PR modifies the lowering of split operation, so that it won't generate guards,
specializing on the sizes parameter. Instead, it specializes on the number of output
tensors being generated (i.e. function of the size of the base tensor, and the sizes
parameter).

As a result, operations such as chunk (whose number of output tensors usually is
constant given a static chunk number) won't trigger recompiles when varying the size of
the base tensor.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Nov 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141077

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 06644fb with merge base 12e95aa (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

@ezyang
Copy link
Contributor

ezyang commented Nov 20, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 20, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Rebased gh/ysiraichi/75/orig onto refs/remotes/origin/viable/strict because #141078 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/141077)

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team Raised by workflow job

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Nov 21, 2024
This PR turns clamping off for the `split` operation. By doing so, we generate less bound
guards and reduce the number of recompilation when varying the input size.

```python
@torch.compile(dynamic=True)
def f(x):
    return x.chunk(4)

>>> f(torch.arange(12))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

>>> f(torch.arange(11))

(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))

>>> f(torch.arange(10))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
```

Pull Request resolved: #141078
Approved by: https://github.com/ezyang
ghstack dependencies: #141077
youssef62 pushed a commit to youssef62/pytorch that referenced this pull request Nov 23, 2024
…1077)

Fix: pytorch#139936

This PR modifies the lowering of `split` operation, so that it won't generate guards,
specializing on the sizes parameter. Instead, it specializes on the number of output
tensors being generated (i.e. function of the size of the base tensor, and the sizes
parameter).

As a result, operations such as `chunk` (whose number of output tensors usually is
constant given a static chunk number) won't trigger recompiles when varying the size of
the base tensor.

Pull Request resolved: pytorch#141077
Approved by: https://github.com/ezyang
youssef62 pushed a commit to youssef62/pytorch that referenced this pull request Nov 23, 2024
This PR turns clamping off for the `split` operation. By doing so, we generate less bound
guards and reduce the number of recompilation when varying the input size.

```python
@torch.compile(dynamic=True)
def f(x):
    return x.chunk(4)

>>> f(torch.arange(12))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

>>> f(torch.arange(11))

(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))

>>> f(torch.arange(10))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
```

Pull Request resolved: pytorch#141078
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#141077
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024
…1077)

Fix: pytorch#139936

This PR modifies the lowering of `split` operation, so that it won't generate guards,
specializing on the sizes parameter. Instead, it specializes on the number of output
tensors being generated (i.e. function of the size of the base tensor, and the sizes
parameter).

As a result, operations such as `chunk` (whose number of output tensors usually is
constant given a static chunk number) won't trigger recompiles when varying the size of
the base tensor.

Pull Request resolved: pytorch#141077
Approved by: https://github.com/ezyang
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024
This PR turns clamping off for the `split` operation. By doing so, we generate less bound
guards and reduce the number of recompilation when varying the input size.

```python
@torch.compile(dynamic=True)
def f(x):
    return x.chunk(4)

>>> f(torch.arange(12))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

>>> f(torch.arange(11))

(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))

>>> f(torch.arange(10))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
```

Pull Request resolved: pytorch#141078
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#141077
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…1077)

Fix: pytorch#139936

This PR modifies the lowering of `split` operation, so that it won't generate guards,
specializing on the sizes parameter. Instead, it specializes on the number of output
tensors being generated (i.e. function of the size of the base tensor, and the sizes
parameter).

As a result, operations such as `chunk` (whose number of output tensors usually is
constant given a static chunk number) won't trigger recompiles when varying the size of
the base tensor.

Pull Request resolved: pytorch#141077
Approved by: https://github.com/ezyang
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This PR turns clamping off for the `split` operation. By doing so, we generate less bound
guards and reduce the number of recompilation when varying the input size.

```python
@torch.compile(dynamic=True)
def f(x):
    return x.chunk(4)

>>> f(torch.arange(12))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

>>> f(torch.arange(11))

(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))

>>> f(torch.arange(10))
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9]))
```

Pull Request resolved: pytorch#141078
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#141077
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024
Fix: #139936

This PR modifies the lowering of `split` operation, so that it won't generate guards,
specializing on the sizes parameter. Instead, it specializes on the number of output
tensors being generated (i.e. function of the size of the base tensor, and the sizes
parameter).

As a result, operations such as `chunk` (whose number of output tensors usually is
constant given a static chunk number) won't trigger recompiles when varying the size of
the base tensor.

ghstack-source-id: 791bb9b7265fc455db5f7d4d99dcef5f2e919e87
Pull Request resolved: pytorch/pytorch#141077
@github-actions github-actions bot deleted the gh/ysiraichi/74/head branch December 22, 2024 02:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants