diff --git a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp index 81a055a4a09a5b..0c0ddf7e637050 100644 --- a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp +++ b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp @@ -204,6 +204,7 @@ REGISTER_FACTORY(v5, Loop); // ------------------------------ Supported v6 ops ------------------------------ // REGISTER_FACTORY(v6, CTCGreedyDecoderSeqLen); REGISTER_FACTORY(v6, MVN); +REGISTER_FACTORY(v6, GatherElements); // ------------------------------ Supported v7 ops ------------------------------ // REGISTER_FACTORY(v7, Gather); diff --git a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp new file mode 100644 index 00000000000000..d61382807506c1 --- /dev/null +++ b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp @@ -0,0 +1,66 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "cldnn_program.h" +#include "cldnn_common_utils.h" + +#include "ngraph/op/gather_elements.hpp" +#include "ngraph/op/constant.hpp" + +#include "cldnn/primitives/gather_elements.hpp" + +namespace CLDNNPlugin { + +static cldnn::gather_elements::gather_elements_axis GetGatherAxis(int axis, unsigned rank) { + if (axis < 0) + axis += rank; + if (axis < 0 || axis >= rank) + IE_THROW() << "GatherElements axis is not correspond to number of dimensions"; + + // Difference in dimension ordering between IE and clDNN, + // reverse spatial dimensions after batch and feature. 
+ unsigned cldnn_axis = axis; + if (axis >= 2) { + auto spatial_axis = axis - 2; + // Default and minimum number of dimensions is 4 + auto spatial_size = std::max(rank, 4u) - 2; + cldnn_axis = spatial_size - spatial_axis - 1 + 2; + } + + switch (cldnn_axis) { + case 0: return cldnn::gather_elements::gather_elements_axis::along_b; + case 1: return cldnn::gather_elements::gather_elements_axis::along_f; + case 2: return cldnn::gather_elements::gather_elements_axis::along_x; + case 3: return cldnn::gather_elements::gather_elements_axis::along_y; + case 4: return cldnn::gather_elements::gather_elements_axis::along_z; + case 5: return cldnn::gather_elements::gather_elements_axis::along_w; + default: IE_THROW() << "Unsupported GatherElements axis: " << axis; + } + return cldnn::gather_elements::gather_elements_axis::along_f; // shouldn't get here +} + +void CreateGatherElementsOp(Program& p, const std::shared_ptr& op) { + p.ValidateInputs(op, {2}); + auto inputPrimitives = p.GetInputPrimitiveIDs(op); + std::string layerName = layer_type_name_ID(op); + + size_t rank = op->get_input_shape(0).size(); + int32_t axis = static_cast(op->get_axis()); + + auto outLayout = DefaultFormatForDims(op->get_output_shape(0).size()); + + auto primitive = cldnn::gather_elements(layerName, + inputPrimitives[0], + inputPrimitives[1], + outLayout, + CldnnTensorFromIEDims(op->get_output_shape(0)), + GetGatherAxis(axis, rank)); + + p.AddPrimitive(primitive); + p.AddPrimitiveToProfiler(op); +} + +REGISTER_FACTORY_IMPL(v6, GatherElements); + +} // namespace CLDNNPlugin diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 1ad8bbd0d4c335..0220364af315f8 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ 
b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -4,7 +4,8 @@ #include -#include "shared_test_classes/single_layer/gather_elements.hpp" +#include "single_layer_tests/gather_elements.hpp" +#include "common_test_utils/test_constants.hpp" using namespace LayerTestsDefinitions; @@ -16,8 +17,6 @@ const std::vector dPrecisions = { InferenceEngine::Precision::I32, InferenceEngine::Precision::I64, InferenceEngine::Precision::I16, - InferenceEngine::Precision::U8, - InferenceEngine::Precision::I8 }; const std::vector iPrecisions = { InferenceEngine::Precision::I32, diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp new file mode 100644 index 00000000000000..cbc4e9fed4fc5f --- /dev/null +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -0,0 +1,227 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "single_layer_tests/gather_elements.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; +using namespace ngraph::opset6; + +namespace { + +const std::vector inputPrecisions = { + InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16, + InferenceEngine::Precision::I32, +}; + +const std::vector idxPrecisions = { + InferenceEngine::Precision::I32, + InferenceEngine::Precision::I64, +}; + +INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({2, 2})), + ::testing::Values(std::vector({2, 2})), + ::testing::ValuesIn(std::vector({-1, 0, 1})), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + 
GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({2, 2, 1})), + ::testing::Values(std::vector({4, 2, 1})), + ::testing::ValuesIn(std::vector({0, -3})), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({2, 2, 3, 5})), + ::testing::Values(std::vector({2, 2, 3, 7})), + ::testing::Values(3, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({3, 2, 3, 8})), + ::testing::Values(std::vector({2, 2, 3, 8})), + ::testing::Values(0, -4), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({3, 2, 3, 4, 8})), + ::testing::Values(std::vector({3, 2, 3, 5, 8})), + ::testing::Values(3, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis0, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{7, 7, 8, 4}), + ::testing::Values(std::vector{2, 7, 8, 4}), + ::testing::Values(0), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + 
+INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{6, 1, 8, 4}), + ::testing::Values(std::vector{6, 8, 8, 4}), + ::testing::Values(1, -3), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{6, 7, 4, 4}), + ::testing::Values(std::vector{6, 7, 2, 4}), + ::testing::Values(2, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{6, 5, 8, 7}), + ::testing::Values(std::vector{6, 5, 8, 7}), + ::testing::Values(3, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis0, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{2, 3, 9, 4, 9}), + ::testing::Values(std::vector{1, 3, 9, 4, 9}), + ::testing::Values(0), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{2, 3, 5, 4, 7}), + ::testing::Values(std::vector{2, 9, 5, 4, 7}), + ::testing::Values(1, -4), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + 
GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{1, 2, 6, 8, 9}), + ::testing::Values(std::vector{1, 2, 6, 8, 9}), + ::testing::Values(2, -3), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{2, 2, 4, 7, 7}), + ::testing::Values(std::vector{2, 2, 4, 3, 7}), + ::testing::Values(3, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis4, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{1, 3, 9, 3, 2}), + ::testing::Values(std::vector{1, 3, 9, 3, 9}), + ::testing::Values(4, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis0, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{3, 3, 2, 4, 4, 3}), + ::testing::Values(std::vector{7, 3, 2, 4, 4, 3}), + ::testing::Values(0), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{1, 6, 2, 3, 5, 9}), + ::testing::Values(std::vector{1, 6, 2, 3, 5, 9}), + ::testing::Values(1, -5), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + 
::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{2, 3, 9, 7, 2, 1}), + ::testing::Values(std::vector{2, 3, 5, 7, 2, 1}), + ::testing::Values(2, -4), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{1, 3, 4, 5, 1, 3}), + ::testing::Values(std::vector{1, 3, 4, 4, 1, 3}), + ::testing::Values(3, -3), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis4, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{1, 3, 2, 4, 3, 3}), + ::testing::Values(std::vector{1, 3, 2, 4, 6, 3}), + ::testing::Values(4, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis5, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector{2, 1, 7, 8, 1, 6}), + ::testing::Values(std::vector{2, 1, 7, 8, 1, 5}), + ::testing::Values(5, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp new file mode 100644 index 
00000000000000..eea88d4abf3183 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp @@ -0,0 +1,15 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shared_test_classes/single_layer/gather_elements.hpp" + +namespace LayerTestsDefinitions { + +TEST_P(GatherElementsLayerTest, CompareWithRefs) { + Run(); +} + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp b/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp index af5832302aa0be..d559e04a53d2c1 100644 --- a/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp +++ b/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp @@ -48,7 +48,4 @@ void GatherElementsLayerTest::SetUp() { function = std::make_shared(results, params, "gatherEl"); } -TEST_P(GatherElementsLayerTest, CompareWithRefs) { - Run(); -} } // namespace LayerTestsDefinitions
diff --git a/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp
new file mode 100644
index 00000000000000..d6d0ca9fdb24f9
--- /dev/null
+++ b/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp
@@ -0,0 +1,58 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+#pragma once
+#include "primitive.hpp"
+
+namespace cldnn {
+/// @addtogroup cpp_api C++ API
+/// @{
+/// @addtogroup cpp_topology Network Topology
+/// @{
+/// @addtogroup cpp_primitives Primitives
+/// @{
+
+/// @brief Gathers elements of the data input at the positions given by the indices input along one axis.
+/// @details Takes two inputs (data, indices); output format and shape are passed in explicitly by the caller.
+struct gather_elements : public primitive_base {
+    CLDNN_DECLARE_PRIMITIVE(gather_elements)
+
+    enum gather_elements_axis {
+        along_b,
+        along_f,
+        along_x,
+        along_y,
+        along_z,
+        along_w
+    };
+
+    /// @brief Constructs gather_elements primitive.
+    /// @param id This primitive id.
+    /// @param data Input data primitive id.
+    /// @param indices Input indexes primitive id.
+    /// @param output_format Output format.
+    /// @param output_shape Output shape.
+    /// @param axis Gathering axis.
+    gather_elements(const primitive_id& id,
+                    const primitive_id& data,
+                    const primitive_id& indices,
+                    const format& output_format,
+                    const tensor& output_shape,
+                    const gather_elements_axis axis,
+                    const padding& output_padding = padding())
+        : primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {}
+
+    /// @brief Gather Elements output format
+    format output_format;
+    /// @brief Gather Elements output shape
+    tensor output_shape;
+
+    /// @brief Which axis to gather on.
+    gather_elements_axis axis;
+};
+/// @}
+/// @}
+/// @}
+}  // namespace cldnn
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h index 7a072d998d4789..dbe6bd7004c672 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h @@ -48,6 +48,7 @@ enum class KernelType { ONE_HOT, GATHER, GATHER_ND, + GATHER_ELEMENTS, SCATTER_UPDATE, SCATTER_ND_UPDATE, SCATTER_ELEMENTS_UPDATE, diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp new file mode 100644 index 00000000000000..eb01e12a12f0ee --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -0,0 +1,154 @@ +// Copyright (C) 2018-2021 Intel 
Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gather_elements_kernel_ref.h" +#include "kernel_selector_utils.h" +#include +#include + +namespace kernel_selector { +static size_t GetGatherElementsChannelIndex(const gather_elements_params& params) { + Tensor::DataChannelName name = Tensor::DataChannelName::X; + + size_t inputSize = params.inputs[0].GetDims().size(); + + switch (params.axis) { + case GatherAxis::X: + return inputSize - 1; + case GatherAxis::Y: + return inputSize - 2; + case GatherAxis::Z: + return inputSize - 3; + case GatherAxis::W: + return 2; + case GatherAxis::FEATURE: + return 1; + case GatherAxis::BATCH: + return 0; + default: + break; + } + + return DataTensor::Channelndex(params.output.GetLayout(), name); +} + +ParamsKey GatherElementsKernelRef::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT32); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableInputLayout(DataLayout::bfzyx); + k.EnableOutputLayout(DataLayout::bfzyx); + k.EnableInputLayout(DataLayout::bfwzyx); + k.EnableOutputLayout(DataLayout::bfwzyx); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDifferentTypes(); + return k; +} + +static inline std::vector GetDefaultOrder(size_t size) { + std::vector default_order; + if (size <= 4) { + default_order = { "b", "f", "y", "x" }; + } else if (size == 5) { + default_order = { "b", "f", "z", "y", "x" }; + } else if (size == 6) { + default_order = { "b", "f", "w", "z", "y", "x" }; + } + + return default_order; +} + +CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_params& params, const optional_params&) const { + CommonDispatchData dispatchData; + + const auto& output = params.output; 
+ + switch (params.inputs[1].GetLayout()) { + case DataLayout::bfyx: + dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v}; + break; + + case DataLayout::bfzyx: + dispatchData.gws = {output.X().v, output.Y().v * output.Z().v, output.Feature().v * output.Batch().v}; + break; + + case DataLayout::bfwzyx: + dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v}; + break; + + default: + throw std::invalid_argument("Unsupported data layout for gather elements primitive"); + break; + } + + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); + + return dispatchData; +} + +JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + jit.AddConstant(MakeJitConstant("AXIS", GetGatherElementsChannelIndex(params))); + + if (!params.fused_ops.empty()) { + std::vector idx_order = GetDefaultOrder(params.inputs[0].GetDims().size()); + FusedOpsConfiguration conf = { "", idx_order, "val", params.inputs[0].GetDType() }; + jit.Merge(MakeFusedOpsJitConstants(params, { conf })); + } + + return jit; +} + +bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o) const { + if (p.GetType() != KernelType::GATHER_ELEMENTS || o.GetType() != KernelType::GATHER_ELEMENTS) { + return false; + } + + const gather_elements_params& params = static_cast(p); + auto input_dims = params.inputs[0].LogicalDims(); + auto indices_dims = params.inputs[1].LogicalDims(); + + if (input_dims.size() != indices_dims.size()) { + return false; + } + + for (auto& fused_op : params.fused_ops) { + if (!IsFusedPrimitiveSupported(fused_op)) + return false; + } + + return true; +} + +KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { + if (!Validate(params, options)) { + return {}; + } + + KernelData kd = 
KernelData::Default(params); + gather_elements_params& newParams = *static_cast(kd.params.get()); + + auto dispatchData = SetDefault(newParams, options); + auto cldnn_jit = GetJitConstants(newParams); + + auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options); + auto jit = CreateJit(kernelName, cldnn_jit, entry_point); + auto& kernel = kd.kernels[0]; + FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2, GetFusedPrimitiveInputsCount(params)); + return { kd }; +} + +KernelsPriority GatherElementsKernelRef::GetKernelsPriority(const Params& /*params*/, const optional_params& /*options*/) const { + return DONT_USE_IF_HAVE_SOMETHING_ELSE; +} +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h new file mode 100644 index 00000000000000..8eec4ae96326fa --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_base_opencl.h" + +namespace kernel_selector { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// gather_elements_params +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct gather_elements_params : public base_params { + gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherAxis::BATCH) {} + + GatherAxis axis; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// gather_elements_optional_params 
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct gather_elements_optional_params : optional_params { + gather_elements_optional_params() : optional_params(KernelType::GATHER_ELEMENTS) {} +}; + +class GatherElementsKernelRef : public KernelBaseOpenCL { +public: + GatherElementsKernelRef() : KernelBaseOpenCL("gather_elements_ref") {} + virtual ~GatherElementsKernelRef() {} + virtual JitConstants GetJitConstants(const gather_elements_params& params) const; + virtual CommonDispatchData SetDefault(const gather_elements_params& params, const optional_params&) const; + KernelsData GetKernelsData(const Params& params, const optional_params& options) const override; + KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override; + ParamsKey GetSupportedKey() const override; + std::vector GetSupportedFusedOps() const override { + return { FusedOpType::QUANTIZE, + FusedOpType::SCALE, + FusedOpType::ACTIVATION, + FusedOpType::ELTWISE }; + } + +protected: + bool Validate(const Params& p, const optional_params& o) const override; +}; +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp new file mode 100644 index 00000000000000..3a451cf574add9 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp @@ -0,0 +1,27 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "gather_elements_kernel_selector.h" +#include "gather_elements_kernel_ref.h" + +namespace kernel_selector { + +gather_elements_kernel_selector::gather_elements_kernel_selector() { Attach(); } + +KernelsData gather_elements_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const { + return GetNaiveBestKernel(params, options, KernelType::GATHER_ELEMENTS); +} +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.h new file mode 100644 index 00000000000000..333298a45de53d --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.h @@ -0,0 +1,35 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+*/ + +#pragma once + +#include "kernel_selector.h" + +namespace kernel_selector { +class gather_elements_kernel_selector : public kernel_selector_base { +public: + static gather_elements_kernel_selector& Instance() { + static gather_elements_kernel_selector instance_; + return instance_; + } + + gather_elements_kernel_selector(); + + virtual ~gather_elements_kernel_selector() {} + + KernelsData GetBestKernels(const Params& params, const optional_params& options) const override; +}; +} // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl
new file mode 100644
index 00000000000000..d03c1c85b13aa2
--- /dev/null
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl
@@ -0,0 +1,93 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "include/data_types.cl"
+#include "include/fetch_data.cl"
+
+#define GET_OUTPUT_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
+
+// Reference GatherElements kernel. Each work item handles one position of the indices
+// tensor: it reads the index stored there, fetches the data element whose coordinate
+// along AXIS equals that index (all other coordinates are those of the work item) and
+// writes it to the same position of the output.
+// AXIS is a jit constant: the position of the gather axis in b,f,[w],[z],y,x order.
+KERNEL(gather_elements_ref)(const __global INPUT0_TYPE* data,
+                   const __global INPUT1_TYPE* indices,
+                   __global OUTPUT_TYPE* output
+#if HAS_FUSED_OPS_DECLS
+                   , FUSED_OPS_DECLS
+#endif
+)
+{
+    const uint dim0 = get_global_id(0);
+    const uint dim1 = get_global_id(1);
+    const uint dim2 = get_global_id(2);
+
+    // Unpack the work-item id into tensor coordinates; dimensions that share one global
+    // work dimension are separated again with a division / modulo pair.
+#if INPUT1_DIMS == 4
+    #define ORDER b,f,y,x
+    const uint x = dim0;
+    const uint y = dim1;
+#elif INPUT1_DIMS == 5
+    #define ORDER b,f,z,y,x
+    const uint x = dim0;
+    const uint y = dim1 % OUTPUT_SIZE_Y;
+    const uint z = dim1 / OUTPUT_SIZE_Y;
+#else
+    #define ORDER b,f,w,z,y,x
+    const uint x = dim0 % OUTPUT_SIZE_X;
+    const uint y = dim0 / OUTPUT_SIZE_X;
+    const uint z = dim1 % OUTPUT_SIZE_Z;
+    const uint w = dim1 / OUTPUT_SIZE_Z;
+#endif
+    const uint f = dim2 % OUTPUT_FEATURE_NUM;
+    const uint b = dim2 / OUTPUT_FEATURE_NUM;
+
+    // Offset of this position in the indices tensor; also used as the output offset.
+    // NOTE(review): the arithmetic below treats it as a dense row-major index - confirm
+    // that padded inputs are never routed to this kernel.
+    const int out_idx = GET_OUTPUT_INDEX(INPUT1, ORDER);
+
+#if INPUT1_DIMS == 4
+    size_t data_shape[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
+    size_t indices_shape[4] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
+#elif INPUT1_DIMS == 5
+    size_t data_shape[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
+    size_t indices_shape[5] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
+#else
+    size_t data_shape[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
+    size_t indices_shape[6] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
+#endif
+
+    // inner_size:         elements per index step, i.e. product of the dims after AXIS.
+    // data_axis_block:    elements of the data tensor spanned by AXIS and the dims after it.
+    // indices_axis_block: the same quantity for the indices tensor.
+    size_t inner_size = 1, data_axis_block = 1, indices_axis_block = 1;
+    for (size_t i = AXIS + 1; i < INPUT1_DIMS; i++)
+        inner_size *= indices_shape[i];
+
+    for (size_t i = AXIS; i < INPUT1_DIMS; i++) {
+        data_axis_block *= data_shape[i];
+        indices_axis_block *= indices_shape[i];
+    }
+
+    // Dimensions before AXIS select the block in the data tensor, the gathered index
+    // selects the slice inside the block and the dims after AXIS select the element.
+    const size_t outer_offset = (out_idx / indices_axis_block) * data_axis_block;
+    const size_t inner_offset = out_idx % inner_size;
+
+    const uint idx = outer_offset + inner_size * indices[out_idx] + inner_offset;
+    INPUT0_TYPE val = data[idx];
+
+#if HAS_FUSED_OPS
+    FUSED_OPS;
+    output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT);
+#else
+    output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
+#endif
+}
+
+#undef ORDER
+#undef GET_OUTPUT_INDEX
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp index deeb31e350e890..35c2115e5ac9e3 100644 --- 
a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp @@ -402,6 +402,8 @@ std::string toString(GatherAxis a) { switch (a) { case GatherAxis::X: return "X"; case GatherAxis::Y: return "Y"; + case GatherAxis::Z: return "Z"; + case GatherAxis::W: return "W"; case GatherAxis::FEATURE: return "FEATURE"; case GatherAxis::BATCH: return "BATCH"; default: return ""; diff --git a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp new file mode 100644 index 00000000000000..7a3a920aa6277e --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp @@ -0,0 +1,62 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gather_elements_inst.h" + +#include "primitive_type_base.h" +#include "cldnn/runtime/error_handler.hpp" +#include "json_object.h" +#include + +namespace cldnn { +primitive_type_id gather_elements::type_id() { + static primitive_type_base instance; + return &instance; +} + +layout gather_elements_inst::calc_output_layout(gather_elements_node const& node) { + auto op = node.get_primitive(); + + auto input_layout_origin = node.input(0).get_output_layout(); + auto indices_layout_origin = node.input(1).get_output_layout(); + + auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format); + auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format); + + if (node.has_fused_primitives()) { + input_layout_origin.data_type = node.get_fused_output_layout().data_type; + } + + auto output_type = indices_layout_origin.data_type; + auto output_format = op->output_format; + auto output_shape = op->output_shape; + + // calculate initial output shape + return layout(output_type, output_format, output_shape); +} + +std::string gather_elements_inst::to_string(gather_elements_node const& node) 
{ + auto desc = node.get_primitive(); + auto node_info = node.desc_to_json(); + auto& input = node.input(); + + std::stringstream primitive_description; + + json_composite gather_elements_info; + gather_elements_info.add("input id", input.id()); + gather_elements_info.add("input shape", node.input(0).get_output_layout().size.to_string()); + gather_elements_info.add("indices shape", node.input(1).get_output_layout().size.to_string()); + gather_elements_info.add("output format", calc_output_layout(node).format); + gather_elements_info.add("output shape", calc_output_layout(node).size.to_string()); + gather_elements_info.add("axis", desc->axis); + + node_info->add("gather_elements info", gather_elements_info); + node_info->dump(primitive_description); + + return primitive_description.str(); +} + +gather_elements_inst::typed_primitive_inst(network_impl& network, gather_elements_node const& node) : parent(network, node) {} + +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp index ffabb96f2e40bf..7b82fb54e3c409 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp @@ -32,6 +32,7 @@ #include "space_to_depth_inst.h" #include "gather_inst.h" #include "gather_nd_inst.h" +#include "gather_elements_inst.h" #include "scatter_update_inst.h" #include "scatter_nd_update_inst.h" #include "scatter_elements_update_inst.h" @@ -200,6 +201,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) { !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && + !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type())) @@ -609,6 +611,8 @@ 
void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -677,6 +681,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -767,6 +773,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); @@ -829,6 +837,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { (parents[i]->is_type() && eltwise_supports_fusings(parents[i]->as())) || (parents[i]->is_type()) || (parents[i]->is_type()) || + (parents[i]->is_type()) || (parents[i]->is_type()) || (parents[i]->is_type()) || (parents[i]->is_type() && pooling_supports_fusings(parents[i]->as())) || diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp new file mode 100644 index 00000000000000..968eb6bbb7db6c --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp @@ -0,0 +1,86 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gather_elements_inst.h" +#include "primitive_base.hpp" +#include "impls/implementation_map.hpp" +#include "kernel_selector_helper.h" +#include "gather/gather_elements_kernel_selector.h" +#include "gather/gather_elements_kernel_ref.h" +#include "cldnn/runtime/error_handler.hpp" + +using namespace cldnn; + +namespace cldnn { +namespace ocl 
{ +kernel_selector::gather_elements_axis convert_axis(gather_elements::gather_elements_axis axis) { + switch (axis) { + case gather_elements::along_x: + return kernel_selector::gather_elements_axis::X; + case gather_elements::along_y: + return kernel_selector::gather_elements_axis::Y; + case gather_elements::along_z: + return kernel_selector::gather_elements_axis::Z; + case gather_elements::along_w: + return kernel_selector::gather_elements_axis::W; + case gather_elements::along_f: + return kernel_selector::gather_elements_axis::FEATURE; + case gather_elements::along_b: + return kernel_selector::gather_elements_axis::BATCH; + default: + return kernel_selector::gather_elements_axis::BATCH; + } +} + +struct gather_elements_impl : typed_primitive_impl_ocl { + using parent = typed_primitive_impl_ocl; + using parent::parent; + + std::unique_ptr clone() const override { + return make_unique(*this); + } + +public: + static primitive_impl* create(const gather_elements_node& arg) { + auto gather_elements_params = get_default_params(arg); + auto gather_elements_optional_params = + get_default_optional_params(arg.get_program()); + + gather_elements_params.axis = convert_axis(arg.get_primitive()->axis); + + gather_elements_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout())); + + auto& kernel_selector = kernel_selector::gather_elements_kernel_selector::Instance(); + auto best_kernels = kernel_selector.GetBestKernels(gather_elements_params, gather_elements_optional_params); + + CLDNN_ERROR_BOOL(arg.id(), + "Best_kernel.empty()", + best_kernels.empty(), + "Cannot find a proper kernel with this arguments"); + + auto gather_elements = new gather_elements_impl(arg, best_kernels[0]); + + return gather_elements; + } +}; + +namespace detail { + +attach_gather_elements_impl::attach_gather_elements_impl() { + implementation_map::add(impl_types::ocl, gather_elements_impl::create, { + std::make_tuple(data_types::f32, format::bfyx), + 
std::make_tuple(data_types::f16, format::bfyx), + std::make_tuple(data_types::i32, format::bfyx), + std::make_tuple(data_types::f32, format::bfzyx), + std::make_tuple(data_types::f16, format::bfzyx), + std::make_tuple(data_types::i32, format::bfzyx), + std::make_tuple(data_types::f32, format::bfwzyx), + std::make_tuple(data_types::f16, format::bfwzyx), + std::make_tuple(data_types::i32, format::bfwzyx), + }); +} + +} // namespace detail +} // namespace ocl +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp index c62b64de62dff8..86a423a84713e6 100644 --- a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp @@ -30,6 +30,7 @@ void register_implementations() { REGISTER_OCL(eltwise); REGISTER_OCL(fully_connected); REGISTER_OCL(gather); + REGISTER_OCL(gather_elements); REGISTER_OCL(gather_nd); REGISTER_OCL(gemm); REGISTER_OCL(lrn); diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp index 036162ed8d82fb..dcd58574e52558 100644 --- a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp @@ -22,6 +22,7 @@ #include "cldnn/primitives/fully_connected.hpp" #include "cldnn/primitives/gather.hpp" #include "cldnn/primitives/gather_nd.hpp" +#include "cldnn/primitives/gather_elements.hpp" #include "cldnn/primitives/gemm.hpp" #include "cldnn/primitives/lrn.hpp" #include "cldnn/primitives/lstm.hpp" @@ -94,6 +95,7 @@ REGISTER_OCL(embed); REGISTER_OCL(fully_connected); REGISTER_OCL(gather); REGISTER_OCL(gather_nd); +REGISTER_OCL(gather_elements); REGISTER_OCL(gemm); REGISTER_OCL(lrn); REGISTER_OCL(lstm_gemm); diff --git a/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h 
b/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h new file mode 100644 index 00000000000000..ebefc9c032dea6 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h @@ -0,0 +1,49 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once +#include "cldnn/primitives/gather_elements.hpp" +#include "primitive_inst.h" +#include + +namespace cldnn { +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + +public: + using parent::parent; + + program_node& input(size_t index = 0) const { return get_dependency(index); } +}; + +using gather_elements_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + +public: + static layout calc_output_layout(gather_elements_node const& node); + static std::string to_string(gather_elements_node const& node); + +public: + typed_primitive_inst(network_impl& network, gather_elements_node const& desc); +}; + +using gather_elements_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/include/kernel_selector_helper.h b/inference-engine/thirdparty/clDNN/src/include/kernel_selector_helper.h index 
// Parameter bundle describing one GatherElements fusing test case.
// It is brace-initialized positionally by the CASE_GATHER_ELEMENTS_* macros,
// so the order of the fields must not change.
struct gather_elements_test_params {
    // Element type used by the fixture for the input, indices and output layouts.
    data_types data_type;

    // Format and shape of the data tensor.
    format input_format;
    tensor input_shape;

    // Format and shape of the indices tensor.
    format indices_format;
    tensor indices_shape;

    // Format and shape passed to the gather_elements primitive as its output.
    format output_format;
    tensor output_shape;

    // Axis along which elements are gathered.
    cldnn::gather_elements::gather_elements_axis axis;

    // Type/format used for per-channel helper data and the final reorder.
    data_types default_type;
    format default_format;

    // NOTE(review): presumably the primitive counts the base fixture checks for
    // the fused and the non-fused network - confirm against BaseFusingTest::compare.
    size_t expected_fused_primitives;
    size_t expected_not_fused_primitives;
};
+#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, cldnn::gather_elements::gather_elements_axis::along_y, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, cldnn::gather_elements::gather_elements_axis::along_b, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfyx + +#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, cldnn::gather_elements::gather_elements_axis::along_z, data_types::f32, format::bfzyx + +#define CASE_GATHER_ELEMENTS_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, cldnn::gather_elements::gather_elements_axis::along_f, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_w, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfwzyx + +class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest { +public: + 
void execute(gather_elements_test_params& p) { + auto input_prim = get_mem(get_input_layout(p)); + network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused); + network network_fused(this->engine, this->topology_fused, bo_fused); + network_fused.set_input_data("input", input_prim); + network_not_fused.set_input_data("input", input_prim); + compare(network_not_fused, network_fused, p); + } + + size_t get_axis_dim(gather_elements_test_params& p) { + switch (p.axis) { + case cldnn::gather_elements::gather_elements_axis::along_x: + return p.input_shape.spatial[0]; + case cldnn::gather_elements::gather_elements_axis::along_y: + return p.input_shape.spatial[1]; + case cldnn::gather_elements::gather_elements_axis::along_z: + return p.input_shape.spatial[2]; + case cldnn::gather_elements::gather_elements_axis::along_w: + return p.input_shape.spatial[3]; + case cldnn::gather_elements::gather_elements_axis::along_f: + return p.input_shape.feature[0]; + case cldnn::gather_elements::gather_elements_axis::along_b: + return p.input_shape.batch[0]; + default: + return 1; + } + } + + layout get_input_layout(gather_elements_test_params& p) { + return layout{ p.data_type, p.input_format, p.input_shape }; + } + + layout get_indices_layout(gather_elements_test_params& p) { + return layout{ p.data_type, p.indices_format, p.indices_shape }; + } + + layout get_output_layout(gather_elements_test_params& p) { + return layout{ p.data_type, p.output_format, p.output_shape }; + } + + layout get_per_channel_layout(gather_elements_test_params& p) { + return layout{ p.default_type, p.default_format, tensor{1, p.output_shape.feature[0], 1, 1} }; + } +}; + +class gather_elements_quantize : public GatherElementsPrimitiveFusingTest {}; +TEST_P(gather_elements_quantize, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), + data("in_lo", 
get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_lo", get_mem(get_single_element_layout(p), -127)), + data("out_hi", get_mem(get_single_element_layout(p), 127)), + gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), + quantize("quantize", "gather_elements_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8), + reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) + ); + tolerance = 1.f; + execute(p); +} + +INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_quantize, + ::testing::ValuesIn(std::vector{ + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 3 }, +})); + + +class gather_elements_scale_activation : public GatherElementsPrimitiveFusingTest {}; +TEST_P(gather_elements_scale_activation, basic) { + auto p = 
GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), + data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)), + gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), + activation("activation", "gather_elements_prim", activation_func::abs), + scale("scale", "activation", "scale_data"), + reorder("reorder_bfyx", "scale", p.default_format, data_types::f32) + ); + + tolerance = 1e-5f; + execute(p); +} + +INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_scale_activation, + ::testing::ValuesIn(std::vector{ + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 4 }, +})); + + +class gather_elements_activation_scale_eltwise : public GatherElementsPrimitiveFusingTest {}; 
+TEST_P(gather_elements_activation_scale_eltwise, basic) { + auto p = GetParam(); + + create_topologies(input_layout("input", get_input_layout(p)), + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), + data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)), + data("eltwise_data", get_mem(get_output_layout(p))), + gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), + activation("activation", "gather_elements_prim", activation_func::abs), + scale("scale", "activation", "scale_data"), + eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type), + reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32) + ); + + tolerance = 1e-5f; + execute(p); +} + +INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_activation_scale_eltwise, + ::testing::ValuesIn(std::vector{ + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 
2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 5 }, +})); diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp new file mode 100644 index 00000000000000..034f9f6699ada5 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -0,0 +1,1141 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test_utils.h" + +#include +#include +#include +#include +#include + +#include +#include + +using namespace cldnn; +using namespace ::tests; + +inline void DoTest(engine& engine, + const cldnn::memory::ptr& input0, // data + const cldnn::memory::ptr& input1, // indices + const std::vector& expected_results, + const tensor& output_tensor, + const cldnn::gather_elements::gather_elements_axis axis) { + topology topology; + topology.add(input_layout("InputData", input0->get_layout())); + topology.add(input_layout("InputIndices", input1->get_layout())); + topology.add( + gather_elements("gather_elements", "InputData", "InputIndices", input1->get_layout().format, output_tensor, axis) + ); + + network network(engine, topology); + + network.set_input_data("InputData", input0); + network.set_input_data("InputIndices", input1); + auto outputs = network.execute(); + auto output = outputs.at("gather_elements").get_memory(); + cldnn::mem_lock output_ptr(output, get_test_stream()); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_b; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 2, 8, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 8, 3 
} }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), 
FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(5), FLOAT16(3), + FLOAT16(6), FLOAT16(5), FLOAT16(6), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(6), FLOAT16(1), + FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(8), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(8), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(0), FLOAT16(6), FLOAT16(1), + FLOAT16(2), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(3), + FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(3), FLOAT16(0), 
FLOAT16(4), FLOAT16(10), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(9), FLOAT16(3), FLOAT16(0), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(3), FLOAT16(3), FLOAT16(10), FLOAT16(6), FLOAT16(1), + }; + + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 8, 3), axis); +} + +TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_x; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // indices + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(2), + FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(0), + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), + FLOAT16(3), FLOAT16(1), FLOAT16(5), + FLOAT16(9), FLOAT16(10), FLOAT16(0), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), 
FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(2), FLOAT16(2), FLOAT16(5), + FLOAT16(0), FLOAT16(0), FLOAT16(7), + FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(0), FLOAT16(9), FLOAT16(0), + FLOAT16(7), FLOAT16(0), FLOAT16(7), + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(6), FLOAT16(7), FLOAT16(10), + FLOAT16(5), FLOAT16(9), FLOAT16(5), + FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(8), FLOAT16(2), FLOAT16(2), + FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(8), FLOAT16(6), FLOAT16(10), + FLOAT16(4), FLOAT16(10), FLOAT16(10), + FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(0), FLOAT16(9), + FLOAT16(4), FLOAT16(8), FLOAT16(8), + FLOAT16(3), FLOAT16(3), FLOAT16(5), + FLOAT16(5), FLOAT16(3), FLOAT16(3), + FLOAT16(9), FLOAT16(9), FLOAT16(0), + }; + + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 3, 5), axis); +} + +TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_x; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 3, 2, 9 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 3, 5, 9 } }); // indices + set_values(input0, { + FLOAT16(0), FLOAT16(1), + FLOAT16(8), FLOAT16(5), + FLOAT16(5), FLOAT16(2), + FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), + FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), + FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), + FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), + FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), + FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), + 
FLOAT16(2), FLOAT16(0), + FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), + FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(0), + FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), + FLOAT16(8), FLOAT16(5), + FLOAT16(2), FLOAT16(3), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + }); + 
+ std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), + FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(4), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(10), + FLOAT16(5), FLOAT16(9), FLOAT16(5), FLOAT16(9), FLOAT16(5), + FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(7), + FLOAT16(8), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(8), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(8), FLOAT16(8), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(8), FLOAT16(6), + FLOAT16(10), FLOAT16(4), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(10), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(6), FLOAT16(9), FLOAT16(9), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(2), FLOAT16(4), + FLOAT16(5), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(8), + FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(3), FLOAT16(3), + }; + + DoTest(engine, input0, input1, expected_results, tensor(1, 3, 5, 9), axis); +} + +TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_y; + auto input0 = 
engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 8, 5, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 8, 2, 3 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + 
FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), + FLOAT16(2), FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(8), FLOAT16(10), + FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(2), FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(10), FLOAT16(6), FLOAT16(4), FLOAT16(9), + FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(1), FLOAT16(4), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(0), + FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(10), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(5), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(3), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), + FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), + FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(0), FLOAT16(2), FLOAT16(2), + 
FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(3), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(3), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(3), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(4), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(7), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(0), FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(5), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(7), + FLOAT16(1), FLOAT16(9), FLOAT16(8), FLOAT16(9), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(6), FLOAT16(2), FLOAT16(7), FLOAT16(6), FLOAT16(1), + FLOAT16(7), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(2), FLOAT16(7), FLOAT16(5), FLOAT16(3), FLOAT16(9), FLOAT16(4), FLOAT16(5), + FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(9), + FLOAT16(1), FLOAT16(7), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(5), + }; + + DoTest(engine, input0, input1, expected_results, tensor(1, 2, 8, 2, 3), axis); +} + +TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_f; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 2, 5, 4, 4, 1 } }); // data + auto input1 = 
engine.allocate_memory({ data_types::f16, format::bfzyx, { 2, 2, 4, 4, 1 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), + FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), + FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), + FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), + FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), + FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), + FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), + FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), + FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), + FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), + FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(6), + FLOAT16(2), 
FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(6), + FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), + + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), + FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), + FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(3), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), + FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), + FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), + FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), + FLOAT16(3), FLOAT16(0), FLOAT16(2), FLOAT16(4), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(5), + FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(10), + FLOAT16(3), FLOAT16(10), FLOAT16(1), FLOAT16(5), + FLOAT16(4), FLOAT16(0), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(3), + FLOAT16(10), FLOAT16(8), FLOAT16(6), FLOAT16(1), + FLOAT16(2), FLOAT16(5), FLOAT16(7), FLOAT16(5), + FLOAT16(4), FLOAT16(0), FLOAT16(6), FLOAT16(3), + FLOAT16(10), FLOAT16(9), FLOAT16(6), FLOAT16(9), + FLOAT16(1), FLOAT16(6), FLOAT16(5), FLOAT16(7), + FLOAT16(5), FLOAT16(2), FLOAT16(6), FLOAT16(6), + FLOAT16(1), FLOAT16(5), FLOAT16(6), FLOAT16(1), + FLOAT16(6), FLOAT16(4), FLOAT16(1), FLOAT16(6), + FLOAT16(2), FLOAT16(6), FLOAT16(5), FLOAT16(7), + FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(6), + FLOAT16(6), FLOAT16(5), FLOAT16(10), FLOAT16(8), + }; + + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 4, 4, 1), axis); +} + +TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { + auto& engine = 
get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_b; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 3, 2, 8, 4, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 8, 4, 3 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(4), 
FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), + FLOAT16(2), FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(8), FLOAT16(10), + FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(2), FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(10), FLOAT16(6), FLOAT16(4), FLOAT16(9), + FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(1), FLOAT16(4), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(0), + FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(10), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(5), + FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(1), + FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(4), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(9), FLOAT16(1), FLOAT16(0), + FLOAT16(8), FLOAT16(6), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(10), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(7), FLOAT16(3), FLOAT16(8), FLOAT16(8), FLOAT16(4), FLOAT16(3), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(10), FLOAT16(2), FLOAT16(9), FLOAT16(1), FLOAT16(4), + FLOAT16(6), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(1), + 
FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(6), + FLOAT16(0), FLOAT16(6), FLOAT16(2), FLOAT16(3), FLOAT16(7), FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(6), FLOAT16(6), FLOAT16(3), FLOAT16(7), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(9), + FLOAT16(8), FLOAT16(6), FLOAT16(8), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(3), FLOAT16(6), + FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(8), + FLOAT16(7), FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(7), + FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(3), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(2), FLOAT16(10), + FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(5), FLOAT16(8), FLOAT16(0), FLOAT16(1), FLOAT16(7), + FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(7), FLOAT16(3), FLOAT16(8), + FLOAT16(1), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(1), FLOAT16(9), FLOAT16(8), FLOAT16(8), + FLOAT16(4), FLOAT16(0), FLOAT16(6), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(4), FLOAT16(2), + FLOAT16(4), FLOAT16(6), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(3), FLOAT16(8), FLOAT16(4), + FLOAT16(7), FLOAT16(3), FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(4), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(6), FLOAT16(7), FLOAT16(3), + FLOAT16(0), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(4), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(7), FLOAT16(5), FLOAT16(0), FLOAT16(1), FLOAT16(10), FLOAT16(2), + FLOAT16(3), FLOAT16(6), FLOAT16(6), FLOAT16(1), FLOAT16(6), FLOAT16(10), FLOAT16(3), FLOAT16(9), + FLOAT16(10), FLOAT16(2), 
FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(2), FLOAT16(8), + FLOAT16(7), FLOAT16(4), FLOAT16(2), FLOAT16(7), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(6), + FLOAT16(0), FLOAT16(1), FLOAT16(6), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(4), FLOAT16(9), + FLOAT16(1), FLOAT16(10), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(8), FLOAT16(10), FLOAT16(2), + FLOAT16(3), FLOAT16(8), FLOAT16(5), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(6), FLOAT16(4), FLOAT16(2), FLOAT16(2), + FLOAT16(7), FLOAT16(1), FLOAT16(8), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(1), FLOAT16(10), + FLOAT16(5), FLOAT16(6), FLOAT16(10), FLOAT16(0), FLOAT16(6), FLOAT16(7), FLOAT16(5), FLOAT16(0), + FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(0), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(3), + FLOAT16(4), FLOAT16(8), FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(10), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(1), FLOAT16(5), FLOAT16(1), FLOAT16(9), FLOAT16(4), + FLOAT16(4), FLOAT16(3), FLOAT16(7), FLOAT16(6), FLOAT16(9), FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(4), FLOAT16(10), FLOAT16(6), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(2), + FLOAT16(0), FLOAT16(4), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(2), FLOAT16(8), FLOAT16(5), + FLOAT16(7), FLOAT16(9), FLOAT16(2), FLOAT16(7), FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(5), + + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(2), 
FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(7), + 
FLOAT16(4), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(5), FLOAT16(5), + FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(9), FLOAT16(5), FLOAT16(8), FLOAT16(6), FLOAT16(4), + FLOAT16(8), FLOAT16(5), FLOAT16(8), FLOAT16(1), FLOAT16(4), FLOAT16(7), FLOAT16(5), FLOAT16(0), + FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(8), FLOAT16(5), FLOAT16(4), + FLOAT16(4), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(3), + FLOAT16(9), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(3), FLOAT16(4), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(3), FLOAT16(6), FLOAT16(2), FLOAT16(9), FLOAT16(10), FLOAT16(10), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(4), FLOAT16(6), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(10), FLOAT16(2), FLOAT16(7), FLOAT16(6), FLOAT16(9), + FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(2), + FLOAT16(3), FLOAT16(6), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(8), FLOAT16(7), FLOAT16(8), + FLOAT16(0), FLOAT16(6), FLOAT16(2), FLOAT16(9), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(3), FLOAT16(7), FLOAT16(0), FLOAT16(7), FLOAT16(5), FLOAT16(9), + FLOAT16(8), FLOAT16(3), FLOAT16(10), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(4), FLOAT16(6), + FLOAT16(4), FLOAT16(5), FLOAT16(6), FLOAT16(4), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(1), + FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(10), + FLOAT16(7), FLOAT16(1), FLOAT16(5), FLOAT16(10), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(4), + FLOAT16(2), FLOAT16(3), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(8), FLOAT16(5), FLOAT16(5), + FLOAT16(4), FLOAT16(4), 
FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(10), FLOAT16(4), FLOAT16(2), + FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(8), FLOAT16(5), + FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(2), FLOAT16(7), FLOAT16(3), FLOAT16(5), + }; + + DoTest(engine, input0, input1, expected_results, tensor(1, 2, 8, 4, 3), axis); +} + +TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_x; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 3, 4, 4, 2 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 6, 4, 4, 2 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(2), + FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(0), + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), + FLOAT16(3), FLOAT16(1), FLOAT16(5), + FLOAT16(9), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(10), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(10), + FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), + FLOAT16(5), FLOAT16(10), FLOAT16(0), + FLOAT16(8), FLOAT16(8), FLOAT16(9), + FLOAT16(1), FLOAT16(0), FLOAT16(7), + FLOAT16(9), FLOAT16(6), FLOAT16(8), + FLOAT16(7), FLOAT16(10), FLOAT16(9), + FLOAT16(2), FLOAT16(3), FLOAT16(3), + FLOAT16(5), FLOAT16(6), FLOAT16(9), + 
FLOAT16(4), FLOAT16(9), FLOAT16(2), + FLOAT16(4), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(1), FLOAT16(1), + FLOAT16(6), FLOAT16(8), FLOAT16(0), + FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(8), FLOAT16(6), FLOAT16(9), + FLOAT16(6), FLOAT16(9), FLOAT16(1), + FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), + FLOAT16(4), FLOAT16(0), FLOAT16(7), + FLOAT16(10), FLOAT16(2), FLOAT16(1), + FLOAT16(3), FLOAT16(9), FLOAT16(7), + FLOAT16(1), FLOAT16(7), FLOAT16(4), + FLOAT16(4), FLOAT16(5), FLOAT16(1), + FLOAT16(6), FLOAT16(9), FLOAT16(6), + FLOAT16(10), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), + FLOAT16(6), FLOAT16(2), FLOAT16(5), + FLOAT16(5), FLOAT16(10), FLOAT16(1), + FLOAT16(2), FLOAT16(3), FLOAT16(6), + FLOAT16(1), FLOAT16(7), FLOAT16(6), + FLOAT16(8), FLOAT16(2), FLOAT16(5), + FLOAT16(4), FLOAT16(2), FLOAT16(0), + FLOAT16(9), FLOAT16(4), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), + FLOAT16(9), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(4), FLOAT16(2), + FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(3), FLOAT16(4), FLOAT16(8), + FLOAT16(10), FLOAT16(7), FLOAT16(2), + FLOAT16(7), FLOAT16(9), FLOAT16(2), + FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(8), + FLOAT16(5), FLOAT16(10), FLOAT16(6), + FLOAT16(4), FLOAT16(9), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(10), + FLOAT16(9), FLOAT16(3), FLOAT16(5), + FLOAT16(5), FLOAT16(1), FLOAT16(4), + FLOAT16(6), FLOAT16(9), FLOAT16(4), + FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(8), FLOAT16(7), FLOAT16(8), + FLOAT16(0), FLOAT16(9), FLOAT16(5), + FLOAT16(5), FLOAT16(0), FLOAT16(7), + FLOAT16(5), FLOAT16(7), FLOAT16(7), + FLOAT16(2), FLOAT16(10), FLOAT16(9), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(4), FLOAT16(10), FLOAT16(2), + FLOAT16(4), FLOAT16(3), FLOAT16(5), + FLOAT16(9), FLOAT16(4), FLOAT16(5), + FLOAT16(8), FLOAT16(4), FLOAT16(2), + FLOAT16(10), FLOAT16(1), FLOAT16(6), + FLOAT16(6), FLOAT16(0), FLOAT16(0), + 
FLOAT16(8), FLOAT16(8), FLOAT16(3), + FLOAT16(4), FLOAT16(7), FLOAT16(7), + FLOAT16(2), FLOAT16(9), FLOAT16(7), + FLOAT16(9), FLOAT16(1), FLOAT16(0), + FLOAT16(8), FLOAT16(6), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(4), + FLOAT16(10), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(7), FLOAT16(3), + FLOAT16(8), FLOAT16(8), FLOAT16(4), + FLOAT16(3), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(10), FLOAT16(2), + FLOAT16(9), FLOAT16(1), FLOAT16(4), + FLOAT16(6), FLOAT16(1), FLOAT16(9), + FLOAT16(1), FLOAT16(10), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(6), FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(6), FLOAT16(0), FLOAT16(6), + FLOAT16(2), FLOAT16(3), FLOAT16(7), + FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(6), FLOAT16(6), FLOAT16(3), + FLOAT16(7), FLOAT16(1), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(8), + FLOAT16(6), FLOAT16(8), FLOAT16(3), + FLOAT16(1), FLOAT16(5), FLOAT16(3), + FLOAT16(6), FLOAT16(5), FLOAT16(4), + FLOAT16(2), FLOAT16(4), FLOAT16(4), + FLOAT16(4), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(5), FLOAT16(10), + FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(7), FLOAT16(4), FLOAT16(6), + FLOAT16(10), FLOAT16(1), FLOAT16(7), + FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(3), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(2), FLOAT16(10), FLOAT16(2), + FLOAT16(9), FLOAT16(7), FLOAT16(5), + FLOAT16(8), FLOAT16(0), FLOAT16(1), + FLOAT16(7), FLOAT16(7), FLOAT16(4), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(7), FLOAT16(3), FLOAT16(8), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), 
FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), 
FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), 
FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), 
FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), 
FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(0), + FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(10), FLOAT16(5), + FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(9), FLOAT16(9), + FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(7), FLOAT16(10), + FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(9), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(2), 
FLOAT16(10), FLOAT16(8), FLOAT16(8), FLOAT16(2), FLOAT16(2), + FLOAT16(8), FLOAT16(8), FLOAT16(0), FLOAT16(3), FLOAT16(0), FLOAT16(0), + FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(8), FLOAT16(6), + FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(10), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(8), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(2), + FLOAT16(5), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(3), FLOAT16(5), + FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(1), + FLOAT16(10), FLOAT16(9), FLOAT16(0), FLOAT16(10), FLOAT16(9), FLOAT16(0), + FLOAT16(9), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(3), FLOAT16(5), FLOAT16(10), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(10), + FLOAT16(0), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(0), + FLOAT16(10), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(10), + FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(10), + FLOAT16(8), FLOAT16(9), FLOAT16(8), FLOAT16(9), FLOAT16(8), FLOAT16(9), + FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), + FLOAT16(8), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(6), + FLOAT16(9), FLOAT16(9), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(2), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(3), + FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(6), FLOAT16(5), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(9), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(0), + FLOAT16(10), FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(5), 
FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(9), FLOAT16(9), FLOAT16(8), FLOAT16(9), + FLOAT16(9), FLOAT16(6), FLOAT16(6), FLOAT16(1), FLOAT16(9), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(7), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(4), FLOAT16(0), + FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), + FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(9), FLOAT16(9), FLOAT16(7), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(4), FLOAT16(4), + FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(5), + FLOAT16(9), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(9), + FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(1), FLOAT16(10), FLOAT16(1), FLOAT16(10), FLOAT16(1), FLOAT16(10), + FLOAT16(2), FLOAT16(5), FLOAT16(6), FLOAT16(2), FLOAT16(2), FLOAT16(6), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(10), FLOAT16(10), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(6), + FLOAT16(1), FLOAT16(1), FLOAT16(6), FLOAT16(7), FLOAT16(7), FLOAT16(6), + FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(8), FLOAT16(8), FLOAT16(2), + FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(4), + FLOAT16(1), FLOAT16(9), FLOAT16(4), FLOAT16(9), FLOAT16(9), FLOAT16(4), + FLOAT16(1), FLOAT16(4), FLOAT16(1), FLOAT16(4), FLOAT16(4), FLOAT16(10), + FLOAT16(1), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(9), FLOAT16(1), + FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(8), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(8), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(8), + FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(2), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(9), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(5), FLOAT16(5), 
FLOAT16(9), FLOAT16(9), FLOAT16(9), + FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(5), FLOAT16(6), FLOAT16(6), FLOAT16(5), FLOAT16(10), FLOAT16(5), + FLOAT16(7), FLOAT16(9), FLOAT16(7), FLOAT16(7), FLOAT16(9), FLOAT16(7), + FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(10), FLOAT16(7), FLOAT16(10), + FLOAT16(5), FLOAT16(3), FLOAT16(9), FLOAT16(3), FLOAT16(9), FLOAT16(3), + FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(9), FLOAT16(9), FLOAT16(9), FLOAT16(4), FLOAT16(6), FLOAT16(6), + FLOAT16(9), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(7), FLOAT16(9), + FLOAT16(8), FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(2), FLOAT16(9), FLOAT16(2), FLOAT16(9), FLOAT16(9), FLOAT16(10), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(5), FLOAT16(9), + FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(10), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(5), + FLOAT16(5), FLOAT16(9), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(4), + FLOAT16(4), FLOAT16(8), FLOAT16(8), FLOAT16(2), FLOAT16(4), FLOAT16(4), + FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(1), FLOAT16(10), FLOAT16(6), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(0), FLOAT16(0), + FLOAT16(3), FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(8), FLOAT16(8), + FLOAT16(4), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(9), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(7), FLOAT16(7), + FLOAT16(9), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(10), FLOAT16(10), 
FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(3), FLOAT16(7), FLOAT16(3), + FLOAT16(4), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(3), FLOAT16(0), FLOAT16(3), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(10), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(9), FLOAT16(4), FLOAT16(1), FLOAT16(1), FLOAT16(4), FLOAT16(4), + FLOAT16(6), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(10), FLOAT16(1), FLOAT16(10), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(6), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(0), + FLOAT16(7), FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(7), FLOAT16(3), + FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(5), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), + FLOAT16(1), FLOAT16(1), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), + FLOAT16(9), FLOAT16(5), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(9), + FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(6), + FLOAT16(3), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(1), FLOAT16(1), + FLOAT16(6), FLOAT16(5), FLOAT16(4), FLOAT16(5), FLOAT16(6), FLOAT16(5), + FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(2), FLOAT16(2), + FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(3), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(4), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(5), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(6), 
FLOAT16(6), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(1), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(5), + FLOAT16(0), FLOAT16(9), FLOAT16(3), FLOAT16(9), FLOAT16(0), FLOAT16(3), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(6), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(5), FLOAT16(9), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(9), + FLOAT16(0), FLOAT16(8), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(8), + FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(8), FLOAT16(10), FLOAT16(8), FLOAT16(6), FLOAT16(10), FLOAT16(8), + FLOAT16(3), FLOAT16(3), FLOAT16(7), FLOAT16(8), FLOAT16(3), FLOAT16(8), + }; + + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 6, 4, 4, 2), axis); +} + +TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_z; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 5, 1 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 2, 1 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), + FLOAT16(2), FLOAT16(3), 
FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), + FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), + FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), + FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(3), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(4), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(7), + FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(0), + FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(5), FLOAT16(4), FLOAT16(10), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(8), + }; + + DoTest(engine, input0, input1, expected_results, tensor(1, 2, 4, 2, 2, 1), axis); +} + +TEST(gather_elements_gpu_fp16, d233113_i233115_a2) { + auto& engine = get_test_engine(); + + auto axis = cldnn::gather_elements::gather_elements_axis::along_w; + auto input0 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 5 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(2), + FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(0), + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(7), FLOAT16(6), 
FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(5), FLOAT16(7), + FLOAT16(0), FLOAT16(7), FLOAT16(8), + FLOAT16(0), FLOAT16(1), FLOAT16(7), + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(1), FLOAT16(2), + FLOAT16(9), FLOAT16(7), FLOAT16(0), + FLOAT16(5), FLOAT16(0), FLOAT16(0), + FLOAT16(9), FLOAT16(4), FLOAT16(0), + FLOAT16(9), FLOAT16(4), FLOAT16(0), + FLOAT16(5), FLOAT16(4), FLOAT16(5), + 
FLOAT16(7), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(7), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(5), FLOAT16(1), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(3), + FLOAT16(10), FLOAT16(8), FLOAT16(3), + FLOAT16(0), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(4), FLOAT16(7), + FLOAT16(7), FLOAT16(4), FLOAT16(3), + FLOAT16(7), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(8), FLOAT16(7), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(10), + FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(5), FLOAT16(4), FLOAT16(9), + FLOAT16(0), FLOAT16(2), FLOAT16(8), + FLOAT16(5), FLOAT16(4), FLOAT16(3), + FLOAT16(0), FLOAT16(6), FLOAT16(8), + FLOAT16(5), FLOAT16(6), FLOAT16(3), + }; + + DoTest(engine, input0, input1, expected_results, tensor(2, 3, 3, 1, 1, 5), axis); +}