Skip to content

Commit

Permalink
prelu backward rule
Browse files Browse the repository at this point in the history
  • Loading branch information
samdow committed Mar 31, 2022
1 parent c7cc7b2 commit 5e8bb19
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 1 deletion.
106 changes: 106 additions & 0 deletions functorch/csrc/BatchRulesActivation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,113 @@ std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
return std::make_tuple(res, 0);
}

VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) {
// helper function that get the size of input, ensuring that there's batch dim, without expanding input
if (has_bdim) {
// sad to have to copy but got garbage if tried to return an IntArrayRef and just do input.sizes()
VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
return new_shape;
}
VmapDimVector new_shape(1, batch_size);
new_shape.reserve(input.dim() + 1);
new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end());
return new_shape;
}

VmapDimVector shape_maybe_with_bdim(const Tensor& input, const bool need_bdim, const bool has_bdim, const int64_t batch_size) {
// if need_bdim, will return the input with a guaranteed bdim. If not, will return the input logical size (no batch dim)
if (need_bdim) {
return ensure_shape_with_bdim(input, has_bdim, batch_size);
} else if (has_bdim) { // !need_bdim && has_bdim
VmapDimVector new_shape(input.sizes().begin() + 1, input.sizes().end());
return new_shape;
} else { // !need_bdim && !has_bdim
VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
return new_shape;
}
}

std::tuple<Tensor, Tensor> prelu_backward_batched(
const Tensor& grad_out, const Tensor& self, const Tensor& weight,
const VmapDimVector& self_shape_with_bdim, const VmapDimVector& padded_shape, const VmapDimVector& unpadded_shape) {
// helper function that assumes that self, weight, and grad_out are all of the correct shape
// based on decomposition of prelu_backward
const auto input_grad = at::where(self > 0, grad_out, weight * grad_out);
const auto input_grad_collector = native::sum_to_size(input_grad, self_shape_with_bdim);
const auto weight_grad_collector = at::where(self > 0, at::zeros(1, self.options()), self * grad_out);
const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, padded_shape);
const auto weight_grad = weight_grad_collector_2.view(unpadded_shape);
return std::make_tuple(input_grad_collector, weight_grad);
}

std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_batch_rule(
const Tensor& grad_out, optional<int64_t> grad_out_bdim,
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& weight, optional<int64_t> weight_bdim) {
const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim);
const auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
const auto self_ = moveBatchDimToFront(self, self_bdim);
const auto self_size_with_bdim = ensure_shape_with_bdim(self_, self_bdim.has_value(), batch_size);
if (!weight_bdim && weight.dim() == 0) {
VmapDimVector weight_grad_shape(1, batch_size);
VmapDimVector weight_grad_shape_padded(self_bdim.has_value() ? self.dim() : self.dim() + 1, 1);
weight_grad_shape_padded[0] = batch_size;
const auto grads = prelu_backward_batched(grad_out_, self_, weight, self_size_with_bdim, weight_grad_shape_padded, weight_grad_shape);
return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), 0);
}
const auto weight_ = moveBatchDimToFront(weight, weight_bdim);
auto weight_flatten = weight_;
if (weight_flatten.dim() > 1) {
// for an input [N, C, ...]
// weight can be a non-vector but the total number of elements must be the same as C
weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
}

const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim);
VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank;
new_shape.reserve(final_size);

if (weight_flatten.dim() == 2 || !weight_bdim) {
// if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the
// decomposition, we pad the weight to

// copies checks from prelu if the weight (without vmap) is not a scalar
TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor.");

int64_t channel_size = 1; // channel_size default to 1
if (self_logical_rank > 1) {
channel_size = self_.size(self_bdim.has_value() ? 2 : 1);
}

const auto weight_num = weight_flatten.size(-1);
TORCH_CHECK(channel_size == weight_num,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
" and channel size = ", channel_size, ".");

// pads to the left so that the flattened shape matches up with the channel
if (!weight_bdim) {
new_shape.insert(new_shape.begin(), 1);
} else {
new_shape.insert(new_shape.begin() + 1, 1);
}
}

for (int64_t i = new_shape.size(); i < final_size; i ++) {
new_shape.push_back(1);
}
const auto weight_padded = weight_flatten.view(new_shape);

const auto weight_grad_is_batched = grad_out_bdim.has_value() || self_bdim.has_value(); // weight grad does not depend on weight values
const auto weight_grad_shape = shape_maybe_with_bdim(weight_, weight_grad_is_batched, weight_bdim.has_value(), batch_size);
const auto weight_padded_grad_shape = shape_maybe_with_bdim(weight_padded, weight_grad_is_batched, weight_bdim.has_value(), batch_size);

const auto grads = prelu_backward_batched(grad_out_, self_, weight_padded, self_size_with_bdim, weight_padded_grad_shape, weight_grad_shape);
return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), (weight_grad_is_batched ? optional<int64_t>(0) : nullopt));
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(prelu, prelu_batch_rule)
VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule)
}
}} // namespace at::functorch
1 change: 0 additions & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,6 @@ def test():
xfail('nn.functional.huber_loss'),
xfail('nn.functional.poisson_nll_loss'),
xfail('nn.functional.bilinear'),
xfail('nn.functional.prelu'),
xfail('nn.functional.glu'),
xfail('nn.functional.fractional_max_pool3d'),
xfail('as_strided'),
Expand Down

0 comments on commit 5e8bb19

Please sign in to comment.