diff --git a/docs/articles_en/about-openvino/compatibility-and-support/supported-operations-framework-frontend.rst b/docs/articles_en/about-openvino/compatibility-and-support/supported-operations-framework-frontend.rst index ec60d5a4fdb638..e6c30736bb825d 100644 --- a/docs/articles_en/about-openvino/compatibility-and-support/supported-operations-framework-frontend.rst +++ b/docs/articles_en/about-openvino/compatibility-and-support/supported-operations-framework-frontend.rst @@ -210,6 +210,7 @@ This page lists operations supported by OpenVINO Framework Frontend. aten::masked_fill_ aten::masked_scatter aten::masked_scatter_ + aten::masked_select aten::matmul aten::max aten::max_pool1d diff --git a/src/frontends/pytorch/src/op/masked_select.cpp b/src/frontends/pytorch/src/op/masked_select.cpp new file mode 100644 index 00000000000000..380b13a03470c8 --- /dev/null +++ b/src/frontends/pytorch/src/op/masked_select.cpp @@ -0,0 +1,25 @@ +#include "openvino/frontend/pytorch/node_context.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_masked_select(const NodeContext& context) { + // aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + num_inputs_check(context, 2, 2); + auto data = context.get_input(0); + auto mask = context.get_input(1); + ov::pass::NodeRegistry rg; + auto res = masked_select(rg, data, mask); + context.mark_nodes(rg.get()); + return {res}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index ae62d4b30e7d3b..f89fa042731bd9 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -127,6 +127,7 @@ OP_CONVERTER(translate_loop); OP_CONVERTER(translate_lstm); OP_CONVERTER(translate_masked_fill); OP_CONVERTER(translate_masked_scatter); +OP_CONVERTER(translate_masked_select); OP_CONVERTER(translate_max); OP_CONVERTER(translate_maximum); OP_CONVERTER(translate_max_poolnd); @@ -499,6 +500,7 @@ const std::map get_supported_ops_ts() { {"aten::masked_fill_", op::inplace_op}, {"aten::masked_scatter", op::translate_masked_scatter}, {"aten::masked_scatter_", op::inplace_op}, + {"aten::masked_select", op::translate_masked_select}, {"aten::matmul", op::translate_1to1_match_2_inputs}, {"aten::max", op::translate_max}, {"aten::maximum", op::translate_maximum}, diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 32f62eed603d47..f163b65eaca397 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -9,6 +9,7 @@ #include "openvino/core/rt_info.hpp" #include "openvino/frontend/pytorch/decoder.hpp" #include "openvino/opsets/opset10.hpp" +#include "openvino/op/constant.hpp" #include "openvino/util/log.hpp" #include "pt_framework_node.hpp" #include "translate_session.hpp" @@ -591,6 +592,13 @@ Output masked_fill(ov::pass::NodeRegistry& rg, return rg.make(bool_mask, _value, data); } +Output masked_select(ov::pass::NodeRegistry& rg, + const Output& data, + const Output& mask) { + auto _index = rg.make(mask); + return rg.make(data, _index); +} + } // namespace pytorch } // namespace frontend } // namespace ov diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 6493d9a3f62c2d..09e4b6bd666d93 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -107,6 +107,10 @@ Output masked_fill(ov::pass::NodeRegistry& rg, const Output& mask, const Output& value); +Output masked_select(ov::pass::NodeRegistry& rg, + const Output& data, + const Output& mask); + namespace op { template OutputVector inplace_op(const NodeContext& context) { diff --git a/tests/layer_tests/pytorch_tests/test_masked_select.py b/tests/layer_tests/pytorch_tests/test_masked_select.py new file mode 100644 index 00000000000000..a2276e1eb2b3b0 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_masked_select.py @@ -0,0 +1,61 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch +from packaging.version import parse as parse_version +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestMaskedSelect(PytorchLayerTest): + def _prepare_input(self, mask_select='ones', mask_dtype=bool, input_dtype=float): + input_shape = [1, 10] + mask = np.zeros(input_shape).astype(mask_dtype) + if mask_select == 'ones': + mask = np.ones(input_shape).astype(mask_dtype) + if mask_select == 'random': + idx = np.random.choice(10, 5) + mask[:, idx] = 1 + return (np.random.randn(1, 10).astype(input_dtype), mask) + + def create_model(self): + import torch + + class aten_masked_select(torch.nn.Module): + def __init__(self): + super(aten_masked_select, self).__init__() + + def forward(self, x, mask): + return x.masked_select(mask) + + ref_net = None + + return aten_masked_select(), ref_net, "aten::masked_select" + + @pytest.mark.parametrize( + "mask_select", ['zeros', 'ones', 'random']) + @pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_masked_select(self, mask_select, input_dtype, ie_device, precision, ir_version): + self._test(*self.create_model(), + ie_device, precision, ir_version, + dynamic_shapes=False, + trace_model=True, + kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': bool, "input_dtype": input_dtype}) + + @pytest.mark.skipif(parse_version(torch.__version__) >= parse_version("2.1.0"), reason="pytorch 2.1 and above does not support nonboolean mask") + @pytest.mark.parametrize( + "mask_select", ['zeros', 'ones', 'random']) + @pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32]) + @pytest.mark.parametrize("mask_dtype", [np.uint8, np.int32]) # np.float32 incorrectly casted to bool + @pytest.mark.nightly + @pytest.mark.precommit + def test_masked_select_non_bool_mask(self, mask_select, mask_dtype, input_dtype, ie_device, precision, ir_version): + self._test(*self.create_model(), + ie_device, precision, ir_version, + dynamic_shapes=False, + trace_model=True, + kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': mask_dtype, "input_dtype": input_dtype})