Skip to content

Commit

Permalink
Added batch rule for (adaptive_)avg_poolNd (#248)
Browse files Browse the repository at this point in the history
* Added batch rule for (adaptive_)avg_poolNd
Description:
- Added batch rule for adaptive_avg_pool{1d,2d,3d} and avg_pool{1d,2d,3d}
- Updated tests

* Enabled nn.functional.interpolate mode=area
  • Loading branch information
vfdev-5 authored Nov 12, 2021
1 parent fedd426 commit 5423fdd
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 10 deletions.
3 changes: 3 additions & 0 deletions functorch/csrc/BatchRulesDecompositions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ namespace at { namespace functorch {

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
OP_DECOMPOSE(absolute);
OP_DECOMPOSE(avg_pool1d);
OP_DECOMPOSE(adaptive_avg_pool1d);
OP_DECOMPOSE(adaptive_avg_pool2d);
OP_DECOMPOSE(adaptive_avg_pool3d);
OP_DECOMPOSE(arccos);
OP_DECOMPOSE(arccosh);
OP_DECOMPOSE(arcsin);
Expand Down
5 changes: 5 additions & 0 deletions functorch/csrc/BatchRulesPooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ max_pool2d_with_indices_batch_rule(

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
EXISTING_BDIM(_adaptive_avg_pool2d);
EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward);
EXISTING_BDIM(_adaptive_avg_pool3d);
EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool3d_backward);
EXISTING_BDIM(avg_pool2d);
EXISTING_BDIM(avg_pool3d);
EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward);
VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
}
Expand Down
5 changes: 0 additions & 5 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,9 @@ def test_vmapvjp(self, device, dtype, op):
xfail('msort'),
xfail('nanmedian'),
xfail('nanquantile'),
xfail('nn.functional.adaptive_avg_pool2d'),
xfail('nn.functional.conv_transpose2d'),
xfail('nn.functional.gelu'),
xfail('nn.functional.grid_sample'),
xfail('nn.functional.interpolate', 'area'),
xfail('nn.functional.pad', 'circular'),
xfail('nn.functional.unfold'),
xfail('norm', 'fro'),
Expand Down Expand Up @@ -487,9 +485,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('fft.ihfft2'),
xfail('fft.ihfftn'),
xfail('fft.rfft2'),
xfail('nn.functional.adaptive_avg_pool1d'),
xfail('nn.functional.adaptive_avg_pool3d'),
xfail('nn.functional.avg_pool3d'),
xfail('nn.functional.embedding'),
}))
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
Expand Down
5 changes: 0 additions & 5 deletions test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3088,7 +3088,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('nn.functional.cross_entropy', 'mean'),
xfail('nn.functional.cross_entropy', 'none'),
xfail('nn.functional.cross_entropy', 'sum'),
xfail('nn.functional.interpolate', 'area'),
xfail('nn.functional.pad', 'circular'),
xfail('nn.functional.unfold'),
xfail('norm', 'fro'),
Expand Down Expand Up @@ -3137,10 +3136,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('fft.rfft2'),
xfail('isinf'),
xfail('isreal'),
xfail('nn.functional.adaptive_avg_pool1d'),
xfail('nn.functional.adaptive_avg_pool3d'),
xfail('nn.functional.avg_pool1d'),
xfail('nn.functional.avg_pool3d'),
xfail('nn.functional.pixel_shuffle'),
xfail('nn.functional.pixel_unshuffle'),
}))
Expand Down

0 comments on commit 5423fdd

Please sign in to comment.