Skip to content

Commit

Permalink
[functorch] Added backward batch rule for pad replicate/reflect modes (
Browse files Browse the repository at this point in the history
…pytorch/functorch#251)

Description:
- Added backward batch rule for pad replicate/reflect modes
- Updated tests

Related to pytorch/functorch#240
  • Loading branch information
vfdev-5 authored and zou3519 committed Jul 20, 2022
1 parent 7a42ec1 commit 05ec481
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 8 additions & 0 deletions functorch/functorch/csrc/BatchRulesModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,14 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
EXISTING_BDIM(replication_pad2d);
EXISTING_BDIM(replication_pad3d);

EXISTING_BDIM_ALL_BOXED(replication_pad1d_backward);
EXISTING_BDIM_ALL_BOXED(replication_pad2d_backward);
EXISTING_BDIM_ALL_BOXED(replication_pad3d_backward);

EXISTING_BDIM_ALL_BOXED(reflection_pad1d_backward);
EXISTING_BDIM_ALL_BOXED(reflection_pad2d_backward);
EXISTING_BDIM_ALL_BOXED(reflection_pad3d_backward);

UPSAMPLE_BATCH(upsample_bicubic2d);
UPSAMPLE_BATCH(upsample_bilinear2d);
UPSAMPLE_BATCH(upsample_linear1d);
Expand Down
2 changes: 0 additions & 2 deletions functorch/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('nn.functional.grid_sample'),
xfail('nn.functional.interpolate', 'area'),
xfail('nn.functional.pad', 'circular'),
xfail('nn.functional.pad', 'reflect'),
xfail('nn.functional.pad', 'replicate'),
xfail('nn.functional.unfold'),
xfail('norm', 'fro'),
xfail('norm', 'inf'),
Expand Down

0 comments on commit 05ec481

Please sign in to comment.