From 7d34590ee7ae71aacf4cc8ef17f7b853e3518efe Mon Sep 17 00:00:00 2001 From: Piotr Kowalczyk Date: Thu, 24 Oct 2024 15:14:25 +0200 Subject: [PATCH] [Pytorch fronted]: Added support for Search Sorted op (#26976) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Details: - Added support for SearchSorted op with unittest. ### Tickets: - *[CVS-154559](https://jira.devtools.intel.com/browse/CVS-154559)* Depends on: - https://github.com/openvinotoolkit/openvino/pull/26958 - https://github.com/openvinotoolkit/openvino/pull/27036 --------- Signed-off-by: Kazantsev, Roman Signed-off-by: dependabot[bot] Co-authored-by: Michal Lukaszewski Co-authored-by: Pawel Raasz Co-authored-by: Andrey Babushkin Co-authored-by: Alicja Miloszewska Co-authored-by: Bogdan Pereanu Co-authored-by: Karol Blaszczak Co-authored-by: Tatiana Savina Co-authored-by: Anastasiya(Asya) Pronina Co-authored-by: Dmitry Matveev Co-authored-by: Andrei Beleiu Co-authored-by: Andrew Kwangwoong Park Co-authored-by: Roman Kazantsev Co-authored-by: Pavel Durandin Co-authored-by: Alexey Smirnov Co-authored-by: Hubert Błaszczyk <56601011+hub-bla@users.noreply.github.com> Co-authored-by: Vladimir Paramuzov Co-authored-by: Sergey Shlyapnikov Co-authored-by: Ivan Tikhonov Co-authored-by: Andrzej Kopytko Co-authored-by: Sebastian Golebiewski Co-authored-by: Alina Kladieva Co-authored-by: Ilya Lavrenov Co-authored-by: Maxim Vafin Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Mateusz Mikolajczyk --- .../pytorch/src/op/search_sorted.cpp | 34 ++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 2 + .../pytorch_tests/test_search_sorted.py | 47 +++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 src/frontends/pytorch/src/op/search_sorted.cpp create mode 100644 tests/layer_tests/pytorch_tests/test_search_sorted.py diff --git a/src/frontends/pytorch/src/op/search_sorted.cpp b/src/frontends/pytorch/src/op/search_sorted.cpp new file mode 100644 index 00000000000000..ca9f6b49ff7bf9 --- /dev/null +++ b/src/frontends/pytorch/src/op/search_sorted.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/search_sorted.hpp" + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_search_sorted(const NodeContext& context) { + num_inputs_check(context, 2, 5); + Output sorted; + Output values; + std::tie(sorted, values) = get_inputs_with_promoted_types(context, 0, 1); + const bool out_int32 = context.const_input(2); + PYTORCH_OP_CONVERSION_CHECK(out_int32 == false, "aten::searchsorted(out_int32=true) unsupported"); + const bool right_mode = context.const_input(3); + PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(4), "aten::searchsorted(side) unsupported"); + PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(5), "aten::searchsorted(out) unsupported"); + PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(6), "aten::searchsorted(sorter) unsupported"); + auto op = context.mark_node(std::make_shared(sorted, values, right_mode)); + return {op}; +}; +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 195977432e40e5..66c76e33032ef6 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -300,6 +300,7 @@ OP_CONVERTER(translate_reshape_fx); OP_CONVERTER(translate_rsub_fx); OP_CONVERTER(translate_scalar_tensor_fx); OP_CONVERTER(translate_scaled_dot_product_attention_fx); +OP_CONVERTER(translate_search_sorted); OP_CONVERTER(translate_select_scatter_fx); OP_CONVERTER(translate_slice_fx); OP_CONVERTER(translate_slice_scatter_fx); @@ -617,6 +618,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::rsqrt", op::optional_out}, {"aten::rsqrt_", op::inplace_op}, {"aten::rsub", op::translate_rsub}, + {"aten::searchsorted", op::translate_search_sorted}, {"aten::ScalarImplicit", op::skip_node}, {"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention}, {"aten::scatter", op::translate_scatter}, diff --git a/tests/layer_tests/pytorch_tests/test_search_sorted.py b/tests/layer_tests/pytorch_tests/test_search_sorted.py new file mode 100644 index 00000000000000..645033e2ee260b --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_search_sorted.py @@ -0,0 +1,47 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest +import numpy as np + + +class TestSearchSorted(PytorchLayerTest): + def _prepare_input(self): + return (np.array(self.sorted).astype(self.sorted_type),np.array(self.values).astype(self.values_type)) + + def create_model(self, right_mode): + import torch + + class aten_searchsorted(torch.nn.Module): + def __init__(self, right_mode): + super(aten_searchsorted, self).__init__() + self.right_mode = right_mode + + def forward(self, sorted, values): + return torch.searchsorted(sorted, values, right=self.right_mode) + + ref_net = None + + return aten_searchsorted(right_mode), ref_net, "aten::searchsorted" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize(("sorted", "values"), [ + ([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], [[3, 6, 9], [3, 6, 9]]), + ([1, 3, 5, 7, 9], [[3, 6, 9],[0, 5, 20]]), + ([4091, 4092], [[4091, 4092]]), # fp16 cannot exactly represent 4091 number + ([1.23, 2.99], [[1.355, 2.9991]]) + ]) + @pytest.mark.parametrize("right_mode", [False, True]) + @pytest.mark.parametrize("sorted_type", [np.float32, np.float16, np.int8]) + @pytest.mark.parametrize("values_type", [np.float16, np.int32, np.int64]) + def test_searchsorted(self, sorted, values, right_mode, sorted_type, values_type, ie_device, precision, ir_version): + self.sorted = sorted + self.values = values + self.sorted_type = sorted_type + self.values_type = values_type + if ie_device == "CPU" and sorted_type == np.float16 and sorted == [4091, 4092]: + pytest.skip(reason="CPU plugin on defult converts fp16 to fp32, if that happens the test will fail for those malicious values") + self._test(*self.create_model(right_mode), ie_device, precision, ir_version)