Skip to content

Commit

Permalink
[ONNX] Add symbolic support for torch.nn.cosinesimilarity (#72128) (#…
Browse files Browse the repository at this point in the history
…73283)

Summary:
Pull Request resolved: #73283

* Add support for torch.nn.cosine_similarity

* Remove fallback logic

* Fix onnx test failures

* Fix opset version

* Modify rtol

* Add aten fallback mode

* fix mypy

* gate with caffe2 fallback

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D34625650

Pulled By: malfet

fbshipit-source-id: bf15d32b1d7055d0ca166d9941ba90b5c8e81cc2
(cherry picked from commit 7086031)
  • Loading branch information
BowenBao authored and pytorchmergebot committed Mar 9, 2022
1 parent 95b1232 commit 97ae431
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
6 changes: 6 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7271,6 +7271,12 @@ def forward(self, x):
for x in [torch.randn(3, 4), torch.randn(3, 4).to(dtype=torch.bool)]:
self.run_test(EinsumModelTranspose(), input=(x,))

@skipIfUnsupportedMinOpsetVersion(9)
def test_cosine_similarity(self):
x = torch.randn(5, 3, 2)
y = torch.randn(5, 3, 2)
self.run_test(torch.nn.CosineSimilarity(dim=2), input=(x, y))

@skipIfUnsupportedMinOpsetVersion(12)
def test_crossentropyloss(self):
for ignore_index in [-100, 1]:
Expand Down
5 changes: 3 additions & 2 deletions test/onnx/test_utility_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def test_onnx_fallthrough(self):
# Test aten export of op with symbolic for aten
x = torch.randn(100, 128)
y = torch.randn(100, 128)
model = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
model = torch.nn.PairwiseDistance(p=2, eps=1e-6)

graph, _, __ = self._model_to_graph(model, (x, y),
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
Expand All @@ -966,7 +966,8 @@ def test_onnx_fallthrough(self):
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "onnx::Constant")
self.assertEqual(next(iter).kind(), "onnx::Constant")
self.assertEqual(next(iter).kind(), "aten::cosine_similarity")
self.assertEqual(next(iter).kind(), "onnx::Constant")
self.assertEqual(next(iter).kind(), "aten::pairwise_distance")

# prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
@skipIfUnsupportedMaxOpsetVersion(10)
Expand Down
14 changes: 11 additions & 3 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,10 +1548,18 @@ def type_as(g, self, other):

@parse_args("v", "v", "i", "f")
def cosine_similarity(g, x1, x2, dim, eps):
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
# preserve legacy behavior for Caffe2
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and \
torch.onnx._CAFFE2_ATEN_FALLBACK:
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
else:
return sym_help._onnx_unsupported("cosine_similarity")
cross = sym_help._reducesum_helper(g, mul(g, x1, x2),
axes_i=[dim], keepdims_i=0)
x1_l2 = sym_help._reducesum_helper(g, mul(g, x1, x1),
axes_i=[dim], keepdims_i=0)
x2_l2 = sym_help._reducesum_helper(g, mul(g, x2, x2),
axes_i=[dim], keepdims_i=0)
div_tens = max(g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])))
return div(g, cross, div_tens)


# ignore clone operators that are inserted by PyTorch autograd
Expand Down

0 comments on commit 97ae431

Please sign in to comment.