Skip to content

Commit

Permalink
[functorch] Fixed nn.functional.pad constant mode (pytorch/functorch#249)
Browse files Browse the repository at this point in the history

* Fixed nn.functional.pad constant mode
Description:
- Fixed nn.functional.pad constant mode
- Updated tests

* Fixed issues with unexpected failures for fft tests

* Update BatchRulesModules.cpp
  • Loading branch information
vfdev-5 authored and zou3519 committed Jul 20, 2022
1 parent 05ec481 commit 66a851a
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 21 deletions.
2 changes: 1 addition & 1 deletion functorch/functorch/csrc/BatchRulesModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
VMAP_SUPPORT("cross", cross_batch_rule);

UNARY_POINTWISE(constant_pad_nd);
VARIADIC_BDIMS(constant_pad_nd);
EXISTING_BDIM(reflection_pad1d);
EXISTING_BDIM(reflection_pad2d);
EXISTING_BDIM(reflection_pad3d);
Expand Down
14 changes: 0 additions & 14 deletions functorch/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,12 @@ def vjp_of_vjp(*args_and_cotangents):
xfail('diag_embed'),
xfail('eig'),
xfail('nn.functional.conv_transpose2d'),
xfail('nn.functional.pad', 'constant'),
xfail('view_as_complex'),
xfail('fft.fft'),
xfail('fft.ifft'),
xfail('fft.ihfft'),
xfail('fft.ihfft'),
xfail('fft.rfft'),
xfail('fft.rfft'),
xfail('fft.fftn'),
xfail('fft.rfftn'),
xfail('fft.ifftn'),
xfail('cdist'),
xfail('fmax'),
xfail('fmin'),
Expand Down Expand Up @@ -357,8 +352,6 @@ def vjp_of_vjp(*args_and_cotangents):
xfail('nanmean'),
xfail('block_diag'),
xfail('nn.functional.dropout'),
xfail('fft.fft2'),
xfail('fft.ifft2'),
xfail('fft.ihfft2'),
xfail('fft.ihfftn'),
xfail('fft.rfft2'),
Expand Down Expand Up @@ -388,7 +381,6 @@ def test_vmapvjp(self, device, dtype, op):

@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
xfail('nn.functional.pad', 'constant'),
xfail('view_as_complex'),
xfail('__getitem__'),
xfail('__rpow__'),
Expand All @@ -406,10 +398,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('diag'),
xfail('diag_embed'),
xfail('eig'),
xfail('fft.fft'),
xfail('fft.fftn'),
xfail('fft.ifft'),
xfail('fft.ifftn'),
xfail('fft.ihfft'),
xfail('fft.rfft'),
xfail('fft.rfftn'),
Expand Down Expand Up @@ -496,8 +484,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('_masked.sum'),
xfail('_masked.prod'),
xfail('cholesky_solve'),
xfail('fft.fft2'),
xfail('fft.ifft2'),
xfail('fft.ihfft2'),
xfail('fft.ihfftn'),
xfail('fft.rfft2'),
Expand Down
6 changes: 0 additions & 6 deletions functorch/test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3013,14 +3013,8 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('linalg.svd', device_type='cuda'),
xfail('index_put'),
xfail('matrix_exp'),
xfail('fft.fft'),
xfail('fft.ifft'),
xfail('fft.ihfft'),
xfail('fft.rfft'),
xfail('fft.rfftn'),
xfail('nn.functional.batch_norm'),
xfail('lu_unpack'),
xfail('nn.functional.pad', 'constant'),
xfail('empty_like'),
xfail('histogramdd'),
xfail('nn.functional.embedding'),
Expand Down

0 comments on commit 66a851a

Please sign in to comment.