-
Notifications
You must be signed in to change notification settings - Fork 23.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inductor] Don't specialize split
on sizes
parameter.
#141077
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141077
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 06644fb with merge base 12e95aa (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Rebased |
Merge failedReason: New commits were pushed while merging. Please rerun the merge command. Details for Dev Infra teamRaised by workflow job |
This PR turns clamping off for the `split` operation. By doing so, we generate less bound guards and reduce the number of recompilation when varying the input size. ```python @torch.compile(dynamic=True) def f(x): return x.chunk(4) >>> f(torch.arange(12)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11])) >>> f(torch.arange(11)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10])) >>> f(torch.arange(10)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])) ``` Pull Request resolved: #141078 Approved by: https://github.com/ezyang ghstack dependencies: #141077
…1077) Fix: pytorch#139936 This PR modifies the lowering of `split` operation, so that it won't generate guards, specializing on the sizes parameter. Instead, it specializes on the number of output tensors being generated (i.e. function of the size of the base tensor, and the sizes parameter). As a result, operations such as `chunk` (whose number of output tensors usually is constant given a static chunk number) won't trigger recompiles when varying the size of the base tensor. Pull Request resolved: pytorch#141077 Approved by: https://github.com/ezyang
This PR turns clamping off for the `split` operation. By doing so, we generate less bound guards and reduce the number of recompilation when varying the input size. ```python @torch.compile(dynamic=True) def f(x): return x.chunk(4) >>> f(torch.arange(12)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11])) >>> f(torch.arange(11)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10])) >>> f(torch.arange(10)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])) ``` Pull Request resolved: pytorch#141078 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#141077
…1077) Fix: pytorch#139936 This PR modifies the lowering of `split` operation, so that it won't generate guards, specializing on the sizes parameter. Instead, it specializes on the number of output tensors being generated (i.e. function of the size of the base tensor, and the sizes parameter). As a result, operations such as `chunk` (whose number of output tensors usually is constant given a static chunk number) won't trigger recompiles when varying the size of the base tensor. Pull Request resolved: pytorch#141077 Approved by: https://github.com/ezyang
This PR turns clamping off for the `split` operation. By doing so, we generate less bound guards and reduce the number of recompilation when varying the input size. ```python @torch.compile(dynamic=True) def f(x): return x.chunk(4) >>> f(torch.arange(12)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11])) >>> f(torch.arange(11)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10])) >>> f(torch.arange(10)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])) ``` Pull Request resolved: pytorch#141078 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#141077
…1077) Fix: pytorch#139936 This PR modifies the lowering of `split` operation, so that it won't generate guards, specializing on the sizes parameter. Instead, it specializes on the number of output tensors being generated (i.e. function of the size of the base tensor, and the sizes parameter). As a result, operations such as `chunk` (whose number of output tensors usually is constant given a static chunk number) won't trigger recompiles when varying the size of the base tensor. Pull Request resolved: pytorch#141077 Approved by: https://github.com/ezyang
This PR turns clamping off for the `split` operation. By doing so, we generate less bound guards and reduce the number of recompilation when varying the input size. ```python @torch.compile(dynamic=True) def f(x): return x.chunk(4) >>> f(torch.arange(12)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11])) >>> f(torch.arange(11)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10])) >>> f(torch.arange(10)) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])) ``` Pull Request resolved: pytorch#141078 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#141077
Fix: #139936 This PR modifies the lowering of `split` operation, so that it won't generate guards, specializing on the sizes parameter. Instead, it specializes on the number of output tensors being generated (i.e. function of the size of the base tensor, and the sizes parameter). As a result, operations such as `chunk` (whose number of output tensors usually is constant given a static chunk number) won't trigger recompiles when varying the size of the base tensor. ghstack-source-id: 791bb9b7265fc455db5f7d4d99dcef5f2e919e87 Pull Request resolved: pytorch/pytorch#141077
Stack from ghstack (oldest at bottom):
split
operation. #141078split
onsizes
parameter. #141077Fix: #139936
This PR modifies the lowering of
split
operation, so that it won't generate guards,specializing on the sizes parameter. Instead, it specializes on the number of output
tensors being generated (i.e. function of the size of the base tensor, and the sizes
parameter).
As a result, operations such as
chunk
(whose number of output tensors usually isconstant given a static chunk number) won't trigger recompiles when varying the size of
the base tensor.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov