diff --git a/src/frontends/pytorch/src/op/bucketize.cpp b/src/frontends/pytorch/src/op/bucketize.cpp
new file mode 100644
index 00000000000000..07ac70458824cb
--- /dev/null
+++ b/src/frontends/pytorch/src/op/bucketize.cpp
@@ -0,0 +1,50 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/bucketize.hpp"
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/convert_like.hpp"
+#include "openvino/op/logical_or.hpp"
+#include "openvino/op/multiply.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_bucketize(const NodeContext& context) {
+    num_inputs_check(context, 2, 5);
+    auto input = context.get_input(0);
+    auto boundaries = context.get_input(1);
+
+    element::Type output_type = ov::element::i64;
+    if (!context.input_is_none(2) && context.const_input<bool>(2)) {
+        output_type = ov::element::i32;
+    }
+
+    bool with_right_bound = true;
+    if (!context.input_is_none(3)) {
+        with_right_bound = !context.const_input<bool>(3);
+    }
+
+    auto bucketize =
+        context.mark_node(std::make_shared<v3::Bucketize>(input, boundaries, output_type, with_right_bound));
+
+    if (!context.input_is_none(4)) {
+        context.mutate_input(4, bucketize);
+    }
+
+    return {bucketize};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index b34f87fe8bf139..ea2ff9cf6c5a59 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -47,6 +47,7 @@ OP_CONVERTER(translate_bitwise_and);
 OP_CONVERTER(translate_bitwise_not);
 OP_CONVERTER(translate_bitwise_or);
 OP_CONVERTER(translate_bitwise_xor);
+OP_CONVERTER(translate_bucketize);
 OP_CONVERTER(translate_cat);
 OP_CONVERTER(translate_cdist);
 OP_CONVERTER(translate_celu);
@@ -374,6 +375,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::Bool", op::translate_bool},
         // aten::broadcast_tensors - Supported in limited set of patterns
         {"aten::broadcast_to", op::translate_expand},
+        {"aten::bucketize", op::translate_bucketize},
         {"aten::cat", op::translate_cat},
         {"aten::cdist", op::translate_cdist},
         {"aten::ceil", op::optional_out<op::translate_1to1_match_1_inputs<opset10::Ceiling>, 1>},
diff --git a/tests/layer_tests/pytorch_tests/test_bucketize.py b/tests/layer_tests/pytorch_tests/test_bucketize.py
new file mode 100644
index 00000000000000..29fb550708e464
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_bucketize.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestBucketize(PytorchLayerTest):
+
+    def _prepare_input(self, input_shape, boundaries_range, input_dtype, boundaries_dtype):
+        return (
+            np.random.randn(*input_shape).astype(input_dtype),
+            np.arange(*boundaries_range).astype(boundaries_dtype))
+
+    def create_model(self, out_int32, right, is_out):
+        class aten_bucketize(torch.nn.Module):
+
+            def __init__(self, out_int32, right, is_out) -> None:
+                super().__init__()
+                self.out_int32 = out_int32
+                self.right = right
+                self.is_out = is_out
+
+            def forward(self, input, boundaries):
+                if self.is_out:
+                    output_dtype = torch.int32 if self.out_int32 else torch.int64
+                    output = torch.zeros_like(input, dtype=output_dtype)
+                    torch.bucketize(input, boundaries, out_int32=self.out_int32, right=self.right, out=output)
+                    return output
+                else:
+                    return torch.bucketize(input, boundaries, out_int32=self.out_int32, right=self.right)
+
+        ref_net = None
+
+        return aten_bucketize(out_int32, right, is_out), ref_net, "aten::bucketize"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize("out_int32", [True, False])
+    @pytest.mark.parametrize("right", [True, False])
+    @pytest.mark.parametrize("is_out", [True, False])
+    @pytest.mark.parametrize("input_shape", [[1, ], [2, 1], [2, 2, 1]])
+    @pytest.mark.parametrize("input_dtype", ["float32", "int32"])
+    @pytest.mark.parametrize("boundaries_range", [[1, 10], (100, 200)])
+    @pytest.mark.parametrize("boundaries_dtype", ["float32", "int32"])
+    def test_bucketize(self, input_shape, boundaries_range, input_dtype, boundaries_dtype, out_int32, right, is_out, ie_device, precision, ir_version):
+        self._test(*self.create_model(out_int32, right, is_out), ie_device, precision, ir_version, kwargs_to_prepare_input={
+            "input_shape": input_shape, "input_dtype": input_dtype,
+            "boundaries_range": boundaries_range, "boundaries_dtype": boundaries_dtype,
+        })
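
Semantics note, with a minimal PyTorch sketch (not part of the diff, added for illustration): torch.bucketize's right flag is the inverse of OpenVINO Bucketize's with_right_bound attribute, which is why the converter negates input 3, and out_int32=True corresponds to element::i32 in the converter.

# Minimal sketch; boundaries/values follow the torch.bucketize documentation example.
import torch

boundaries = torch.tensor([1, 3, 5, 7, 9])
values = torch.tensor([3, 6, 9])

# right=False (default): index of the first boundary >= value, i.e. the right edge of each
# bucket is included -- OpenVINO Bucketize with with_right_bound=True.
print(torch.bucketize(values, boundaries))              # tensor([1, 3, 4])

# right=True: index of the first boundary strictly greater than the value --
# OpenVINO Bucketize with with_right_bound=False.
print(torch.bucketize(values, boundaries, right=True))  # tensor([2, 3, 5])

# out_int32=True selects 32-bit indices, matching element::i32 in the converter.
print(torch.bucketize(values, boundaries, out_int32=True).dtype)  # torch.int32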