Actually actionable list of batching rules to write #240
Comments
@zou3519
Can we do something with that?
Aha. Nope, we can't do anything about that until functionalization is in, good catch.
Description:
- Added backward batch rule for pad replicate/reflect modes
- Updated tests
Related to #240
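For context, a minimal usage sketch of what the backward batch rule enables (my own illustration, not taken from the PR; `padded_sum` is a hypothetical helper, and the functorch-era `vmap`/`grad` API is assumed — the same transforms now live in `torch.func`):

```python
import torch
from functorch import grad, vmap  # functorch-era API; now torch.func.vmap / torch.func.grad

def padded_sum(x):
    # reflect-mode padding; its backward op is what needs the new batch rule
    return torch.nn.functional.pad(x, (1, 1, 1, 1), mode="reflect").sum()

x = torch.randn(8, 3, 5, 5)  # a batch of 8 examples of shape (3, 5, 5)
per_sample_grads = vmap(grad(padded_sum))(x)
print(per_sample_grads.shape)  # torch.Size([8, 3, 5, 5])
```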
Hi @zou3519, does "forward pass only" for the ops above mean that the vjp and related operators require functionalization too?
Description:
- Added im2col batch rule and enabled vmap for nn.functional.unfold op
- Updated tests
Using EXISTING_BDIM macro to put bdim into 0, as im2col expects dim=0 to be the batch dim.
Related to pytorch#240
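A hypothetical sketch of what this enables (not taken from the PR; assumes the functorch-era `vmap` API): vmap over nn.functional.unfold, where the batch rule folds vmap's batch dim into the dim-0 batch dim that im2col already expects.

```python
import torch
from functorch import vmap  # now torch.func.vmap in recent PyTorch

# vmap adds an extra leading batch dim on top of unfold's own N dim;
# the im2col batch rule merges it into dim 0, which im2col treats as batch
x = torch.randn(5, 2, 3, 8, 8)  # (vmap batch, N, C, H, W)
unfold3x3 = lambda t: torch.nn.functional.unfold(t, kernel_size=3)  # illustrative name
out = vmap(unfold3x3)(x)
print(out.shape)  # torch.Size([5, 2, 27, 36]) -> C*k*k = 27, L = 6*6 = 36
```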
Description:
- Added adaptive_max_poolNd fw/bw batch rules
- Updated tests
Related to pytorch#240
Notes: I created two additional macros to handle adaptive_max_pool2d and adaptive_max_pool3d_backward. Not sure if we could make a generic rule to handle max_pool2d_with_indices_backward_batch_rule and adaptive_max_pool3d_backward, as max_pool2d_with_indices_backward_batch_rule requires some args in the middle between gradOutput, input and indices.
Yes, "forward pass only" means we should only try to get the forward pass (vmap) working for those ops for now.
* Added adaptive_max_poolNd fw/bw batch rules (description above)
* Replaced EXISTING_BDIM_MULTIOUT by EXISTING_BDIM_ALL_BOXED
* Removed specific implementations with indices.contiguous() for max_pool2d_with_indices_backward, adaptive_max_pool2d_backward and adaptive_max_pool3d_backward, and added ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1 to handle that
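For reference, a small sketch of the multi-output case these rules cover (my own illustration, assuming the functorch-era `vmap` API; `pool_with_indices` is a hypothetical helper): adaptive_max_pool2d returns both values and indices, and both outputs carry a batch dim the batching rule has to thread through.

```python
import torch
from functorch import vmap  # functorch-era API; now torch.func.vmap

def pool_with_indices(x):
    # returns (values, indices); both outputs need their batch dim handled
    return torch.nn.functional.adaptive_max_pool2d(x, (2, 2), return_indices=True)

x = torch.randn(4, 3, 7, 7)
values, indices = vmap(pool_with_indices)(x)
print(values.shape, indices.shape)  # torch.Size([4, 3, 2, 2]) for both
```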
@kshitij12345 which tasks from Parcel 2 are you working on or planning to work on? I can start working on …
@vfdev-5 I think I'll be picking …
I'll take …
Description:
- Added batch rule for LU op
- Updated tests
Related to #240
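A minimal sketch of what a batch rule for the LU op makes possible (my own illustration, not from the PR; the PR targets the underlying LU op, while this sketch assumes the user-facing torch.linalg.lu_factor entry point and the functorch-era `vmap` API):

```python
import torch
from functorch import vmap  # now torch.func.vmap

A = torch.randn(10, 4, 4)  # ten 4x4 matrices

# each call logically sees a single (4, 4) matrix; the batch rule maps it over the leading dim
LU, pivots = vmap(torch.linalg.lu_factor)(A)
print(LU.shape, pivots.shape)  # torch.Size([10, 4, 4]) torch.Size([10, 4])
```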
To close this issue, it remains to finalize parcel 2: … and in parcel 4: …
There are always more batching rules to write; I'll put up a new issue for them later :)
Note that …
@lezcano thanks for the update! I see that …
It will indeed. And that's a good reminder for me to put up a PR doing so :D
@zou3519 can we update the description list with what was done? I think we can remove Parcel 4 from here and create a new issue for that if needed. What remains here is to sync and merge the householder product PR (#322), cc @kshitij12345.
Fwiw, following up on the point above on deprecating …
Description:
- Added cholesky_solve op batch rule
- Updated tests
Related to pytorch/functorch#240
Description:
- Added addr op decomposition
- Updated tests
Related to pytorch/functorch#240
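For context, addr is a natural candidate for a decomposition because it can be expressed through ops that already have batch rules. Below is a Python-level illustration of such a decomposition (my own sketch; the actual decomposition in functorch is in C++ and may handle edge cases like beta=0 and type promotion differently; `addr_decomposed` is a hypothetical name):

```python
import torch

def addr_decomposed(input, vec1, vec2, *, beta=1, alpha=1):
    # torch.addr(input, vec1, vec2) = beta * input + alpha * outer(vec1, vec2)
    return beta * input + alpha * torch.outer(vec1, vec2)

M = torch.randn(3, 4)
v1, v2 = torch.randn(3), torch.randn(4)
assert torch.allclose(torch.addr(M, v1, v2, beta=0.5, alpha=2.0),
                      addr_decomposed(M, v1, v2, beta=0.5, alpha=2.0))
```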
For each of the items here, we should make sure all compositions (vmap, vmap x vjp) have a batching rule. All of these items should be actionable (in that it is possible to write a batching rule and we are not blocked on functionalization, which is coming soon).
Note: you may need to write an OpInfo for the operator if it doesn't exist already or wait for one to be added. A lot of folks are adding OpInfos right now, so if the OpInfo doesn't exist please ask first to see if someone is working on it.
Note: if any of the operations decompose into in-place operations, then we need functionalization to handle them. I think I've already filtered out all of those, but please check me on that.
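To make "all compositions (vmap, vmap x vjp)" concrete, here is a minimal sketch of the two call patterns each op should support, using adaptive_avg_pool2d from the list below (my own illustration; `grad_of_f` is a hypothetical helper, and the functorch-era `vmap`/`vjp` API is assumed — the same transforms now live in `torch.func`):

```python
import torch
from functorch import vmap, vjp  # functorch-era API; torch.func in current PyTorch

def f(x):
    return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2))

x = torch.randn(8, 3, 6, 6)  # batch of 8 examples of shape (3, 6, 6)

# composition 1: plain vmap -> needs the forward batching rule
y = vmap(f)(x)

# composition 2: vmap x vjp -> also exercises the batching rule of the backward op
def grad_of_f(xi):
    out, pullback = vjp(f, xi)
    return pullback(torch.ones_like(out))[0]

grads = vmap(grad_of_f)(x)
print(y.shape, grads.shape)  # torch.Size([8, 3, 2, 2]) torch.Size([8, 3, 6, 6])
```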
Parcel 1: top nn.functional.* and top torch.* foo
- adaptive_avg_pool{1, 2, 3}d (as well as their backward variants while we're at it)
- logical_{and, or, xor}, if those don't exist yet. We may need to also add a change to PyTorch core to make the logical_* functions primitives w.r.t. autograd

Parcel 2: new_blah
- adaptive_max_pool{1, 2, 3}d, as well as the backward variants

Parcel 3: linalg things

Parcel 4:
- index_select, index_copy, etc. all need a backward formula in pytorch/pytorch
- vmap over composite out-of-place ops whose in-place variant is non-composite #260