diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index 9af421b3e2e99..cf8b411de1b93 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -303,9 +303,9 @@ Tensor cosine_similarity(const Tensor& x1_, const Tensor& x2_, int64_t dim, doub auto commonDtype = at::result_type(x1_, x2_); TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype); - auto common_size = at::infer_size_dimvector(x1_.sizes(), x2_.sizes()); - auto x1 = x1_.to(commonDtype).expand(common_size); - auto x2 = x2_.to(commonDtype).expand(common_size); + auto common_size = at::infer_size_symdimvector(x1_.sym_sizes(), x2_.sym_sizes()); + auto x1 = x1_.to(commonDtype).expand_symint(common_size); + auto x2 = x2_.to(commonDtype).expand_symint(common_size); auto x1_squared_norm = at::pow(x1, 2).sum(dim, /*keepdim=*/true); auto x2_squared_norm = at::pow(x2, 2).sum(dim, /*keepdim=*/true); diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 882b1248af31e..282d64b721b44 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2811,7 +2811,6 @@ def forward(self, x): xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2... skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for