Added grid_sample backward batch rule
Description:
- Added grid_sample backward batch rule: CPU and CUDA
- Updated tests

Notes:
In most cases I had to expand on dim 0; the forward-pass tricks of merging
the batch dim into either the channel or the H_out dimension could not be
used here, because they produce wrong grid gradients in those cases.
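
For illustration, a minimal sketch of the kind of call this rule supports
(assuming functorch's top-level vmap/grad API; shapes and names are
illustrative, not taken from this commit):

import torch
import torch.nn.functional as F
from functorch import vmap, grad

# a batch of 3 independent (input, grid) pairs:
# inputs are NCHW with N=1, grids are N x H_out x W_out x 2 in [-1, 1]
inputs = torch.randn(3, 1, 2, 5, 5)
grids = torch.rand(3, 1, 4, 4, 2) * 2 - 1

def loss(inp, g):
    return F.grid_sample(inp, g, align_corners=True).sum()

# grad w.r.t. the grid dispatches to grid_sampler_2d_backward,
# which is what the new batch rule handles under vmap
grid_grads = vmap(grad(loss, argnums=1))(inputs, grids)  # (3, 1, 4, 4, 2)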
vfdev-5 committed Nov 24, 2021
1 parent 11c3caf commit 42ebfac
Showing 4 changed files with 2,150 additions and 1,093 deletions.
4 changes: 1 addition & 3 deletions codegen/codegen_outofplacebatching.py
@@ -155,7 +155,7 @@ def parse_return(return_t):
    return tuple([x.strip() for x in m.group(1).split(',')])

def parse_args(args_t):
-    args = args_t.split(',')
+    args = args_t.split(', ')
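    # splitting on ', ' keeps templated types such as std::array<bool,2>
    # (which contain a comma but no following space) intact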
    result = []
    for arg in args:
        split_idx = arg.rfind(' ')
@@ -170,8 +170,6 @@ def get_signatures(path='build/aten/src/ATen/RegistrationDeclarations.h', includ
    for line in lines:
        if 'void' in line:
            continue
-        if 'std::array' in line:
-            continue
        m = re.match(r'(.*) \w+\((.*)\); // {"schema": "aten::(\w+\.?\w*)\(.*', line)
        if m is None:
            continue
248 changes: 248 additions & 0 deletions functorch/csrc/BatchRulesModules.cpp
@@ -255,6 +255,202 @@ grid_sample_batch_rule(const Tensor& input, optional<int64_t> input_bdim, const
  return result;
}

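// Broadcasts `x` along a new leading dim of size `batch_size`, then merges
// that dim into dim `dst` of the result; e.g. with dst == 0,
// x: (N, C, H, W) -> (B*N, C, H, W).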
Tensor expand_reshape_dim_into(int64_t batch_size, int64_t dst, const Tensor& x) {
  auto x_ = x.unsqueeze(0);
  VmapDimVector new_shape(x_.sizes().begin(), x_.sizes().end());
  new_shape[0] = batch_size;
  x_ = x_.expand(new_shape);
  return reshape_dim_into(0, dst, x_);
}


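// Brings grad_output, input and grid to a common flattened batch layout:
// each batched argument has its bdim moved to the front and merged into N,
// and each unbatched argument is broadcast to match, so the unbatched
// backward kernel can run once over (B*N). Returns the adjusted tensors,
// the output bdims for grad_input / grad_grid, and the batch size
// (0 if nothing is batched).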
std::tuple<Tensor, Tensor, optional<int64_t>, Tensor, optional<int64_t>, int64_t>
grid_sample_backward_helper_in(
    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
    const Tensor& input, optional<int64_t> input_bdim,
    const Tensor& grid, optional<int64_t> grid_bdim) {
  auto new_grad_output = grad_output;
  auto new_input = input;
  auto new_grid = grid;

  optional<int64_t> grad_input_out_bdim = nullopt;
  optional<int64_t> grad_grid_out_bdim = nullopt;
  int64_t bdim_size = 0;

  if (grad_output_bdim) {
    bdim_size = grad_output.sizes()[*grad_output_bdim];

    if (input_bdim && grid_bdim) {
      // case 1: grad_output, input and grid are all batched
      // grad_output: (BN)CH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
      // grad_input: (BN)CH_{in}W_{in}
      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
      new_input = reshape_dim_into(*input_bdim, 0, input);
      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
      grad_input_out_bdim = 0;
      grad_grid_out_bdim = 0;
    } else if (input_bdim && !grid_bdim) {
      // case 2: grad_output and input are batched, grid is not
      // (merging the batch dim into the channel dim, as in the forward pass,
      // would produce a wrong grad_grid)
      //
      // grad_output: (BN)CH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: NH_{out}W_{out}2
      // -> grid: (BN)H_{out}W_{out}2
      // grad_input: (BN)CH_{in}W_{in}
      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
      new_input = reshape_dim_into(*input_bdim, 0, input);
      grad_input_out_bdim = 0;
      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
      grad_grid_out_bdim = 0;
    } else if (!input_bdim && grid_bdim) {
      // case 3: grad_output and grid are batched, input is not
      // (merging the batch dim into H_out, as in the forward pass, would
      // produce a wrong grad_grid)
      //
      // grad_output: (BN)CH_{out}W_{out}, input: NCH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
      // -> input: (BN)CH_{in}W_{in}
      // grad_input: (BN)CH_{in}W_{in}
      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
      grad_grid_out_bdim = 0;
      // expand input to (BN)CH_{in}W_{in}
      new_input = expand_reshape_dim_into(bdim_size, 0, new_input);
      grad_input_out_bdim = 0;
    } else {
      // case 4: only grad_output is batched
      // (merging the batch dim into H_out, as in the forward pass, would
      // produce a wrong grad_grid)
      //
      // grad_output: (BN)CH_{out}W_{out}, input: NCH_{in}W_{in}, grid: NH_{out}W_{out}2
      // -> grid: (BN)H_{out}W_{out}2
      // -> input: (BN)CH_{in}W_{in}
      // grad_input: (BN)CH_{in}W_{in}
      new_grad_output = reshape_dim_into(*grad_output_bdim, 0, grad_output);
      // expand grid to (BN)H_{out}W_{out}2
      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
      grad_grid_out_bdim = 0;
      // expand input to (BN)CH_{in}W_{in}
      new_input = expand_reshape_dim_into(bdim_size, 0, input);
      grad_input_out_bdim = 0;
    }
  } else {
    if (input_bdim && grid_bdim) {
      // case 5: input and grid are batched, grad_output is not
      // grad_output: NCH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
      // -> grad_output: (BN)CH_{out}W_{out}
      // grad_input: (BN)CH_{in}W_{in}
      bdim_size = input.sizes()[*input_bdim];
      // expand new_grad_output to (BN)CH_{out}W_{out}
      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
      new_input = reshape_dim_into(*input_bdim, 0, input);
      grad_input_out_bdim = 0;
      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
      grad_grid_out_bdim = 0;
    } else if (input_bdim && !grid_bdim) {
      // case 6: only input is batched
      // grad_output: NCH_{out}W_{out}, input: (BN)CH_{in}W_{in}, grid: NH_{out}W_{out}2
      // -> grad_output: (BN)CH_{out}W_{out}
      // -> grid: (BN)H_{out}W_{out}2
      // grad_input: (BN)CH_{in}W_{in}
      bdim_size = input.sizes()[*input_bdim];
      // expand new_grad_output to (BN)CH_{out}W_{out}
      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
      new_input = reshape_dim_into(*input_bdim, 0, input);
      grad_input_out_bdim = 0;
      // expand new_grid to (BN)H_{out}W_{out}2
      new_grid = expand_reshape_dim_into(bdim_size, 0, grid);
      grad_grid_out_bdim = 0;
    } else if (!input_bdim && grid_bdim) {
      // case 7: only grid is batched
      // (merging the batch dim into H_out, as in the forward pass, would
      // produce a wrong grad_grid)
      //
      // grad_output: NCH_{out}W_{out}, input: NCH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
      // -> grad_output: (BN)CH_{out}W_{out}
      // -> input: (BN)CH_{in}W_{in}
      // grad_input: (BN)CH_{in}W_{in}
      bdim_size = grid.sizes()[*grid_bdim];
      // expand new_grad_output to (BN)CH_{out}W_{out}
      new_grad_output = expand_reshape_dim_into(bdim_size, 0, new_grad_output);
      // expand new_input to (BN)CH_{in}W_{in}
      new_input = expand_reshape_dim_into(bdim_size, 0, new_input);
      grad_input_out_bdim = 0;
      new_grid = reshape_dim_into(*grid_bdim, 0, grid);
      grad_grid_out_bdim = 0;
    }
    // case 8 (nothing batched) can be ignored: the batch rule is only
    // invoked when at least one argument is batched
  }
  return std::make_tuple(
      new_grad_output, new_input, grad_input_out_bdim, new_grid, grad_grid_out_bdim, bdim_size);
}

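// Undoes the flattening done by grid_sample_backward_helper_in: splits the
// batch dim of size `bdim_size` back out of grad_input / grad_grid.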
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
grid_sample_backward_helper_out(
    const std::tuple<Tensor, Tensor>& bw_out,
    optional<int64_t> grad_input_out_bdim,
    optional<int64_t> grad_grid_out_bdim,
    int64_t bdim_size) {
  auto grad_input = std::get<0>(bw_out);
  auto grad_grid = std::get<1>(bw_out);
  if (grad_input_out_bdim) {
    grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
  }
  if (grad_grid_out_bdim) {
    grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
  }
  return std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
}


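// Batch rule for grid_sampler_{2d,3d}_backward(grad_output, input, grid, ...):
// flatten any batch dims into N, run the unbatched backward, then unflatten.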
template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
grid_sample_backward_batch_rule(
    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
    const Tensor& input, optional<int64_t> input_bdim,
    const Tensor& grid, optional<int64_t> grid_bdim,
    ExtraArgs... extra_args) {
  auto new_bw_input = grid_sample_backward_helper_in(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto new_grad_output = std::get<0>(new_bw_input);
  auto new_input = std::get<1>(new_bw_input);
  auto grad_input_out_bdim = std::get<2>(new_bw_input);
  auto new_grid = std::get<3>(new_bw_input);
  auto grad_grid_out_bdim = std::get<4>(new_bw_input);
  int64_t bdim_size = std::get<5>(new_bw_input);

  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);

  return grid_sample_backward_helper_out(bw_out, grad_input_out_bdim, grad_grid_out_bdim, bdim_size);
}

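// Same idea for cudnn_grid_sampler_backward, which takes its arguments in
// (input, grid, grad_output) order.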
template<typename F, F Func>
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
cudnn_grid_sample_backward_batch_rule(
    const Tensor& input, optional<int64_t> input_bdim,
    const Tensor& grid, optional<int64_t> grid_bdim,
    const Tensor& grad_output, optional<int64_t> grad_output_bdim) {
  auto new_bw_input = grid_sample_backward_helper_in(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto new_grad_output = std::get<0>(new_bw_input);
  auto new_input = std::get<1>(new_bw_input);
  auto grad_input_out_bdim = std::get<2>(new_bw_input);
  auto new_grid = std::get<3>(new_bw_input);
  auto grad_grid_out_bdim = std::get<4>(new_bw_input);
  int64_t bdim_size = std::get<5>(new_bw_input);

  auto bw_out = Func(new_input, new_grid, new_grad_output);

  return grid_sample_backward_helper_out(bw_out, grad_input_out_bdim, grad_grid_out_bdim, bdim_size);
}

std::tuple<Tensor, optional<int64_t>> cross_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& other, optional<int64_t> other_bdim,
@@ -370,12 +566,53 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  }
};

template <typename A, A a, typename C>
struct GridSampleBackwardBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim,
      const Tensor& input, optional<int64_t> input_batch_dim,
      const Tensor& grid, optional<int64_t> grid_batch_dim,
      T... extra_args) {
    return grid_sample_backward_batch_rule<F, Func, T...>(
        grad_output, grad_output_batch_dim,
        input, input_batch_dim,
        grid, grid_batch_dim,
        std::forward<T>(extra_args)...);
  }
};

template <typename F, F Func>
struct CudnnGridSampleBackwardBatchRuleHelper {
  static std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> apply(
      const Tensor& input, optional<int64_t> input_batch_dim,
      const Tensor& grid, optional<int64_t> grid_batch_dim,
      const Tensor& grad_output, optional<int64_t> grad_output_batch_dim) {
    return cudnn_grid_sample_backward_batch_rule<F, Func>(
        input, input_batch_dim,
        grid, grid_batch_dim,
        grad_output, grad_output_batch_dim);
  }
};

#define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
    GridSampleBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn),\
      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
    GridSampleBackwardBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn),\
      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
    CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply

#define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT(#op"."#overload, SINGLE_ARG(\
    UpsampleBackwardBatchRuleHelper<\
      decltype(&ATEN_FN2(op, overload)),\
@@ -386,6 +623,12 @@ struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  EXISTING_BDIM2(op, vec); \
  EXISTING_BDIM(op);

Tensor this_grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
  return input;
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("convolution", convolution_batch_rule);
  // m.impl("conv_transpose2d", convNd_transpose_decomp);
@@ -400,7 +643,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  EXISTING_BDIM(im2col_backward);

  VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
  VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));

  VMAP_SUPPORT("grid_sampler_3d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
  VMAP_SUPPORT("grid_sampler_3d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
  VMAP_SUPPORT("cudnn_grid_sampler_backward", CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));

  VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));
  VMAP_SUPPORT("cross", cross_batch_rule);

(the remaining two changed files are not shown)
