From 66a851afc4cf901424c4bd7aa0486da0765d6da3 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 10 Nov 2021 23:39:45 +0100 Subject: [PATCH] [functorch] Fixed nn.functional.pad constant mode (pytorch/functorch#249) * Fixed nn.functional.pad constant mode Description: - Fixed nn.functional.pad constant mode - Updated tests * Fixed issues with unexpected failures for fft tests * Update BatchRulesModules.cpp --- functorch/functorch/csrc/BatchRulesModules.cpp | 2 +- functorch/test/test_ops.py | 14 -------------- functorch/test/test_vmap.py | 6 ------ 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/functorch/functorch/csrc/BatchRulesModules.cpp b/functorch/functorch/csrc/BatchRulesModules.cpp index 03cfd9f096a42..5001cc0a3e2fd 100644 --- a/functorch/functorch/csrc/BatchRulesModules.cpp +++ b/functorch/functorch/csrc/BatchRulesModules.cpp @@ -401,7 +401,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler)); VMAP_SUPPORT("cross", cross_batch_rule); - UNARY_POINTWISE(constant_pad_nd); + VARIADIC_BDIMS(constant_pad_nd); EXISTING_BDIM(reflection_pad1d); EXISTING_BDIM(reflection_pad2d); EXISTING_BDIM(reflection_pad3d); diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py index 542a32965214b..4e156196fab59 100644 --- a/functorch/test/test_ops.py +++ b/functorch/test/test_ops.py @@ -310,17 +310,12 @@ def vjp_of_vjp(*args_and_cotangents): xfail('diag_embed'), xfail('eig'), xfail('nn.functional.conv_transpose2d'), - xfail('nn.functional.pad', 'constant'), xfail('view_as_complex'), - xfail('fft.fft'), - xfail('fft.ifft'), xfail('fft.ihfft'), xfail('fft.ihfft'), xfail('fft.rfft'), xfail('fft.rfft'), - xfail('fft.fftn'), xfail('fft.rfftn'), - xfail('fft.ifftn'), xfail('cdist'), xfail('fmax'), xfail('fmin'), @@ -357,8 +352,6 @@ def vjp_of_vjp(*args_and_cotangents): xfail('nanmean'), xfail('block_diag'), xfail('nn.functional.dropout'), - xfail('fft.fft2'), - xfail('fft.ifft2'), xfail('fft.ihfft2'), xfail('fft.ihfftn'), xfail('fft.rfft2'), @@ -388,7 +381,6 @@ def test_vmapvjp(self, device, dtype, op): @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({ - xfail('nn.functional.pad', 'constant'), xfail('view_as_complex'), xfail('__getitem__'), xfail('__rpow__'), @@ -406,10 +398,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('diag'), xfail('diag_embed'), xfail('eig'), - xfail('fft.fft'), - xfail('fft.fftn'), - xfail('fft.ifft'), - xfail('fft.ifftn'), xfail('fft.ihfft'), xfail('fft.rfft'), xfail('fft.rfftn'), @@ -496,8 +484,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('_masked.sum'), xfail('_masked.prod'), xfail('cholesky_solve'), - xfail('fft.fft2'), - xfail('fft.ifft2'), xfail('fft.ihfft2'), xfail('fft.ihfftn'), xfail('fft.rfft2'), diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py index 8cffb0674f8d7..d4d5a42a8cb63 100644 --- a/functorch/test/test_vmap.py +++ b/functorch/test/test_vmap.py @@ -3013,14 +3013,8 @@ class TestVmapOperatorsOpInfo(TestCase): xfail('linalg.svd', device_type='cuda'), xfail('index_put'), xfail('matrix_exp'), - xfail('fft.fft'), - xfail('fft.ifft'), - xfail('fft.ihfft'), - xfail('fft.rfft'), - xfail('fft.rfftn'), xfail('nn.functional.batch_norm'), xfail('lu_unpack'), - xfail('nn.functional.pad', 'constant'), xfail('empty_like'), xfail('histogramdd'), xfail('nn.functional.embedding'),