Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prelu batching rule (forward + backward) #609

Merged
merged 2 commits into from
Apr 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions functorch/csrc/BatchRulesActivation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/BatchRulesHelper.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <ATen/Operators.h>

// NB: most activation functions fit pointwise unary or binary rules.
// These are only the ones that have special batch rules to help with organization
namespace at { namespace functorch {
std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
const Tensor& input, optional<int64_t> input_bdim,
const Tensor& weight, optional<int64_t> weight_bdim) {
if (!weight_bdim && weight.dim() == 0) {
return std::make_tuple(at::prelu(input, weight), input_bdim);
}

const auto input_ = moveBatchDimToFront(input, input_bdim);
auto weight_flatten = moveBatchDimToFront(weight, weight_bdim);

if (weight_flatten.dim() > 1) {
// for an input [N, C, ...]
// weight can be a non-vector but the total number of elements must be the same as C
weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
}

const int64_t input_logical_rank = rankWithoutBatchDim(input, input_bdim);
VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
const int64_t final_size = weight_bdim ? (input_logical_rank + 1) : input_logical_rank;
new_shape.reserve(final_size);

if (weight_flatten.dim() == 2 || !weight_bdim) {
// if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the
// decomposition, we pad the weight to

// copies checks from prelu if the weight (without vmap) is not a scalar
TORCH_CHECK(input_logical_rank > 0, "Not allow zero-dim input tensor.");

int64_t channel_size = 1; // channel_size default to 1
if (input_logical_rank > 1) {
const auto channel_dim = input_bdim ? 2 : 1;
channel_size = input_.size(channel_dim);
}
const auto weight_num = weight_flatten.size(-1);
TORCH_CHECK(channel_size == weight_num,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
" and channel size = ", channel_size, ".");

// pads to the left so that the flattened shape matches up with the channel
if (!weight_bdim) {
new_shape.insert(new_shape.begin(), 1);
} else {
new_shape.insert(new_shape.begin() + 1, 1);
}
}

for (int64_t i = new_shape.size(); i < final_size; i ++) {
new_shape.push_back(1);
}
TORCH_INTERNAL_ASSERT(new_shape.size() == final_size);
const auto weight_padded = weight_flatten.view(new_shape);
auto zero_tensor = at::zeros(1, input.options());

// decomposes function,
auto res = at::maximum(zero_tensor, input_) + weight_padded * at::minimum(zero_tensor, input_);
return std::make_tuple(res, 0);
}

VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) {
// helper function that get the size of input, ensuring that there's batch dim, without expanding input
if (has_bdim) {
// sad to have to copy but got garbage if tried to return an IntArrayRef and just do input.sizes()
VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
return new_shape;
}
VmapDimVector new_shape(1, batch_size);
new_shape.reserve(input.dim() + 1);
new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end());
return new_shape;
}

VmapDimVector shape_maybe_with_bdim(const Tensor& input, const bool need_bdim, const bool has_bdim, const int64_t batch_size) {
// if need_bdim, will return the input with a guaranteed bdim. If not, will return the input logical size (no batch dim)
if (need_bdim) {
return ensure_shape_with_bdim(input, has_bdim, batch_size);
} else if (has_bdim) { // !need_bdim && has_bdim
VmapDimVector new_shape(input.sizes().begin() + 1, input.sizes().end());
return new_shape;
} else { // !need_bdim && !has_bdim
VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
return new_shape;
}
}

std::tuple<Tensor, Tensor> prelu_backward_batched(
const Tensor& grad_out, const Tensor& self, const Tensor& weight,
const VmapDimVector& self_grad_shape, const VmapDimVector& weight_grad_padded_shape, const VmapDimVector& weight_grad_shape) {
// helper function that produces a batched gradient for prelu using a decomposition inspired by the AOTAutograd ones
const auto input_grad_collector = at::where(self > 0, grad_out, weight * grad_out);
const auto input_grad = native::sum_to_size(input_grad_collector, self_grad_shape);
const auto weight_grad_collector = at::where(self > 0, at::zeros(1, self.options()), self * grad_out);
const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_grad_padded_shape);
const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape);
return std::make_tuple(input_grad, weight_grad);
}

