Skip to content

Commit

Permalink
[PT FE] Support aten::masked_select for pytorch models (#26162)
Browse files Browse the repository at this point in the history
### Details:
 - support `aten::masked_select` operator

### Tickets:
 - [None](#23325)
  • Loading branch information
hub-bla authored Aug 26, 2024
1 parent b80a179 commit 2660f85
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/frontends/pytorch/src/op/masked_select.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_masked_select(const NodeContext& context) {
// aten::masked_select(Tensor self, Tensor mask, Tensor source) -> Tensor
num_inputs_check(context, 2, 2);
auto data = context.get_input(0);
auto mask = context.get_input(1);
auto res = masked_select(context, data, mask);
return {res};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ OP_CONVERTER(translate_loop);
OP_CONVERTER(translate_lstm);
OP_CONVERTER(translate_masked_fill);
OP_CONVERTER(translate_masked_scatter);
OP_CONVERTER(translate_masked_select);
OP_CONVERTER(translate_max);
OP_CONVERTER(translate_maximum);
OP_CONVERTER(translate_max_poolnd);
Expand Down Expand Up @@ -528,6 +529,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
{"aten::masked_fill", op::translate_masked_fill},
{"aten::masked_scatter", op::translate_masked_scatter},
{"aten::masked_select", op::translate_masked_select},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::max", op::translate_max},
{"aten::mv", op::translate_1to1_match_2_inputs<opset10::MatMul>},
Expand Down
11 changes: 11 additions & 0 deletions src/frontends/pytorch/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_promote_types.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/gather_nd.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/non_zero.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
Expand All @@ -22,6 +25,7 @@
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/util/log.hpp"
#include "pt_framework_node.hpp"
Expand Down Expand Up @@ -592,6 +596,13 @@ Output<Node> concat_list_from_inputs(const NodeContext& context, size_t begin, s
return concat;
}

Output<Node> masked_select(const NodeContext& context, const Output<Node>& data, const Output<Node>& mask) {
auto input_order = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
auto nonzero = context.mark_node(std::make_shared<v3::NonZero>(mask));
auto masked_id = context.mark_node(std::make_shared<v1::Transpose>(nonzero, input_order));
return context.mark_node(std::make_shared<v8::GatherND>(data, masked_id));
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ Output<Node> masked_fill(ov::pass::NodeRegistry& rg,

Output<Node> concat_list_from_inputs(const NodeContext& context, size_t begin, size_t end);

Output<Node> masked_select(const NodeContext& context, const Output<Node>& data, const Output<Node>& mask);

namespace op {
template <OutputVector (*T)(const NodeContext&), size_t idx = 0>
OutputVector inplace_op(const NodeContext& context) {
Expand Down
62 changes: 62 additions & 0 deletions tests/layer_tests/pytorch_tests/test_masked_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
from packaging.version import parse as parse_version
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestMaskedSelect(PytorchLayerTest):
def _prepare_input(self, mask_select='ones', mask_dtype=bool, input_dtype=float):
input_shape = [1, 10]
mask = np.zeros(input_shape).astype(mask_dtype)
if mask_select == 'ones':
mask = np.ones(input_shape).astype(mask_dtype)
if mask_select == 'random':
idx = np.random.choice(10, 5)
mask[:, idx] = 1
return (np.random.randn(1, 10).astype(input_dtype), mask)

def create_model(self):
import torch

class aten_masked_select(torch.nn.Module):
def __init__(self):
super(aten_masked_select, self).__init__()

def forward(self, x, mask):
return x.masked_select(mask)

ref_net = None

return aten_masked_select(), ref_net, "aten::masked_select"

@pytest.mark.parametrize(
"mask_select", ['zeros', 'ones', 'random'])
@pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32])
@pytest.mark.nightly
@pytest.mark.precommit
def test_masked_select(self, mask_select, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(),
ie_device, precision, ir_version,
dynamic_shapes=False,
trace_model=True,
kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': bool, "input_dtype": input_dtype})

@pytest.mark.skipif(parse_version(torch.__version__) >= parse_version("2.1.0"), reason="pytorch 2.1 and above does not support nonboolean mask")
@pytest.mark.parametrize(
"mask_select", ['zeros', 'ones', 'random'])
@pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32])
@pytest.mark.parametrize("mask_dtype", [np.uint8, np.int32, np.float32])
@pytest.mark.nightly
@pytest.mark.precommit
def test_masked_select_non_bool_mask(self, mask_select, mask_dtype, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(),
ie_device, precision, ir_version,
dynamic_shapes=False,
trace_model=True,
kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': mask_dtype, "input_dtype": input_dtype})

0 comments on commit 2660f85

Please sign in to comment.