Skip to content

Commit

Permalink
[Pytorch fronted]: Added support for Search Sorted op (openvinotoolki…
Browse files Browse the repository at this point in the history
…t#26976)

### Details:
 - Added support for SearchSorted op with unittest.

### Tickets:
 - *[CVS-154559](https://jira.devtools.intel.com/browse/CVS-154559)*

Depends on: 
 - openvinotoolkit#26958
 - openvinotoolkit#27036

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Michal Lukaszewski <[email protected]>
Co-authored-by: Pawel Raasz <[email protected]>
Co-authored-by: Andrey Babushkin <[email protected]>
Co-authored-by: Alicja Miloszewska <[email protected]>
Co-authored-by: Bogdan Pereanu <[email protected]>
Co-authored-by: Karol Blaszczak <[email protected]>
Co-authored-by: Tatiana Savina <[email protected]>
Co-authored-by: Anastasiya(Asya) Pronina <[email protected]>
Co-authored-by: Dmitry Matveev <[email protected]>
Co-authored-by: Andrei Beleiu <[email protected]>
Co-authored-by: Andrew Kwangwoong Park <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
Co-authored-by: Pavel Durandin <[email protected]>
Co-authored-by: Alexey Smirnov <[email protected]>
Co-authored-by: Hubert Błaszczyk <[email protected]>
Co-authored-by: Vladimir Paramuzov <[email protected]>
Co-authored-by: Sergey Shlyapnikov <[email protected]>
Co-authored-by: Ivan Tikhonov <[email protected]>
Co-authored-by: Andrzej Kopytko <[email protected]>
Co-authored-by: Sebastian Golebiewski <[email protected]>
Co-authored-by: Alina Kladieva <[email protected]>
Co-authored-by: Ilya Lavrenov <[email protected]>
Co-authored-by: Maxim Vafin <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Mateusz Mikolajczyk <[email protected]>
  • Loading branch information
1 parent 9c1055f commit 7d34590
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/frontends/pytorch/src/op/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/search_sorted.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_search_sorted(const NodeContext& context) {
num_inputs_check(context, 2, 5);
Output<Node> sorted;
Output<Node> values;
std::tie(sorted, values) = get_inputs_with_promoted_types(context, 0, 1);
const bool out_int32 = context.const_input<bool>(2);
PYTORCH_OP_CONVERSION_CHECK(out_int32 == false, "aten::searchsorted(out_int32=true) unsupported");
const bool right_mode = context.const_input<bool>(3);
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(4), "aten::searchsorted(side) unsupported");
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(5), "aten::searchsorted(out) unsupported");
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(6), "aten::searchsorted(sorter) unsupported");
auto op = context.mark_node(std::make_shared<ov::op::v15::SearchSorted>(sorted, values, right_mode));
return {op};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ OP_CONVERTER(translate_reshape_fx);
OP_CONVERTER(translate_rsub_fx);
OP_CONVERTER(translate_scalar_tensor_fx);
OP_CONVERTER(translate_scaled_dot_product_attention_fx);
OP_CONVERTER(translate_search_sorted);
OP_CONVERTER(translate_select_scatter_fx);
OP_CONVERTER(translate_slice_fx);
OP_CONVERTER(translate_slice_scatter_fx);
Expand Down Expand Up @@ -617,6 +618,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::rsqrt", op::optional_out<op::translate_rsqrt, 1>},
{"aten::rsqrt_", op::inplace_op<op::translate_rsqrt>},
{"aten::rsub", op::translate_rsub},
{"aten::searchsorted", op::translate_search_sorted},
{"aten::ScalarImplicit", op::skip_node},
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
{"aten::scatter", op::translate_scatter},
Expand Down
47 changes: 47 additions & 0 deletions tests/layer_tests/pytorch_tests/test_search_sorted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest
import numpy as np


class TestSearchSorted(PytorchLayerTest):
def _prepare_input(self):
return (np.array(self.sorted).astype(self.sorted_type),np.array(self.values).astype(self.values_type))

def create_model(self, right_mode):
import torch

class aten_searchsorted(torch.nn.Module):
def __init__(self, right_mode):
super(aten_searchsorted, self).__init__()
self.right_mode = right_mode

def forward(self, sorted, values):
return torch.searchsorted(sorted, values, right=self.right_mode)

ref_net = None

return aten_searchsorted(right_mode), ref_net, "aten::searchsorted"

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize(("sorted", "values"), [
([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], [[3, 6, 9], [3, 6, 9]]),
([1, 3, 5, 7, 9], [[3, 6, 9],[0, 5, 20]]),
([4091, 4092], [[4091, 4092]]), # fp16 cannot exactly represent 4091 number
([1.23, 2.99], [[1.355, 2.9991]])
])
@pytest.mark.parametrize("right_mode", [False, True])
@pytest.mark.parametrize("sorted_type", [np.float32, np.float16, np.int8])
@pytest.mark.parametrize("values_type", [np.float16, np.int32, np.int64])
def test_searchsorted(self, sorted, values, right_mode, sorted_type, values_type, ie_device, precision, ir_version):
self.sorted = sorted
self.values = values
self.sorted_type = sorted_type
self.values_type = values_type
if ie_device == "CPU" and sorted_type == np.float16 and sorted == [4091, 4092]:
pytest.skip(reason="CPU plugin on defult converts fp16 to fp32, if that happens the test will fail for those malicious values")
self._test(*self.create_model(right_mode), ie_device, precision, ir_version)

0 comments on commit 7d34590

Please sign in to comment.