Skip to content

Commit

Permalink
Prelu batching rule (forward + backward) (#609) (#669)
Browse files Browse the repository at this point in the history
* prelu forward rule

* prelu backward rule
  • Loading branch information
Samantha Andow authored Apr 6, 2022
1 parent b504e6d commit d8152ab
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 2 deletions.
181 changes: 181 additions & 0 deletions functorch/csrc/BatchRulesActivation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/BatchRulesHelper.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <ATen/Operators.h>

// NB: most activation functions fit pointwise unary or binary rules.
// These are only the ones that have special batch rules to help with organization
namespace at { namespace functorch {
std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
const Tensor& input, optional<int64_t> input_bdim,
const Tensor& weight, optional<int64_t> weight_bdim) {
if (!weight_bdim && weight.dim() == 0) {
return std::make_tuple(at::prelu(input, weight), input_bdim);
}

const auto input_ = moveBatchDimToFront(input, input_bdim);
auto weight_flatten = moveBatchDimToFront(weight, weight_bdim);

if (weight_flatten.dim() > 1) {
// for an input [N, C, ...]
// weight can be a non-vector but the total number of elements must be the same as C
weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
}

const int64_t input_logical_rank = rankWithoutBatchDim(input, input_bdim);
VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
const int64_t final_size = weight_bdim ? (input_logical_rank + 1) : input_logical_rank;
new_shape.reserve(final_size);

if (weight_flatten.dim() == 2 || !weight_bdim) {
// if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the
// decomposition, we pad the weight to

// copies checks from prelu if the weight (without vmap) is not a scalar
TORCH_CHECK(input_logical_rank > 0, "Not allow zero-dim input tensor.");

int64_t channel_size = 1; // channel_size default to 1
if (input_logical_rank > 1) {
const auto channel_dim = input_bdim ? 2 : 1;
channel_size = input_.size(channel_dim);
}
const auto weight_num = weight_flatten.size(-1);
TORCH_CHECK(channel_size == weight_num,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
" and channel size = ", channel_size, ".");

// pads to the left so that the flattened shape matches up with the channel
if (!weight_bdim) {
new_shape.insert(new_shape.begin(), 1);
} else {
new_shape.insert(new_shape.begin() + 1, 1);
}
}

for (int64_t i = new_shape.size(); i < final_size; i ++) {
new_shape.push_back(1);
}
TORCH_INTERNAL_ASSERT(new_shape.size() == final_size);
const auto weight_padded = weight_flatten.view(new_shape);
auto zero_tensor = at::zeros(1, input.options());

// decomposes function,
auto res = at::maximum(zero_tensor, input_) + weight_padded * at::minimum(zero_tensor, input_);
return std::make_tuple(res, 0);
}

VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) {
// helper function that get the size of input, ensuring that there's batch dim, without expanding input
if (has_bdim) {
// sad to have to copy but got garbage if tried to return an IntArrayRef and just do input.sizes()
VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
return new_shape;
}
VmapDimVector new_shape(1, batch_size);
new_shape.reserve(input.dim() + 1);
new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end());
return new_shape;
}

VmapDimVector shape_maybe_with_bdim(const Tensor& input, const bool need_bdim, const bool has_bdim, const int64_t batch_size) {
// if need_bdim, will return the input with a guaranteed bdim. If not, will return the input logical size (no batch dim)
if (need_bdim) {
return ensure_shape_with_bdim(input, has_bdim, batch_size);
} else if (has_bdim) { // !need_bdim && has_bdim
VmapDimVector new_shape(input.sizes().begin() + 1, input.sizes().end());
return new_shape;
} else { // !need_bdim && !has_bdim
VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
return new_shape;
}
}

std::tuple<Tensor, Tensor> prelu_backward_batched(
const Tensor& grad_out, const Tensor& self, const Tensor& weight,
const VmapDimVector& self_grad_shape, const VmapDimVector& weight_grad_padded_shape, const VmapDimVector& weight_grad_shape) {
// helper function that produces a batched gradient for prelu using a decomposition inspired by the AOTAutograd ones
const auto input_grad_collector = at::where(self > 0, grad_out, weight * grad_out);
const auto input_grad = native::sum_to_size(input_grad_collector, self_grad_shape);
const auto weight_grad_collector = at::where(self > 0, at::zeros(1, self.options()), self * grad_out);
const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_grad_padded_shape);
const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape);
return std::make_tuple(input_grad, weight_grad);
}