std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_batch_rule(
const Tensor& grad_out, optional<int64_t> grad_out_bdim,
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& weight, optional<int64_t> weight_bdim) {
const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim);
const auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
const auto self_ = moveBatchDimToFront(self, self_bdim);
const auto self_size_with_bdim = ensure_shape_with_bdim(self_, self_bdim.has_value(), batch_size);
if (!weight_bdim && weight.dim() == 0) {
VmapDimVector weight_grad_shape(1, batch_size);
VmapDimVector weight_grad_shape_padded(self_bdim.has_value() ? self.dim() : self.dim() + 1, 1);
weight_grad_shape_padded[0] = batch_size;
const auto grads = prelu_backward_batched(grad_out_, self_, weight, self_size_with_bdim, weight_grad_shape_padded, weight_grad_shape);
return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), 0);
}
const auto weight_ = moveBatchDimToFront(weight, weight_bdim);
auto weight_flatten = weight_;
if (weight_flatten.dim() > 1) {
// for an input [N, C, ...]
// weight can be a non-vector but the total number of elements must be the same as C
weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
}

const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim);
VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank;
new_shape.reserve(final_size);

if (weight_flatten.dim() == 2 || !weight_bdim) {
// if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the
// decomposition, we pad the weight to

// copies checks from prelu if the weight (without vmap) is not a scalar
TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor.");

int64_t channel_size = 1; // channel_size default to 1
if (self_logical_rank > 1) {
channel_size = self_.size(self_bdim.has_value() ? 2 : 1);
}

const auto weight_num = weight_flatten.size(-1);
TORCH_CHECK(channel_size == weight_num,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
" and channel size = ", channel_size, ".");

// pads to the left so that the flattened shape matches up with the channel
if (!weight_bdim) {
new_shape.insert(new_shape.begin(), 1);
} else {
new_shape.insert(new_shape.begin() + 1, 1);
}
}

for (int64_t i = new_shape.size(); i < final_size; i ++) {
new_shape.push_back(1);
}
// weight grad does not depend on weight values. It is batched iff grad_out or self are batched
const auto weight_grad_is_batched = grad_out_bdim.has_value() || self_bdim.has_value();

const auto weight_padded = weight_flatten.view(new_shape);
const auto weight_grad_shape = shape_maybe_with_bdim(weight_, weight_grad_is_batched, weight_bdim.has_value(), batch_size);
const auto weight_padded_grad_shape = shape_maybe_with_bdim(weight_padded, weight_grad_is_batched, weight_bdim.has_value(), batch_size);

const auto grads = prelu_backward_batched(grad_out_, self_, weight_padded, self_size_with_bdim, weight_padded_grad_shape, weight_grad_shape);
return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), (weight_grad_is_batched ? optional<int64_t>(0) : nullopt));
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(prelu, prelu_batch_rule)
VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule)
}
}} // namespace at::functorch
1 change: 0 additions & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,6 @@ def test():
xfail('nn.functional.huber_loss'),
xfail('nn.functional.poisson_nll_loss'),
xfail('nn.functional.bilinear'),
xfail('nn.functional.prelu'),
xfail('nn.functional.glu'),
xfail('nn.functional.fractional_max_pool3d'),
xfail('as_strided'),
Expand Down
1 change: 0 additions & 1 deletion test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3240,7 +3240,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('stft'),
xfail('linalg.solve_triangular'),
xfail('nn.functional.glu'),
xfail('nn.functional.prelu'),
xfail('isclose'),
xfail('nn.functional.fractional_max_pool3d'),
xfail('nn.functional.bilinear'),
Expand Down