Skip to content

Commit

Permalink
[pt2] add SymInt support for cosine_similarity (pytorch#103400)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Jun 13, 2023
1 parent c076344 commit d38b651
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/native/Distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ Tensor cosine_similarity(const Tensor& x1_, const Tensor& x2_, int64_t dim, doub
auto commonDtype = at::result_type(x1_, x2_);
TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype);

auto common_size = at::infer_size_dimvector(x1_.sizes(), x2_.sizes());
auto x1 = x1_.to(commonDtype).expand(common_size);
auto x2 = x2_.to(commonDtype).expand(common_size);
auto common_size = at::infer_size_symdimvector(x1_.sym_sizes(), x2_.sym_sizes());
auto x1 = x1_.to(commonDtype).expand_symint(common_size);
auto x2 = x2_.to(commonDtype).expand_symint(common_size);

auto x1_squared_norm = at::pow(x1, 2).sum(dim, /*keepdim=*/true);
auto x2_squared_norm = at::pow(x2, 2).sum(dim, /*keepdim=*/true);
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2811,7 +2811,6 @@ def forward(self, x):
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2...
skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
xfail('nn.functional.binary_cross_entropy', ''), # aten.fill_.Scalar - couldn't find symbolic meta funct...
xfail('nn.functional.cosine_similarity', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.cross_entropy', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.ctc_loss', ''), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/deco...
xfail('nn.functional.embedding_bag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,7 +1491,6 @@ def f(a, b, c, d, e):
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl...
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
xfail('nn.functional.embedding_bag', ''), # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun...
Expand Down

0 comments on commit d38b651

Please sign in to comment.