std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_batch_rule(
const Tensor& grad_out, optional<int64_t> grad_out_bdim,
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& weight, optional<int64_t> weight_bdim) {
const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim);
const auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
const auto self_ = moveBatchDimToFront(self, self_bdim);
const auto self_size_with_bdim = ensure_shape_with_bdim(self_, self_bdim.has_value(), batch_size);
if (!weight_bdim && weight.dim() == 0) {
VmapDimVector weight_grad_shape(1, batch_size);
VmapDimVector weight_grad_shape_padded(self_bdim.has_value() ? self.dim() : self.dim() + 1, 1);
weight_grad_shape_padded[0] = batch_size;
const auto grads = prelu_backward_batched(grad_out_, self_, weight, self_size_with_bdim, weight_grad_shape_padded, weight_grad_shape);
return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), 0);
}
const auto weight_ = moveBatchDimToFront(weight, weight_bdim);
auto weight_flatten = weight_;
if (weight_flatten.dim() > 1) {
// for an input [N, C, ...]
// weight can be a non-vector but the total number of elements must be the same as C
weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
}

const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim);
VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank;
new_shape.reserve(final_size);

if (weight_flatten.dim() == 2 || !weight_bdim) {
// if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the
// decomposition, we pad the weight to

// copies checks from prelu if the weight (without vmap) is not a scalar
TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor.");

int64_t channel_size = 1; // channel_size default to 1
if (self_logical_rank > 1) {
channel_size = self_.size(self_bdim.has_value() ? 2 : 1);
}

const auto weight_num = weight_flatten.size(-1);
TORCH_CHECK(channel_size == weight_num,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
" and channel size = ", channel_size, ".");

// pads to the left so that the flattened shape matches up with the channel
if (!weight_bdim) {
new_shape.insert(new_shape.begin(), 1);
} else {
new_shape.insert(new_shape.begin() + 1, 1);
}
}

for (int64_t i = new_shape.size(); i < final_size; i ++) {
new_shape.push_back(1);
}
// weight grad does not depend on weight values. It is batched iff grad_out or self are batched
const auto weight_grad_is_batched = grad_out_bdim.has_value() || self_bdim.has_value();

const auto weight_padded = weight_flatten.view(new_shape);
const auto weight_grad_shape = shape_maybe_with_bdim(weight_, weight_grad_is_batched, weight_bdim.has_value(), batch_size);
const auto weight_padded_grad_shape = shape_maybe_with_bdim(weight_padded, weight_grad_is_batched, weight_bdim.has_value(), batch_size);

const auto grads = prelu_backward_batched(grad_out_, self_, weight_padded, self_size_with_bdim, weight_padded_grad_shape, weight_grad_shape);
return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), (weight_grad_is_batched ? optional<int64_t>(0) : nullopt));
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(prelu, prelu_batch_rule)
VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule)
}
}} // namespace at::functorch
1 change: 0 additions & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,6 @@ def test_vmapjvpall(self, device, dtype, op):
xfail('nn.functional.huber_loss'),
xfail('nn.functional.poisson_nll_loss'),
xfail('nn.functional.bilinear'),
xfail('nn.functional.prelu'),
xfail('nn.functional.glu'),
xfail('nn.functional.fractional_max_pool3d'),
xfail('as_strided'),
Expand Down
1 change: 0 additions & 1 deletion test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3230,7 +3230,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('stft'),
xfail('linalg.solve_triangular'),
xfail('nn.functional.glu'),
xfail('nn.functional.prelu'),
xfail('isclose'),
xfail('nn.functional.fractional_max_pool3d'),
xfail('nn.functional.bilinear'),
Expand Down

0 comments on commit d8152ab

Please sign in to comment.