-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PT FE] [23325] Add aten::masked_select support for pytorch models. #23354
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
using namespace ov::op; | ||
|
||
OutputVector translate_masked_select(const NodeContext& context) { | ||
// aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor | ||
num_inputs_check(context, 2, 2); | ||
auto data = context.get_input(0); | ||
auto mask = context.get_input(1); | ||
ov::pass::NodeRegistry rg; | ||
auto res = masked_select(rg, data, mask); | ||
context.mark_nodes(rg.get()); | ||
return {res}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
import torch | ||
from packaging.version import parse as parse_version | ||
import pytest | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
|
||
class TestMaskedSelect(PytorchLayerTest): | ||
def _prepare_input(self, mask_select='ones', mask_dtype=bool, input_dtype=float): | ||
input_shape = [1, 10] | ||
mask = np.zeros(input_shape).astype(mask_dtype) | ||
if mask_select == 'ones': | ||
mask = np.ones(input_shape).astype(mask_dtype) | ||
if mask_select == 'random': | ||
idx = np.random.choice(10, 5) | ||
mask[:, idx] = 1 | ||
return (np.random.randn(1, 10).astype(input_dtype), mask) | ||
|
||
def create_model(self): | ||
import torch | ||
|
||
class aten_masked_select(torch.nn.Module): | ||
def __init__(self): | ||
super(aten_masked_select, self).__init__() | ||
|
||
def forward(self, x, mask): | ||
return x.masked_select(mask) | ||
|
||
ref_net = None | ||
|
||
return aten_masked_select(), ref_net, "aten::masked_select" | ||
|
||
@pytest.mark.parametrize( | ||
"mask_select", ['zeros', 'ones', 'random']) | ||
@pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
def test_masked_select(self, mask_select, input_dtype, ie_device, precision, ir_version): | ||
self._test(*self.create_model(), | ||
ie_device, precision, ir_version, | ||
dynamic_shapes=False, | ||
trace_model=True, | ||
kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': bool, "input_dtype": input_dtype}) | ||
|
||
@pytest.mark.skipif(parse_version(torch.__version__) >= parse_version("2.1.0"), reason="pytorch 2.1 and above does not support nonboolean mask") | ||
@pytest.mark.parametrize( | ||
"mask_select", ['zeros', 'ones', 'random']) | ||
@pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32]) | ||
@pytest.mark.parametrize("mask_dtype", [np.uint8, np.int32]) # np.float32 incorrectly casted to bool | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. np.float32 is incorrectly casted by torch or openvino? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was just following test case in test_masked_fill. Should I keep it or remove it? |
||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
def test_masked_select_non_bool_mask(self, mask_select, mask_dtype, input_dtype, ie_device, precision, ir_version): | ||
self._test(*self.create_model(), | ||
ie_device, precision, ir_version, | ||
dynamic_shapes=False, | ||
trace_model=True, | ||
kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': mask_dtype, "input_dtype": input_dtype}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to do
Transpose
afterNonZero
, since the format of indices inGatherND
is different. Like here:openvino/src/frontends/pytorch/src/op/index.cpp
Line 241 in c20d232