Skip to content

Commit

Permalink
[functorch] Added im2col batch rule and enabled vmap for nn.functiona…
Browse files Browse the repository at this point in the history
…l.unfold op (pytorch/functorch#262)

Description:
- Added im2col batch rule and enabled vmap for nn.functional.unfold op
- Updated tests

Using EXISTING_BDIM macro to put bdim into 0 as im2col expects dim=0 to be batch dim

Related to pytorch/functorch#240
  • Loading branch information
vfdev-5 authored and zou3519 committed Jul 20, 2022
1 parent ed20f40 commit 13c4548
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
6 changes: 3 additions & 3 deletions functorch/functorch/csrc/BatchRulesDecompositions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
OP_DECOMPOSE2(multiply, Tensor );
OP_DECOMPOSE(narrow);
OP_DECOMPOSE(negative);
OP_DECOMPOSE(nll_loss_nd);
OP_DECOMPOSE(nll_loss);
OP_DECOMPOSE(nll_loss2d);
OP_DECOMPOSE2(not_equal, Tensor );
OP_DECOMPOSE(outer);
OP_DECOMPOSE(pairwise_distance);
Expand Down Expand Up @@ -125,9 +128,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
OP_DECOMPOSE(var_mean);
OP_DECOMPOSE2(var_mean, dim);
OP_DECOMPOSE2(where, self);
OP_DECOMPOSE(nll_loss_nd);
OP_DECOMPOSE(nll_loss);
OP_DECOMPOSE(nll_loss2d);
}

}}
Expand Down
3 changes: 3 additions & 0 deletions functorch/functorch/csrc/BatchRulesModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("conv2d", convNd_decomp);
m.impl("conv3d", convNd_decomp);

EXISTING_BDIM(im2col);
EXISTING_BDIM(im2col_backward);

VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
Expand Down
1 change: 0 additions & 1 deletion functorch/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('nn.functional.gelu'),
xfail('nn.functional.grid_sample'),
xfail('nn.functional.pad', 'circular'),
xfail('nn.functional.unfold'),
xfail('norm', 'fro'),
xfail('norm', 'inf'),
xfail('norm', 'nuc'),
Expand Down
1 change: 0 additions & 1 deletion functorch/test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3124,7 +3124,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('nn.functional.cross_entropy', 'none'),
xfail('nn.functional.cross_entropy', 'sum'),
xfail('nn.functional.pad', 'circular'),
xfail('nn.functional.unfold'),
xfail('norm', 'fro'),
xfail('norm', 'nuc'),
xfail('ormqr'),
Expand Down

0 comments on commit 13c4548

Please sign in to comment.