prelu backward rule
samdow committed Mar 29, 2022
1 parent a34469c commit b6ff916
Showing 2 changed files with 81 additions and 1 deletion.
functorch/csrc/BatchRulesBinaryOps.cpp (81 additions, 0 deletions)
@@ -329,6 +329,86 @@ std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
  return std::make_tuple(res, 0);
}

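// Builds the shape that the weight grad is summed down to: if `input` already
// carries a batch dim, its shape is returned unchanged; otherwise `batch_size`
// is prepended as a new leading dim.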
VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) {
  if (has_bdim) {
    VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
    return new_shape;
  }
  VmapDimVector new_shape(1, batch_size);
  new_shape.reserve(input.dim() + 1);
  new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end());
  return new_shape;
}

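// Batch rule for prelu_backward. grad_out and self are materialized with a
// leading batch dim; the main work is reshaping the weight so it broadcasts
// against the batched input, then summing the weight grad back down.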
std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_batch_rule(
    const Tensor& grad_out, optional<int64_t> grad_out_bdim,
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& weight, optional<int64_t> weight_bdim) {
  const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim);
  const auto grad_out_ = moveBatchDimToFront(ensure_has_bdim(grad_out, grad_out_bdim.has_value(), batch_size), grad_out_bdim);
  const auto self_ = moveBatchDimToFront(ensure_has_bdim(self, self_bdim.has_value(), batch_size), self_bdim);
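  // Fast path: an unbatched scalar weight broadcasts against self_ as-is, and
  // the weight grad reduces over every dim except the batch dim.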
  if (!weight_bdim && weight.dim() == 0) {
    VmapDimVector weight_grad_shape(1, batch_size);
    VmapDimVector weight_grad_shape_padded(self_.dim(), 1);
    weight_grad_shape_padded[0] = batch_size;
    const auto input_grad = at::where(self_ > 0, grad_out_, weight * grad_out_);
    const auto input_grad_collector = native::sum_to_size(input_grad, self_.sizes());
    const auto weight_grad_collector = at::where(self_ > 0, at::zeros(1, self.options()), self_ * grad_out_);
    const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_grad_shape_padded);
    const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape);

    return std::make_tuple(input_grad_collector, 0, weight_grad, 0);
  }
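  // General case: flatten the (possibly batched) weight so it can be reshaped
  // to line up with the channel dim of self_.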
  const auto weight_ = moveBatchDimToFront(weight, weight_bdim);
  auto weight_flatten = weight_;
  if (weight_flatten.dim() > 1) {
    weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
  }

  const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
  const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank;
  new_shape.reserve(final_size);

  if (weight_flatten.dim() == 2 || !weight_bdim) {
    // copies the checks from prelu for the case where the (unbatched) weight is not a scalar
    TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor.");

    int64_t channel_size = 1; // channel_size defaults to 1
    if (self_logical_rank > 1) {
      channel_size = self_.size(2); // the channel dim is always 2 since we ensure self_ has a batch dim
    }

    const auto weight_num = weight_flatten.size(-1);
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");

    // pad on the left so that the flattened weight lines up with the channel dim
    if (!weight_bdim) {
      new_shape.insert(new_shape.begin(), 1);
    } else {
      new_shape.insert(new_shape.begin() + 1, 1);
    }
  }

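  // pad on the right with size-1 dims so the reshaped weight broadcasts
  // against self_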
  for (int64_t i = new_shape.size(); i < final_size; i++) {
    new_shape.push_back(1);
  }
  const auto weight_padded = weight_flatten.view(new_shape);

  const auto weight_grad_shape = ensure_shape_with_bdim(weight_, weight_bdim.has_value(), batch_size);
  const auto weight_padded_grad_shape = ensure_shape_with_bdim(weight_padded, weight_bdim.has_value(), batch_size);

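  // grad wrt self: grad_out where self > 0, otherwise weight * grad_out.
  // grad wrt weight: self * grad_out where self <= 0, summed down to the
  // weight's (batch-expanded) shape.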
  const auto input_grad = at::where(self_ > 0, grad_out_, weight_padded * grad_out_);
  const auto input_grad_collector = native::sum_to_size(input_grad, self_.sizes());
  const auto weight_grad_collector = at::where(self_ > 0, at::zeros(1, self.options()), self_ * grad_out_);
  const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_padded_grad_shape);
  const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape);

  return std::make_tuple(input_grad_collector, 0, weight_grad, 0);
}

TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
#define BINARY_RANDOM_POINTWISE(op) \
  m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
@@ -455,6 +535,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  BINARY_POINTWISE(logit_backward);
  POINTWISE_BOXED(log_sigmoid_backward);
  BINARY_POINTWISE(gelu_backward);
  VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule);
  BINARY_POINTWISE(sigmoid_backward);
  POINTWISE_BOXED(softplus_backward);
  BINARY_POINTWISE(softshrink_backward);
test/test_ops.py (0 additions, 1 deletion)
@@ -915,7 +915,6 @@ def test():
        xfail('nn.functional.huber_loss'),
        xfail('nn.functional.poisson_nll_loss'),
        xfail('nn.functional.bilinear'),
        xfail('nn.functional.prelu'),
        xfail('nn.functional.glu'),
        xfail('nn.functional.fractional_max_pool3d'),
        xfail('as_strided'),
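With the prelu xfail removed, vmap through prelu's backward now exercises the new batch rule. A minimal sketch of the kind of call this enables (the shapes and names below are illustrative, not part of the commit):

import torch
import torch.nn.functional as F
from functorch import vmap, grad

x = torch.randn(5, 2, 3, 8)   # a batch of 5 inputs, each of logical shape (2, 3, 8)
w = torch.randn(3)            # one prelu weight per channel

# grad() differentiates through prelu, which invokes prelu_backward;
# vmap() then dispatches that call to prelu_backward_batch_rule
per_sample_grads = vmap(grad(lambda xi: F.prelu(xi, w).sum()))(x)
print(per_sample_grads.shape)  # torch.Size([5, 2, 3, 8])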
