Skip to content

Commit

Permalink
[pt2] add SymInt support for bilinear (pytorch#103396)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#103396
Approved by: https://github.com/ezyang
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Jun 13, 2023
1 parent 4a76fb4 commit c076344
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
44 changes: 22 additions & 22 deletions aten/src/ATen/native/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,45 +626,45 @@ Tensor _trilinear(const Tensor& i1_, const Tensor& i2_, const Tensor& i3_,
Tensor i1 = i1_;
Tensor i2 = i2_;
Tensor i3 = i3_;
std::vector<int64_t> output_size;
std::vector<c10::SymInt> output_size;
std::vector<int64_t> sum_dims_12, sum_dims_23;
int64_t unroll_size = -1;
// asserts...
for (const auto i : c10::irange(total_dim)) {
int64_t s = 0;
c10::SymInt s = 0;
if (expand1[i]) {
i1 = i1.unsqueeze(i);
} else {
s = i1.size(i);
s = i1.sym_size(i);
}
if (expand2[i]) {
i2 = i2.unsqueeze(i);
} else {
s = i2.size(i);
s = i2.sym_size(i);
}
if (expand3[i]) {
i3 = i3.unsqueeze(i);
if (sumdim[i] && (i != unroll_dim))
sum_dims_12.push_back(i);
} else {
s = i3.size(i);
s = i3.sym_size(i);
if (sumdim[i] && (i != unroll_dim))
sum_dims_23.push_back(i);
}
output_size.push_back(sumdim[i] ? 1 : s);
if (i == unroll_dim)
unroll_size = s;
unroll_size = s.guard_int(__FILE__, __LINE__);
}
int64_t slicemul1 = (expand1[unroll_dim] ? 0 : 1);
int64_t slicemul2 = (expand2[unroll_dim] ? 0 : 1);
int64_t slicemul3 = (expand3[unroll_dim] ? 0 : 1);

auto output = at::zeros(output_size, i1.options());
auto output = at::zeros_symint(output_size, i1.options());

// Three conditionals are necessary since this function is meant to work for both
// forward and backward, which changes the dimensions of the inputs.
// Note that if output has zero elems, it is because (at least) one of i1, i2, i3 has zero elems.
if (i1.numel() != 0 && i2.numel() != 0 && i3.numel() != 0) {
if (i1.sym_numel() != 0 && i2.sym_numel() != 0 && i3.sym_numel() != 0) {
if (! sumdim[unroll_dim]) {
for (const auto k : c10::irange(unroll_size)) {
Tensor buf = at::native::sumproduct_pair(i1.narrow(unroll_dim, k * slicemul1, 1),
Expand Down Expand Up @@ -696,26 +696,26 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight

TORCH_CHECK(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got ", input1.dim(), " and ", input2.dim());
for (const auto i : c10::irange(input1.dim() - 1)) {
TORCH_CHECK(input1.size(i) == input2.size(i),
"bilinear(): input batch dimensions do not match at dim ", i, ": got ", input1.size(i), " and ", input2.size(i));
TORCH_CHECK(input1.sym_size(i) == input2.sym_size(i),
"bilinear(): input batch dimensions do not match at dim ", i, ": got ", input1.sym_size(i), " and ", input2.sym_size(i));
}
TORCH_CHECK(input1.size(input1.dim() - 1) == weight.size(1),
TORCH_CHECK(input1.sym_size(input1.dim() - 1) == weight.sym_size(1),
"bilinear(): input1 size does not match weight size: got ",
input1.size(input1.dim() - 1), " but expected ", weight.size(1));
TORCH_CHECK(input2.size(input2.dim() - 1) == weight.size(2),
input1.sym_size(input1.dim() - 1), " but expected ", weight.sym_size(1));
TORCH_CHECK(input2.sym_size(input2.dim() - 1) == weight.sym_size(2),
"bilinear(): input2 size does not match weight size: got ",
input2.size(input2.dim() - 1), " but expected ", weight.size(2));
TORCH_CHECK(!bias.defined() || bias.size(0) == weight.size(0),
input2.sym_size(input2.dim() - 1), " but expected ", weight.sym_size(2));
TORCH_CHECK(!bias.defined() || bias.sym_size(0) == weight.sym_size(0),
"bilinear(): bias size does not match weight size: got ",
bias.size(0), " but expected ", weight.size(0));
bias.sym_size(0), " but expected ", weight.sym_size(0));

std::vector<int64_t> output_size;
auto size1 = input1.sizes();
std::vector<c10::SymInt> output_size;
auto size1 = input1.sym_sizes();
output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
output_size.push_back(weight.size(0));
auto input1_flattened = input1.reshape({-1, input1.size(-1)});
auto input2_flattened = input2.reshape({-1, input2.size(-1)});
Tensor output = at::_trilinear(input1_flattened, weight, input2_flattened, {1,3}, {0}, {1,2}, {2,3}).reshape(output_size);
output_size.push_back(weight.sym_size(0));
auto input1_flattened = input1.reshape_symint({-1, input1.sym_size(-1)});
auto input2_flattened = input2.reshape_symint({-1, input2.sym_size(-1)});
Tensor output = at::_trilinear(input1_flattened, weight, input2_flattened, {1,3}, {0}, {1,2}, {2,3}).reshape_symint(output_size);
if (bias.defined()) {
output = output + bias;
}
Expand Down
2 changes: 0 additions & 2 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2810,7 +2810,6 @@ def forward(self, x):
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbo...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2...
skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
xfail('nn.functional.bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.binary_cross_entropy', ''), # aten.fill_.Scalar - couldn't find symbolic meta funct...
xfail('nn.functional.cosine_similarity', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.cross_entropy', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
Expand Down Expand Up @@ -3040,7 +3039,6 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
torch.nn.Bilinear, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.ReplicationPad1d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.ReplicationPad2d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.ReplicationPad3d, # Cannot call sizes() on tensor with symbolic sizes/strides
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,7 +1490,6 @@ def f(a, b, c, d, e):
xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl...
xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Expand Down

0 comments on commit c076344

Please sign in to comment.