Added adaptive_max_poolNd batch rule (#263)
* Added adaptive_max_poolNd fw/bw batch rules
Description:
- Added adaptive_max_poolNd fw/bw batch rules
- Updated tests

Related to #240

Notes:
I created two additional macros to handle adaptive_max_pool2d and adaptive_max_pool3d_backward.
I am not sure we could write a single generic rule covering both max_pool2d_with_indices_backward and the adaptive backward ops,
since max_pool2d_with_indices_backward takes several non-tensor arguments sitting between gradOutput/input and indices (see the schema sketch at the end of this message).

* Replaced EXISTING_BDIM_MULTIOUT with EXISTING_BDIM_ALL_BOXED

* Removed the op-specific implementations that called indices.contiguous() for
- max_pool2d_with_indices_backward
- adaptive_max_pool2d_backward
- adaptive_max_pool3d_backward
and added ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1 to handle them generically
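
For context, a rough sketch of the ATen signatures involved (paraphrased for illustration, not copied from the dispatcher schemas; treat the exact types as approximate). The adaptive backward op takes only tensors, while max_pool2d_with_indices_backward interleaves the pooling parameters between the input tensors and indices, which is what makes one shared batch rule awkward:

// Assumed ATen C++ signatures (sketch only)
Tensor adaptive_max_pool2d_backward(const Tensor& grad_output, const Tensor& self, const Tensor& indices);
Tensor max_pool2d_with_indices_backward(const Tensor& grad_output, const Tensor& self,
    IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding,
    IntArrayRef dilation, bool ceil_mode, const Tensor& indices);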
vfdev-5 authored Nov 23, 2021
1 parent 10df6ca commit da27398
Showing 5 changed files with 26 additions and 42 deletions.
1 change: 1 addition & 0 deletions functorch/csrc/BatchRulesDecompositions.cpp
@@ -19,6 +19,7 @@ namespace at { namespace functorch {
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
OP_DECOMPOSE(absolute);
OP_DECOMPOSE(avg_pool1d);
+ OP_DECOMPOSE(adaptive_max_pool1d);
OP_DECOMPOSE(adaptive_avg_pool1d);
OP_DECOMPOSE(adaptive_avg_pool2d);
OP_DECOMPOSE(adaptive_avg_pool3d);
22 changes: 19 additions & 3 deletions functorch/csrc/BatchRulesHelper.h
@@ -265,7 +265,7 @@ inline void boxed_existing_bdim_all_batch_rule(
#define EXISTING_BDIM_ALL_BOXED(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());

- template <int64_t feature_rank>
+ template <int64_t feature_rank, int64_t contig_tensor_index=-1>
inline void boxed_all_tensors_have_optional_bdim(
const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
@@ -302,11 +302,19 @@ inline void boxed_all_tensors_have_optional_bdim(
}
if (*is_no_batch_dim_case) {
TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
- (*stack)[args_begin + tensor_pos[tensor_idx]] = moveBatchDimToFront(value_, bdim);
+ value_ = moveBatchDimToFront(value_, bdim);
+ if (tensor_idx == contig_tensor_index) {
+   value_ = value_.contiguous();
+ }
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
continue;
}
TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
- (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
+ value_ = reshape_dim_into(*bdim, 0, value_);
+ if (tensor_idx == contig_tensor_index) {
+   value_ = value_.contiguous();
+ }
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
}

op.callBoxed(stack);
@@ -330,6 +338,14 @@ inline void boxed_all_tensors_have_optional_bdim(
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());

+ #define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
+   m.impl(#op, \
+     torch::CppFunction::makeFromBoxedFunction<\
+       boxed_all_tensors_have_optional_bdim<\
+         feature_rank, \
+         contig_tensor_index>\
+   >());

template <typename A, A a, typename C>
struct ExistingBdimBatchRuleHelper;

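As a reading aid (not part of the commit): the registration ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2) added below in BatchRulesPooling.cpp should expand to roughly the following. The boxed rule is instantiated with feature_rank = 3 (C, H, W without the batch dim) and contig_tensor_index = 2, so the third tensor argument (indices) is made contiguous before the op is called:

// Approximate expansion of the new macro, inside TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m)
m.impl("adaptive_max_pool2d_backward",
    torch::CppFunction::makeFromBoxedFunction<
        boxed_all_tensors_have_optional_bdim<3, 2>>());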
39 changes: 6 additions & 33 deletions functorch/csrc/BatchRulesPooling.cpp
@@ -26,38 +26,6 @@ static Tensor reshape_bdim_into_front(
return reshape_dim_into(*bdim, 0, value_);
}

- // We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
- // kernel rightfully assumes that indices is contiguous.
- std::tuple<Tensor,optional<int64_t>> max_pool2d_with_indices_backward_batch_rule(
-     const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
-     const Tensor& input, optional<int64_t> input_bdim,
-     IntArrayRef kernel_size,
-     IntArrayRef stride,
-     IntArrayRef padding,
-     IntArrayRef dilation,
-     bool ceil_mode,
-     const Tensor& indices, optional<int64_t> indices_bdim) {
-   TORCH_INTERNAL_ASSERT(input_bdim.has_value() ^ !indices_bdim.has_value());
-   const auto bdim_size = get_bdim_size2(gradOutput, gradOutput_bdim, input, input_bdim);
-   const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
-   bool chw_case = input_logical_rank == 3;
-
-   const auto gradOutput_ = reshape_bdim_into_front(gradOutput, gradOutput_bdim, bdim_size, chw_case);
-   const auto input_ = reshape_bdim_into_front(input, input_bdim, bdim_size, chw_case);
-   const auto indices_ = reshape_bdim_into_front(indices, indices_bdim, bdim_size, chw_case);
-
-   const auto result = at::max_pool2d_with_indices_backward(
-       gradOutput_, input_, kernel_size, stride, padding, dilation, ceil_mode,
-       // max_pool2d_with_indices rightfully assumes that indices is contiguous
-       indices_.contiguous());
-
-   if (chw_case) {
-     return std::make_tuple(std::move(result), 0);
-   } else {
-     return std::make_tuple(reshape_dim_outof(0, bdim_size, result), 0);
-   }
- }
-
std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
max_pool2d_with_indices_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
@@ -91,8 +59,13 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
EXISTING_BDIM(avg_pool3d);
EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward);
+ EXISTING_BDIM_ALL_BOXED(adaptive_max_pool2d);
+ EXISTING_BDIM_ALL_BOXED(adaptive_max_pool3d);
+ ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2);
+ ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, adaptive_max_pool3d_backward, 2);

VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, max_pool2d_with_indices_backward, 2);
}

}}
3 changes: 0 additions & 3 deletions test/test_ops.py
@@ -492,9 +492,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('diagonal_scatter'),
xfail('double', 'channels_last'),
xfail('linalg.cross'),
- xfail('nn.functional.adaptive_max_pool1d'),
- xfail('nn.functional.adaptive_max_pool2d'),
- xfail('nn.functional.adaptive_max_pool3d'),
xfail('nn.functional.conv1d'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('nn.functional.hardsigmoid'),
3 changes: 0 additions & 3 deletions test/test_vmap.py
@@ -3201,9 +3201,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('slice_scatter'),
xfail('unique_consecutive'),
xfail('unique'),
- xfail('nn.functional.adaptive_max_pool1d'),
- xfail('nn.functional.adaptive_max_pool2d'),
- xfail('nn.functional.adaptive_max_pool3d'),
xfail('nn.functional.conv1d'),
xfail('nn.functional.cosine_embedding_loss'),
# xfail('nn.functional.cross_entropy'),
