Skip to content

Commit

Permalink
Added backward batch rule for pad replicate/reflect modes
Browse files Browse the repository at this point in the history
Description:
- Added backward batch rule for pad replicate/reflect modes
- Updated tests

Related to pytorch#240
  • Loading branch information
vfdev-5 committed Nov 9, 2021
1 parent b94cfc7 commit f6d2d80
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 8 additions & 0 deletions functorch/csrc/BatchRulesModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,14 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
EXISTING_BDIM(replication_pad2d);
EXISTING_BDIM(replication_pad3d);

EXISTING_BDIM_ALL_BOXED(replication_pad1d_backward);
EXISTING_BDIM_ALL_BOXED(replication_pad2d_backward);
EXISTING_BDIM_ALL_BOXED(replication_pad3d_backward);

EXISTING_BDIM_ALL_BOXED(reflection_pad1d_backward);
EXISTING_BDIM_ALL_BOXED(reflection_pad2d_backward);
EXISTING_BDIM_ALL_BOXED(reflection_pad3d_backward);

UPSAMPLE_BATCH(upsample_bicubic2d);
UPSAMPLE_BATCH(upsample_bilinear2d);
UPSAMPLE_BATCH(upsample_linear1d);
Expand Down
2 changes: 0 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('nn.functional.grid_sample'),
xfail('nn.functional.interpolate', 'area'),
xfail('nn.functional.pad', 'circular'),
xfail('nn.functional.pad', 'reflect'),
xfail('nn.functional.pad', 'replicate'),
xfail('nn.functional.unfold'),
xfail('norm', 'fro'),
xfail('norm', 'inf'),
Expand Down

0 comments on commit f6d2d80

Please sign in to comment.