From 1b6875b813333a1c927ae7b62018faaad357c788 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sun, 14 Nov 2021 22:41:35 +0000 Subject: [PATCH] Added im2col batch rule and enabled vmap for nn.functional.unfold op Description: - Added im2col batch rule and enabled vmap for nn.functional.unfold op - Updated tests Using EXISTING_BDIM macro to put bdim into 0 as im2col expects dim=0 to be batch dim Related to #240 --- functorch/csrc/BatchRulesDecompositions.cpp | 6 +++--- functorch/csrc/BatchRulesModules.cpp | 3 +++ test/test_ops.py | 1 - test/test_vmap.py | 1 - 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/csrc/BatchRulesDecompositions.cpp index eb230315e..409ca7302 100644 --- a/functorch/csrc/BatchRulesDecompositions.cpp +++ b/functorch/csrc/BatchRulesDecompositions.cpp @@ -87,6 +87,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { OP_DECOMPOSE2(multiply, Tensor ); OP_DECOMPOSE(narrow); OP_DECOMPOSE(negative); + OP_DECOMPOSE(nll_loss_nd); + OP_DECOMPOSE(nll_loss); + OP_DECOMPOSE(nll_loss2d); OP_DECOMPOSE2(not_equal, Tensor ); OP_DECOMPOSE(outer); OP_DECOMPOSE(pairwise_distance); @@ -125,9 +128,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { OP_DECOMPOSE(var_mean); OP_DECOMPOSE2(var_mean, dim); OP_DECOMPOSE2(where, self); - OP_DECOMPOSE(nll_loss_nd); - OP_DECOMPOSE(nll_loss); - OP_DECOMPOSE(nll_loss2d); } }} diff --git a/functorch/csrc/BatchRulesModules.cpp b/functorch/csrc/BatchRulesModules.cpp index 5001cc0a3..127c10ea0 100644 --- a/functorch/csrc/BatchRulesModules.cpp +++ b/functorch/csrc/BatchRulesModules.cpp @@ -396,6 +396,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { m.impl("conv2d", convNd_decomp); m.impl("conv3d", convNd_decomp); + EXISTING_BDIM(im2col); + EXISTING_BDIM(im2col_backward); + VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler)); VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler)); VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler)); diff --git a/test/test_ops.py b/test/test_ops.py index 6bff1ddd0..4370d400e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -456,7 +456,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('nn.functional.gelu'), xfail('nn.functional.grid_sample'), xfail('nn.functional.pad', 'circular'), - xfail('nn.functional.unfold'), xfail('norm', 'fro'), xfail('norm', 'inf'), xfail('norm', 'nuc'), diff --git a/test/test_vmap.py b/test/test_vmap.py index a26e90e6a..d7c5d77cd 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3124,7 +3124,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('nn.functional.cross_entropy', 'none'), xfail('nn.functional.cross_entropy', 'sum'), xfail('nn.functional.pad', 'circular'), - xfail('nn.functional.unfold'), xfail('norm', 'fro'), xfail('norm', 'nuc'), xfail('ormqr'),