diff --git a/src/frontends/jax/src/op/select_n.cpp b/src/frontends/jax/src/op/select_n.cpp new file mode 100644 index 00000000000000..26a2f3a1d82f90 --- /dev/null +++ b/src/frontends/jax/src/op/select_n.cpp @@ -0,0 +1,46 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/jax/node_context.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather_elements.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils.hpp" + +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace jax { +namespace op { + +OutputVector translate_select_n(const NodeContext& context) { + num_inputs_check(context, 2); + auto num_inputs = static_cast(context.get_input_size()); + Output which = context.get_input(0); + if (which.get_element_type() == element::boolean) { + which = std::make_shared(which, element::i32); + } + auto const_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector{0}); + OutputVector unsqueezed_cases(num_inputs - 1); + unsqueezed_cases.reserve(num_inputs - 1); + for (int ind = 1; ind < num_inputs; ++ind) { + auto case_input = context.get_input(ind); + auto unsqueeze = std::make_shared(case_input, const_axis); + unsqueezed_cases[ind - 1] = unsqueeze; + } + Output cases = std::make_shared(unsqueezed_cases, 0); + which = + std::make_shared(which, + ov::op::v0::Constant::create(element::i64, Shape{1}, std::vector{0})); + Output result = std::make_shared(cases, which, 0); + return {result}; +}; + +} // namespace op +} // namespace jax +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp index 3ca58745bc1909..9c492dfa3e119d 100644 --- a/src/frontends/jax/src/op_table.cpp +++ b/src/frontends/jax/src/op_table.cpp @@ -52,6 +52,7 @@ OP_CONVERTER(translate_reduce_window_max); OP_CONVERTER(translate_reduce_window_sum); OP_CONVERTER(translate_reshape); OP_CONVERTER(translate_rsqrt); +OP_CONVERTER(translate_select_n); OP_CONVERTER(translate_slice); OP_CONVERTER(translate_square); OP_CONVERTER(translate_squeeze); @@ -92,6 +93,7 @@ const std::map get_supported_ops_jaxpr() { {"transpose", op::translate_transpose}, {"rsqrt", op::translate_rsqrt}, {"reshape", op::translate_reshape}, + {"select_n", op::translate_select_n}, {"slice", op::translate_slice}, {"square", op::translate_square}, {"sqrt", op::translate_1to1_match_1_input}, diff --git a/tests/layer_tests/jax_tests/test_select_n.py b/tests/layer_tests/jax_tests/test_select_n.py new file mode 100644 index 00000000000000..b09fd676efed07 --- /dev/null +++ b/tests/layer_tests/jax_tests/test_select_n.py @@ -0,0 +1,45 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +from jax_layer_test_class import JaxLayerTest + +rng = np.random.default_rng(5402) + + +class TestSelectN(JaxLayerTest): + def _prepare_input(self): + cases = [] + if (self.case_num == 2): + which = rng.choice([True, False], self.input_shape) + else: + which = rng.uniform(0, self.case_num, self.input_shape).astype(self.input_type) + which = np.array(which) + for i in range(self.case_num): + cases.append(jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type))) + cases = np.array(cases) + return (which, cases) + + def create_model(self, input_shape, input_type, case_num): + self.input_shape = input_shape + self.input_type = input_type + self.case_num = case_num + + def jax_select_n(which, cases): + return jax.lax.select_n(which, *cases) + + return jax_select_n, None, 'select_n' + + @pytest.mark.parametrize("input_shape", [[], [1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]) + @pytest.mark.parametrize("input_type", [np.int32, np.int64]) + @pytest.mark.parametrize("case_num", [2, 3, 4]) + @pytest.mark.nightly + @pytest.mark.precommit_jax_fe + def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type, case_num): + self._test(*self.create_model(input_shape, input_type, case_num), + ie_device, precision, + ir_version)