From 95793838906463faa43afb78a9cafebdcc810136 Mon Sep 17 00:00:00 2001 From: yunji Date: Thu, 17 Jun 2021 18:20:26 +0900 Subject: [PATCH 01/11] Init boilerplate codes --- .../cldnn_engine/cldnn_primitives_list.hpp | 2 + .../src/cldnn_engine/ops/gather_elements.cpp | 37 + .../single_layer_tests/gather_elements.cpp | 81 ++ .../single_layer_tests/gather_elements.hpp | 15 + .../thirdparty/clDNN/api/gather_elements.hpp | 57 + .../kernel_selector/common/common_types.h | 1 + .../gather/gather_elements_kernel_ref.cpp | 210 +++ .../gather/gather_elements_kernel_ref.h | 61 + .../gather_elements_kernel_selector.cpp | 27 + .../gather/gather_elements_kernel_selector.h | 35 + .../core/cl_kernels/gather_elements_ref.cl | 231 ++++ .../thirdparty/clDNN/src/gather_elements.cpp | 114 ++ .../clDNN/src/gpu/gather_elements_gpu.cpp | 78 ++ .../prepare_primitive_fusing.cpp | 9 + .../clDNN/src/impls/ocl/register.hpp | 1 + .../clDNN/src/include/gather_elements_inst.h | 49 + .../test_cases/gather_elements_gpu_test.cpp | 1210 +++++++++++++++++ 17 files changed, 2218 insertions(+) create mode 100644 inference-engine/src/cldnn_engine/ops/gather_elements.cpp create mode 100644 inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp create mode 100644 inference-engine/thirdparty/clDNN/api/gather_elements.hpp create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.h create mode 100644 inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl create mode 100644 inference-engine/thirdparty/clDNN/src/gather_elements.cpp create mode 100644 inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp create mode 100644 inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h create mode 100644 inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp diff --git a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp index 81a055a4a09a5b..4081c5be17029e 100644 --- a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp +++ b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp @@ -204,9 +204,11 @@ REGISTER_FACTORY(v5, Loop); // ------------------------------ Supported v6 ops ------------------------------ // REGISTER_FACTORY(v6, CTCGreedyDecoderSeqLen); REGISTER_FACTORY(v6, MVN); +REGISTER_FACTORY(v6, GatherElements); // ------------------------------ Supported v7 ops ------------------------------ // REGISTER_FACTORY(v7, Gather); +// REGISTER_FACTORY(v7, GatherElements); // ------------------------------ Supported v8 ops ------------------------------ // REGISTER_FACTORY(v8, Gather); diff --git a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp new file mode 100644 index 00000000000000..75ce16463994be --- /dev/null +++ b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp @@ -0,0 +1,37 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "cldnn_program.h" +#include "cldnn_common_utils.h" + +#include "ngraph/op/gather_elements.hpp" +#include "ngraph/op/constant.hpp" + +#include "api/gather_elements.hpp" + +namespace CLDNNPlugin { + +void CreateGatherElementsOp(Program& p, const std::shared_ptr& op) { + p.ValidateInputs(op, {2}); + auto inputPrimitives = p.GetInputPrimitiveIDs(op); + std::string layerName = layer_type_name_ID(op); + + // 아마도 필요 없을 듯 + int32_t indices_rank = static_cast(op->get_input_shape(1).size()); + + auto axis = op->get_axis(); + + auto primitive = cldnn::gather_elements(layerName, + inputPrimitives[0], + inputPrimitives[1], + indices_rank, + axis); + + p.AddPrimitive(primitive); + p.AddPrimitiveToProfiler(op); +} + +REGISTER_FACTORY_IMPL(v6, GatherElements); + +} // namespace CLDNNPlugin \ No newline at end of file diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp new file mode 100644 index 00000000000000..d7ca24b2a85219 --- /dev/null +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -0,0 +1,81 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// // + +// #include +// #include + +// #include "single_layer_tests/gather_elements.hpp" +// #include "common_test_utils/test_constants.hpp" + +// using namespace LayerTestsDefinitions; +// using namespace ngraph::opset6; + +// namespace { + +// const std::vector inputPrecisions = { +// InferenceEngine::Precision::FP32, +// InferenceEngine::Precision::FP16, +// InferenceEngine::Precision::I32, +// }; + +// const std::vector idxPrecisions = { +// InferenceEngine::Precision::I32, +// InferenceEngine::Precision::I64, +// }; + +// set1 +// const auto gatherNDArgsSubset1 = ::testing::Combine( +// ::testing::ValuesIn(std::vector>( +// { {2, 2}, {2, 3, 4} })), // Data shape +// ::testing::ValuesIn(std::vector>( +// { {2, 1}, {2, 1, 1} })), // Indices shape +// ::testing::ValuesIn(std::vector({ 0, 1 })) // Batch dims +// ); + +// INSTANTIATE_TEST_CASE_P(smoke_GatherND_set1, GatherNDLayerTest, +// ::testing::Combine( +// gatherNDArgsSubset1, +// ::testing::ValuesIn(inputPrecisions), +// ::testing::ValuesIn(idxPrecisions), +// ::testing::Values(CommonTestUtils::DEVICE_GPU), +// ::testing::Values({})), +// GatherNDLayerTest::getTestCaseName); + +// // set2 +// const auto gatherNDArgsSubset2 = ::testing::Combine( +// ::testing::ValuesIn(std::vector>( +// { {15, 12, 20, 15, 2}, {15, 12, 18, 7, 17} })), // Data shape +// ::testing::ValuesIn(std::vector>( +// { {15, 12, 2}, {15, 12, 5, 9, 1, 3} })), // Indices shape +// ::testing::ValuesIn(std::vector({ 1, 2 })) // Batch dims +// ); + +// INSTANTIATE_TEST_CASE_P(smoke_GatherND_set2, GatherNDLayerTest, +// ::testing::Combine( +// gatherNDArgsSubset2, +// ::testing::ValuesIn(inputPrecisions), +// ::testing::ValuesIn(idxPrecisions), +// ::testing::Values(CommonTestUtils::DEVICE_GPU), +// ::testing::Values({})), +// GatherNDLayerTest::getTestCaseName); + +// // set3 +// const auto gatherNDArgsSubset3 = ::testing::Combine( +// ::testing::ValuesIn(std::vector>( +// { {4, 3, 2, 5, 5, 2}, {4, 3, 2, 5, 7, 2} })), // Data shape +// ::testing::ValuesIn(std::vector>( +// { {4, 3, 2, 5, 1}, {4, 3, 2, 5, 6, 2} })), // Indices shape +// ::testing::ValuesIn(std::vector({ 3, 4 })) // Batch dims +// ); + +// INSTANTIATE_TEST_CASE_P(smoke_GatherND_set3, GatherNDLayerTest, +// ::testing::Combine( +// gatherNDArgsSubset3, +// ::testing::ValuesIn(inputPrecisions), +// ::testing::ValuesIn(idxPrecisions), +// ::testing::Values(CommonTestUtils::DEVICE_GPU), +// ::testing::Values({})), +// GatherNDLayerTest::getTestCaseName); + +// } // namespace diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp new file mode 100644 index 00000000000000..eea88d4abf3183 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp @@ -0,0 +1,15 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shared_test_classes/single_layer/gather_elements.hpp" + +namespace LayerTestsDefinitions { + +TEST_P(GatherElementsLayerTest, CompareWithRefs) { + Run(); +} + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp new file mode 100644 index 00000000000000..a945a156cb1bde --- /dev/null +++ b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp @@ -0,0 +1,57 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once +#include "primitive.hpp" + +namespace cldnn { +/// @addtogroup cpp_api C++ API +/// @{ +/// @addtogroup cpp_topology Network Topology +/// @{ +/// @addtogroup cpp_primitives Primitives +/// @{ + +/// @brief +/// @details +struct gather_elements : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(gather_elements) + + /// @brief Constructs gather_elements primitive. + /// @param id This primitive id. + /// @param data Input data primitive id. + /// @param indices Input indexes primitive id. + /// @param indices_rank Rank of indices. + /// @param axis An attribute of GatherElements. Required. + gather_elements(const primitive_id& id, + const primitive_id& data, + const primitive_id& indices, + const uint8_t indices_rank, + const uint8_t axis = 0, + const padding& output_padding = padding()) + : primitive_base(id, {data, indices}, output_padding), indices_rank(indices_rank), axis(axis) {} + + /// @brief indices_rank + uint8_t indices_rank; + + /// @brief Which axis to gather on. + uint8_t axis; +}; +/// @} +/// @} +/// @} +} // namespace cldnn \ No newline at end of file diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h index 7a072d998d4789..dbe6bd7004c672 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h @@ -48,6 +48,7 @@ enum class KernelType { ONE_HOT, GATHER, GATHER_ND, + GATHER_ELEMENTS, SCATTER_UPDATE, SCATTER_ND_UPDATE, SCATTER_ELEMENTS_UPDATE, diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp new file mode 100644 index 00000000000000..ceb415a2571752 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -0,0 +1,210 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "gather_elements_kernel_ref.h" +#include "kernel_selector_utils.h" +#include +#include + +namespace kernel_selector { + +ParamsKey GatherElementsKernelRef::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableInputLayout(DataLayout::bfzyx); + k.EnableOutputLayout(DataLayout::bfzyx); + k.EnableInputLayout(DataLayout::bfwzyx); + k.EnableOutputLayout(DataLayout::bfwzyx); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDifferentTypes(); + return k; +} + +static inline std::string GetOrderString(std::vector& order) { + std::string order_str = order[0]; + for (size_t i = 1; i < order.size(); i++) + order_str += ", " + order[i]; + + return order_str; +} + +static inline std::vector GetDefaultOrder(size_t size) { + std::vector default_order; + if (size <= 4) { + default_order = { "b", "f", "y", "x" }; + } else if (size == 5) { + default_order = { "b", "f", "z", "y", "x" }; + } else if (size == 6) { + default_order = { "b", "f", "w", "z", "y", "x" }; + } + + return default_order; +} + +CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_params& params, const optional_params&) const { + CommonDispatchData dispatchData; + + auto indices_dims = params.inputs[1].LogicalDims(); + + if (indices_dims.size() > 1) { + std::reverse(indices_dims.begin(), indices_dims.end()); + } + + indices_dims[params.indices_rank - 1] = 1; // set last dim of indices to 1 + + switch (params.inputs[1].GetLayout()) { + case DataLayout::bfyx: + dispatchData.gws = { indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] }; + break; + + case DataLayout::bfzyx: + dispatchData.gws = { indices_dims[4] * indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] }; + break; + + case DataLayout::bfwzyx: + dispatchData.gws = { indices_dims[5] * indices_dims[4], indices_dims[3] * indices_dims[2], indices_dims[1] * indices_dims[0] }; + break; + + default: + throw std::invalid_argument("Unsupported data layout for scatter elements update primitive"); + break; + } + + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); + + return dispatchData; +} + +static size_t GetIndicesLastDim(const gather_elements_params& params) { + // get indices dims + auto indices_dims = params.inputs[1].LogicalDims(); + + if (indices_dims.size() > 1) { + std::reverse(indices_dims.begin(), indices_dims.end()); + } + + auto indices_last_dim = indices_dims[params.indices_rank - 1]; + + return indices_last_dim; +} + +static size_t GetSliceSize(const gather_elements_params& params) { + // get input dims + auto input_dims = params.inputs[0].LogicalDims(); + + if (input_dims.size() > 1) { + std::reverse(input_dims.begin(), input_dims.end()); + } + + // get last dim of indices + auto indices_last_dim = GetIndicesLastDim(params); + + // calculate slize size which is used in kernel to copy + size_t wi_slice_size = 1; + for (size_t i = params.batch_dims + indices_last_dim; i < input_dims.size(); i++) { + wi_slice_size *= input_dims[i]; + } + + return wi_slice_size; +} + +JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + jit.AddConstant(MakeJitConstant("INDICES_RANK", params.indices_rank)); + jit.AddConstant(MakeJitConstant("BATCH_DIMS", params.batch_dims)); + jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE", GetSliceSize(params))); + jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", GetIndicesLastDim(params))); + + if (!params.fused_ops.empty()) { + FusedOpsConfiguration conf = { "", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() }; + jit.Merge(MakeFusedOpsJitConstants(params, { conf })); + } + + return jit; +} + +bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o) const { + if (p.GetType() != KernelType:: GATHER_ELEMENTS || o.GetType() != KernelType::GATHER_ELEMENTS) { + return false; + } + + const gather_elements_params& params = static_cast(p); + auto input_dims = params.inputs[0].LogicalDims(); + auto indices_dims = params.inputs[1].LogicalDims(); + auto indices_rank = params.indices_rank; + auto batch_dims = params.batch_dims; + + std::reverse(input_dims.begin(), input_dims.end()); + std::reverse(indices_dims.begin(), indices_dims.end()); + + if (indices_rank < 1) { + return false; + } + + if (batch_dims + indices_dims[indices_rank - 1] > input_dims.size()) { + return false; + } + + if (batch_dims >= std::min(input_dims.size(), static_cast(indices_rank))) { + return false; + } + + for (uint8_t i = 0; i < batch_dims; i++) { + if (input_dims[i] != indices_dims[i]) { + return false; + } + } + + for (auto& fused_op : params.fused_ops) { + if (!IsFusedPrimitiveSupported(fused_op)) + return false; + } + + return true; +} + +KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { + if (!Validate(params, options)) { + return {}; + } + + KernelData kd = KernelData::Default(params); + gather_elements_params& newParams = *static_cast(kd.params.get()); + + auto dispatchData = SetDefault(newParams, options); + auto cldnn_jit = GetJitConstants(newParams); + + auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options); + auto jit = CreateJit(kernelName, cldnn_jit, entry_point); + auto& kernel = kd.kernels[0]; + FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2, GetFusedPrimitiveInputsCount(params)); + + return { kd }; +} + +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h new file mode 100644 index 00000000000000..a6097e4ccaf400 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h @@ -0,0 +1,61 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#pragma once + +#include "kernel_base_opencl.h" + +namespace kernel_selector { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// gather_elements_params +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct gather_elements_params : public base_params { + gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), indices_rank(0), batch_dims(0) {} + + uint8_t indices_rank; + + uint8_t batch_dims; + uint8_t axis; + + virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// gather_elements_optional_params +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct gather_elements_optional_params : optional_params { + gather_elements_optional_params() : optional_params(KernelType::GATHER_ELEMENTS) {} +}; + +class GatherElementsKernelRef : public KernelBaseOpenCL { +public: + GatherElementsKernelRef() : KernelBaseOpenCL("gather_elements_ref") {} + virtual ~GatherElementsKernelRef() {} + virtual JitConstants GetJitConstants(const gather_elements_params& params) const; + virtual CommonDispatchData SetDefault(const gather_elements_params& params, const optional_params&) const; + KernelsData GetKernelsData(const Params& params, const optional_params& options) const; + ParamsKey GetSupportedKey() const override; + std::vector GetSupportedFusedOps() const override { + return { FusedOpType::QUANTIZE, + FusedOpType::SCALE, + FusedOpType::ACTIVATION, + FusedOpType::ELTWISE }; + } + +protected: + bool Validate(const Params& p, const optional_params& o) const override; +}; +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp new file mode 100644 index 00000000000000..361e89e6ad5c2b --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp @@ -0,0 +1,27 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "gather_elements_kernel_selector.h" +#include "gather_elements_kernel_ref.h" + +namespace kernel_selector { + +gather_elements_kernel_selector::gather_elements_kernel_selector() { Attach(); } + +KernelsData gather_elements_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const { + return GetNaiveBestKernel(params, options, KernelType::GATHER_ELEMENTS); +} +} // namespace kernel_selector \ No newline at end of file diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.h new file mode 100644 index 00000000000000..7395a55ca720c8 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.h @@ -0,0 +1,35 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#pragma once + +#include "kernel_selector.h" + +namespace kernel_selector { +class gather_elements_kernel_selector : public kernel_selector_base { +public: + static gather_elements_kernel_selector& Instance() { + static gather_elements_kernel_selector instance_; + return instance_; + } + + gather_elements_kernel_selector(); + + virtual ~gather_elements_kernel_selector() {} + + KernelsData GetBestKernels(const Params& params, const optional_params& options) const override; +}; +} // namespace kernel_selector \ No newline at end of file diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl new file mode 100644 index 00000000000000..91cb7d9be773e9 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -0,0 +1,231 @@ +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "include/fetch.cl" + +#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order) +#define GET_OUTPUT_INDEX(out_order) OUTPUT_GET_INDEX(out_order) + +#if INPUT0_DIMS == 4 + #define IN_ORDER in_b,in_f,in_y,in_x +#elif INPUT0_DIMS == 5 + #define IN_ORDER in_b,in_f,in_z,in_y,in_x +#else + #define IN_ORDER in_b,in_f,in_w,in_z,in_y,in_x +#endif + +#if INPUT1_DIMS == 4 + #define IDX_ORDER idx_b,idx_f,idx_y,idx_x +#elif INPUT1_DIMS == 5 + #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x +#else + #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x +#endif + +#if OUTPUT_DIMS == 4 + #define OUT_ORDER out_b,out_f,out_y,out_x +#elif OUTPUT_DIMS == 5 + #define OUT_ORDER out_b,out_f,out_z,out_y,out_x +#else + #define OUT_ORDER out_b,out_f,out_w,out_z,out_y,out_x +#endif + +#define INDICES_MAX_DIM 6 + +KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, + const __global INPUT1_TYPE* indices, + __global OUTPUT_TYPE* output +#if HAS_FUSED_OPS_DECLS + , FUSED_OPS_DECLS +#endif +) +{ + + const uint dim0 = get_global_id(0); + const uint dim1 = get_global_id(1); + const uint dim2 = get_global_id(2); + + // Calculate indice index + const uint F_NUM = (INDICES_RANK == 2) ? 1 : INPUT1_FEATURE_NUM; + const uint idx_f = dim2 % F_NUM; + const uint idx_b = dim2 / F_NUM; + + #if INPUT1_DIMS == 4 + const uint idx_x = dim0; + const uint idx_y = dim1; + const uint idx_z = 0; + const uint idx_w = 0; + + const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_y, idx_x, 0, 0, 0, 0}; + const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X}; + #elif INPUT1_DIMS == 5 + const uint X_NUM = (INDICES_RANK == 5) ? 1 : INPUT1_SIZE_X; + + const uint idx_x = dim0 % X_NUM; + const uint idx_y = dim0 / X_NUM; + const uint idx_z = dim1; + const uint idx_w = 0; + + const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0}; + const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; + #else + const uint X_NUM = (INDICES_RANK == 6) ? 1 : INPUT1_SIZE_X; + const uint Z_NUM = (INDICES_RANK == 4) ? 1 : INPUT1_SIZE_Z; + + const uint idx_x = dim0 % X_NUM; + const uint idx_y = dim0 / X_NUM; + const uint idx_z = dim1 % Z_NUM; + const uint idx_w = dim1 / Z_NUM; + + const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_w, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0, 0}; + const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; + #endif + + const int idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); + + // Calculate data index + uint indices_val[INDICES_MAX_DIM + BATCH_DIMS]; + for (int i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) { + indices_val[i] = 0; + } + + for (int i = 0; i < BATCH_DIMS; i++) { + indices_val[i] = idx_arr[i]; + } + + for (int i = 0; i < INDICES_LAST_DIM; i++) { + indices_val[i + BATCH_DIMS] = indices[idx+i]; + } + + #if INPUT0_DIMS == 4 + const uint in_x = indices_val[3]; + const uint in_y = indices_val[2]; + #elif INPUT0_DIMS == 5 + const uint in_x = indices_val[4]; + const uint in_y = indices_val[3]; + const uint in_z = indices_val[2]; + #else + const uint in_x = indices_val[5]; + const uint in_y = indices_val[4]; + const uint in_z = indices_val[3]; + const uint in_w = indices_val[2]; + #endif + const uint in_f = indices_val[1]; + const uint in_b = indices_val[0]; + + const uint data_idx = GET_UPDATES_INDEX(INPUT0, IN_ORDER); + + // Calculate output index + #if BATCH_DIMS <= 1 + const uint out_x = idx_x; + const uint out_y = idx_y; + const uint out_z = idx_z; + const uint out_w = idx_w; + const uint out_f = idx_f; + const uint out_b = idx_b; + #else + uint pitch_acc = 1; + uint output_batch_size = 0; + for (int i = BATCH_DIMS - 1; i >= 0; i--) { + output_batch_size += (idx_arr[i] * pitch_acc); + pitch_acc *= idx_dim[i]; + } + + #if OUTPUT_DIMS == 4 + const uint out_x = idx_arr[BATCH_DIMS+2]; + const uint out_y = idx_arr[BATCH_DIMS+1]; + #elif OUTPUT_DIMS == 5 + const uint out_x = idx_arr[BATCH_DIMS+3]; + const uint out_y = idx_arr[BATCH_DIMS+2]; + const uint out_z = idx_arr[BATCH_DIMS+1]; + #else + const uint out_x = idx_arr[BATCH_DIMS+4]; + const uint out_y = idx_arr[BATCH_DIMS+3]; + const uint out_z = idx_arr[BATCH_DIMS+2]; + const uint out_w = idx_arr[BATCH_DIMS+1]; + #endif + const uint out_f = idx_arr[BATCH_DIMS+0]; + const uint out_b = output_batch_size; + #endif + + const uint output_idx = GET_OUTPUT_INDEX(OUT_ORDER); + + // Copy data to output as slice size + #if HAS_FUSED_OPS + #if OUTPUT_DIMS == 4 + const uint y_pitch = OUTPUT_SIZE_X; + const uint f_pitch = y_pitch * OUTPUT_SIZE_Y; + #elif OUTPUT_DIMS == 5 + const uint y_pitch = OUTPUT_SIZE_X; + const uint z_pitch = y_pitch * OUTPUT_SIZE_Y; + const uint f_pitch = z_pitch * OUTPUT_SIZE_Z; + #else + const uint y_pitch = OUTPUT_SIZE_X; + const uint z_pitch = y_pitch * OUTPUT_SIZE_Y; + const uint w_pitch = z_pitch * OUTPUT_SIZE_Z; + const uint f_pitch = w_pitch * OUTPUT_SIZE_W; + #endif + const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM; + #endif + + for (int i = 0; i < WI_SLICE_SIZE; i++) { + uint dst_idx = output_idx + i; + INPUT0_TYPE val = data[data_idx + i]; + + #if HAS_FUSED_OPS + const uint b_remain = dst_idx % b_pitch; + const uint f_remain = b_remain % f_pitch; + #if OUTPUT_DIMS == 4 + const uint y_remain = f_remain % y_pitch; + + const uint y = f_remain / y_pitch; + #elif OUTPUT_DIMS == 5 + const uint z_remain = f_remain % z_pitch; + const uint y_remain = z_remain % y_pitch; + + const uint z = f_remain / z_pitch; + const uint y = z_remain / y_pitch; + #else + const uint w_remain = f_remain % w_pitch; + const uint z_remain = w_remain % z_pitch; + const uint y_remain = z_remain % y_pitch; + + const uint w = f_remain / w_pitch; + const uint z = w_remain / z_pitch; + const uint y = z_remain / y_pitch; + #endif + const uint b = dst_idx / b_pitch; + const uint f = b_remain / f_pitch; + const uint x = y_remain; + + #if FUSED_OPS_CAN_USE_PRELOAD + FUSED_OPS_PRELOAD; + FUSED_OPS_CALC; + #else + FUSED_OPS; + #endif + + output[dst_idx] = FUSED_OPS_RESULT; + #else + output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS); + #endif + } +} + +#undef INDICES_MAX_DIM +#undef GET_UPDATES_INDEX +#undef GET_OUTPUT_INDEX +#undef OUT_ORDER +#undef IDX_ORDER +#undef IN_ORDER diff --git a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp new file mode 100644 index 00000000000000..e8209186708b63 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp @@ -0,0 +1,114 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "gather_elements_inst.h" + +#include "primitive_type_base.h" +#include "error_handler.h" +#include "json_object.h" +#include + +namespace cldnn { +primitive_type_id gather_elements::type_id() { + static primitive_type_base instance; + return &instance; +} + +layout gather_elements_inst::calc_output_layout(gather_elements_node const& node) { + auto op = node.get_primitive(); + + auto input_layout_origin = node.input(0).get_output_layout(); + auto indices_layout_origin = node.input(1).get_output_layout(); + + auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format); + auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format); + + // const size_t input_dims = input_layout.size(); + + // const auto indices_rank = op->indices_rank; + const auto axis = op->axis; + + // calculate initial output shape + std::vector output_sizes; + + // for (uint8_t x = 0; x < indices_rank - 1; x++) { + // output_sizes.push_back(indices_layout[x]); + // } + + // const size_t indices_last_dim = indices_layout[indices_rank - 1]; + // for (size_t x = static_cast(axis + indices_last_dim); x < input_dims; x++) { + // output_sizes.push_back(input_layout[x]); + // } + + // // calculate batch_size by axis + // int batch_size = 1; + // for (uint8_t x = 0; x < axis; x++) { + // batch_size *= output_sizes[x]; + // } + + // create final output shape by axis + std::vector final_output_sizes; + + // if (axis > 0) { + // final_output_sizes.push_back(batch_size); + // } + + for (size_t x = static_cast(axis); x < output_sizes.size(); x++) { + final_output_sizes.push_back(output_sizes[x]); + } + + auto output_format = cldnn::format::bfyx; + if (final_output_sizes.size() >= 6) { + output_format = cldnn::format::bfwzyx; + } else if (final_output_sizes.size() == 5) { + output_format = cldnn::format::bfzyx; + } + + auto output_sizes_tensor = tensor(tensor(final_output_sizes).sizes(output_format)); + auto padding = op->output_padding; + + + // if (node.has_fused_primitives()) { + // input_layout_origin.data_type = node.get_fused_output_layout().data_type; + // } + + return layout(input_layout_origin.data_type, output_format, output_sizes_tensor, padding); +} + +std::string gather_elements_inst::to_string(gather_elements_node const& node) { + auto desc = node.get_primitive(); + auto node_info = node.desc_to_json(); + auto& input = node.input(); + + std::stringstream primitive_description; + + json_composite gather_elements_info; + gather_elements_info.add("input id", input.id()); + gather_elements_info.add("input shape", node.input(0).get_output_layout().size.to_string()); + gather_elements_info.add("indices shape", node.input(1).get_output_layout().size.to_string()); + // gather_elements_info.add("indices rank", desc->indices_rank); + gather_elements_info.add("axis", desc->axis); + // gather_elements_info.add("output shape", calc_output_layout(node).size.to_string()); + + node_info->add("gather_elements info", gather_elements_info); + node_info->dump(primitive_description); + + return primitive_description.str(); +} + +gather_elements_inst::typed_primitive_inst(network_impl& network, gather_elements_node const& node) : parent(network, node) {} + +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp new file mode 100644 index 00000000000000..1789a0d5de7a3f --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp @@ -0,0 +1,78 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "gather_elements_inst.h" +#include "primitive_gpu_base.h" +#include "implementation_map.h" +#include "kernel_selector_helper.h" +#include "gather/gather_elements_kernel_selector.h" +#include "gather/gather_elements_kernel_ref.h" +#include "error_handler.h" + +using namespace cldnn; + +namespace cldnn { +namespace gpu { + +struct gather_elements_gpu : typed_primitive_gpu_impl { + using parent = typed_primitive_gpu_impl; + using parent::parent; + +public: + static primitive_impl* create(const gather_elements_node& arg) { + auto gather_elements_params = get_default_params(arg); + auto gather_elements_optional_params = + get_default_optional_params(arg.get_program()); + + // gather_elements_params.indices_rank = arg.get_primitive()->indices_rank; + gather_elements_params.axis = arg.get_primitive()->axis; + + gather_elements_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout())); + + auto& kernel_selector = kernel_selector::gather_elements_kernel_selector::Instance(); + auto best_kernels = kernel_selector.GetBestKernels(gather_elements_params, gather_elements_optional_params); + + CLDNN_ERROR_BOOL(arg.id(), + "Best_kernel.empty()", + best_kernels.empty(), + "Cannot find a proper kernel with this arguments"); + + auto gather_elements = new gather_elements_gpu(arg, best_kernels[0]); + + return gather_elements; + } +}; + +namespace detail { + +attach_gather_elements_gpu::attach_gather_elements_gpu() { + auto val_fw = gather_elements_gpu::create; + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw); +} + +} // namespace detail +} // namespace gpu +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp index ffabb96f2e40bf..7b82fb54e3c409 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp @@ -32,6 +32,7 @@ #include "space_to_depth_inst.h" #include "gather_inst.h" #include "gather_nd_inst.h" +#include "gather_elements_inst.h" #include "scatter_update_inst.h" #include "scatter_nd_update_inst.h" #include "scatter_elements_update_inst.h" @@ -200,6 +201,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) { !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && + !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type())) @@ -609,6 +611,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -677,6 +681,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -767,6 +773,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); @@ -829,6 +837,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { (parents[i]->is_type() && eltwise_supports_fusings(parents[i]->as())) || (parents[i]->is_type()) || (parents[i]->is_type()) || + (parents[i]->is_type()) || (parents[i]->is_type()) || (parents[i]->is_type()) || (parents[i]->is_type() && pooling_supports_fusings(parents[i]->as())) || diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp index 036162ed8d82fb..aaf2a777cde83a 100644 --- a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp @@ -94,6 +94,7 @@ REGISTER_OCL(embed); REGISTER_OCL(fully_connected); REGISTER_OCL(gather); REGISTER_OCL(gather_nd); +REGISTER_OCL(gather_elements); REGISTER_OCL(gemm); REGISTER_OCL(lrn); REGISTER_OCL(lstm_gemm); diff --git a/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h b/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h new file mode 100644 index 00000000000000..2b6952bdf6f015 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h @@ -0,0 +1,49 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once +#include "api/gather_elements.hpp" +#include "primitive_inst.h" +#include + +namespace cldnn { +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + +public: + using parent::parent; + + program_node& input(size_t index = 0) const { return get_dependency(index); } +}; + +using gather_elements_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + +public: + static layout calc_output_layout(gather_elements_node const& node); + static std::string to_string(gather_elements_node const& node); + +public: + typed_primitive_inst(network_impl& network, gather_elements_node const& desc); +}; + +using gather_elements_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp new file mode 100644 index 00000000000000..aa043a1e8d6dad --- /dev/null +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -0,0 +1,1210 @@ +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include + +#include +#include +#include +#include +#include + +#include +#include + +using namespace cldnn; +using namespace ::tests; + +inline void DoTest(const engine& engine, + const cldnn::memory& input0, // data + const cldnn::memory& input1, // indices + const std::vector& expected_results, + // const int indices_rank, + const int axis) { + topology topology; + topology.add(input_layout("InputData", input0.get_layout())); + topology.add(input_layout("InputIndices", input1.get_layout())); + int indices_rank = 2; + topology.add( + gather_elements("gather_elements", "InputData", "InputIndices", indices_rank, axis) + ); + + network network(engine, topology); + + network.set_input_data("InputData", input0); + network.set_input_data("InputIndices", input1); + auto outputs = network.execute(); + auto output = outputs.at("gather_elements").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +// 4-1-1 +TEST(gather_elements_gpu_fp16, d2235_i2237_a3) { + const auto& engine = get_test_engine(); + + const int axis = 3; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 7 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(10), + FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), + + FLOAT16(5), FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), + FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), + + + FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), + FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), + + FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), + FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), + FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(4), + + FLOAT16(2), FLOAT16(1), FLOAT16(3), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(3), FLOAT16(4), + + + FLOAT16(3), FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(2), + FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(0), FLOAT16(2), + + FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(3), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), + + FLOAT16(0), FLOAT16(7), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), + FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(6), FLOAT16(6), + FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(5), FLOAT16(4), FLOAT16(7), + + + FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(10), FLOAT16(0), FLOAT16(2), + FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(3), + FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(7), FLOAT16(8), FLOAT16(2), FLOAT16(7), + + FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(9), FLOAT16(9), FLOAT16(0), FLOAT16(2), + FLOAT16(5), FLOAT16(5), FLOAT16(8), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// 4-1-2 +TEST(gather_elements_gpu_fp16, d2235_i2237_an1) { + const auto& engine = get_test_engine(); + + const int axis = -1; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 7 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(10), + FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), + + FLOAT16(5), FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), + FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), + + + FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), + FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), + + FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), + FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), + FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(4), + + FLOAT16(2), FLOAT16(1), FLOAT16(3), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(3), FLOAT16(4), + + + FLOAT16(3), FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(2), + FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(0), FLOAT16(2), + + FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(3), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), + + FLOAT16(0), FLOAT16(7), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), + FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(6), FLOAT16(6), + FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(5), FLOAT16(4), FLOAT16(7), + + + FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(10), FLOAT16(0), FLOAT16(2), + FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(3), + FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(7), FLOAT16(8), FLOAT16(2), FLOAT16(7), + + FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(9), FLOAT16(9), FLOAT16(0), FLOAT16(2), + FLOAT16(5), FLOAT16(5), FLOAT16(8), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// 4-2 +TEST(gather_elements_gpu_fp16, d2329_i2329_a2) { + const auto& engine = get_test_engine(); + + const int axis = 2; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 9 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 9 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), + + FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), + + FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), + + + FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(9), + FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), FLOAT16(10), FLOAT16(9), + + FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(9), FLOAT16(2), + FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(9), FLOAT16(5), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(9), FLOAT16(4), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(7), + FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(10), FLOAT16(9), + FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(7), FLOAT16(8), FLOAT16(9), + FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(2), + FLOAT16(2), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(4), FLOAT16(8), FLOAT16(0), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// 4-3 +TEST(gather_elements_gpu_fp16, d3238_i2238_a0) { + const auto& engine = get_test_engine(); + + const int axis = 0; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 3, 8 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 8 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + + + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + + + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), + + + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(5), FLOAT16(3), + FLOAT16(6), FLOAT16(5), FLOAT16(6), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(6), FLOAT16(1), + FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(8), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(8), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(0), FLOAT16(6), FLOAT16(1), + FLOAT16(2), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(3), + FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(10), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(9), FLOAT16(3), FLOAT16(0), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(3), FLOAT16(3), FLOAT16(10), FLOAT16(6), FLOAT16(1), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// 5-1 +TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { + const auto& engine = get_test_engine(); + + const int axis = 4; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 2, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 2, 8 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(2), + + FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(4), FLOAT16(5), + + FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(0), + + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(9), + + FLOAT16(2), FLOAT16(4), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), + FLOAT16(3), FLOAT16(1), FLOAT16(5), + FLOAT16(9), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(10), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(10), + FLOAT16(0), FLOAT16(5), FLOAT16(4), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(0), FLOAT16(7), + FLOAT16(4), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(10), FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(9), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(0), + FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(6), FLOAT16(6), + FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(5), FLOAT16(9), FLOAT16(1), FLOAT16(9), FLOAT16(9), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(8), FLOAT16(2), FLOAT16(8), FLOAT16(2), FLOAT16(10), FLOAT16(10), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(0), FLOAT16(0), + FLOAT16(10), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(10), + FLOAT16(10), FLOAT16(4), FLOAT16(10), FLOAT16(4), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(4), + FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(9), FLOAT16(6), FLOAT16(6), + FLOAT16(4), FLOAT16(2), FLOAT16(8), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(2), FLOAT16(4), + FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(5), + FLOAT16(1), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(3), + FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(5), FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(5), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(10), FLOAT16(10), + FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(4), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// 5-2 +TEST(gather_elements_gpu_fp16, d23327_i23327_a3) { + const auto& engine = get_test_engine(); + + const int axis = 3; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 3, 2, 7 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 3, 2, 7 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), + FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), + FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), + FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), + FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), + FLOAT16(5), FLOAT16(4), FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), + FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), FLOAT16(10), FLOAT16(9), FLOAT16(2), + FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(1), + FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), + FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), + FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), + FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(1), + FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(1), FLOAT16(2), + FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(2), + FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(8), FLOAT16(10), FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(2), + FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), + FLOAT16(10), FLOAT16(6), FLOAT16(4), FLOAT16(9), FLOAT16(7), FLOAT16(7), FLOAT16(10), + FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(4), + FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(0), + FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(10), FLOAT16(9), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(4), + FLOAT16(3), FLOAT16(5), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(1), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(0), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(7), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(1), FLOAT16(7), + FLOAT16(0), FLOAT16(5), FLOAT16(10), FLOAT16(9), FLOAT16(4), FLOAT16(0), FLOAT16(7), + FLOAT16(4), FLOAT16(7), FLOAT16(8), FLOAT16(10), FLOAT16(4), FLOAT16(0), FLOAT16(8), + FLOAT16(3), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(10), + FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(6), FLOAT16(3), FLOAT16(1), + FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(6), FLOAT16(3), FLOAT16(2), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(0), + FLOAT16(5), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(0), + FLOAT16(5), FLOAT16(4), FLOAT16(3), FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(10), + FLOAT16(5), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), + FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(2), + FLOAT16(9), FLOAT16(6), FLOAT16(5), FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(2), + FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(5), FLOAT16(3), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(10), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(4), FLOAT16(0), FLOAT16(7), + FLOAT16(1), FLOAT16(9), FLOAT16(6), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(7), + FLOAT16(10), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(1), FLOAT16(6), FLOAT16(1), + FLOAT16(7), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(9), + FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(2), + FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(5), FLOAT16(10), FLOAT16(1), FLOAT16(2), + FLOAT16(3), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(2), + FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(8), FLOAT16(2), + FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(1), FLOAT16(4), + FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(8), FLOAT16(5), + FLOAT16(8), FLOAT16(10), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), + FLOAT16(10), FLOAT16(6), FLOAT16(3), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(4), + FLOAT16(10), FLOAT16(6), FLOAT16(3), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(4), + FLOAT16(7), FLOAT16(8), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(5), FLOAT16(0), + FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(8), FLOAT16(9), FLOAT16(7), FLOAT16(8), + FLOAT16(9), FLOAT16(5), FLOAT16(7), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(9), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(5), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(9), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(0), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// 6-1 +TEST(gather_elements_gpu_fp16, d232328_i232328_a3) { + const auto& engine = get_test_engine(); + + const int axis = 3; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 3, 2, 8 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 3, 2, 8 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), + FLOAT16(2), FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(8), FLOAT16(10), + FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(2), FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(10), FLOAT16(6), FLOAT16(4), FLOAT16(9), + FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(1), FLOAT16(4), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(0), + FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(10), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(5), + FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(1), + FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(4), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(9), FLOAT16(1), FLOAT16(0), + FLOAT16(8), FLOAT16(6), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(10), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(7), FLOAT16(3), FLOAT16(8), FLOAT16(8), FLOAT16(4), FLOAT16(3), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(10), FLOAT16(2), FLOAT16(9), FLOAT16(1), FLOAT16(4), + FLOAT16(6), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(6), + FLOAT16(0), FLOAT16(6), FLOAT16(2), FLOAT16(3), FLOAT16(7), FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(6), FLOAT16(6), FLOAT16(3), FLOAT16(7), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(9), + FLOAT16(8), FLOAT16(6), FLOAT16(8), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(3), FLOAT16(6), + FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(8), + FLOAT16(7), FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(7), + FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(3), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(2), FLOAT16(10), + FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(5), FLOAT16(8), FLOAT16(0), FLOAT16(1), FLOAT16(7), + FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(7), FLOAT16(3), FLOAT16(8), + FLOAT16(1), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(1), FLOAT16(9), FLOAT16(8), FLOAT16(8), + FLOAT16(4), FLOAT16(0), FLOAT16(6), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(4), FLOAT16(2), + FLOAT16(4), FLOAT16(6), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(3), FLOAT16(8), FLOAT16(4), + FLOAT16(7), FLOAT16(3), FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(4), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(6), FLOAT16(7), FLOAT16(3), + FLOAT16(0), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(4), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(7), FLOAT16(5), FLOAT16(0), FLOAT16(1), FLOAT16(10), FLOAT16(2), + FLOAT16(3), FLOAT16(6), FLOAT16(6), FLOAT16(1), FLOAT16(6), FLOAT16(10), FLOAT16(3), FLOAT16(9), + FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(2), FLOAT16(8), + FLOAT16(7), FLOAT16(4), FLOAT16(2), FLOAT16(7), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(6), + FLOAT16(0), FLOAT16(1), FLOAT16(6), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(4), FLOAT16(9), + FLOAT16(1), FLOAT16(10), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(8), FLOAT16(10), FLOAT16(2), + FLOAT16(3), FLOAT16(8), FLOAT16(5), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(6), FLOAT16(4), FLOAT16(2), FLOAT16(2), + FLOAT16(7), FLOAT16(1), FLOAT16(8), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(1), FLOAT16(10), + FLOAT16(5), FLOAT16(6), FLOAT16(10), FLOAT16(0), FLOAT16(6), FLOAT16(7), FLOAT16(5), FLOAT16(0), + FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(0), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(3), + FLOAT16(4), FLOAT16(8), FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(10), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(1), FLOAT16(5), FLOAT16(1), FLOAT16(9), FLOAT16(4), + FLOAT16(4), FLOAT16(3), FLOAT16(7), FLOAT16(6), FLOAT16(9), FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(4), FLOAT16(10), FLOAT16(6), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(2), + FLOAT16(0), FLOAT16(4), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(2), FLOAT16(8), FLOAT16(5), + FLOAT16(7), FLOAT16(9), FLOAT16(2), FLOAT16(7), FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(5), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(2), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(0), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(0), FLOAT16(6), FLOAT16(7), FLOAT16(0), FLOAT16(4), + FLOAT16(9), FLOAT16(10), FLOAT16(1), FLOAT16(8), FLOAT16(9), FLOAT16(0), FLOAT16(10), FLOAT16(9), + FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(7), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(9), + FLOAT16(10), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(3), FLOAT16(8), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(6), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(9), + FLOAT16(10), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(2), FLOAT16(3), FLOAT16(5), FLOAT16(7), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(1), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(5), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(4), FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4), + FLOAT16(1), FLOAT16(6), FLOAT16(0), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(2), + FLOAT16(6), FLOAT16(9), FLOAT16(9), FLOAT16(9), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(1), FLOAT16(8), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(10), + FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(9), FLOAT16(0), FLOAT16(5), + FLOAT16(10), FLOAT16(1), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(1), + FLOAT16(7), FLOAT16(2), FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(1), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(5), FLOAT16(8), FLOAT16(10), + FLOAT16(1), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(8), + FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(9), FLOAT16(10), FLOAT16(6), FLOAT16(9), FLOAT16(7), + FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(3), FLOAT16(0), + FLOAT16(1), FLOAT16(5), FLOAT16(7), FLOAT16(9), FLOAT16(4), FLOAT16(6), FLOAT16(4), FLOAT16(9), + FLOAT16(5), FLOAT16(1), FLOAT16(8), FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(7), FLOAT16(5), FLOAT16(8), FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(4), FLOAT16(9), + FLOAT16(8), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(5), + FLOAT16(9), FLOAT16(4), FLOAT16(2), FLOAT16(8), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(6), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(8), FLOAT16(4), FLOAT16(1), FLOAT16(4), + FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), FLOAT16(9), FLOAT16(1), FLOAT16(3), + FLOAT16(8), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(9), FLOAT16(3), FLOAT16(4), + FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(8), FLOAT16(4), FLOAT16(3), + FLOAT16(8), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(4), + FLOAT16(8), FLOAT16(1), FLOAT16(8), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(8), FLOAT16(6), + FLOAT16(2), FLOAT16(6), FLOAT16(3), FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(4), + FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(6), + FLOAT16(2), FLOAT16(6), FLOAT16(2), FLOAT16(7), FLOAT16(1), FLOAT16(4), FLOAT16(7), FLOAT16(4), + FLOAT16(8), FLOAT16(1), FLOAT16(9), FLOAT16(3), FLOAT16(10), FLOAT16(1), FLOAT16(3), FLOAT16(6), + FLOAT16(5), FLOAT16(6), FLOAT16(2), FLOAT16(8), FLOAT16(1), FLOAT16(8), FLOAT16(7), FLOAT16(9), + FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(7), + FLOAT16(7), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(5), FLOAT16(8), + FLOAT16(2), FLOAT16(9), FLOAT16(10), FLOAT16(2), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(5), FLOAT16(10), FLOAT16(3), FLOAT16(7), FLOAT16(5), FLOAT16(7), + FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(5), FLOAT16(5), FLOAT16(7), + FLOAT16(0), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(0), FLOAT16(5), FLOAT16(8), FLOAT16(5), + FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(3), + FLOAT16(4), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(4), + FLOAT16(4), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(0), FLOAT16(8), FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(5), FLOAT16(3), FLOAT16(8), FLOAT16(5), + FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), + FLOAT16(0), FLOAT16(10), FLOAT16(6), FLOAT16(7), FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(5), + FLOAT16(10), FLOAT16(8), FLOAT16(7), FLOAT16(5), FLOAT16(8), FLOAT16(1), FLOAT16(4), FLOAT16(9), + FLOAT16(3), FLOAT16(6), FLOAT16(5), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(1), FLOAT16(6), + FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(0), FLOAT16(9), FLOAT16(2), FLOAT16(8), + FLOAT16(3), FLOAT16(4), FLOAT16(6), FLOAT16(0), FLOAT16(9), FLOAT16(10), FLOAT16(1), FLOAT16(9), + FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(4), FLOAT16(0), FLOAT16(9), FLOAT16(10), FLOAT16(9), + FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(4), FLOAT16(2), FLOAT16(2), + FLOAT16(3), FLOAT16(5), FLOAT16(8), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(10), + FLOAT16(1), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(6), FLOAT16(8), FLOAT16(5), FLOAT16(0), + FLOAT16(4), FLOAT16(1), FLOAT16(5), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(3), + FLOAT16(1), FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(5), FLOAT16(2), + FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(5), FLOAT16(3), + FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(9), FLOAT16(7), + FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(4), + FLOAT16(4), FLOAT16(3), FLOAT16(7), FLOAT16(6), FLOAT16(4), FLOAT16(2), FLOAT16(9), FLOAT16(5), + FLOAT16(7), FLOAT16(1), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(9), FLOAT16(5), + FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(8), FLOAT16(9), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(7), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(5), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// 6-2 +TEST(gather_elements_gpu_fp16, d222443_i222446_a5) { + const auto& engine = get_test_engine(); + + const int axis = 5; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 2, 4, 4, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 2, 4, 4, 6 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(2), + FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(0), + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), + FLOAT16(3), FLOAT16(1), FLOAT16(5), + FLOAT16(9), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(10), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(10), + FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), + FLOAT16(5), FLOAT16(10), FLOAT16(0), + FLOAT16(8), FLOAT16(8), FLOAT16(9), + FLOAT16(1), FLOAT16(0), FLOAT16(7), + FLOAT16(9), FLOAT16(6), FLOAT16(8), + FLOAT16(7), FLOAT16(10), FLOAT16(9), + FLOAT16(2), FLOAT16(3), FLOAT16(3), + FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), + FLOAT16(4), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(1), FLOAT16(1), + FLOAT16(6), FLOAT16(8), FLOAT16(0), + FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(8), FLOAT16(6), FLOAT16(9), + FLOAT16(6), FLOAT16(9), FLOAT16(1), + FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), + FLOAT16(4), FLOAT16(0), FLOAT16(7), + FLOAT16(10), FLOAT16(2), FLOAT16(1), + FLOAT16(3), FLOAT16(9), FLOAT16(7), + FLOAT16(1), FLOAT16(7), FLOAT16(4), + FLOAT16(4), FLOAT16(5), FLOAT16(1), + FLOAT16(6), FLOAT16(9), FLOAT16(6), + FLOAT16(10), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), + FLOAT16(6), FLOAT16(2), FLOAT16(5), + FLOAT16(5), FLOAT16(10), FLOAT16(1), + FLOAT16(2), FLOAT16(3), FLOAT16(6), + FLOAT16(1), FLOAT16(7), FLOAT16(6), + FLOAT16(8), FLOAT16(2), FLOAT16(5), + FLOAT16(4), FLOAT16(2), FLOAT16(0), + FLOAT16(9), FLOAT16(4), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), + FLOAT16(9), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(4), FLOAT16(2), + FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(3), FLOAT16(4), FLOAT16(8), + FLOAT16(10), FLOAT16(7), FLOAT16(2), + FLOAT16(7), FLOAT16(9), FLOAT16(2), + FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(8), + FLOAT16(5), FLOAT16(10), FLOAT16(6), + FLOAT16(4), FLOAT16(9), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(10), + FLOAT16(9), FLOAT16(3), FLOAT16(5), + FLOAT16(5), FLOAT16(1), FLOAT16(4), + FLOAT16(6), FLOAT16(9), FLOAT16(4), + FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(8), FLOAT16(7), FLOAT16(8), + FLOAT16(0), FLOAT16(9), FLOAT16(5), + FLOAT16(5), FLOAT16(0), FLOAT16(7), + FLOAT16(5), FLOAT16(7), FLOAT16(7), + FLOAT16(2), FLOAT16(10), FLOAT16(9), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(4), FLOAT16(10), FLOAT16(2), + FLOAT16(4), FLOAT16(3), FLOAT16(5), + FLOAT16(9), FLOAT16(4), FLOAT16(5), + FLOAT16(8), FLOAT16(4), FLOAT16(2), + FLOAT16(10), FLOAT16(1), FLOAT16(6), + FLOAT16(6), FLOAT16(0), FLOAT16(0), + FLOAT16(8), FLOAT16(8), FLOAT16(3), + FLOAT16(4), FLOAT16(7), FLOAT16(7), + FLOAT16(2), FLOAT16(9), FLOAT16(7), + FLOAT16(9), FLOAT16(1), FLOAT16(0), + FLOAT16(8), FLOAT16(6), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(4), + FLOAT16(10), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(7), FLOAT16(3), + FLOAT16(8), FLOAT16(8), FLOAT16(4), + FLOAT16(3), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(10), FLOAT16(2), + FLOAT16(9), FLOAT16(1), FLOAT16(4), + FLOAT16(6), FLOAT16(1), FLOAT16(9), + FLOAT16(1), FLOAT16(10), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(6), FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(6), FLOAT16(0), FLOAT16(6), + FLOAT16(2), FLOAT16(3), FLOAT16(7), + FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(6), FLOAT16(6), FLOAT16(3), + FLOAT16(7), FLOAT16(1), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(8), + FLOAT16(6), FLOAT16(8), FLOAT16(3), + FLOAT16(1), FLOAT16(5), FLOAT16(3), + FLOAT16(6), FLOAT16(5), FLOAT16(4), + FLOAT16(2), FLOAT16(4), FLOAT16(4), + FLOAT16(4), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(5), FLOAT16(10), + FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(7), FLOAT16(4), FLOAT16(6), + FLOAT16(10), FLOAT16(1), FLOAT16(7), + FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(3), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(2), FLOAT16(10), FLOAT16(2), + FLOAT16(9), FLOAT16(7), FLOAT16(5), + FLOAT16(8), FLOAT16(0), FLOAT16(1), + FLOAT16(7), FLOAT16(7), FLOAT16(4), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(7), FLOAT16(3), FLOAT16(8), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(0), + FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(10), FLOAT16(5), + FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(9), FLOAT16(9), + FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(7), FLOAT16(10), + FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(9), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(2), FLOAT16(10), FLOAT16(8), FLOAT16(8), FLOAT16(2), FLOAT16(2), + FLOAT16(8), FLOAT16(8), FLOAT16(0), FLOAT16(3), FLOAT16(0), FLOAT16(0), + FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(8), FLOAT16(6), + FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(10), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(8), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(2), + FLOAT16(5), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(3), FLOAT16(5), + FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(1), + FLOAT16(10), FLOAT16(9), FLOAT16(0), FLOAT16(10), FLOAT16(9), FLOAT16(0), + FLOAT16(9), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(3), FLOAT16(5), FLOAT16(10), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(10), + FLOAT16(0), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(0), + FLOAT16(10), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(10), + FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(10), + FLOAT16(8), FLOAT16(9), FLOAT16(8), FLOAT16(9), FLOAT16(8), FLOAT16(9), + FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), + FLOAT16(8), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(6), + FLOAT16(9), FLOAT16(9), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(2), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(3), + FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(6), FLOAT16(5), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(9), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(5), + FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(0), + FLOAT16(10), FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(9), FLOAT16(9), FLOAT16(8), FLOAT16(9), + FLOAT16(9), FLOAT16(6), FLOAT16(6), FLOAT16(1), FLOAT16(9), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(7), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(4), FLOAT16(0), + FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), + FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(9), FLOAT16(9), FLOAT16(7), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(4), FLOAT16(4), + FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(5), + FLOAT16(9), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(9), + FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(1), FLOAT16(10), FLOAT16(1), FLOAT16(10), FLOAT16(1), FLOAT16(10), + FLOAT16(2), FLOAT16(5), FLOAT16(6), FLOAT16(2), FLOAT16(2), FLOAT16(6), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(10), FLOAT16(10), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(6), + FLOAT16(1), FLOAT16(1), FLOAT16(6), FLOAT16(7), FLOAT16(7), FLOAT16(6), + FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(8), FLOAT16(8), FLOAT16(2), + FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(4), + FLOAT16(1), FLOAT16(9), FLOAT16(4), FLOAT16(9), FLOAT16(9), FLOAT16(4), + FLOAT16(1), FLOAT16(4), FLOAT16(1), FLOAT16(4), FLOAT16(4), FLOAT16(10), + FLOAT16(1), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(9), FLOAT16(1), + FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(8), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(8), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(8), + FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(2), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(9), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(9), FLOAT16(9), FLOAT16(9), + FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(5), FLOAT16(6), FLOAT16(6), FLOAT16(5), FLOAT16(10), FLOAT16(5), + FLOAT16(7), FLOAT16(9), FLOAT16(7), FLOAT16(7), FLOAT16(9), FLOAT16(7), + FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(10), FLOAT16(7), FLOAT16(10), + FLOAT16(5), FLOAT16(3), FLOAT16(9), FLOAT16(3), FLOAT16(9), FLOAT16(3), + FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(9), FLOAT16(9), FLOAT16(9), FLOAT16(4), FLOAT16(6), FLOAT16(6), + FLOAT16(9), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(7), FLOAT16(9), + FLOAT16(8), FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(2), FLOAT16(9), FLOAT16(2), FLOAT16(9), FLOAT16(9), FLOAT16(10), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(5), FLOAT16(9), + FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(10), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(5), + FLOAT16(5), FLOAT16(9), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(4), + FLOAT16(4), FLOAT16(8), FLOAT16(8), FLOAT16(2), FLOAT16(4), FLOAT16(4), + FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(1), FLOAT16(10), FLOAT16(6), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(0), FLOAT16(0), + FLOAT16(3), FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(8), FLOAT16(8), + FLOAT16(4), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(9), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(7), FLOAT16(7), + FLOAT16(9), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(2), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), + FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(3), FLOAT16(7), FLOAT16(3), + FLOAT16(4), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(3), FLOAT16(0), FLOAT16(3), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(2), FLOAT16(10), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(9), FLOAT16(4), FLOAT16(1), FLOAT16(1), FLOAT16(4), FLOAT16(4), + FLOAT16(6), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(10), FLOAT16(1), FLOAT16(10), + FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), + FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(6), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(0), + FLOAT16(7), FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(7), FLOAT16(3), + FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(5), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), + FLOAT16(1), FLOAT16(1), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), + FLOAT16(9), FLOAT16(5), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(9), + FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(6), + FLOAT16(3), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(1), FLOAT16(1), + FLOAT16(6), FLOAT16(5), FLOAT16(4), FLOAT16(5), FLOAT16(6), FLOAT16(5), + FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(2), FLOAT16(2), + FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(3), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(4), + FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(5), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(6), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(1), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(5), + FLOAT16(0), FLOAT16(9), FLOAT16(3), FLOAT16(9), FLOAT16(0), FLOAT16(3), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(6), + FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(5), FLOAT16(9), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(9), + FLOAT16(0), FLOAT16(8), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(8), + FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(8), FLOAT16(10), FLOAT16(8), FLOAT16(6), FLOAT16(10), FLOAT16(8), + FLOAT16(3), FLOAT16(3), FLOAT16(7), FLOAT16(8), FLOAT16(3), FLOAT16(8), + }; + + DoTest(engine,input0, input1, expected_results, axis); +} + +// TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { +// const auto& engine = get_test_engine(); + +// const int axis = ; +// auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, }); // data +// auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, }); // indices + +// set_values(input0, { +// }); + +// set_values(input1, { +// }); + +// std::vector expected_results = { +// }; + +// DoTest(engine,input0, input1, expected_results, axis); +// } \ No newline at end of file From af8ff0cba55f590a3197640fab2544101cc6cc34 Mon Sep 17 00:00:00 2001 From: yunji Date: Thu, 17 Jun 2021 18:27:46 +0900 Subject: [PATCH 02/11] Add cldnn unit test implementation - fix shpae error. - add rank=4,5,6 test cases. - change gws and lws setting. --- .../src/cldnn_engine/ops/gather_elements.cpp | 7 +- .../single_layer_tests/gather_elements.cpp | 36 ++- .../thirdparty/clDNN/api/gather_elements.hpp | 14 +- .../gather/gather_elements_kernel_ref.cpp | 118 ++++---- .../gather/gather_elements_kernel_ref.h | 5 +- .../core/cl_kernels/gather_elements_ref.cl | 272 +++++++----------- .../thirdparty/clDNN/src/gather_elements.cpp | 55 +--- .../test_cases/gather_elements_gpu_test.cpp | 42 +-- 8 files changed, 216 insertions(+), 333 deletions(-) diff --git a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp index 75ce16463994be..914a7021611cbe 100644 --- a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp +++ b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp @@ -17,15 +17,14 @@ void CreateGatherElementsOp(Program& p, const std::shared_ptr(op->get_input_shape(1).size()); - auto axis = op->get_axis(); + auto outLayout = DefaultFormatForDims(op->get_output_shape(0).size()); auto primitive = cldnn::gather_elements(layerName, inputPrimitives[0], inputPrimitives[1], - indices_rank, + outLayout, + CldnnTensorFromIEDims(op->get_output_shape(0)), axis); p.AddPrimitive(primitive); diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 1ad8bbd0d4c335..dd2f462cb7825b 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -11,17 +11,17 @@ using namespace LayerTestsDefinitions; namespace { const std::vector dPrecisions = { - InferenceEngine::Precision::FP32, + // InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP16, - InferenceEngine::Precision::I32, - InferenceEngine::Precision::I64, - InferenceEngine::Precision::I16, - InferenceEngine::Precision::U8, - InferenceEngine::Precision::I8 + // InferenceEngine::Precision::I32, + // InferenceEngine::Precision::I64, + // InferenceEngine::Precision::I16, + // InferenceEngine::Precision::U8, + // InferenceEngine::Precision::I8 }; const std::vector iPrecisions = { InferenceEngine::Precision::I32, - InferenceEngine::Precision::I64 + // InferenceEngine::Precision::I64 }; INSTANTIATE_TEST_SUITE_P(smoke_set1, GatherElementsLayerTest, @@ -48,27 +48,35 @@ INSTANTIATE_TEST_SUITE_P(smoke_set3, GatherElementsLayerTest, ::testing::Combine( ::testing::Values(std::vector({2, 2, 3, 5})), // Data shape ::testing::Values(std::vector({2, 2, 3, 7})), // Indices shape - ::testing::Values(3, -1), // Axis + ::testing::Values(3), // Axis ::testing::ValuesIn(dPrecisions), ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); +<<<<<<< HEAD INSTANTIATE_TEST_SUITE_P(smoke_set4, GatherElementsLayerTest, +======= +INSTANTIATE_TEST_CASE_P(yunji_set2, GatherElementsLayerTest, +>>>>>>> Add cldnn unit test implementation ::testing::Combine( - ::testing::Values(std::vector({3, 2, 3, 8})), // Data shape - ::testing::Values(std::vector({2, 2, 3, 8})), // Indices shape - ::testing::Values(0, -4), // Axis + ::testing::Values(std::vector({3, 2, 2, 2, 3})), // Data shape + ::testing::Values(std::vector({3, 2, 2, 2, 8})), // Indices shape + ::testing::Values(4), // Axis ::testing::ValuesIn(dPrecisions), ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); +<<<<<<< HEAD INSTANTIATE_TEST_SUITE_P(smoke_set5, GatherElementsLayerTest, +======= +INSTANTIATE_TEST_CASE_P(yunji_set3, GatherElementsLayerTest, +>>>>>>> Add cldnn unit test implementation ::testing::Combine( - ::testing::Values(std::vector({3, 2, 3, 4, 8})), // Data shape - ::testing::Values(std::vector({3, 2, 3, 5, 8})), // Indices shape - ::testing::Values(3, -2), // Axis + ::testing::Values(std::vector({2, 2, 2, 4, 4, 3})), // Data shape + ::testing::Values(std::vector({2, 2, 2, 4, 4, 6})), // Indices shape + ::testing::Values(5), // Axis ::testing::ValuesIn(dPrecisions), ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), diff --git a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp index a945a156cb1bde..b48f0091f7f6a9 100644 --- a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp +++ b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp @@ -35,18 +35,22 @@ struct gather_elements : public primitive_base { /// @param id This primitive id. /// @param data Input data primitive id. /// @param indices Input indexes primitive id. - /// @param indices_rank Rank of indices. + /// @param output_format Output format: bfyx, bfzyx, bfwzyx + /// @param output_shape Output shape: {2, 2, 3, 5}, {2, 2, 3, 3, 6} /// @param axis An attribute of GatherElements. Required. gather_elements(const primitive_id& id, const primitive_id& data, const primitive_id& indices, - const uint8_t indices_rank, + const format& output_format, + const tensor& output_shape, const uint8_t axis = 0, const padding& output_padding = padding()) - : primitive_base(id, {data, indices}, output_padding), indices_rank(indices_rank), axis(axis) {} + : primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {} - /// @brief indices_rank - uint8_t indices_rank; + /// @brief Gather Elements output format + format output_format; + /// @brief Gather Elements output shape + tensor output_shape; /// @brief Which axis to gather on. uint8_t axis; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index ceb415a2571752..307fa7c0f996f9 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -18,7 +18,7 @@ #include "kernel_selector_utils.h" #include #include - +#include namespace kernel_selector { ParamsKey GatherElementsKernelRef::GetSupportedKey() const { @@ -68,77 +68,75 @@ static inline std::vector GetDefaultOrder(size_t size) { CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_params& params, const optional_params&) const { CommonDispatchData dispatchData; - auto indices_dims = params.inputs[1].LogicalDims(); - - if (indices_dims.size() > 1) { - std::reverse(indices_dims.begin(), indices_dims.end()); - } - - indices_dims[params.indices_rank - 1] = 1; // set last dim of indices to 1 + const auto& output = params.output; + // printf("%ld %ld %ld %ld %ld %ld\n", output.X().v, output.Y().v, output.Z().v, output.W().v, output.Feature().v, output.Batch().v); switch (params.inputs[1].GetLayout()) { case DataLayout::bfyx: - dispatchData.gws = { indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] }; + dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v}; break; case DataLayout::bfzyx: - dispatchData.gws = { indices_dims[4] * indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] }; + dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v}; break; case DataLayout::bfwzyx: - dispatchData.gws = { indices_dims[5] * indices_dims[4], indices_dims[3] * indices_dims[2], indices_dims[1] * indices_dims[0] }; + dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v}; break; default: - throw std::invalid_argument("Unsupported data layout for scatter elements update primitive"); + throw std::invalid_argument("Unsupported data layout for gather elements primitive"); break; } - dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); + // dispatchData.lws = {1, 1, 1}; return dispatchData; } -static size_t GetIndicesLastDim(const gather_elements_params& params) { - // get indices dims - auto indices_dims = params.inputs[1].LogicalDims(); +// static size_t GetIndicesLastDim(const gather_elements_params& params) { +// // get indices dims +// auto indices_dims = params.inputs[1].LogicalDims(); +// // std::cout << indices_dims << "incide dims\n"; - if (indices_dims.size() > 1) { - std::reverse(indices_dims.begin(), indices_dims.end()); - } +// if (indices_dims.size() > 1) { +// std::reverse(indices_dims.begin(), indices_dims.end()); +// } - auto indices_last_dim = indices_dims[params.indices_rank - 1]; +// auto indices_last_dim = indices_dims[0]; - return indices_last_dim; -} +// return indices_last_dim; +// } -static size_t GetSliceSize(const gather_elements_params& params) { - // get input dims - auto input_dims = params.inputs[0].LogicalDims(); +// static size_t GetSliceSize(const gather_elements_params& params) { +// // get input dims +// // auto input_dims = params.inputs[0].LogicalDims(); - if (input_dims.size() > 1) { - std::reverse(input_dims.begin(), input_dims.end()); - } +// // if (input_dims.size() > 1) { +// // std::reverse(input_dims.begin(), input_dims.end()); +// // } - // get last dim of indices - auto indices_last_dim = GetIndicesLastDim(params); +// // // get last dim of indices +// // auto indices_last_dim = GetIndicesLastDim(params); - // calculate slize size which is used in kernel to copy - size_t wi_slice_size = 1; - for (size_t i = params.batch_dims + indices_last_dim; i < input_dims.size(); i++) { - wi_slice_size *= input_dims[i]; - } +// // // calculate slize size which is used in kernel to copy +// // size_t wi_slice_size = 1; +// // for (size_t i = params.batch_dims + indices_last_dim; i < input_dims.size(); i++) { +// // wi_slice_size *= input_dims[i]; +// // } - return wi_slice_size; -} +// return 3; +// } JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const { JitConstants jit = MakeBaseParamsJitConstants(params); - - jit.AddConstant(MakeJitConstant("INDICES_RANK", params.indices_rank)); - jit.AddConstant(MakeJitConstant("BATCH_DIMS", params.batch_dims)); - jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE", GetSliceSize(params))); - jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", GetIndicesLastDim(params))); + // parameters in gather_elements_kernel_ref.h + auto p_axis = static_cast(params.axis); + if (p_axis < 0) { + p_axis = params.inputs[0].LogicalDims().size() + params.axis; + } + // printf("%d\n", p_axis); + jit.AddConstant(MakeJitConstant("AXIS", p_axis)); if (!params.fused_ops.empty()) { FusedOpsConfiguration conf = { "", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() }; @@ -149,15 +147,14 @@ JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_para } bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o) const { - if (p.GetType() != KernelType:: GATHER_ELEMENTS || o.GetType() != KernelType::GATHER_ELEMENTS) { + if (p.GetType() != KernelType::GATHER_ELEMENTS || o.GetType() != KernelType::GATHER_ELEMENTS) { return false; } const gather_elements_params& params = static_cast(p); auto input_dims = params.inputs[0].LogicalDims(); auto indices_dims = params.inputs[1].LogicalDims(); - auto indices_rank = params.indices_rank; - auto batch_dims = params.batch_dims; + auto indices_rank = indices_dims.size(); std::reverse(input_dims.begin(), input_dims.end()); std::reverse(indices_dims.begin(), indices_dims.end()); @@ -166,24 +163,24 @@ bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o return false; } - if (batch_dims + indices_dims[indices_rank - 1] > input_dims.size()) { - return false; - } + // if (batch_dims + indices_dims[indices_rank - 1] > input_dims.size()) { + // return false; + // } - if (batch_dims >= std::min(input_dims.size(), static_cast(indices_rank))) { - return false; - } + // if (batch_dims >= std::min(input_dims.size(), static_cast(indices_rank))) { + // return false; + // } - for (uint8_t i = 0; i < batch_dims; i++) { - if (input_dims[i] != indices_dims[i]) { - return false; - } - } + // for (uint8_t i = 0; i < batch_dims; i++) { + // if (input_dims[i] != indices_dims[i]) { + // return false; + // } + // } - for (auto& fused_op : params.fused_ops) { - if (!IsFusedPrimitiveSupported(fused_op)) - return false; - } + // for (auto& fused_op : params.fused_ops) { + // if (!IsFusedPrimitiveSupported(fused_op)) + // return false; + // } return true; } @@ -203,7 +200,6 @@ KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const auto jit = CreateJit(kernelName, cldnn_jit, entry_point); auto& kernel = kd.kernels[0]; FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2, GetFusedPrimitiveInputsCount(params)); - return { kd }; } diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h index a6097e4ccaf400..a58fa3f87ab991 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h @@ -23,11 +23,8 @@ namespace kernel_selector { // gather_elements_params //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// struct gather_elements_params : public base_params { - gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), indices_rank(0), batch_dims(0) {} + gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(0) {} - uint8_t indices_rank; - - uint8_t batch_dims; uint8_t axis; virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); } diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl index 91cb7d9be773e9..b4800d7392be05 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -15,15 +15,10 @@ #include "include/fetch.cl" #define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order) -#define GET_OUTPUT_INDEX(out_order) OUTPUT_GET_INDEX(out_order) +#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order) -#if INPUT0_DIMS == 4 - #define IN_ORDER in_b,in_f,in_y,in_x -#elif INPUT0_DIMS == 5 - #define IN_ORDER in_b,in_f,in_z,in_y,in_x -#else - #define IN_ORDER in_b,in_f,in_w,in_z,in_y,in_x -#endif +#define ORDER b,f,y,x +#define IN_ORDER in_b,in_f,in_y,in_x #if INPUT1_DIMS == 4 #define IDX_ORDER idx_b,idx_f,idx_y,idx_x @@ -33,194 +28,119 @@ #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x #endif -#if OUTPUT_DIMS == 4 - #define OUT_ORDER out_b,out_f,out_y,out_x -#elif OUTPUT_DIMS == 5 - #define OUT_ORDER out_b,out_f,out_z,out_y,out_x -#else - #define OUT_ORDER out_b,out_f,out_w,out_z,out_y,out_x -#endif +#define OUT_ORDER out_b,out_f,out_y,out_x +#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order) #define INDICES_MAX_DIM 6 KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const __global INPUT1_TYPE* indices, - __global OUTPUT_TYPE* output -#if HAS_FUSED_OPS_DECLS - , FUSED_OPS_DECLS -#endif -) + __global OUTPUT_TYPE* output) { - const uint dim0 = get_global_id(0); const uint dim1 = get_global_id(1); const uint dim2 = get_global_id(2); // Calculate indice index - const uint F_NUM = (INDICES_RANK == 2) ? 1 : INPUT1_FEATURE_NUM; + const uint F_NUM = INPUT1_FEATURE_NUM; const uint idx_f = dim2 % F_NUM; const uint idx_b = dim2 / F_NUM; - #if INPUT1_DIMS == 4 - const uint idx_x = dim0; - const uint idx_y = dim1; - const uint idx_z = 0; - const uint idx_w = 0; - - const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_y, idx_x, 0, 0, 0, 0}; - const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X}; - #elif INPUT1_DIMS == 5 - const uint X_NUM = (INDICES_RANK == 5) ? 1 : INPUT1_SIZE_X; - - const uint idx_x = dim0 % X_NUM; - const uint idx_y = dim0 / X_NUM; - const uint idx_z = dim1; - const uint idx_w = 0; - - const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0}; - const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; - #else - const uint X_NUM = (INDICES_RANK == 6) ? 1 : INPUT1_SIZE_X; - const uint Z_NUM = (INDICES_RANK == 4) ? 1 : INPUT1_SIZE_Z; - - const uint idx_x = dim0 % X_NUM; - const uint idx_y = dim0 / X_NUM; - const uint idx_z = dim1 % Z_NUM; - const uint idx_w = dim1 / Z_NUM; - - const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_w, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0, 0}; - const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; - #endif - - const int idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); - - // Calculate data index - uint indices_val[INDICES_MAX_DIM + BATCH_DIMS]; - for (int i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) { - indices_val[i] = 0; - } +#if INPUT1_DIMS == 4 + // const uint idx_x = dim0; // y + // const uint idx_y = dim1; // x + // const uint idx_z = 0; + // const uint idx_w = 0; + const uint idx_x = dim0; + const uint idx_y = dim1; +#elif INPUT1_DIMS == 5 + // const uint idx_x = dim0 / INPUT1_SIZE_Y; // z + // const uint idx_y = dim0 % INPUT1_SIZE_Y; // y + // const uint idx_z = dim1; // x + // const uint idx_w = 0; + const uint idx_x = dim0 % OUTPUT_SIZE_X; + const uint idx_y = dim0 / OUTPUT_SIZE_X; + const uint idx_z = dim1; - for (int i = 0; i < BATCH_DIMS; i++) { - indices_val[i] = idx_arr[i]; +#else + // INPUT1_DIMS == 6 + const uint idx_x = dim0 % OUTPUT_SIZE_X; // x + const uint idx_y = dim0 / OUTPUT_SIZE_X; // y + const uint idx_z = dim1 % OUTPUT_SIZE_Z; // z + const uint idx_w = dim1 / OUTPUT_SIZE_Z; // w +#endif + + const int out_idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); + // printf("%d\n", out_idx); + int axis = AXIS; + size_t rank = INPUT0_DIMS; // indices_shape.size(), data_shape.size() +//     printf("rank and axis: %d %d\n", rank, axis); + + size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_X, INPUT0_SIZE_Y, INPUT0_SIZE_Z, INPUT0_SIZE_W}; + size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_X, INPUT1_SIZE_Y, INPUT1_SIZE_Z, INPUT1_SIZE_W}; + + size_t max_inner_sum = 1, max_outer_sum = 1, outer_sum_inc_data = 1, outer_sum_inc_indices = 1; + for (size_t i = axis + 1; i < rank; i++) + max_inner_sum *= indices_shape[i]; + + for (int i = 0; i < axis; i++) + max_outer_sum *= indices_shape[i]; + + for (size_t i = axis; i < rank; i++) { + outer_sum_inc_data *= data_shape[i]; } + max_outer_sum *= outer_sum_inc_data; - for (int i = 0; i < INDICES_LAST_DIM; i++) { - indices_val[i + BATCH_DIMS] = indices[idx+i]; + for (size_t i = axis; i < rank; i++) { + outer_sum_inc_indices *= indices_shape[i]; } - #if INPUT0_DIMS == 4 - const uint in_x = indices_val[3]; - const uint in_y = indices_val[2]; - #elif INPUT0_DIMS == 5 - const uint in_x = indices_val[4]; - const uint in_y = indices_val[3]; - const uint in_z = indices_val[2]; - #else - const uint in_x = indices_val[5]; - const uint in_y = indices_val[4]; - const uint in_z = indices_val[3]; - const uint in_w = indices_val[2]; - #endif - const uint in_f = indices_val[1]; - const uint in_b = indices_val[0]; - - const uint data_idx = GET_UPDATES_INDEX(INPUT0, IN_ORDER); - - // Calculate output index - #if BATCH_DIMS <= 1 - const uint out_x = idx_x; - const uint out_y = idx_y; - const uint out_z = idx_z; - const uint out_w = idx_w; - const uint out_f = idx_f; - const uint out_b = idx_b; - #else - uint pitch_acc = 1; - uint output_batch_size = 0; - for (int i = BATCH_DIMS - 1; i >= 0; i--) { - output_batch_size += (idx_arr[i] * pitch_acc); - pitch_acc *= idx_dim[i]; - } - - #if OUTPUT_DIMS == 4 - const uint out_x = idx_arr[BATCH_DIMS+2]; - const uint out_y = idx_arr[BATCH_DIMS+1]; - #elif OUTPUT_DIMS == 5 - const uint out_x = idx_arr[BATCH_DIMS+3]; - const uint out_y = idx_arr[BATCH_DIMS+2]; - const uint out_z = idx_arr[BATCH_DIMS+1]; - #else - const uint out_x = idx_arr[BATCH_DIMS+4]; - const uint out_y = idx_arr[BATCH_DIMS+3]; - const uint out_z = idx_arr[BATCH_DIMS+2]; - const uint out_w = idx_arr[BATCH_DIMS+1]; - #endif - const uint out_f = idx_arr[BATCH_DIMS+0]; - const uint out_b = output_batch_size; - #endif - - const uint output_idx = GET_OUTPUT_INDEX(OUT_ORDER); - - // Copy data to output as slice size - #if HAS_FUSED_OPS - #if OUTPUT_DIMS == 4 - const uint y_pitch = OUTPUT_SIZE_X; - const uint f_pitch = y_pitch * OUTPUT_SIZE_Y; - #elif OUTPUT_DIMS == 5 - const uint y_pitch = OUTPUT_SIZE_X; - const uint z_pitch = y_pitch * OUTPUT_SIZE_Y; - const uint f_pitch = z_pitch * OUTPUT_SIZE_Z; - #else - const uint y_pitch = OUTPUT_SIZE_X; - const uint z_pitch = y_pitch * OUTPUT_SIZE_Y; - const uint w_pitch = z_pitch * OUTPUT_SIZE_Z; - const uint f_pitch = w_pitch * OUTPUT_SIZE_W; - #endif - const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM; - #endif - - for (int i = 0; i < WI_SLICE_SIZE; i++) { - uint dst_idx = output_idx + i; - INPUT0_TYPE val = data[data_idx + i]; - - #if HAS_FUSED_OPS - const uint b_remain = dst_idx % b_pitch; - const uint f_remain = b_remain % f_pitch; - #if OUTPUT_DIMS == 4 - const uint y_remain = f_remain % y_pitch; - - const uint y = f_remain / y_pitch; - #elif OUTPUT_DIMS == 5 - const uint z_remain = f_remain % z_pitch; - const uint y_remain = z_remain % y_pitch; - - const uint z = f_remain / z_pitch; - const uint y = z_remain / y_pitch; - #else - const uint w_remain = f_remain % w_pitch; - const uint z_remain = w_remain % z_pitch; - const uint y_remain = z_remain % y_pitch; - - const uint w = f_remain / w_pitch; - const uint z = w_remain / z_pitch; - const uint y = z_remain / y_pitch; - #endif - const uint b = dst_idx / b_pitch; - const uint f = b_remain / f_pitch; - const uint x = y_remain; - - #if FUSED_OPS_CAN_USE_PRELOAD - FUSED_OPS_PRELOAD; - FUSED_OPS_CALC; - #else - FUSED_OPS; - #endif - - output[dst_idx] = FUSED_OPS_RESULT; - #else - output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS); - #endif +//     printf("max_inner_sum: %ld\n", max_inner_sum); +//     printf("outer_sum_inc_data: %ld\n",outer_sum_inc_data); +//     printf("max_inner_sum, max_outer_sum, outer_sum_inc_data: %d %d %d\n",max_inner_sum, max_outer_sum, outer_sum_inc); + +// ======================================================================================== + + size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data; + size_t inner_sum = out_idx % max_inner_sum; + if (indices[out_idx] < 0 || indices[out_idx] >= data_shape[axis]) { + printf("indices values of GatherElement exceed data size.\n"); + return; } + uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum; + uint tmp = outer_sum; + + INPUT0_TYPE val = data[idx]; + output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); + + // output[out_idx] = TO_OUTPUT_TYPE(axis); + // output[out_idx] = axis; +// ======================================================================================== + + // output[out_idx] = TO_OUTPUT_TYPE(out_idx); + +// ======================================================================================== + + // for (size_t outer_sum = 0, i = 0; outer_sum < max_outer_sum; outer_sum += outer_sum_inc_data) { + // for (size_t k = 0; k < indices_shape[axis]; k++) { + // for (size_t inner_sum = 0; inner_sum < max_inner_sum; inner_sum++) { + // if (indices[i] < 0 || indices[i] >= data_shape[axis]) + // { + // printf("indices values of GatherElement exceed data size.\n"); + // return; + // } + + // // uint idx = outer_sum + max_inner_sum * indices[i] + inner_sum; + // uint idx = outer_sum; + // // uint idx = max_inner_sum * indices[i]; + // // INPUT0_TYPE val = data[idx]; + // // output[i] = ACTIVATION(val, ACTIVATION_PARAMS); + // output[i] = idx; + // // output[output_idx] = TO_OUTPUT_TYPE(val); + // i++; + // } + // } + // } } #undef INDICES_MAX_DIM diff --git a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp index e8209186708b63..80e912233955fc 100644 --- a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp +++ b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp @@ -37,55 +37,14 @@ layout gather_elements_inst::calc_output_layout(gather_elements_node const& node auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format); // const size_t input_dims = input_layout.size(); + auto output_type = indices_layout_origin.data_type; + auto output_format = op->output_format; + auto output_shape = op->output_shape; - // const auto indices_rank = op->indices_rank; - const auto axis = op->axis; + // const auto axis = op->axis; // calculate initial output shape - std::vector output_sizes; - - // for (uint8_t x = 0; x < indices_rank - 1; x++) { - // output_sizes.push_back(indices_layout[x]); - // } - - // const size_t indices_last_dim = indices_layout[indices_rank - 1]; - // for (size_t x = static_cast(axis + indices_last_dim); x < input_dims; x++) { - // output_sizes.push_back(input_layout[x]); - // } - - // // calculate batch_size by axis - // int batch_size = 1; - // for (uint8_t x = 0; x < axis; x++) { - // batch_size *= output_sizes[x]; - // } - - // create final output shape by axis - std::vector final_output_sizes; - - // if (axis > 0) { - // final_output_sizes.push_back(batch_size); - // } - - for (size_t x = static_cast(axis); x < output_sizes.size(); x++) { - final_output_sizes.push_back(output_sizes[x]); - } - - auto output_format = cldnn::format::bfyx; - if (final_output_sizes.size() >= 6) { - output_format = cldnn::format::bfwzyx; - } else if (final_output_sizes.size() == 5) { - output_format = cldnn::format::bfzyx; - } - - auto output_sizes_tensor = tensor(tensor(final_output_sizes).sizes(output_format)); - auto padding = op->output_padding; - - - // if (node.has_fused_primitives()) { - // input_layout_origin.data_type = node.get_fused_output_layout().data_type; - // } - - return layout(input_layout_origin.data_type, output_format, output_sizes_tensor, padding); + return layout(output_type, output_format, output_shape); } std::string gather_elements_inst::to_string(gather_elements_node const& node) { @@ -99,9 +58,9 @@ std::string gather_elements_inst::to_string(gather_elements_node const& node) { gather_elements_info.add("input id", input.id()); gather_elements_info.add("input shape", node.input(0).get_output_layout().size.to_string()); gather_elements_info.add("indices shape", node.input(1).get_output_layout().size.to_string()); - // gather_elements_info.add("indices rank", desc->indices_rank); + gather_elements_info.add("output format", calc_output_layout(node).format); + gather_elements_info.add("output shape", calc_output_layout(node).size.to_string()); gather_elements_info.add("axis", desc->axis); - // gather_elements_info.add("output shape", calc_output_layout(node).size.to_string()); node_info->add("gather_elements info", gather_elements_info); node_info->dump(primitive_description); diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp index aa043a1e8d6dad..12aaf82910b2ca 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -31,14 +31,13 @@ inline void DoTest(const engine& engine, const cldnn::memory& input0, // data const cldnn::memory& input1, // indices const std::vector& expected_results, - // const int indices_rank, + const tensor& output_tensor, const int axis) { topology topology; topology.add(input_layout("InputData", input0.get_layout())); topology.add(input_layout("InputIndices", input1.get_layout())); - int indices_rank = 2; topology.add( - gather_elements("gather_elements", "InputData", "InputIndices", indices_rank, axis) + gather_elements("gather_elements", "InputData", "InputIndices", input1.get_layout().format, output_tensor, axis) ); network network(engine, topology); @@ -50,6 +49,8 @@ inline void DoTest(const engine& engine, auto output_ptr = output.pointer(); for (size_t i = 0; i < expected_results.size(); ++i) { + // printf("%ld : %f %f\n", i, expected_results[i], float16_to_float32(output_ptr[i]) ); + // printf("%ld : %f\n", i, float16_to_float32(output_ptr[i]) ); EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); } } @@ -119,7 +120,7 @@ TEST(gather_elements_gpu_fp16, d2235_i2237_a3) { FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 7), axis); } // 4-1-2 @@ -187,7 +188,7 @@ TEST(gather_elements_gpu_fp16, d2235_i2237_an1) { FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 7), axis); } // 4-2 @@ -249,7 +250,7 @@ TEST(gather_elements_gpu_fp16, d2329_i2329_a2) { FLOAT16(2), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(4), FLOAT16(8), FLOAT16(0), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 3, 2, 9), axis); } // 4-3 @@ -322,7 +323,7 @@ TEST(gather_elements_gpu_fp16, d3238_i2238_a0) { FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(3), FLOAT16(3), FLOAT16(10), FLOAT16(6), FLOAT16(1), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 8), axis); } // 5-1 @@ -330,8 +331,8 @@ TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { const auto& engine = get_test_engine(); const int axis = 4; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 2, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 2, 8 } }); // indices + auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 2, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 2, 8 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), @@ -419,7 +420,7 @@ TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(4), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(3, 2, 2, 2, 8), axis); } // 5-2 @@ -427,8 +428,8 @@ TEST(gather_elements_gpu_fp16, d23327_i23327_a3) { const auto& engine = get_test_engine(); const int axis = 3; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 3, 2, 7 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 3, 2, 7 } }); // indices + auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 3, 2, 7 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 3, 2, 7 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), @@ -547,7 +548,7 @@ TEST(gather_elements_gpu_fp16, d23327_i23327_a3) { FLOAT16(3), FLOAT16(10), FLOAT16(9), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(0), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 3, 3, 2, 7), axis); } // 6-1 @@ -555,8 +556,8 @@ TEST(gather_elements_gpu_fp16, d232328_i232328_a3) { const auto& engine = get_test_engine(); const int axis = 3; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 3, 2, 8 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 3, 2, 8 } }); // indices + auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 3, 2, 8 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 3, 2, 8 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), @@ -783,16 +784,15 @@ TEST(gather_elements_gpu_fp16, d232328_i232328_a3) { FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(7), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(5), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 3, 2, 3, 2, 8), axis); } - // 6-2 TEST(gather_elements_gpu_fp16, d222443_i222446_a5) { const auto& engine = get_test_engine(); const int axis = 5; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 2, 4, 4, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 2, 4, 4, 6 } }); // indices + auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 4, 4, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 4, 4, 6 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), @@ -1187,7 +1187,7 @@ TEST(gather_elements_gpu_fp16, d222443_i222446_a5) { FLOAT16(3), FLOAT16(3), FLOAT16(7), FLOAT16(8), FLOAT16(3), FLOAT16(8), }; - DoTest(engine,input0, input1, expected_results, axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 2, 4, 4, 6), axis); } // TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { @@ -1207,4 +1207,4 @@ TEST(gather_elements_gpu_fp16, d222443_i222446_a5) { // }; // DoTest(engine,input0, input1, expected_results, axis); -// } \ No newline at end of file +// } From 845963d3f2d2171aedcf22d20fd7e804fa78b6f7 Mon Sep 17 00:00:00 2001 From: yunji Date: Wed, 7 Jul 2021 01:16:27 +0900 Subject: [PATCH 03/11] Add Fusing Test implementation --- .../gather/gather_elements_kernel_ref.cpp | 15 +- .../core/cl_kernels/gather_elements_ref.cl | 103 ++++----- .../thirdparty/clDNN/src/gather_elements.cpp | 4 + .../tests/test_cases/fusings_gpu_test.cpp | 216 +++++++++++++++++- 4 files changed, 276 insertions(+), 62 deletions(-) diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index 307fa7c0f996f9..a4649fc80176ff 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -77,7 +77,8 @@ CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_par break; case DataLayout::bfzyx: - dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v}; + // dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v}; + dispatchData.gws = {output.X().v, output.Y().v * output.Z().v, output.Feature().v * output.Batch().v}; break; case DataLayout::bfwzyx: @@ -130,6 +131,7 @@ CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_par JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const { JitConstants jit = MakeBaseParamsJitConstants(params); + // parameters in gather_elements_kernel_ref.h auto p_axis = static_cast(params.axis); if (p_axis < 0) { @@ -139,7 +141,8 @@ JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_para jit.AddConstant(MakeJitConstant("AXIS", p_axis)); if (!params.fused_ops.empty()) { - FusedOpsConfiguration conf = { "", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() }; + std::vector idx_order = GetDefaultOrder(params.inputs[0].GetDims().size()); + FusedOpsConfiguration conf = { "", idx_order, "val", params.inputs[0].GetDType() }; jit.Merge(MakeFusedOpsJitConstants(params, { conf })); } @@ -177,10 +180,10 @@ bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o // } // } - // for (auto& fused_op : params.fused_ops) { - // if (!IsFusedPrimitiveSupported(fused_op)) - // return false; - // } + for (auto& fused_op : params.fused_ops) { + if (!IsFusedPrimitiveSupported(fused_op)) + return false; + } return true; } diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl index b4800d7392be05..40c48356ef7cb3 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -13,21 +13,13 @@ // limitations under the License. #include "include/fetch.cl" +#include "include/include_all.cl" #define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order) #define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order) -#define ORDER b,f,y,x #define IN_ORDER in_b,in_f,in_y,in_x -#if INPUT1_DIMS == 4 - #define IDX_ORDER idx_b,idx_f,idx_y,idx_x -#elif INPUT1_DIMS == 5 - #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x -#else - #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x -#endif - #define OUT_ORDER out_b,out_f,out_y,out_x #define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order) @@ -35,42 +27,58 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const __global INPUT1_TYPE* indices, - __global OUTPUT_TYPE* output) + __global OUTPUT_TYPE* output +#if HAS_FUSED_OPS_DECLS + , FUSED_OPS_DECLS +#endif +) { const uint dim0 = get_global_id(0); const uint dim1 = get_global_id(1); const uint dim2 = get_global_id(2); // Calculate indice index - const uint F_NUM = INPUT1_FEATURE_NUM; - const uint idx_f = dim2 % F_NUM; - const uint idx_b = dim2 / F_NUM; - + // const uint idx_f = dim2 % INPUT1_FEATURE_NUM; + // const uint idx_b = dim2 / INPUT1_FEATURE_NUM; + const uint f = dim2 % OUTPUT_FEATURE_NUM; + // const uint f = 0; + const uint b = dim2 / OUTPUT_FEATURE_NUM; #if INPUT1_DIMS == 4 - // const uint idx_x = dim0; // y - // const uint idx_y = dim1; // x - // const uint idx_z = 0; - // const uint idx_w = 0; - const uint idx_x = dim0; - const uint idx_y = dim1; + #define IDX_ORDER idx_b,idx_f,idx_y,idx_x + #define ORDER b,f,y,x + const uint x = dim0; + const uint y = dim1; + + // const uint idx_x = dim0; + // const uint idx_y = dim1; #elif INPUT1_DIMS == 5 - // const uint idx_x = dim0 / INPUT1_SIZE_Y; // z - // const uint idx_y = dim0 % INPUT1_SIZE_Y; // y - // const uint idx_z = dim1; // x - // const uint idx_w = 0; - const uint idx_x = dim0 % OUTPUT_SIZE_X; - const uint idx_y = dim0 / OUTPUT_SIZE_X; - const uint idx_z = dim1; + #define ORDER b,f,z,y,x + #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x + const uint x = dim0; + const uint y = dim1 % OUTPUT_SIZE_Y; + const uint z = dim1 / OUTPUT_SIZE_Y; + // x*y, z + // const uint idx_x = dim0 % OUTPUT_SIZE_X; + // const uint idx_y = dim0 / OUTPUT_SIZE_X; + // const uint idx_z = dim1; #else - // INPUT1_DIMS == 6 - const uint idx_x = dim0 % OUTPUT_SIZE_X; // x - const uint idx_y = dim0 / OUTPUT_SIZE_X; // y - const uint idx_z = dim1 % OUTPUT_SIZE_Z; // z - const uint idx_w = dim1 / OUTPUT_SIZE_Z; // w + #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x + #define ORDER b,f,w,z,y,x + const uint x = dim0 % OUTPUT_SIZE_X; + const uint y = dim0 / OUTPUT_SIZE_X; + const uint z = dim1 % OUTPUT_SIZE_Z; + const uint w = dim1 / OUTPUT_SIZE_Z; + // const uint f = dim2 % OUTPUT_FEATURE_NUM; + // const uint b = dim2 / OUTPUT_FEATURE_NUM; + // const uint idx_x = dim0 % OUTPUT_SIZE_X; // x + // const uint idx_y = dim0 / OUTPUT_SIZE_X; // y + // const uint idx_z = dim1 % OUTPUT_SIZE_Z; // z + // const uint idx_w = dim1 / OUTPUT_SIZE_Z; // w #endif - const int out_idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); + // const int out_idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); + const int out_idx = GET_UPDATES_INDEX(INPUT1, ORDER); // printf("%d\n", out_idx); int axis = AXIS; size_t rank = INPUT0_DIMS; // indices_shape.size(), data_shape.size() @@ -111,7 +119,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, uint tmp = outer_sum; INPUT0_TYPE val = data[idx]; - output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); + // output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); // output[out_idx] = TO_OUTPUT_TYPE(axis); // output[out_idx] = axis; @@ -120,27 +128,12 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, // output[out_idx] = TO_OUTPUT_TYPE(out_idx); // ======================================================================================== - - // for (size_t outer_sum = 0, i = 0; outer_sum < max_outer_sum; outer_sum += outer_sum_inc_data) { - // for (size_t k = 0; k < indices_shape[axis]; k++) { - // for (size_t inner_sum = 0; inner_sum < max_inner_sum; inner_sum++) { - // if (indices[i] < 0 || indices[i] >= data_shape[axis]) - // { - // printf("indices values of GatherElement exceed data size.\n"); - // return; - // } - - // // uint idx = outer_sum + max_inner_sum * indices[i] + inner_sum; - // uint idx = outer_sum; - // // uint idx = max_inner_sum * indices[i]; - // // INPUT0_TYPE val = data[idx]; - // // output[i] = ACTIVATION(val, ACTIVATION_PARAMS); - // output[i] = idx; - // // output[output_idx] = TO_OUTPUT_TYPE(val); - // i++; - // } - // } - // } +#if HAS_FUSED_OPS + FUSED_OPS; + output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT); +#else + output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); +#endif } #undef INDICES_MAX_DIM diff --git a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp index 80e912233955fc..869c31e322e359 100644 --- a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp +++ b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp @@ -36,6 +36,10 @@ layout gather_elements_inst::calc_output_layout(gather_elements_node const& node auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format); auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format); + if (node.has_fused_primitives()) { + input_layout_origin.data_type = node.get_fused_output_layout().data_type; + } + // const size_t input_dims = input_layout.size(); auto output_type = indices_layout_origin.data_type; auto output_format = op->output_format; diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 0fea8be6648aa4..58b0edb34e0e54 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -178,6 +178,10 @@ class BaseFusingTest : public ::testing::TestWithParam { description << " " << i.original_id << " " << i.kernel_id << std::endl; } SCOPED_TRACE(description.str()); + // std::cout << "Count reorder? " << count_reorder << std::endl; + // std::cout << "(executed primitives) fused, not fused: " << fused.get_executed_primitives().size() << ", " << not_fused.get_executed_primitives().size() << std::endl; + // std::cout << "(reorder count) fused, not fused: " << reorders_count_fused << ", " << reorders_count_not_fused << std::endl; + // std::cout << "(exepected) fused, not fused: " << p.expected_fused_primitives << ", " << p.expected_not_fused_primitives << std::endl; // Subtract reorders count to handle execution in different layouts when input/output reorders can be added in the graph ASSERT_EQ(fused.get_executed_primitives().size() - (count_reorder ? 0 : reorders_count_fused), p.expected_fused_primitives); ASSERT_EQ(not_fused.get_executed_primitives().size() - (count_reorder ? 0 : reorders_count_not_fused), p.expected_not_fused_primitives); @@ -8412,4 +8416,214 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_nd_activation_scale_eltwise, gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_2, 2, 5 }, gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 5 }, gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 5 }, -})); +}), ); + + +/* ----------------------------------------------------------------------------------------------------- */ +/* ------------------------------------------ GatherElements cases ------------------------------------------- */ +/* ----------------------------------------------------------------------------------------------------- */ +struct gather_elements_test_params { + data_types data_type; + + format input_format; + tensor input_shape; + + format indices_format; + tensor indices_shape; + + format output_format; + tensor output_shape; + + int axis; + + data_types default_type; + format default_format; + + size_t expected_fused_primitives; + size_t expected_not_fused_primitives; +}; + +#define CASE_GATHER_ELEMENTS_FP16_4D_1 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 1, 1, 1}, 0, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_2 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 1, 1, 1}, 1, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_3 data_types::f16, format::bfyx, {2, 2, 3, 5}, format::bfyx, {2, 2, 3, 7}, format::bfyx, {2, 2, 3, 7}, 3, data_types::f16, format::bfyx + +#define CASE_GATHER_ELEMENTS_FP16_5D_1 data_types::f16, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 8}, format::bfzyx, {3, 2, 2, 2, 8}, 4, data_types::f16, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP16_5D_2 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {5, 4, 1, 1, 3}, 2, data_types::f16, format::bfzyx + +#define CASE_GATHER_ELEMENTS_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, 4, data_types::f16, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP16_6D_2 data_types::f16, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, 3, data_types::f16, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP16_6D_3 data_types::f16, format::bfwzyx, {2, 2, 2, 4, 4, 3}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, 5, data_types::f16, format::bfwzyx + +#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 1, 1, 1}, 0, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 1, 1, 1}, 1, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {2, 2, 3, 5}, format::bfyx, {2, 2, 3, 7}, format::bfyx, {2, 2, 3, 7}, 3, data_types::f32, format::bfyx + +#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 8}, format::bfzyx, {3, 2, 2, 2, 8}, 4, data_types::f32, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {5, 4, 1, 1, 3}, 2, data_types::f32, format::bfzyx + +#define CASE_GATHER_ELEMENTS_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, 4, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, 3, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 2, 4, 4, 3}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, 5, data_types::f32, format::bfwzyx + +class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest { +public: + void execute(gather_elements_test_params& p) { + auto input_prim = get_mem(get_input_layout(p)); + network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused); + // network network_fused(this->engine, this->topology_non_fused, bo_not_fused); + network network_fused(this->engine, this->topology_fused, bo_fused); + network_fused.set_input_data("input", input_prim); + network_not_fused.set_input_data("input", input_prim); + compare(network_not_fused, network_fused, p); + } + + layout get_input_layout(gather_elements_test_params& p) { + return layout{ p.data_type, p.input_format, p.input_shape }; + } + + layout get_indices_layout(gather_elements_test_params& p) { + return layout{ p.data_type, p.indices_format, p.indices_shape }; + } + + layout get_output_layout(gather_elements_test_params& p) { + return layout{ p.data_type, p.output_format, p.output_shape }; + } + + layout get_per_channel_layout(gather_elements_test_params& p) { + return layout{ p.default_type, p.default_format, tensor{1, p.output_shape.feature[0], 1, 1} }; + } +}; + +class gather_elements_quantize : public GatherElementsPrimitiveFusingTest {}; +TEST_P(gather_elements_quantize, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)), + // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, /*p.max_number_in_indices - 1*/)), + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_lo", get_mem(get_single_element_layout(p), -127)), + data("out_hi", get_mem(get_single_element_layout(p), 127)), + // output format, output shape, axis + gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), + quantize("quantize", "gather_elements_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8), + reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) + ); + tolerance = 1.f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_quantize, + ::testing::ValuesIn(std::vector{ + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 3 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 3 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 3 }, +}), ); + + +class gather_elements_scale_activation : public GatherElementsPrimitiveFusingTest {}; +TEST_P(gather_elements_scale_activation, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + // data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p)))), + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)), + gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), + activation("activation", "gather_elements_prim", activation_func::abs), + scale("scale", "activation", "scale_data"), + reorder("reorder_bfyx", "scale", p.default_format, data_types::f32) + ); + + tolerance = 1e-5f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_scale_activation, + ::testing::ValuesIn(std::vector{ + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 4 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 4 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 4 }, +}), ); + + +class gather_elements_activation_scale_eltwise : public GatherElementsPrimitiveFusingTest {}; +TEST_P(gather_elements_activation_scale_eltwise, basic) { + auto p = GetParam(); + + create_topologies(input_layout("input", get_input_layout(p)), + // data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)), + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)), + data("eltwise_data", get_mem(get_output_layout(p))), + gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), + activation("activation", "gather_elements_prim", activation_func::abs), + scale("scale", "activation", "scale_data"), + eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type), + reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32) + ); + + tolerance = 1e-5f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_activation_scale_eltwise, + ::testing::ValuesIn(std::vector{ + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 5 }, + + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 5 }, + gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 5 }, +}), ); From a6e3296ff7674130e74244d6995075a12e8ac24f Mon Sep 17 00:00:00 2001 From: yunji Date: Wed, 14 Jul 2021 19:54:01 +0900 Subject: [PATCH 04/11] Add functional test implementation --- .../single_layer_tests/gather_elements.cpp | 47 +- .../single_layer_tests/gather_elements.cpp | 399 +++++-- .../single_layer_tests/gather_elements.hpp | 6 +- .../core/cl_kernels/gather_elements_ref.cl | 63 +- .../tests/test_cases/fusings_gpu_test.cpp | 88 +- .../test_cases/gather_elements_gpu_test.cpp | 1045 ++++++++--------- 6 files changed, 980 insertions(+), 668 deletions(-) diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index dd2f462cb7825b..822fd2bdf783b7 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -11,17 +11,17 @@ using namespace LayerTestsDefinitions; namespace { const std::vector dPrecisions = { - // InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP16, - // InferenceEngine::Precision::I32, - // InferenceEngine::Precision::I64, - // InferenceEngine::Precision::I16, - // InferenceEngine::Precision::U8, - // InferenceEngine::Precision::I8 + InferenceEngine::Precision::I32, + InferenceEngine::Precision::I64, + InferenceEngine::Precision::I16, + InferenceEngine::Precision::U8, + InferenceEngine::Precision::I8 }; const std::vector iPrecisions = { InferenceEngine::Precision::I32, - // InferenceEngine::Precision::I64 + InferenceEngine::Precision::I64 }; INSTANTIATE_TEST_SUITE_P(smoke_set1, GatherElementsLayerTest, @@ -48,37 +48,56 @@ INSTANTIATE_TEST_SUITE_P(smoke_set3, GatherElementsLayerTest, ::testing::Combine( ::testing::Values(std::vector({2, 2, 3, 5})), // Data shape ::testing::Values(std::vector({2, 2, 3, 7})), // Indices shape - ::testing::Values(3), // Axis + ::testing::Values(3, -1), // Axis ::testing::ValuesIn(dPrecisions), ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); +<<<<<<< HEAD <<<<<<< HEAD INSTANTIATE_TEST_SUITE_P(smoke_set4, GatherElementsLayerTest, ======= INSTANTIATE_TEST_CASE_P(yunji_set2, GatherElementsLayerTest, >>>>>>> Add cldnn unit test implementation +======= +INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, +>>>>>>> Add functional test implementation ::testing::Combine( - ::testing::Values(std::vector({3, 2, 2, 2, 3})), // Data shape - ::testing::Values(std::vector({3, 2, 2, 2, 8})), // Indices shape - ::testing::Values(4), // Axis + ::testing::Values(std::vector({3, 2, 3, 8})), // Data shape + ::testing::Values(std::vector({2, 2, 3, 8})), // Indices shape + ::testing::Values(0, -4), // Axis ::testing::ValuesIn(dPrecisions), ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); +<<<<<<< HEAD <<<<<<< HEAD INSTANTIATE_TEST_SUITE_P(smoke_set5, GatherElementsLayerTest, ======= INSTANTIATE_TEST_CASE_P(yunji_set3, GatherElementsLayerTest, >>>>>>> Add cldnn unit test implementation +======= + +INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, +>>>>>>> Add functional test implementation ::testing::Combine( - ::testing::Values(std::vector({2, 2, 2, 4, 4, 3})), // Data shape - ::testing::Values(std::vector({2, 2, 2, 4, 4, 6})), // Indices shape - ::testing::Values(5), // Axis + ::testing::Values(std::vector({3, 2, 3, 4, 8})), // Data shape + ::testing::Values(std::vector({3, 2, 3, 5, 8})), // Indices shape + ::testing::Values(3, -2), // Axis ::testing::ValuesIn(dPrecisions), ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); + +// INSTANTIATE_TEST_CASE_P(yunji_set35, GatherElementsLayerTest, +// ::testing::Combine( +// ::testing::Values(std::vector({2, 3, 3, 1, 1, 3})), // Data shape +// ::testing::Values(std::vector({2, 3, 5, 1, 1, 3})), // Indices shape +// ::testing::Values(2), // Axis +// ::testing::ValuesIn(dPrecisions), +// ::testing::ValuesIn(iPrecisions), +// ::testing::Values(CommonTestUtils::DEVICE_CPU)), +// GatherElementsLayerTest::getTestCaseName); } // namespace diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index d7ca24b2a85219..180863e9c04958 100644 --- a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -2,80 +2,325 @@ // SPDX-License-Identifier: Apache-2.0 // // -// #include -// #include - -// #include "single_layer_tests/gather_elements.hpp" -// #include "common_test_utils/test_constants.hpp" - -// using namespace LayerTestsDefinitions; -// using namespace ngraph::opset6; - -// namespace { - -// const std::vector inputPrecisions = { -// InferenceEngine::Precision::FP32, -// InferenceEngine::Precision::FP16, -// InferenceEngine::Precision::I32, -// }; - -// const std::vector idxPrecisions = { -// InferenceEngine::Precision::I32, -// InferenceEngine::Precision::I64, -// }; - -// set1 -// const auto gatherNDArgsSubset1 = ::testing::Combine( -// ::testing::ValuesIn(std::vector>( -// { {2, 2}, {2, 3, 4} })), // Data shape -// ::testing::ValuesIn(std::vector>( -// { {2, 1}, {2, 1, 1} })), // Indices shape -// ::testing::ValuesIn(std::vector({ 0, 1 })) // Batch dims -// ); - -// INSTANTIATE_TEST_CASE_P(smoke_GatherND_set1, GatherNDLayerTest, -// ::testing::Combine( -// gatherNDArgsSubset1, -// ::testing::ValuesIn(inputPrecisions), -// ::testing::ValuesIn(idxPrecisions), -// ::testing::Values(CommonTestUtils::DEVICE_GPU), -// ::testing::Values({})), -// GatherNDLayerTest::getTestCaseName); - -// // set2 -// const auto gatherNDArgsSubset2 = ::testing::Combine( -// ::testing::ValuesIn(std::vector>( -// { {15, 12, 20, 15, 2}, {15, 12, 18, 7, 17} })), // Data shape -// ::testing::ValuesIn(std::vector>( -// { {15, 12, 2}, {15, 12, 5, 9, 1, 3} })), // Indices shape -// ::testing::ValuesIn(std::vector({ 1, 2 })) // Batch dims -// ); - -// INSTANTIATE_TEST_CASE_P(smoke_GatherND_set2, GatherNDLayerTest, -// ::testing::Combine( -// gatherNDArgsSubset2, -// ::testing::ValuesIn(inputPrecisions), -// ::testing::ValuesIn(idxPrecisions), -// ::testing::Values(CommonTestUtils::DEVICE_GPU), -// ::testing::Values({})), -// GatherNDLayerTest::getTestCaseName); - -// // set3 -// const auto gatherNDArgsSubset3 = ::testing::Combine( -// ::testing::ValuesIn(std::vector>( -// { {4, 3, 2, 5, 5, 2}, {4, 3, 2, 5, 7, 2} })), // Data shape -// ::testing::ValuesIn(std::vector>( -// { {4, 3, 2, 5, 1}, {4, 3, 2, 5, 6, 2} })), // Indices shape -// ::testing::ValuesIn(std::vector({ 3, 4 })) // Batch dims -// ); - -// INSTANTIATE_TEST_CASE_P(smoke_GatherND_set3, GatherNDLayerTest, -// ::testing::Combine( -// gatherNDArgsSubset3, -// ::testing::ValuesIn(inputPrecisions), -// ::testing::ValuesIn(idxPrecisions), -// ::testing::Values(CommonTestUtils::DEVICE_GPU), -// ::testing::Values({})), -// GatherNDLayerTest::getTestCaseName); - -// } // namespace +#include +#include + +#include "single_layer_tests/gather_elements.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; +using namespace ngraph::opset6; + +namespace { + +const std::vector inputPrecisions = { + InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16, + InferenceEngine::Precision::I32, +}; + +const std::vector idxPrecisions = { + InferenceEngine::Precision::I32, + InferenceEngine::Precision::I64, +}; + +// ======= CPU Func Test Cases ======== // +INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({2, 2})), // Data shape + ::testing::Values(std::vector({2, 2})), // Indices shape + ::testing::ValuesIn(std::vector({-1, 0, 1})), // Axis + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({2, 2, 1})), // Data shape + ::testing::Values(std::vector({4, 2, 1})), // Indices shape + ::testing::ValuesIn(std::vector({0, -3})), // Axis + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({2, 2, 3, 5})), // Data shape + ::testing::Values(std::vector({2, 2, 3, 7})), // Indices shape + ::testing::Values(3, -1), // Axis + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({3, 2, 3, 8})), // Data shape + ::testing::Values(std::vector({2, 2, 3, 8})), // Indices shape + ::testing::Values(0, -4), // Axis + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, + ::testing::Combine( + ::testing::Values(std::vector({3, 2, 3, 4, 8})), // Data shape + ::testing::Values(std::vector({3, 2, 3, 5, 8})), // Indices shape + ::testing::Values(3, -2), // Axis + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + GatherElementsLayerTest::getTestCaseName); + +// ======= Rank 4 ======== // + +const std::vector> ShapesRank4Axis0 = { + std::vector{1, 7, 8, 4}, + std::vector{2, 7, 8, 4}, + std::vector{7, 7, 8, 4}, + std::vector{9, 7, 8, 4}, +}; +const std::vector> ShapesRank4Axis1 = { + std::vector{6, 1, 8, 4}, + std::vector{6, 5, 8, 4}, + std::vector{6, 8, 8, 4}, + std::vector{6, 9, 8, 4}, +}; +const std::vector> ShapesRank4Axis2 = { + std::vector{6, 7, 2, 4}, + std::vector{6, 7, 4, 4}, + std::vector{6, 7, 5, 4}, + std::vector{6, 7, 7, 4}, +}; +const std::vector> ShapesRank4Axis3 = { + std::vector{6, 5, 8, 1}, + std::vector{6, 5, 8, 4}, + std::vector{6, 5, 8, 7}, + std::vector{6, 5, 8, 9}, +}; + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis0, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank4Axis0), // Data shapes + ::testing::ValuesIn(ShapesRank4Axis0), // Indices shpae + // ::testing::ValuesIn(axis0), + ::testing::ValuesIn(std::vector({ 0 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank4Axis1), // Data shapes + ::testing::ValuesIn(ShapesRank4Axis1), // Indices shpae + ::testing::ValuesIn(std::vector({ 1, -3 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank4Axis2), // Data shapes + ::testing::ValuesIn(ShapesRank4Axis2), // Indices shpae + ::testing::ValuesIn(std::vector({ 2, -2 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank4Axis3), // Data shapes + ::testing::ValuesIn(ShapesRank4Axis3), // Indices shpae + ::testing::ValuesIn(std::vector({ 3, -1 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +// ====== rank = 5 ====== // +const std::vector> ShapesRank5Axis0 = { + std::vector{2, 3, 9, 4, 9}, + std::vector{1, 3, 9, 4, 9}, + std::vector{5, 3, 9, 4, 9}, + std::vector{7, 3, 9, 4, 9}, +}; +const std::vector> ShapesRank5Axis1 = { + std::vector{2, 1, 5, 4, 7}, + std::vector{2, 3, 5, 4, 7}, + std::vector{2, 8, 5, 4, 7}, + std::vector{2, 9, 5, 4, 7}, +}; +const std::vector> ShapesRank5Axis2 = { + std::vector{1, 2, 2, 8, 9}, + std::vector{1, 2, 3, 8, 9}, + std::vector{1, 2, 6, 8, 9}, + std::vector{1, 2, 7, 8, 9}, +}; +const std::vector> ShapesRank5Axis3 = { + std::vector{2, 2, 4, 3, 7}, + std::vector{2, 2, 4, 4, 7}, + std::vector{2, 2, 4, 7, 7}, + std::vector{2, 2, 4, 9, 7}, +}; +const std::vector> ShapesRank5Axis4 = { + std::vector{1, 3, 9, 3, 1}, + std::vector{1, 3, 9, 3, 2}, + std::vector{1, 3, 9, 3, 5}, + std::vector{1, 3, 9, 3, 9}, +}; + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis0, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank5Axis0), // Data shapes + ::testing::ValuesIn(ShapesRank5Axis0), // Indices shpae + ::testing::ValuesIn(std::vector({ 0 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank5Axis1), // Data shapes + ::testing::ValuesIn(ShapesRank5Axis1), // Indices shpae + ::testing::ValuesIn(std::vector({ 1, -4 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank5Axis2), // Data shapes + ::testing::ValuesIn(ShapesRank5Axis2), // Indices shpae + ::testing::ValuesIn(std::vector({ 2, -3 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank5Axis3), // Data shapes + ::testing::ValuesIn(ShapesRank5Axis3), // Indices shpae + ::testing::ValuesIn(std::vector({ 3, -2 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis4, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank5Axis4), // Data shapes + ::testing::ValuesIn(ShapesRank5Axis4), // Indices shpae + ::testing::ValuesIn(std::vector({ 4, -1 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +// ====== rank = 6 ====== // +const std::vector> ShapesRank6Axis0 = { + std::vector{1, 3, 2, 4, 4, 3}, + std::vector{3, 3, 2, 4, 4, 3}, + std::vector{6, 3, 2, 4, 4, 3}, + std::vector{7, 3, 2, 4, 4, 3}, +}; +const std::vector> ShapesRank6Axis1 = { + std::vector{1, 2, 2, 3, 5, 9}, + std::vector{1, 5, 2, 3, 5, 9}, + std::vector{1, 6, 2, 3, 5, 9}, + std::vector{1, 9, 2, 3, 5, 9}, +}; +const std::vector> ShapesRank6Axis2 = { + std::vector{2, 3, 2, 7, 2, 1}, + std::vector{2, 3, 5, 7, 2, 1}, + std::vector{2, 3, 8, 7, 2, 1}, + std::vector{2, 3, 9, 7, 2, 1}, +}; +const std::vector> ShapesRank6Axis3 = { + std::vector{1, 3, 4, 2, 1, 3}, + std::vector{1, 3, 4, 4, 1, 3}, + std::vector{1, 3, 4, 5, 1, 3}, + std::vector{1, 3, 4, 8, 1, 3}, +}; +const std::vector> ShapesRank6Axis4 = { + std::vector{1, 3, 2, 4, 1, 3}, + std::vector{1, 3, 2, 4, 4, 3}, + std::vector{1, 3, 2, 4, 6, 3}, + std::vector{1, 3, 2, 4, 7, 3}, +}; +const std::vector> ShapesRank6Axis5 = { + std::vector{2, 1, 7, 8, 1, 2}, + std::vector{2, 1, 7, 8, 1, 3}, + std::vector{2, 1, 7, 8, 1, 4}, + std::vector{2, 1, 7, 8, 1, 6}, +}; + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis0, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank6Axis0), // Data shapes + ::testing::ValuesIn(ShapesRank6Axis0), // Indices shpae + ::testing::ValuesIn(std::vector({ 0 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis1, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank6Axis1), // Data shapes + ::testing::ValuesIn(ShapesRank6Axis1), // Indices shpae + ::testing::ValuesIn(std::vector({ 1, -5 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis2, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank6Axis2), // Data shapes + ::testing::ValuesIn(ShapesRank6Axis2), // Indices shpae + ::testing::ValuesIn(std::vector({ 2, -4 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis3, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank6Axis3), // Data shapes + ::testing::ValuesIn(ShapesRank6Axis3), // Indices shpae + ::testing::ValuesIn(std::vector({ 3, -3 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis4, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank6Axis4), // Data shapes + ::testing::ValuesIn(ShapesRank6Axis4), // Indices shpae + ::testing::ValuesIn(std::vector({ 4, -2 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis5, GatherElementsLayerTest, + ::testing::Combine( + ::testing::ValuesIn(ShapesRank6Axis5), // Data shapes + ::testing::ValuesIn(ShapesRank6Axis5), // Indices shpae + ::testing::ValuesIn(std::vector({ 5, -1 })), + ::testing::ValuesIn(inputPrecisions), // Data precision + ::testing::ValuesIn(idxPrecisions), // Indices precision + ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + GatherElementsLayerTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp index eea88d4abf3183..61313a9cbcff0a 100644 --- a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp @@ -8,8 +8,8 @@ namespace LayerTestsDefinitions { -TEST_P(GatherElementsLayerTest, CompareWithRefs) { - Run(); -} +// TEST_P(GatherElementsLayerTest, CompareWithRefs) { +// Run(); +// } } // namespace LayerTestsDefinitions diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl index 40c48356ef7cb3..882e5a12411297 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -38,10 +38,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const uint dim2 = get_global_id(2); // Calculate indice index - // const uint idx_f = dim2 % INPUT1_FEATURE_NUM; - // const uint idx_b = dim2 / INPUT1_FEATURE_NUM; const uint f = dim2 % OUTPUT_FEATURE_NUM; - // const uint f = 0; const uint b = dim2 / OUTPUT_FEATURE_NUM; #if INPUT1_DIMS == 4 #define IDX_ORDER idx_b,idx_f,idx_y,idx_x @@ -49,8 +46,6 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const uint x = dim0; const uint y = dim1; - // const uint idx_x = dim0; - // const uint idx_y = dim1; #elif INPUT1_DIMS == 5 #define ORDER b,f,z,y,x #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x @@ -58,9 +53,6 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const uint y = dim1 % OUTPUT_SIZE_Y; const uint z = dim1 / OUTPUT_SIZE_Y; // x*y, z - // const uint idx_x = dim0 % OUTPUT_SIZE_X; - // const uint idx_y = dim0 / OUTPUT_SIZE_X; - // const uint idx_z = dim1; #else #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x @@ -69,24 +61,35 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const uint y = dim0 / OUTPUT_SIZE_X; const uint z = dim1 % OUTPUT_SIZE_Z; const uint w = dim1 / OUTPUT_SIZE_Z; - // const uint f = dim2 % OUTPUT_FEATURE_NUM; - // const uint b = dim2 / OUTPUT_FEATURE_NUM; - // const uint idx_x = dim0 % OUTPUT_SIZE_X; // x - // const uint idx_y = dim0 / OUTPUT_SIZE_X; // y - // const uint idx_z = dim1 % OUTPUT_SIZE_Z; // z - // const uint idx_w = dim1 / OUTPUT_SIZE_Z; // w #endif - + // const int out_idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); const int out_idx = GET_UPDATES_INDEX(INPUT1, ORDER); // printf("%d\n", out_idx); int axis = AXIS; size_t rank = INPUT0_DIMS; // indices_shape.size(), data_shape.size() -//     printf("rank and axis: %d %d\n", rank, axis); - - size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_X, INPUT0_SIZE_Y, INPUT0_SIZE_Z, INPUT0_SIZE_W}; - size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_X, INPUT1_SIZE_Y, INPUT1_SIZE_Z, INPUT1_SIZE_W}; + // if (out_idx == 10) { + //     printf("rank and axis: %d %d\n", rank, axis); + // } + // if(out_idx == 10) { printf("Axis: %d\n", axis); } +#if INPUT0_DIMS == 4 + // size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_X, INPUT0_SIZE_Y, INPUT0_SIZE_Z, INPUT0_SIZE_W}; + size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X, INPUT0_SIZE_Z, INPUT0_SIZE_W}; + // size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_X, INPUT1_SIZE_Y, INPUT1_SIZE_Z, INPUT1_SIZE_W}; + size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X, INPUT1_SIZE_Z, INPUT1_SIZE_W}; +#elif INPUT0_DIMS == 5 +// #else + size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X, INPUT0_SIZE_W}; + size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X, INPUT1_SIZE_W}; +#else + size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X}; + size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; +#endif + + // 6 5 8 1 : b f y x + // x = 1 + // y = 8 size_t max_inner_sum = 1, max_outer_sum = 1, outer_sum_inc_data = 1, outer_sum_inc_indices = 1; for (size_t i = axis + 1; i < rank; i++) max_inner_sum *= indices_shape[i]; @@ -99,24 +102,39 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, } max_outer_sum *= outer_sum_inc_data; - for (size_t i = axis; i < rank; i++) { + for (size_t i = axis; i < rank; i++) { // 2, 3 outer_sum_inc_indices *= indices_shape[i]; } + if(out_idx == 10) { + // printf("%ld %ld %ld %ld %ld %ld\n", indices[0], indices[1], indices[2], indices[3], indices[4], indices[5]); + // printf("%ld %ld %ld %ld %ld %ld\n", indices[6], indices[7], indices[8], indices[9], indices[10], indices[11]); + // printf("%ld %ld %ld %ld %ld %ld\n", indices[12], indices[13], indices[14], indices[15], indices[16], indices[17]); + + printf("aixs: %ld\n", AXIS); + printf("data: %ld %ld %ld %ld %ld %ld\n", data_shape[0], data_shape[1], data_shape[2], data_shape[3], data_shape[4], data_shape[5]); + printf("indi: %ld %ld %ld %ld %ld %ld\n", indices_shape[0], indices_shape[1], indices_shape[2], indices_shape[3], indices_shape[4], indices_shape[5]); + } + //     printf("max_inner_sum: %ld\n", max_inner_sum); //     printf("outer_sum_inc_data: %ld\n",outer_sum_inc_data); //     printf("max_inner_sum, max_outer_sum, outer_sum_inc_data: %d %d %d\n",max_inner_sum, max_outer_sum, outer_sum_inc); // ======================================================================================== - size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data; + size_t outer_sum = (out_idx / outer_sum_inc_indices); + outer_sum *= outer_sum_inc_data; + // size_t outer_sum = (out_idx) * outer_sum_inc_data; size_t inner_sum = out_idx % max_inner_sum; if (indices[out_idx] < 0 || indices[out_idx] >= data_shape[axis]) { - printf("indices values of GatherElement exceed data size.\n"); + printf("indices values of GatherElement exceed data size. %ld %ld \n", out_idx, indices[out_idx]); return; } uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum; uint tmp = outer_sum; + // printf("%d %d, ", out_idx, outer_sum); + // if(out_idx == 10) { printf("outer_sum: %d\n", tmp); } + INPUT0_TYPE val = data[idx]; // output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); @@ -132,6 +150,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, FUSED_OPS; output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT); #else + // output[out_idx] = outer_sum; output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); #endif } diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 58b0edb34e0e54..ba1a8110209b6b 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -8443,27 +8443,27 @@ struct gather_elements_test_params { size_t expected_not_fused_primitives; }; -#define CASE_GATHER_ELEMENTS_FP16_4D_1 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 1, 1, 1}, 0, data_types::f16, format::bfyx -#define CASE_GATHER_ELEMENTS_FP16_4D_2 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 1, 1, 1}, 1, data_types::f16, format::bfyx -#define CASE_GATHER_ELEMENTS_FP16_4D_3 data_types::f16, format::bfyx, {2, 2, 3, 5}, format::bfyx, {2, 2, 3, 7}, format::bfyx, {2, 2, 3, 7}, 3, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_1 data_types::f16, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, 3, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_2 data_types::f16, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, 0, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_3 data_types::f16, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, -1, data_types::f16, format::bfyx -#define CASE_GATHER_ELEMENTS_FP16_5D_1 data_types::f16, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 8}, format::bfzyx, {3, 2, 2, 2, 8}, 4, data_types::f16, format::bfzyx -#define CASE_GATHER_ELEMENTS_FP16_5D_2 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {5, 4, 1, 1, 3}, 2, data_types::f16, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP16_5D_1 data_types::f16, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, 4, data_types::f16, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP16_5D_2 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, 2, data_types::f16, format::bfzyx #define CASE_GATHER_ELEMENTS_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, 4, data_types::f16, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP16_6D_2 data_types::f16, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, 3, data_types::f16, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP16_6D_3 data_types::f16, format::bfwzyx, {2, 2, 2, 4, 4, 3}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, 5, data_types::f16, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP16_6D_2 data_types::f16, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, -2, data_types::f16, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP16_6D_3 data_types::f16, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, 5, data_types::f16, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 1, 1, 1}, 0, data_types::f32, format::bfyx -#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 1, 1, 1}, 1, data_types::f32, format::bfyx -#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {2, 2, 3, 5}, format::bfyx, {2, 2, 3, 7}, format::bfyx, {2, 2, 3, 7}, 3, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, 3, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, 0, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, -1, data_types::f32, format::bfyx -#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 8}, format::bfzyx, {3, 2, 2, 2, 8}, 4, data_types::f32, format::bfzyx -#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {5, 4, 1, 1, 3}, 2, data_types::f32, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, 4, data_types::f32, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, 2, data_types::f32, format::bfzyx #define CASE_GATHER_ELEMENTS_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, 4, data_types::f32, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, format::bfwzyx, {2, 3, 2, 3, 2, 8}, 3, data_types::f32, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 2, 4, 4, 3}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, format::bfwzyx, {2, 2, 2, 4, 4, 6}, 5, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, -2, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, 5, data_types::f32, format::bfwzyx class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest { public: @@ -8477,6 +8477,56 @@ class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest(get_axis_dim(p))); create_topologies(input_layout("input", get_input_layout(p)), // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)), + // data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p)))), // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, /*p.max_number_in_indices - 1*/)), - data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)), data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), data("out_lo", get_mem(get_single_element_layout(p), -127)), @@ -8545,7 +8597,8 @@ TEST_P(gather_elements_scale_activation, basic) { auto p = GetParam(); create_topologies(input_layout("input", get_input_layout(p)), // data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p)))), - data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)), gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), activation("activation", "gather_elements_prim", activation_func::abs), @@ -8589,7 +8642,8 @@ TEST_P(gather_elements_activation_scale_eltwise, basic) { create_topologies(input_layout("input", get_input_layout(p)), // data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)), - data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? + data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)), data("eltwise_data", get_mem(get_output_layout(p))), gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp index 12aaf82910b2ca..4fe6708748bbcc 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -55,235 +55,31 @@ inline void DoTest(const engine& engine, } } -// 4-1-1 -TEST(gather_elements_gpu_fp16, d2235_i2237_a3) { - const auto& engine = get_test_engine(); - - const int axis = 3; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 7 } }); // indices - - set_values(input0, { - FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), - FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(10), - FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), - - FLOAT16(5), FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), - FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), - FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), - - - FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), - FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), - FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), - - FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), - FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), - FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), - }); - - set_values(input1, { - FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(4), - - FLOAT16(2), FLOAT16(1), FLOAT16(3), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), - FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), - FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(3), FLOAT16(4), - - - FLOAT16(3), FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(2), - FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(0), FLOAT16(2), - - FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(3), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), - }); - - std::vector expected_results = { - FLOAT16(0), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), - FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), - - FLOAT16(0), FLOAT16(7), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), - FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(6), FLOAT16(6), - FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(5), FLOAT16(4), FLOAT16(7), - - - FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(10), FLOAT16(0), FLOAT16(2), - FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(3), - FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(7), FLOAT16(8), FLOAT16(2), FLOAT16(7), - - FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(9), FLOAT16(9), FLOAT16(0), FLOAT16(2), - FLOAT16(5), FLOAT16(5), FLOAT16(8), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(2), - FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), - }; - - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 7), axis); -} - -// 4-1-2 -TEST(gather_elements_gpu_fp16, d2235_i2237_an1) { - const auto& engine = get_test_engine(); - - const int axis = -1; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 7 } }); // indices - - set_values(input0, { - FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), - FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(10), - FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), - - FLOAT16(5), FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), - FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), - FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), - +// ======================== Rank 4 ======================== // - FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), - FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), - FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), - - FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), - FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), - FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), - }); - - set_values(input1, { - FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(4), - - FLOAT16(2), FLOAT16(1), FLOAT16(3), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), - FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), - FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(3), FLOAT16(4), - - - FLOAT16(3), FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(2), - FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(0), FLOAT16(2), - - FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(3), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), - }); - - std::vector expected_results = { - FLOAT16(0), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), - FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), - - FLOAT16(0), FLOAT16(7), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), - FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(6), FLOAT16(6), - FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(5), FLOAT16(4), FLOAT16(7), - - - FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(10), FLOAT16(0), FLOAT16(2), - FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(3), - FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(7), FLOAT16(8), FLOAT16(2), FLOAT16(7), - - FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(9), FLOAT16(9), FLOAT16(0), FLOAT16(2), - FLOAT16(5), FLOAT16(5), FLOAT16(8), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(2), - FLOAT16(1), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), - }; - - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 7), axis); -} - -// 4-2 -TEST(gather_elements_gpu_fp16, d2329_i2329_a2) { - const auto& engine = get_test_engine(); - - const int axis = 2; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 9 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 9 } }); // indices - - set_values(input0, { - FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), - FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), - - FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(5), FLOAT16(1), - FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), - - FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), - FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), - - - FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), - FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), - - FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(9), - FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), FLOAT16(10), FLOAT16(9), - - FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(9), FLOAT16(2), - FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), - }); - - set_values(input1, { - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), - }); - - std::vector expected_results = { - FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(0), - FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), - FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(1), - FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(9), FLOAT16(5), FLOAT16(3), - FLOAT16(6), FLOAT16(8), FLOAT16(9), FLOAT16(4), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), - FLOAT16(6), FLOAT16(8), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(7), - FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), - FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(5), - FLOAT16(3), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(10), FLOAT16(9), - FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(7), FLOAT16(8), FLOAT16(9), - FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(3), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(2), - FLOAT16(2), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(4), FLOAT16(8), FLOAT16(0), - }; - - DoTest(engine,input0, input1, expected_results, tensor(2, 3, 2, 9), axis); -} - -// 4-3 -TEST(gather_elements_gpu_fp16, d3238_i2238_a0) { +TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { const auto& engine = get_test_engine(); const int axis = 0; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 3, 8 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 8 } }); // indices + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 8, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 8, 3 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), - FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), - - FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), - FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), - - FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), - FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), @@ -293,16 +89,12 @@ TEST(gather_elements_gpu_fp16, d3238_i2238_a0) { FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), - FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), - - FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), @@ -323,30 +115,24 @@ TEST(gather_elements_gpu_fp16, d3238_i2238_a0) { FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(3), FLOAT16(3), FLOAT16(10), FLOAT16(6), FLOAT16(1), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 8), axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 8, 3), axis); } -// 5-1 -TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { +TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { const auto& engine = get_test_engine(); - const int axis = 4; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 2, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 2, 8 } }); // indices - + const int axis = 3; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), - FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), - FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), - FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), - FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), FLOAT16(2), @@ -355,209 +141,333 @@ TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), - FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), - FLOAT16(9), FLOAT16(5), FLOAT16(5), - FLOAT16(3), FLOAT16(10), FLOAT16(5), - FLOAT16(2), FLOAT16(0), FLOAT16(10), - FLOAT16(0), FLOAT16(5), FLOAT16(4), }); set_values(input1, { - FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), - FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), - FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), - FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), - FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), - FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), - FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), - FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), - FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), - FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), - FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), }); std::vector expected_results = { - FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), - FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(0), FLOAT16(7), - FLOAT16(4), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(10), FLOAT16(10), FLOAT16(4), FLOAT16(5), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(9), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(0), - FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), - FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(7), FLOAT16(6), FLOAT16(6), - FLOAT16(1), FLOAT16(1), FLOAT16(5), FLOAT16(5), FLOAT16(9), FLOAT16(1), FLOAT16(9), FLOAT16(9), - FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(7), - FLOAT16(8), FLOAT16(2), FLOAT16(8), FLOAT16(2), FLOAT16(10), FLOAT16(10), FLOAT16(8), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(0), FLOAT16(0), - FLOAT16(10), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(6), FLOAT16(10), - FLOAT16(10), FLOAT16(4), FLOAT16(10), FLOAT16(4), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(4), - FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(7), - FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(6), FLOAT16(9), FLOAT16(9), FLOAT16(6), FLOAT16(6), - FLOAT16(4), FLOAT16(2), FLOAT16(8), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(2), FLOAT16(4), - FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(5), - FLOAT16(1), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(3), - FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(0), FLOAT16(10), FLOAT16(0), - FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), - FLOAT16(5), FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(5), FLOAT16(5), - FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(10), FLOAT16(10), - FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(4), + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(2), FLOAT16(2), FLOAT16(5), + FLOAT16(0), FLOAT16(0), FLOAT16(7), + FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(0), FLOAT16(9), FLOAT16(0), + FLOAT16(7), FLOAT16(0), FLOAT16(7), + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(6), FLOAT16(7), FLOAT16(10), + FLOAT16(5), FLOAT16(9), FLOAT16(5), + FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(8), FLOAT16(2), FLOAT16(2), + FLOAT16(8), FLOAT16(8), FLOAT16(8), + FLOAT16(8), FLOAT16(6), FLOAT16(10), + FLOAT16(4), FLOAT16(10), FLOAT16(10), + FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(0), FLOAT16(9), + FLOAT16(4), FLOAT16(8), FLOAT16(8), + FLOAT16(3), FLOAT16(3), FLOAT16(5), + FLOAT16(5), FLOAT16(3), FLOAT16(3), + FLOAT16(9), FLOAT16(9), FLOAT16(0), }; - DoTest(engine,input0, input1, expected_results, tensor(3, 2, 2, 2, 8), axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 5), axis); } -// 5-2 -TEST(gather_elements_gpu_fp16, d23327_i23327_a3) { +TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { + const auto& engine = get_test_engine(); + + const int axis = -1; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 3, 2, 9 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 3, 5, 9 } }); // indices + set_values(input0, { + FLOAT16(0), FLOAT16(1), + FLOAT16(8), FLOAT16(5), + FLOAT16(5), FLOAT16(2), + FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), + FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), + FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), + FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), + FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), + FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), + FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), + FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(0), + FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), + FLOAT16(8), FLOAT16(5), + FLOAT16(2), FLOAT16(3), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(8), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), FLOAT16(5), + FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), + FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(4), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(7), FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(10), + FLOAT16(5), FLOAT16(9), FLOAT16(5), FLOAT16(9), FLOAT16(5), + FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(7), + FLOAT16(8), FLOAT16(10), FLOAT16(10), FLOAT16(10), FLOAT16(8), + FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(8), FLOAT16(8), FLOAT16(3), FLOAT16(8), FLOAT16(8), + FLOAT16(6), FLOAT16(6), FLOAT16(6), FLOAT16(8), FLOAT16(6), + FLOAT16(10), FLOAT16(4), FLOAT16(10), FLOAT16(10), FLOAT16(10), + FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(10), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(7), FLOAT16(7), + FLOAT16(0), FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(6), FLOAT16(9), FLOAT16(9), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(2), FLOAT16(4), + FLOAT16(5), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(8), + FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(3), FLOAT16(3), + }; + + DoTest(engine,input0, input1, expected_results, tensor(1, 3, 5, 9), axis); +} + +// ======================== Rank 5 ======================== // + +TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { const auto& engine = get_test_engine(); const int axis = 3; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 3, 2, 7 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 3, 2, 7 } }); // indices + auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 5, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 2, 3 } }); // indices set_values(input0, { - FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), - FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), - FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), - FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), - FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(8), - FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), FLOAT16(2), FLOAT16(10), - FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(2), - FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), - FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), - FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), - FLOAT16(5), FLOAT16(4), FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), - FLOAT16(0), FLOAT16(8), FLOAT16(8), FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), - FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), FLOAT16(10), FLOAT16(9), FLOAT16(2), - FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(9), - FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(1), - FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), - FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), - FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), - FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), - FLOAT16(7), FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), - FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(1), - FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(1), FLOAT16(2), - FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(2), - FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(1), - FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(3), FLOAT16(4), - FLOAT16(8), FLOAT16(10), FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(2), - FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), - FLOAT16(10), FLOAT16(6), FLOAT16(4), FLOAT16(9), FLOAT16(7), FLOAT16(7), FLOAT16(10), - FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(1), FLOAT16(4), - FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(7), FLOAT16(8), - FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(0), - FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(10), FLOAT16(9), - FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(4), - FLOAT16(3), FLOAT16(5), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), - FLOAT16(2), FLOAT16(10), FLOAT16(1), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), + FLOAT16(2), FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(8), FLOAT16(10), + FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(9), FLOAT16(2), FLOAT16(9), FLOAT16(5), FLOAT16(5), + FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(10), FLOAT16(6), FLOAT16(4), FLOAT16(9), + FLOAT16(7), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), + FLOAT16(1), FLOAT16(4), FLOAT16(6), FLOAT16(9), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(7), + FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(0), + FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(10), FLOAT16(9), FLOAT16(9), + FLOAT16(5), FLOAT16(1), FLOAT16(4), FLOAT16(10), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(5), }); set_values(input1, { - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(3), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), + FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), + FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), FLOAT16(3), FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(3), FLOAT16(1), FLOAT16(1), + FLOAT16(0), FLOAT16(3), FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(3), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(3), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(3), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(4), }); std::vector expected_results = { - FLOAT16(0), FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(2), FLOAT16(0), - FLOAT16(0), FLOAT16(7), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), - FLOAT16(0), FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(1), FLOAT16(7), - FLOAT16(0), FLOAT16(5), FLOAT16(10), FLOAT16(9), FLOAT16(4), FLOAT16(0), FLOAT16(7), - FLOAT16(4), FLOAT16(7), FLOAT16(8), FLOAT16(10), FLOAT16(4), FLOAT16(0), FLOAT16(8), - FLOAT16(3), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(10), - FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(6), FLOAT16(3), FLOAT16(1), - FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(2), FLOAT16(6), FLOAT16(3), FLOAT16(2), - FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(0), - FLOAT16(5), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(0), - FLOAT16(5), FLOAT16(4), FLOAT16(3), FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(10), - FLOAT16(5), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), - FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(10), FLOAT16(9), FLOAT16(2), - FLOAT16(9), FLOAT16(6), FLOAT16(5), FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(2), - FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(5), FLOAT16(3), FLOAT16(10), FLOAT16(8), - FLOAT16(2), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(10), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(4), FLOAT16(0), FLOAT16(7), - FLOAT16(1), FLOAT16(9), FLOAT16(6), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(7), - FLOAT16(10), FLOAT16(2), FLOAT16(4), FLOAT16(3), FLOAT16(1), FLOAT16(6), FLOAT16(1), - FLOAT16(7), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(9), - FLOAT16(6), FLOAT16(2), FLOAT16(5), FLOAT16(1), FLOAT16(10), FLOAT16(4), FLOAT16(2), - FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(5), FLOAT16(10), FLOAT16(1), FLOAT16(2), - FLOAT16(3), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(2), - FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(9), FLOAT16(8), FLOAT16(2), - FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(1), FLOAT16(4), - FLOAT16(4), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(0), - FLOAT16(9), FLOAT16(5), FLOAT16(7), FLOAT16(2), FLOAT16(7), FLOAT16(8), FLOAT16(5), - FLOAT16(8), FLOAT16(10), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(8), FLOAT16(5), - FLOAT16(10), FLOAT16(6), FLOAT16(3), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(4), - FLOAT16(10), FLOAT16(6), FLOAT16(3), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(4), - FLOAT16(7), FLOAT16(8), FLOAT16(4), FLOAT16(8), FLOAT16(9), FLOAT16(5), FLOAT16(0), - FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(8), FLOAT16(9), FLOAT16(7), FLOAT16(8), - FLOAT16(9), FLOAT16(5), FLOAT16(7), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(9), - FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), FLOAT16(10), FLOAT16(10), FLOAT16(4), - FLOAT16(2), FLOAT16(5), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), - FLOAT16(3), FLOAT16(10), FLOAT16(9), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(7), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(0), FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(5), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(6), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(7), + FLOAT16(1), FLOAT16(9), FLOAT16(8), FLOAT16(9), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(6), FLOAT16(2), FLOAT16(7), FLOAT16(6), FLOAT16(1), + FLOAT16(7), FLOAT16(8), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(0), FLOAT16(4), + FLOAT16(2), FLOAT16(2), FLOAT16(7), FLOAT16(5), FLOAT16(3), FLOAT16(9), FLOAT16(4), FLOAT16(5), + FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(9), + FLOAT16(1), FLOAT16(7), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(5), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 3, 3, 2, 7), axis); + DoTest(engine,input0, input1, expected_results, tensor(1, 2, 8, 2, 3), axis); } -// 6-1 -TEST(gather_elements_gpu_fp16, d232328_i232328_a3) { +TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { const auto& engine = get_test_engine(); - const int axis = 3; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 3, 2, 8 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 3, 2, 8 } }); // indices + const int axis = -4; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 5, 4, 4, 1 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 4, 4, 1 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), + FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), + FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(9), FLOAT16(1), FLOAT16(0), FLOAT16(7), + FLOAT16(9), FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(3), + FLOAT16(3), FLOAT16(5), FLOAT16(6), FLOAT16(9), + FLOAT16(4), FLOAT16(9), FLOAT16(2), FLOAT16(4), + FLOAT16(5), FLOAT16(5), FLOAT16(3), FLOAT16(1), + FLOAT16(1), FLOAT16(6), FLOAT16(8), FLOAT16(0), + FLOAT16(5), FLOAT16(5), FLOAT16(10), FLOAT16(8), + FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(9), + FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(1), + FLOAT16(1), FLOAT16(3), FLOAT16(0), FLOAT16(4), + FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), + FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(4), + FLOAT16(5), FLOAT16(1), FLOAT16(6), FLOAT16(9), + FLOAT16(6), FLOAT16(10), FLOAT16(6), FLOAT16(1), + FLOAT16(10), FLOAT16(4), FLOAT16(1), FLOAT16(6), + FLOAT16(2), FLOAT16(5), FLOAT16(5), FLOAT16(10), + FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(6), + FLOAT16(1), FLOAT16(7), FLOAT16(6), FLOAT16(8), + + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), + FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), + FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(3), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(3), + FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(3), FLOAT16(4), FLOAT16(3), FLOAT16(4), + FLOAT16(4), FLOAT16(1), FLOAT16(0), FLOAT16(3), + FLOAT16(2), FLOAT16(4), FLOAT16(4), FLOAT16(4), + FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(4), + FLOAT16(3), FLOAT16(0), FLOAT16(2), FLOAT16(4), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(5), + FLOAT16(10), FLOAT16(2), FLOAT16(0), FLOAT16(10), + FLOAT16(3), FLOAT16(10), FLOAT16(1), FLOAT16(5), + FLOAT16(4), FLOAT16(0), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(3), + FLOAT16(10), FLOAT16(8), FLOAT16(6), FLOAT16(1), + FLOAT16(2), FLOAT16(5), FLOAT16(7), FLOAT16(5), + FLOAT16(4), FLOAT16(0), FLOAT16(6), FLOAT16(3), + FLOAT16(10), FLOAT16(9), FLOAT16(6), FLOAT16(9), + FLOAT16(1), FLOAT16(6), FLOAT16(5), FLOAT16(7), + FLOAT16(5), FLOAT16(2), FLOAT16(6), FLOAT16(6), + FLOAT16(1), FLOAT16(5), FLOAT16(6), FLOAT16(1), + FLOAT16(6), FLOAT16(4), FLOAT16(1), FLOAT16(6), + FLOAT16(2), FLOAT16(6), FLOAT16(5), FLOAT16(7), + FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(6), + FLOAT16(6), FLOAT16(5), FLOAT16(10), FLOAT16(8), + }; + + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 4, 4, 1), axis); +} + +TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { + const auto& engine = get_test_engine(); + + const int axis = 0; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 8, 4, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 4, 3 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), @@ -632,6 +542,7 @@ TEST(gather_elements_gpu_fp16, d232328_i232328_a3) { FLOAT16(4), FLOAT16(10), FLOAT16(6), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(2), FLOAT16(8), FLOAT16(5), FLOAT16(7), FLOAT16(9), FLOAT16(2), FLOAT16(7), FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(5), + }); set_values(input1, { @@ -658,141 +569,47 @@ TEST(gather_elements_gpu_fp16, d232328_i232328_a3) { FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), - FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), - FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), - FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(2), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), - FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), - FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(2), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), - FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), - FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(2), - FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), - FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), - FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), - FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), - FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(2), - FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), - FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), - FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(0), - FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), - FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), - FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(1), - FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), - FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(1), - FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(2), - FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), - FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), - FLOAT16(2), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), - FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(2), - FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), - FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(0), - FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(2), }); std::vector expected_results = { - FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(7), - FLOAT16(2), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(4), FLOAT16(0), FLOAT16(10), FLOAT16(8), - FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(0), FLOAT16(6), FLOAT16(7), FLOAT16(0), FLOAT16(4), - FLOAT16(9), FLOAT16(10), FLOAT16(1), FLOAT16(8), FLOAT16(9), FLOAT16(0), FLOAT16(10), FLOAT16(9), - FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(10), FLOAT16(7), - FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(0), FLOAT16(0), FLOAT16(9), - FLOAT16(10), FLOAT16(1), FLOAT16(0), FLOAT16(7), FLOAT16(9), FLOAT16(3), FLOAT16(8), FLOAT16(1), - FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(8), FLOAT16(8), - FLOAT16(9), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(6), FLOAT16(3), FLOAT16(1), - FLOAT16(5), FLOAT16(9), FLOAT16(2), FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(9), - FLOAT16(10), FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(2), FLOAT16(3), FLOAT16(5), FLOAT16(7), - FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(5), FLOAT16(10), FLOAT16(5), FLOAT16(5), FLOAT16(3), - FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(1), FLOAT16(5), FLOAT16(3), FLOAT16(4), - FLOAT16(5), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(10), FLOAT16(8), - FLOAT16(6), FLOAT16(9), FLOAT16(6), FLOAT16(4), FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4), - FLOAT16(1), FLOAT16(6), FLOAT16(0), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(2), - FLOAT16(6), FLOAT16(9), FLOAT16(9), FLOAT16(9), FLOAT16(5), FLOAT16(2), FLOAT16(3), FLOAT16(1), - FLOAT16(5), FLOAT16(1), FLOAT16(8), FLOAT16(4), FLOAT16(6), FLOAT16(10), FLOAT16(10), FLOAT16(8), - FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(10), - FLOAT16(10), FLOAT16(2), FLOAT16(1), FLOAT16(9), FLOAT16(1), FLOAT16(9), FLOAT16(0), FLOAT16(5), - FLOAT16(10), FLOAT16(1), FLOAT16(8), FLOAT16(2), FLOAT16(0), FLOAT16(4), FLOAT16(4), FLOAT16(1), - FLOAT16(7), FLOAT16(2), FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(1), FLOAT16(5), FLOAT16(5), - FLOAT16(10), FLOAT16(5), FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(5), FLOAT16(8), FLOAT16(10), - FLOAT16(1), FLOAT16(4), FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(8), - FLOAT16(7), FLOAT16(5), FLOAT16(7), FLOAT16(9), FLOAT16(10), FLOAT16(6), FLOAT16(9), FLOAT16(7), - FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(3), FLOAT16(0), - FLOAT16(1), FLOAT16(5), FLOAT16(7), FLOAT16(9), FLOAT16(4), FLOAT16(6), FLOAT16(4), FLOAT16(9), - FLOAT16(5), FLOAT16(1), FLOAT16(8), FLOAT16(10), FLOAT16(9), FLOAT16(3), FLOAT16(5), FLOAT16(5), - FLOAT16(7), FLOAT16(5), FLOAT16(8), FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(4), FLOAT16(9), - FLOAT16(8), FLOAT16(1), FLOAT16(10), FLOAT16(10), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(5), - FLOAT16(9), FLOAT16(4), FLOAT16(2), FLOAT16(8), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(1), - FLOAT16(6), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(8), FLOAT16(4), FLOAT16(1), FLOAT16(4), - FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(8), FLOAT16(4), FLOAT16(9), FLOAT16(1), FLOAT16(3), - FLOAT16(8), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(9), FLOAT16(3), FLOAT16(4), - FLOAT16(4), FLOAT16(2), FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(8), FLOAT16(4), FLOAT16(3), - FLOAT16(8), FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(3), FLOAT16(4), - FLOAT16(8), FLOAT16(1), FLOAT16(8), FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(8), FLOAT16(6), - FLOAT16(2), FLOAT16(6), FLOAT16(3), FLOAT16(8), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(4), - FLOAT16(0), FLOAT16(6), FLOAT16(9), FLOAT16(1), FLOAT16(10), FLOAT16(2), FLOAT16(2), FLOAT16(6), - FLOAT16(2), FLOAT16(6), FLOAT16(2), FLOAT16(7), FLOAT16(1), FLOAT16(4), FLOAT16(7), FLOAT16(4), - FLOAT16(8), FLOAT16(1), FLOAT16(9), FLOAT16(3), FLOAT16(10), FLOAT16(1), FLOAT16(3), FLOAT16(6), - FLOAT16(5), FLOAT16(6), FLOAT16(2), FLOAT16(8), FLOAT16(1), FLOAT16(8), FLOAT16(7), FLOAT16(9), - FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(7), - FLOAT16(7), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(5), FLOAT16(8), - FLOAT16(2), FLOAT16(9), FLOAT16(10), FLOAT16(2), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(5), - FLOAT16(7), FLOAT16(0), FLOAT16(5), FLOAT16(10), FLOAT16(3), FLOAT16(7), FLOAT16(5), FLOAT16(7), - FLOAT16(4), FLOAT16(0), FLOAT16(4), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(5), - FLOAT16(9), FLOAT16(0), FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(5), FLOAT16(5), FLOAT16(7), - FLOAT16(0), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(0), FLOAT16(5), FLOAT16(8), FLOAT16(5), - FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(8), FLOAT16(7), FLOAT16(3), - FLOAT16(4), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(1), FLOAT16(5), FLOAT16(8), FLOAT16(4), - FLOAT16(4), FLOAT16(3), FLOAT16(6), FLOAT16(1), FLOAT16(0), FLOAT16(8), FLOAT16(4), FLOAT16(0), - FLOAT16(4), FLOAT16(1), FLOAT16(7), FLOAT16(3), FLOAT16(5), FLOAT16(3), FLOAT16(8), FLOAT16(5), - FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(0), - FLOAT16(0), FLOAT16(10), FLOAT16(6), FLOAT16(7), FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(5), - FLOAT16(10), FLOAT16(8), FLOAT16(7), FLOAT16(5), FLOAT16(8), FLOAT16(1), FLOAT16(4), FLOAT16(9), - FLOAT16(3), FLOAT16(6), FLOAT16(5), FLOAT16(7), FLOAT16(6), FLOAT16(10), FLOAT16(1), FLOAT16(6), - FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(4), FLOAT16(0), FLOAT16(9), FLOAT16(2), FLOAT16(8), - FLOAT16(3), FLOAT16(4), FLOAT16(6), FLOAT16(0), FLOAT16(9), FLOAT16(10), FLOAT16(1), FLOAT16(9), - FLOAT16(6), FLOAT16(8), FLOAT16(6), FLOAT16(4), FLOAT16(0), FLOAT16(9), FLOAT16(10), FLOAT16(9), - FLOAT16(5), FLOAT16(10), FLOAT16(0), FLOAT16(0), FLOAT16(6), FLOAT16(4), FLOAT16(2), FLOAT16(2), - FLOAT16(3), FLOAT16(5), FLOAT16(8), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(10), - FLOAT16(1), FLOAT16(6), FLOAT16(6), FLOAT16(0), FLOAT16(6), FLOAT16(8), FLOAT16(5), FLOAT16(0), - FLOAT16(4), FLOAT16(1), FLOAT16(5), FLOAT16(0), FLOAT16(7), FLOAT16(7), FLOAT16(8), FLOAT16(3), - FLOAT16(1), FLOAT16(2), FLOAT16(6), FLOAT16(7), FLOAT16(6), FLOAT16(8), FLOAT16(5), FLOAT16(2), - FLOAT16(4), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(7), FLOAT16(10), FLOAT16(5), FLOAT16(3), - FLOAT16(0), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(2), FLOAT16(9), FLOAT16(7), - FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(4), FLOAT16(4), - FLOAT16(4), FLOAT16(3), FLOAT16(7), FLOAT16(6), FLOAT16(4), FLOAT16(2), FLOAT16(9), FLOAT16(5), - FLOAT16(7), FLOAT16(1), FLOAT16(5), FLOAT16(3), FLOAT16(5), FLOAT16(5), FLOAT16(9), FLOAT16(5), - FLOAT16(4), FLOAT16(4), FLOAT16(5), FLOAT16(3), FLOAT16(1), FLOAT16(8), FLOAT16(9), FLOAT16(2), - FLOAT16(0), FLOAT16(1), FLOAT16(5), FLOAT16(7), FLOAT16(5), FLOAT16(5), FLOAT16(0), FLOAT16(5), + FLOAT16(0), FLOAT16(8), FLOAT16(5), FLOAT16(0), FLOAT16(1), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(4), FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(9), FLOAT16(0), FLOAT16(5), FLOAT16(5), + FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(9), FLOAT16(5), FLOAT16(8), FLOAT16(6), FLOAT16(4), + FLOAT16(8), FLOAT16(5), FLOAT16(8), FLOAT16(1), FLOAT16(4), FLOAT16(7), FLOAT16(5), FLOAT16(0), + FLOAT16(0), FLOAT16(5), FLOAT16(7), FLOAT16(7), FLOAT16(2), FLOAT16(8), FLOAT16(5), FLOAT16(4), + FLOAT16(4), FLOAT16(1), FLOAT16(3), FLOAT16(9), FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(3), + FLOAT16(9), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(9), FLOAT16(3), FLOAT16(4), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(3), FLOAT16(6), FLOAT16(2), FLOAT16(9), FLOAT16(10), FLOAT16(10), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(2), FLOAT16(2), FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(8), + FLOAT16(4), FLOAT16(4), FLOAT16(7), FLOAT16(7), FLOAT16(9), FLOAT16(6), FLOAT16(4), FLOAT16(6), + FLOAT16(10), FLOAT16(9), FLOAT16(2), FLOAT16(10), FLOAT16(2), FLOAT16(7), FLOAT16(6), FLOAT16(9), + FLOAT16(1), FLOAT16(9), FLOAT16(2), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(2), + FLOAT16(3), FLOAT16(6), FLOAT16(5), FLOAT16(0), FLOAT16(5), FLOAT16(8), FLOAT16(7), FLOAT16(8), + FLOAT16(0), FLOAT16(6), FLOAT16(2), FLOAT16(9), FLOAT16(6), FLOAT16(1), FLOAT16(7), FLOAT16(2), + FLOAT16(1), FLOAT16(3), FLOAT16(3), FLOAT16(7), FLOAT16(0), FLOAT16(7), FLOAT16(5), FLOAT16(9), + FLOAT16(8), FLOAT16(3), FLOAT16(10), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(4), FLOAT16(6), + FLOAT16(4), FLOAT16(5), FLOAT16(6), FLOAT16(4), FLOAT16(4), FLOAT16(10), FLOAT16(5), FLOAT16(1), + FLOAT16(3), FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(7), FLOAT16(7), FLOAT16(5), FLOAT16(10), + FLOAT16(7), FLOAT16(1), FLOAT16(5), FLOAT16(10), FLOAT16(3), FLOAT16(1), FLOAT16(5), FLOAT16(4), + FLOAT16(2), FLOAT16(3), FLOAT16(7), FLOAT16(1), FLOAT16(7), FLOAT16(8), FLOAT16(5), FLOAT16(5), + FLOAT16(4), FLOAT16(4), FLOAT16(3), FLOAT16(3), FLOAT16(5), FLOAT16(10), FLOAT16(4), FLOAT16(2), + FLOAT16(2), FLOAT16(9), FLOAT16(7), FLOAT16(5), FLOAT16(3), FLOAT16(4), FLOAT16(8), FLOAT16(5), + FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(2), FLOAT16(7), FLOAT16(3), FLOAT16(5), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 3, 2, 3, 2, 8), axis); + DoTest(engine,input0, input1, expected_results, tensor(1, 2, 8, 4, 3), axis); } -// 6-2 -TEST(gather_elements_gpu_fp16, d222443_i222446_a5) { + +// ======================== Rank 6 ======================== // + +TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { const auto& engine = get_test_engine(); const int axis = 5; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 4, 4, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 4, 4, 6 } }); // indices + auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 3, 4, 4, 2 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 6, 4, 4, 2 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), @@ -1187,24 +1004,182 @@ TEST(gather_elements_gpu_fp16, d222443_i222446_a5) { FLOAT16(3), FLOAT16(3), FLOAT16(7), FLOAT16(8), FLOAT16(3), FLOAT16(8), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 2, 4, 4, 6), axis); + DoTest(engine,input0, input1, expected_results, tensor(2, 2, 6, 4, 4, 2), axis); +} + +TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { + const auto& engine = get_test_engine(); + + const int axis = -3; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 5, 1 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 2, 1 } }); // indices + + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), + FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), + FLOAT16(7), FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), FLOAT16(5), + FLOAT16(7), FLOAT16(0), FLOAT16(4), FLOAT16(0), + FLOAT16(4), FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), FLOAT16(7), + FLOAT16(4), FLOAT16(7), FLOAT16(10), FLOAT16(8), + FLOAT16(2), FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), FLOAT16(4), + FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(8), + FLOAT16(7), FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), FLOAT16(5), + FLOAT16(2), FLOAT16(3), FLOAT16(3), FLOAT16(1), + FLOAT16(5), FLOAT16(9), FLOAT16(10), FLOAT16(0), + FLOAT16(9), FLOAT16(5), FLOAT16(5), FLOAT16(3), + FLOAT16(10), FLOAT16(5), FLOAT16(2), FLOAT16(0), + FLOAT16(10), FLOAT16(0), FLOAT16(5), FLOAT16(4), + FLOAT16(3), FLOAT16(10), FLOAT16(5), FLOAT16(5), + FLOAT16(10), FLOAT16(0), FLOAT16(8), FLOAT16(8), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(2), FLOAT16(4), FLOAT16(3), + FLOAT16(4), FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(0), FLOAT16(1), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), + FLOAT16(3), FLOAT16(1), FLOAT16(4), FLOAT16(2), + FLOAT16(4), FLOAT16(2), FLOAT16(1), FLOAT16(3), + FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(4), + FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(4), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(0), FLOAT16(8), FLOAT16(7), + FLOAT16(6), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(2), FLOAT16(1), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(2), FLOAT16(0), FLOAT16(5), + FLOAT16(10), FLOAT16(4), FLOAT16(5), FLOAT16(0), + FLOAT16(10), FLOAT16(5), FLOAT16(3), FLOAT16(4), + FLOAT16(5), FLOAT16(4), FLOAT16(10), FLOAT16(5), + FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(8), + }; + + DoTest(engine,input0, input1, expected_results, tensor(1, 2, 4, 2, 2, 1), axis); } -// TEST(gather_elements_gpu_fp16, d32223_i32228_a4) { -// const auto& engine = get_test_engine(); +TEST(gather_elements_gpu_fp16, d233113_i233115_a2) { + const auto& engine = get_test_engine(); -// const int axis = ; -// auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, }); // data -// auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, }); // indices + const int axis = 2; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 3 } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 5 } }); // indices -// set_values(input0, { -// }); + set_values(input0, { + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(5), FLOAT16(2), + FLOAT16(0), FLOAT16(7), FLOAT16(7), + FLOAT16(10), FLOAT16(4), FLOAT16(5), + FLOAT16(9), FLOAT16(0), FLOAT16(0), + FLOAT16(5), FLOAT16(7), FLOAT16(0), + FLOAT16(4), FLOAT16(0), FLOAT16(4), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(7), + FLOAT16(10), FLOAT16(8), FLOAT16(2), + FLOAT16(0), FLOAT16(8), FLOAT16(3), + FLOAT16(6), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(6), FLOAT16(9), + FLOAT16(2), FLOAT16(4), FLOAT16(8), + FLOAT16(5), FLOAT16(2), FLOAT16(3), + }); + + set_values(input1, { + FLOAT16(0), FLOAT16(1), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(1), + FLOAT16(1), FLOAT16(0), FLOAT16(2), + FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(2), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(0), FLOAT16(2), + FLOAT16(2), FLOAT16(0), FLOAT16(1), + FLOAT16(1), FLOAT16(2), FLOAT16(2), + FLOAT16(1), FLOAT16(1), FLOAT16(0), + FLOAT16(2), FLOAT16(0), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(2), + FLOAT16(2), FLOAT16(1), FLOAT16(0), + FLOAT16(0), FLOAT16(2), FLOAT16(1), + FLOAT16(2), FLOAT16(1), FLOAT16(2), + FLOAT16(0), FLOAT16(0), FLOAT16(1), + FLOAT16(2), FLOAT16(0), FLOAT16(2), + }); + + std::vector expected_results = { + FLOAT16(0), FLOAT16(5), FLOAT16(7), + FLOAT16(0), FLOAT16(7), FLOAT16(8), + FLOAT16(0), FLOAT16(1), FLOAT16(7), + FLOAT16(0), FLOAT16(1), FLOAT16(8), + FLOAT16(5), FLOAT16(1), FLOAT16(2), + FLOAT16(9), FLOAT16(7), FLOAT16(0), + FLOAT16(5), FLOAT16(0), FLOAT16(0), + FLOAT16(9), FLOAT16(4), FLOAT16(0), + FLOAT16(9), FLOAT16(4), FLOAT16(0), + FLOAT16(5), FLOAT16(4), FLOAT16(5), + FLOAT16(7), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(6), FLOAT16(10), + FLOAT16(7), FLOAT16(0), FLOAT16(1), + FLOAT16(4), FLOAT16(5), FLOAT16(1), + FLOAT16(9), FLOAT16(5), FLOAT16(1), + FLOAT16(7), FLOAT16(4), FLOAT16(3), + FLOAT16(10), FLOAT16(8), FLOAT16(3), + FLOAT16(0), FLOAT16(8), FLOAT16(7), + FLOAT16(0), FLOAT16(4), FLOAT16(7), + FLOAT16(7), FLOAT16(4), FLOAT16(3), + FLOAT16(7), FLOAT16(8), FLOAT16(10), + FLOAT16(4), FLOAT16(8), FLOAT16(7), + FLOAT16(4), FLOAT16(2), FLOAT16(10), + FLOAT16(7), FLOAT16(8), FLOAT16(10), + FLOAT16(6), FLOAT16(8), FLOAT16(7), + FLOAT16(5), FLOAT16(4), FLOAT16(9), + FLOAT16(0), FLOAT16(2), FLOAT16(8), + FLOAT16(5), FLOAT16(4), FLOAT16(3), + FLOAT16(0), FLOAT16(6), FLOAT16(8), + FLOAT16(5), FLOAT16(6), FLOAT16(3), + }; -// set_values(input1, { -// }); + DoTest(engine,input0, input1, expected_results, tensor(2, 3, 3, 1, 1, 5), axis); +} -// std::vector expected_results = { -// }; +/* +TEST(gather_elements_gpu_fp16, d_i_a) { + const auto& engine = get_test_engine(); + + const int axis = ; + auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { } }); // data + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { } }); // indices + + set_values(input0, { + + }); + + set_values(input1, { + + }); + + std::vector expected_results = { + + }; + + DoTest(engine,input0, input1, expected_results, tensor(), axis); +} +*/ -// DoTest(engine,input0, input1, expected_results, axis); -// } From 5d05981215aed7c959a89bbfe7c0b3191e8c76e8 Mon Sep 17 00:00:00 2001 From: yunji Date: Wed, 14 Jul 2021 21:11:56 +0900 Subject: [PATCH 05/11] code clean up --- .../single_layer_tests/gather_elements.cpp | 13 +-- .../single_layer_tests/gather_elements.cpp | 6 -- .../thirdparty/clDNN/api/gather_elements.hpp | 4 +- .../gather/gather_elements_kernel_ref.cpp | 68 +----------- .../core/cl_kernels/gather_elements_ref.cl | 101 ++++-------------- .../tests/test_cases/fusings_gpu_test.cpp | 46 ++------ .../test_cases/gather_elements_gpu_test.cpp | 33 ------ 7 files changed, 31 insertions(+), 240 deletions(-) diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 822fd2bdf783b7..b4dad74bba3f6e 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -72,6 +72,7 @@ INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); +<<<<<<< HEAD <<<<<<< HEAD <<<<<<< HEAD INSTANTIATE_TEST_SUITE_P(smoke_set5, GatherElementsLayerTest, @@ -80,6 +81,8 @@ INSTANTIATE_TEST_CASE_P(yunji_set3, GatherElementsLayerTest, >>>>>>> Add cldnn unit test implementation ======= +======= +>>>>>>> code clean up INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, >>>>>>> Add functional test implementation ::testing::Combine( @@ -90,14 +93,4 @@ INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); - -// INSTANTIATE_TEST_CASE_P(yunji_set35, GatherElementsLayerTest, -// ::testing::Combine( -// ::testing::Values(std::vector({2, 3, 3, 1, 1, 3})), // Data shape -// ::testing::Values(std::vector({2, 3, 5, 1, 1, 3})), // Indices shape -// ::testing::Values(2), // Axis -// ::testing::ValuesIn(dPrecisions), -// ::testing::ValuesIn(iPrecisions), -// ::testing::Values(CommonTestUtils::DEVICE_CPU)), -// GatherElementsLayerTest::getTestCaseName); } // namespace diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 180863e9c04958..3395723d41ea7c 100644 --- a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -24,7 +24,6 @@ const std::vector idxPrecisions = { InferenceEngine::Precision::I64, }; -// ======= CPU Func Test Cases ======== // INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest, ::testing::Combine( ::testing::Values(std::vector({2, 2})), // Data shape @@ -75,8 +74,6 @@ INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); -// ======= Rank 4 ======== // - const std::vector> ShapesRank4Axis0 = { std::vector{1, 7, 8, 4}, std::vector{2, 7, 8, 4}, @@ -106,7 +103,6 @@ INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis0, GatherElementsLayerTest ::testing::Combine( ::testing::ValuesIn(ShapesRank4Axis0), // Data shapes ::testing::ValuesIn(ShapesRank4Axis0), // Indices shpae - // ::testing::ValuesIn(axis0), ::testing::ValuesIn(std::vector({ 0 })), ::testing::ValuesIn(inputPrecisions), // Data precision ::testing::ValuesIn(idxPrecisions), // Indices precision @@ -143,7 +139,6 @@ INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis3, GatherElementsLayerTest ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name GatherElementsLayerTest::getTestCaseName); -// ====== rank = 5 ====== // const std::vector> ShapesRank5Axis0 = { std::vector{2, 3, 9, 4, 9}, std::vector{1, 3, 9, 4, 9}, @@ -225,7 +220,6 @@ INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis4, GatherElementsLayerTest ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name GatherElementsLayerTest::getTestCaseName); -// ====== rank = 6 ====== // const std::vector> ShapesRank6Axis0 = { std::vector{1, 3, 2, 4, 4, 3}, std::vector{3, 3, 2, 4, 4, 3}, diff --git a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp index b48f0091f7f6a9..72273b6c33dffe 100644 --- a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp +++ b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp @@ -35,8 +35,8 @@ struct gather_elements : public primitive_base { /// @param id This primitive id. /// @param data Input data primitive id. /// @param indices Input indexes primitive id. - /// @param output_format Output format: bfyx, bfzyx, bfwzyx - /// @param output_shape Output shape: {2, 2, 3, 5}, {2, 2, 3, 3, 6} + /// @param output_format Output format. + /// @param output_shape Output shape. /// @param axis An attribute of GatherElements. Required. gather_elements(const primitive_id& id, const primitive_id& data, diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index a4649fc80176ff..e4aff8ac41c758 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -44,14 +44,6 @@ ParamsKey GatherElementsKernelRef::GetSupportedKey() const { return k; } -static inline std::string GetOrderString(std::vector& order) { - std::string order_str = order[0]; - for (size_t i = 1; i < order.size(); i++) - order_str += ", " + order[i]; - - return order_str; -} - static inline std::vector GetDefaultOrder(size_t size) { std::vector default_order; if (size <= 4) { @@ -69,7 +61,6 @@ CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_par CommonDispatchData dispatchData; const auto& output = params.output; - // printf("%ld %ld %ld %ld %ld %ld\n", output.X().v, output.Y().v, output.Z().v, output.W().v, output.Feature().v, output.Batch().v); switch (params.inputs[1].GetLayout()) { case DataLayout::bfyx: @@ -77,7 +68,6 @@ CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_par break; case DataLayout::bfzyx: - // dispatchData.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v}; dispatchData.gws = {output.X().v, output.Y().v * output.Z().v, output.Feature().v * output.Batch().v}; break; @@ -89,55 +79,19 @@ CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_par throw std::invalid_argument("Unsupported data layout for gather elements primitive"); break; } + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); - // dispatchData.lws = {1, 1, 1}; return dispatchData; } -// static size_t GetIndicesLastDim(const gather_elements_params& params) { -// // get indices dims -// auto indices_dims = params.inputs[1].LogicalDims(); -// // std::cout << indices_dims << "incide dims\n"; - -// if (indices_dims.size() > 1) { -// std::reverse(indices_dims.begin(), indices_dims.end()); -// } - -// auto indices_last_dim = indices_dims[0]; - -// return indices_last_dim; -// } - -// static size_t GetSliceSize(const gather_elements_params& params) { -// // get input dims -// // auto input_dims = params.inputs[0].LogicalDims(); - -// // if (input_dims.size() > 1) { -// // std::reverse(input_dims.begin(), input_dims.end()); -// // } - -// // // get last dim of indices -// // auto indices_last_dim = GetIndicesLastDim(params); - -// // // calculate slize size which is used in kernel to copy -// // size_t wi_slice_size = 1; -// // for (size_t i = params.batch_dims + indices_last_dim; i < input_dims.size(); i++) { -// // wi_slice_size *= input_dims[i]; -// // } - -// return 3; -// } - JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const { JitConstants jit = MakeBaseParamsJitConstants(params); - // parameters in gather_elements_kernel_ref.h auto p_axis = static_cast(params.axis); if (p_axis < 0) { p_axis = params.inputs[0].LogicalDims().size() + params.axis; } - // printf("%d\n", p_axis); jit.AddConstant(MakeJitConstant("AXIS", p_axis)); if (!params.fused_ops.empty()) { @@ -157,29 +111,11 @@ bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o const gather_elements_params& params = static_cast(p); auto input_dims = params.inputs[0].LogicalDims(); auto indices_dims = params.inputs[1].LogicalDims(); - auto indices_rank = indices_dims.size(); - std::reverse(input_dims.begin(), input_dims.end()); - std::reverse(indices_dims.begin(), indices_dims.end()); - - if (indices_rank < 1) { + if (input_dims.size() != indices_dims.size()) { return false; } - // if (batch_dims + indices_dims[indices_rank - 1] > input_dims.size()) { - // return false; - // } - - // if (batch_dims >= std::min(input_dims.size(), static_cast(indices_rank))) { - // return false; - // } - - // for (uint8_t i = 0; i < batch_dims; i++) { - // if (input_dims[i] != indices_dims[i]) { - // return false; - // } - // } - for (auto& fused_op : params.fused_ops) { if (!IsFusedPrimitiveSupported(fused_op)) return false; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl index 882e5a12411297..d99da73fe05e61 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -16,14 +16,6 @@ #include "include/include_all.cl" #define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order) -#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order) - -#define IN_ORDER in_b,in_f,in_y,in_x - -#define OUT_ORDER out_b,out_f,out_y,out_x -#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order) - -#define INDICES_MAX_DIM 6 KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const __global INPUT1_TYPE* indices, @@ -38,126 +30,69 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const uint dim2 = get_global_id(2); // Calculate indice index - const uint f = dim2 % OUTPUT_FEATURE_NUM; - const uint b = dim2 / OUTPUT_FEATURE_NUM; #if INPUT1_DIMS == 4 - #define IDX_ORDER idx_b,idx_f,idx_y,idx_x #define ORDER b,f,y,x const uint x = dim0; const uint y = dim1; - #elif INPUT1_DIMS == 5 #define ORDER b,f,z,y,x - #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x const uint x = dim0; const uint y = dim1 % OUTPUT_SIZE_Y; const uint z = dim1 / OUTPUT_SIZE_Y; - // x*y, z - #else - #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x #define ORDER b,f,w,z,y,x const uint x = dim0 % OUTPUT_SIZE_X; const uint y = dim0 / OUTPUT_SIZE_X; const uint z = dim1 % OUTPUT_SIZE_Z; const uint w = dim1 / OUTPUT_SIZE_Z; #endif - - // const int out_idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); + const uint f = dim2 % OUTPUT_FEATURE_NUM; + const uint b = dim2 / OUTPUT_FEATURE_NUM; + const int out_idx = GET_UPDATES_INDEX(INPUT1, ORDER); - // printf("%d\n", out_idx); - int axis = AXIS; - size_t rank = INPUT0_DIMS; // indices_shape.size(), data_shape.size() - // if (out_idx == 10) { - //     printf("rank and axis: %d %d\n", rank, axis); - // } - // if(out_idx == 10) { printf("Axis: %d\n", axis); } -#if INPUT0_DIMS == 4 - // size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_X, INPUT0_SIZE_Y, INPUT0_SIZE_Z, INPUT0_SIZE_W}; - size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X, INPUT0_SIZE_Z, INPUT0_SIZE_W}; - // size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_X, INPUT1_SIZE_Y, INPUT1_SIZE_Z, INPUT1_SIZE_W}; - size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X, INPUT1_SIZE_Z, INPUT1_SIZE_W}; -#elif INPUT0_DIMS == 5 -// #else - size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X, INPUT0_SIZE_W}; - size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X, INPUT1_SIZE_W}; + +#if INPUT1_DIMS == 4 + size_t data_shape[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X}; + size_t indices_shape[4] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X}; +#elif INPUT1_DIMS == 5 + size_t data_shape[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X}; + size_t indices_shape[5] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; #else - size_t data_shape[10] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X}; - size_t indices_shape[10] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; + size_t data_shape[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X}; + size_t indices_shape[6] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; #endif - - // 6 5 8 1 : b f y x - // x = 1 - // y = 8 size_t max_inner_sum = 1, max_outer_sum = 1, outer_sum_inc_data = 1, outer_sum_inc_indices = 1; - for (size_t i = axis + 1; i < rank; i++) + for (size_t i = AXIS + 1; i < INPUT1_DIMS; i++) max_inner_sum *= indices_shape[i]; - for (int i = 0; i < axis; i++) + for (int i = 0; i < AXIS; i++) max_outer_sum *= indices_shape[i]; - for (size_t i = axis; i < rank; i++) { + for (size_t i = AXIS; i < INPUT1_DIMS; i++) { outer_sum_inc_data *= data_shape[i]; } max_outer_sum *= outer_sum_inc_data; - for (size_t i = axis; i < rank; i++) { // 2, 3 + for (size_t i = AXIS; i < INPUT1_DIMS; i++) { outer_sum_inc_indices *= indices_shape[i]; } - if(out_idx == 10) { - // printf("%ld %ld %ld %ld %ld %ld\n", indices[0], indices[1], indices[2], indices[3], indices[4], indices[5]); - // printf("%ld %ld %ld %ld %ld %ld\n", indices[6], indices[7], indices[8], indices[9], indices[10], indices[11]); - // printf("%ld %ld %ld %ld %ld %ld\n", indices[12], indices[13], indices[14], indices[15], indices[16], indices[17]); - - printf("aixs: %ld\n", AXIS); - printf("data: %ld %ld %ld %ld %ld %ld\n", data_shape[0], data_shape[1], data_shape[2], data_shape[3], data_shape[4], data_shape[5]); - printf("indi: %ld %ld %ld %ld %ld %ld\n", indices_shape[0], indices_shape[1], indices_shape[2], indices_shape[3], indices_shape[4], indices_shape[5]); - } - -//     printf("max_inner_sum: %ld\n", max_inner_sum); -//     printf("outer_sum_inc_data: %ld\n",outer_sum_inc_data); -//     printf("max_inner_sum, max_outer_sum, outer_sum_inc_data: %d %d %d\n",max_inner_sum, max_outer_sum, outer_sum_inc); - -// ======================================================================================== - - size_t outer_sum = (out_idx / outer_sum_inc_indices); - outer_sum *= outer_sum_inc_data; - // size_t outer_sum = (out_idx) * outer_sum_inc_data; + size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data; size_t inner_sum = out_idx % max_inner_sum; - if (indices[out_idx] < 0 || indices[out_idx] >= data_shape[axis]) { + if (indices[out_idx] < 0 || indices[out_idx] >= data_shape[AXIS]) { printf("indices values of GatherElement exceed data size. %ld %ld \n", out_idx, indices[out_idx]); return; } uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum; - uint tmp = outer_sum; - // printf("%d %d, ", out_idx, outer_sum); - // if(out_idx == 10) { printf("outer_sum: %d\n", tmp); } - - INPUT0_TYPE val = data[idx]; - // output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); - - // output[out_idx] = TO_OUTPUT_TYPE(axis); - // output[out_idx] = axis; -// ======================================================================================== - - // output[out_idx] = TO_OUTPUT_TYPE(out_idx); -// ======================================================================================== #if HAS_FUSED_OPS FUSED_OPS; output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT); #else - // output[out_idx] = outer_sum; output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); #endif } -#undef INDICES_MAX_DIM #undef GET_UPDATES_INDEX -#undef GET_OUTPUT_INDEX -#undef OUT_ORDER -#undef IDX_ORDER -#undef IN_ORDER diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index ba1a8110209b6b..8fe712405f46cd 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -178,11 +178,6 @@ class BaseFusingTest : public ::testing::TestWithParam { description << " " << i.original_id << " " << i.kernel_id << std::endl; } SCOPED_TRACE(description.str()); - // std::cout << "Count reorder? " << count_reorder << std::endl; - // std::cout << "(executed primitives) fused, not fused: " << fused.get_executed_primitives().size() << ", " << not_fused.get_executed_primitives().size() << std::endl; - // std::cout << "(reorder count) fused, not fused: " << reorders_count_fused << ", " << reorders_count_not_fused << std::endl; - // std::cout << "(exepected) fused, not fused: " << p.expected_fused_primitives << ", " << p.expected_not_fused_primitives << std::endl; - // Subtract reorders count to handle execution in different layouts when input/output reorders can be added in the graph ASSERT_EQ(fused.get_executed_primitives().size() - (count_reorder ? 0 : reorders_count_fused), p.expected_fused_primitives); ASSERT_EQ(not_fused.get_executed_primitives().size() - (count_reorder ? 0 : reorders_count_not_fused), p.expected_not_fused_primitives); ASSERT_EQ(outputs_ref.size(), outputs_fused.size()); @@ -8470,7 +8465,6 @@ class GatherElementsPrimitiveFusingTest : public ::BaseFusingTestengine, this->topology_non_fused, bo_not_fused); - // network network_fused(this->engine, this->topology_non_fused, bo_not_fused); network network_fused(this->engine, this->topology_fused, bo_fused); network_fused.set_input_data("input", input_prim); network_not_fused.set_input_data("input", input_prim); @@ -8478,24 +8472,23 @@ class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest(get_axis_dim(p))); create_topologies(input_layout("input", get_input_layout(p)), - // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)), - // data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p)))), - // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, /*p.max_number_in_indices - 1*/)), data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)), data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), data("out_lo", get_mem(get_single_element_layout(p), -127)), data("out_hi", get_mem(get_single_element_layout(p), 127)), - // output format, output shape, axis gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), quantize("quantize", "gather_elements_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8), reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) @@ -8596,8 +8566,6 @@ class gather_elements_scale_activation : public GatherElementsPrimitiveFusingTes TEST_P(gather_elements_scale_activation, basic) { auto p = GetParam(); create_topologies(input_layout("input", get_input_layout(p)), - // data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p)))), - // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)), gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis), @@ -8641,8 +8609,6 @@ TEST_P(gather_elements_activation_scale_eltwise, basic) { auto p = GetParam(); create_topologies(input_layout("input", get_input_layout(p)), - // data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)), - // data("gather_elements_indices", get_mem(get_indices_layout(p), 0, 2)), // 2 -> ? data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast(get_axis_dim(p))-1)), data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)), data("eltwise_data", get_mem(get_output_layout(p))), diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp index 4fe6708748bbcc..5e4b8b6eef7e28 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -49,14 +49,10 @@ inline void DoTest(const engine& engine, auto output_ptr = output.pointer(); for (size_t i = 0; i < expected_results.size(); ++i) { - // printf("%ld : %f %f\n", i, expected_results[i], float16_to_float32(output_ptr[i]) ); - // printf("%ld : %f\n", i, float16_to_float32(output_ptr[i]) ); EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); } } -// ======================== Rank 4 ======================== // - TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { const auto& engine = get_test_engine(); @@ -295,8 +291,6 @@ TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { DoTest(engine,input0, input1, expected_results, tensor(1, 3, 5, 9), axis); } -// ======================== Rank 5 ======================== // - TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { const auto& engine = get_test_engine(); @@ -602,8 +596,6 @@ TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { DoTest(engine,input0, input1, expected_results, tensor(1, 2, 8, 4, 3), axis); } -// ======================== Rank 6 ======================== // - TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { const auto& engine = get_test_engine(); @@ -1158,28 +1150,3 @@ TEST(gather_elements_gpu_fp16, d233113_i233115_a2) { DoTest(engine,input0, input1, expected_results, tensor(2, 3, 3, 1, 1, 5), axis); } - -/* -TEST(gather_elements_gpu_fp16, d_i_a) { - const auto& engine = get_test_engine(); - - const int axis = ; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { } }); // indices - - set_values(input0, { - - }); - - set_values(input1, { - - }); - - std::vector expected_results = { - - }; - - DoTest(engine,input0, input1, expected_results, tensor(), axis); -} -*/ - From e3cffc230cbf62f353d29038cc5f7b5d92375749 Mon Sep 17 00:00:00 2001 From: yunji Date: Fri, 16 Jul 2021 00:17:12 +0900 Subject: [PATCH 06/11] Change the type of axis parameter to enumeration. --- .../src/cldnn_engine/ops/gather_elements.cpp | 34 +++++++++- .../single_layer_tests/gather_elements.cpp | 10 +-- .../single_layer_tests/gather_elements.hpp | 4 -- .../thirdparty/clDNN/api/gather_elements.hpp | 16 ++++- .../kernel_selector/common/common_types.h | 12 ++++ .../gather/gather_elements_kernel_ref.cpp | 32 +++++++-- .../gather/gather_elements_kernel_ref.h | 4 +- .../core/cl_kernels/gather_elements_ref.cl | 2 +- .../core/kernel_selector_common.cpp | 12 ++++ .../core/kernel_selector_common.h | 1 + .../clDNN/src/gpu/gather_elements_gpu.cpp | 21 +++++- .../src/include/kernel_selector_helper.h | 1 + .../tests/test_cases/fusings_gpu_test.cpp | 65 ++++++++----------- .../test_cases/gather_elements_gpu_test.cpp | 20 +++--- 14 files changed, 161 insertions(+), 73 deletions(-) diff --git a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp index 914a7021611cbe..7aa061c545dea7 100644 --- a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp +++ b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp @@ -12,12 +12,42 @@ namespace CLDNNPlugin { +static cldnn::gather_elements::gather_elements_axis GetGatherElementsAxis(int axis, unsigned rank) { + if (axis < 0) + axis += rank; + if (axis < 0 || axis >= rank) + IE_THROW() << "GatherElements axis is not correspond to number of dimensions"; + + // Difference in dimension ordering between IE and clDNN, + // reverse spatial dimensions after batch and feature. + unsigned cldnn_axis = axis; + if (axis >= 2) { + auto spatial_axis = axis - 2; + // Default and minimum number of dimensions is 4 + auto spatial_size = std::max(rank, 4u) - 2; + cldnn_axis = spatial_size - spatial_axis - 1 + 2; + } + + switch (cldnn_axis) { + case 0: return cldnn::gather_elements::gather_elements_axis::along_b; + case 1: return cldnn::gather_elements::gather_elements_axis::along_f; + case 2: return cldnn::gather_elements::gather_elements_axis::along_x; + case 3: return cldnn::gather_elements::gather_elements_axis::along_y; + case 4: return cldnn::gather_elements::gather_elements_axis::along_z; + case 5: return cldnn::gather_elements::gather_elements_axis::along_w; + default: IE_THROW() << "Unsupported ScatterElementsUpdate axis: " << axis; + } + return cldnn::gather_elements::gather_elements_axis::along_f; // shouldn't get here +} + void CreateGatherElementsOp(Program& p, const std::shared_ptr& op) { p.ValidateInputs(op, {2}); auto inputPrimitives = p.GetInputPrimitiveIDs(op); std::string layerName = layer_type_name_ID(op); - auto axis = op->get_axis(); + size_t rank = op->get_input_shape(0).size(); + int32_t axis = static_cast(op->get_axis()); + auto outLayout = DefaultFormatForDims(op->get_output_shape(0).size()); auto primitive = cldnn::gather_elements(layerName, @@ -25,7 +55,7 @@ void CreateGatherElementsOp(Program& p, const std::shared_ptrget_output_shape(0)), - axis); + GetGatherElementsAxis(axis, rank)); p.AddPrimitive(primitive); p.AddPrimitiveToProfiler(op); diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 3395723d41ea7c..be951a0bb5840b 100644 --- a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -31,7 +31,7 @@ INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest, ::testing::ValuesIn(std::vector({-1, 0, 1})), // Axis ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest, @@ -41,7 +41,7 @@ INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest, ::testing::ValuesIn(std::vector({0, -3})), // Axis ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest, @@ -51,7 +51,7 @@ INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest, ::testing::Values(3, -1), // Axis ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, @@ -61,7 +61,7 @@ INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, ::testing::Values(0, -4), // Axis ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, @@ -71,7 +71,7 @@ INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, ::testing::Values(3, -2), // Axis ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); const std::vector> ShapesRank4Axis0 = { diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp index 61313a9cbcff0a..9c6329c76b3e81 100644 --- a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp @@ -8,8 +8,4 @@ namespace LayerTestsDefinitions { -// TEST_P(GatherElementsLayerTest, CompareWithRefs) { -// Run(); -// } - } // namespace LayerTestsDefinitions diff --git a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp index 72273b6c33dffe..b05a3044e68a59 100644 --- a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp +++ b/inference-engine/thirdparty/clDNN/api/gather_elements.hpp @@ -31,19 +31,29 @@ namespace cldnn { struct gather_elements : public primitive_base { CLDNN_DECLARE_PRIMITIVE(gather_elements) + enum gather_elements_axis { + along_b, + along_f, + along_x, + along_y, + along_z, + along_w + }; + /// @brief Constructs gather_elements primitive. /// @param id This primitive id. /// @param data Input data primitive id. /// @param indices Input indexes primitive id. /// @param output_format Output format. /// @param output_shape Output shape. - /// @param axis An attribute of GatherElements. Required. + /// @param axis Gathering axis. gather_elements(const primitive_id& id, const primitive_id& data, const primitive_id& indices, const format& output_format, const tensor& output_shape, - const uint8_t axis = 0, + // const uint8_t axis = 0, + const gather_elements_axis axis, const padding& output_padding = padding()) : primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {} @@ -53,7 +63,7 @@ struct gather_elements : public primitive_base { tensor output_shape; /// @brief Which axis to gather on. - uint8_t axis; + gather_elements_axis axis; }; /// @} /// @} diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h index dbe6bd7004c672..fbb108a8124baa 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h @@ -514,6 +514,18 @@ enum class GatherAxis { BATCH, }; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// GatherElementsAxis +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +enum class GatherElementsAxis { + X, + Y, + Z, + W, + FEATURE, + BATCH, +}; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // ScatterUpdateAxis //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index e4aff8ac41c758..c4afc4c3ae1d82 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -18,8 +18,32 @@ #include "kernel_selector_utils.h" #include #include -#include + namespace kernel_selector { +static size_t GetGatherElementsChannelIndex(const gather_elements_params& params) { + Tensor::DataChannelName name = Tensor::DataChannelName::X; + + size_t inputSize = params.inputs[0].GetDims().size(); + + switch (params.axis) { + case GatherElementsAxis::X: + return inputSize - 1; + case GatherElementsAxis::Y: + return inputSize - 2; + case GatherElementsAxis::Z: + return inputSize - 3; + case GatherElementsAxis::W: + return 2; + case GatherElementsAxis::FEATURE: + return 1; + case GatherElementsAxis::BATCH: + return 0; + default: + break; + } + + return DataTensor::Channelndex(params.output.GetLayout(), name); +} ParamsKey GatherElementsKernelRef::GetSupportedKey() const { ParamsKey k; @@ -88,11 +112,7 @@ CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_par JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const { JitConstants jit = MakeBaseParamsJitConstants(params); - auto p_axis = static_cast(params.axis); - if (p_axis < 0) { - p_axis = params.inputs[0].LogicalDims().size() + params.axis; - } - jit.AddConstant(MakeJitConstant("AXIS", p_axis)); + jit.AddConstant(MakeJitConstant("AXIS", GetGatherElementsChannelIndex(params))); if (!params.fused_ops.empty()) { std::vector idx_order = GetDefaultOrder(params.inputs[0].GetDims().size()); diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h index a58fa3f87ab991..5826671790389e 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h @@ -23,9 +23,9 @@ namespace kernel_selector { // gather_elements_params //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// struct gather_elements_params : public base_params { - gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(0) {} + gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherElementsAxis::BATCH) {} - uint8_t axis; + GatherElementsAxis axis; virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); } }; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl index d99da73fe05e61..f7005f3450d6b4 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -81,7 +81,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data; size_t inner_sum = out_idx % max_inner_sum; if (indices[out_idx] < 0 || indices[out_idx] >= data_shape[AXIS]) { - printf("indices values of GatherElement exceed data size. %ld %ld \n", out_idx, indices[out_idx]); + printf("indices values of GatherElement exceed data size.\n"); return; } uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp index deeb31e350e890..0d6578fca27898 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp @@ -408,6 +408,18 @@ std::string toString(GatherAxis a) { } } +std::string toString(GatherElementsAxis a) { + switch (a) { + case GatherElementsAxis::X: return "X"; + case GatherElementsAxis::Y: return "Y"; + case GatherElementsAxis::Z: return "Z"; + case GatherElementsAxis::W: return "W"; + case GatherElementsAxis::FEATURE: return "FEATURE"; + case GatherElementsAxis::BATCH: return "BATCH"; + default: return ""; + } +} + std::string toString(ScatterUpdateAxis a) { switch (a) { case ScatterUpdateAxis::X: return "X"; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.h index d0b0054f1b2dda..9e7009499bdbec 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.h @@ -148,6 +148,7 @@ std::string toString(MVNEpsMode mode); std::string toString(WeightsLayout layout); std::string toString(ConcatAxis a); std::string toString(GatherAxis a); +std::string toString(GatherElementsAxis a); std::string toString(ScatterUpdateAxis a); std::string toString(ResampleType type); std::string toString(CoordinateTransformationMode mode); diff --git a/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp index 1789a0d5de7a3f..a7de56fb6c7dd4 100644 --- a/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp +++ b/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp @@ -26,6 +26,24 @@ using namespace cldnn; namespace cldnn { namespace gpu { +kernel_selector::gather_elements_axis convert_axis(gather_elements::gather_elements_axis axis) { + switch (axis) { + case gather_elements::along_x: + return kernel_selector::gather_elements_axis::X; + case gather_elements::along_y: + return kernel_selector::gather_elements_axis::Y; + case gather_elements::along_z: + return kernel_selector::gather_elements_axis::Z; + case gather_elements::along_w: + return kernel_selector::gather_elements_axis::W; + case gather_elements::along_f: + return kernel_selector::gather_elements_axis::FEATURE; + case gather_elements::along_b: + return kernel_selector::gather_elements_axis::BATCH; + default: + return kernel_selector::gather_elements_axis::X; + } +} struct gather_elements_gpu : typed_primitive_gpu_impl { using parent = typed_primitive_gpu_impl; @@ -37,8 +55,7 @@ struct gather_elements_gpu : typed_primitive_gpu_impl { auto gather_elements_optional_params = get_default_optional_params(arg.get_program()); - // gather_elements_params.indices_rank = arg.get_primitive()->indices_rank; - gather_elements_params.axis = arg.get_primitive()->axis; + gather_elements_params.axis = convert_axis(arg.get_primitive()->axis); gather_elements_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout())); diff --git a/inference-engine/thirdparty/clDNN/src/include/kernel_selector_helper.h b/inference-engine/thirdparty/clDNN/src/include/kernel_selector_helper.h index f97f74ebbbcda8..e08aa400c0887b 100644 --- a/inference-engine/thirdparty/clDNN/src/include/kernel_selector_helper.h +++ b/inference-engine/thirdparty/clDNN/src/include/kernel_selector_helper.h @@ -72,6 +72,7 @@ using shape_calculation_mode = kernel_selector::ShapeCalculationMode; using interpolate_axis = kernel_selector::InterpolateAxis; using border_type = kernel_selector::BorderType; using gather_axis = kernel_selector::GatherAxis; +using gather_elements_axis = kernel_selector::GatherElementsAxis; using scatter_update_axis = kernel_selector::ScatterUpdateAxis; using reduce_mode = kernel_selector::ReduceMode; using cum_sum_axis = kernel_selector::CumSumAxis; diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 8fe712405f46cd..747f575bc47bb3 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -8429,7 +8429,7 @@ struct gather_elements_test_params { format output_format; tensor output_shape; - int axis; + cldnn::gather_elements::gather_elements_axis axis; data_types default_type; format default_format; @@ -8438,27 +8438,28 @@ struct gather_elements_test_params { size_t expected_not_fused_primitives; }; -#define CASE_GATHER_ELEMENTS_FP16_4D_1 data_types::f16, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, 3, data_types::f16, format::bfyx -#define CASE_GATHER_ELEMENTS_FP16_4D_2 data_types::f16, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, 0, data_types::f16, format::bfyx -#define CASE_GATHER_ELEMENTS_FP16_4D_3 data_types::f16, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, -1, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_1 data_types::f16, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, cldnn::gather_elements::gather_elements_axis::along_y, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_2 data_types::f16, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, cldnn::gather_elements::gather_elements_axis::along_b, data_types::f16, format::bfyx +#define CASE_GATHER_ELEMENTS_FP16_4D_3 data_types::f16, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfyx -#define CASE_GATHER_ELEMENTS_FP16_5D_1 data_types::f16, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, 4, data_types::f16, format::bfzyx -#define CASE_GATHER_ELEMENTS_FP16_5D_2 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, 2, data_types::f16, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP16_5D_1 data_types::f16, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP16_5D_2 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, cldnn::gather_elements::gather_elements_axis::along_z, data_types::f16, format::bfzyx -#define CASE_GATHER_ELEMENTS_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, 4, data_types::f16, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP16_6D_2 data_types::f16, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, -2, data_types::f16, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP16_6D_3 data_types::f16, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, 5, data_types::f16, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, cldnn::gather_elements::gather_elements_axis::along_f, data_types::f16, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP16_6D_2 data_types::f16, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_w, data_types::f16, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP16_6D_3 data_types::f16, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, 3, data_types::f32, format::bfyx -#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, 0, data_types::f32, format::bfyx -#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, -1, data_types::f32, format::bfyx -#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, 4, data_types::f32, format::bfzyx -#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, 2, data_types::f32, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, cldnn::gather_elements::gather_elements_axis::along_y, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, cldnn::gather_elements::gather_elements_axis::along_b, data_types::f32, format::bfyx +#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfyx -#define CASE_GATHER_ELEMENTS_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 4, 6, 7, 8, 2}, 4, data_types::f32, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 1}, -2, data_types::f32, format::bfwzyx -#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, 5, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfzyx +#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, cldnn::gather_elements::gather_elements_axis::along_z, data_types::f32, format::bfzyx + +#define CASE_GATHER_ELEMENTS_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, cldnn::gather_elements::gather_elements_axis::along_f, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_w, data_types::f32, format::bfwzyx +#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfwzyx class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest { public: @@ -8472,31 +8473,19 @@ class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest& expected_results, const tensor& output_tensor, - const int axis) { + const cldnn::gather_elements::gather_elements_axis axis) { topology topology; topology.add(input_layout("InputData", input0.get_layout())); topology.add(input_layout("InputIndices", input1.get_layout())); @@ -56,7 +56,7 @@ inline void DoTest(const engine& engine, TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { const auto& engine = get_test_engine(); - const int axis = 0; + auto axis = cldnn::gather_elements::gather_elements_axis::along_b; auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 8, 3 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 8, 3 } }); // indices @@ -117,7 +117,7 @@ TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { const auto& engine = get_test_engine(); - const int axis = 3; + auto axis = cldnn::gather_elements::gather_elements_axis::along_x; auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // indices set_values(input0, { @@ -195,7 +195,7 @@ TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { const auto& engine = get_test_engine(); - const int axis = -1; + auto axis = cldnn::gather_elements::gather_elements_axis::along_x; auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 3, 2, 9 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 3, 5, 9 } }); // indices set_values(input0, { @@ -294,7 +294,7 @@ TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { const auto& engine = get_test_engine(); - const int axis = 3; + auto axis = cldnn::gather_elements::gather_elements_axis::along_y; auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 5, 3 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 2, 3 } }); // indices @@ -367,7 +367,7 @@ TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { const auto& engine = get_test_engine(); - const int axis = -4; + auto axis = cldnn::gather_elements::gather_elements_axis::along_f; auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 5, 4, 4, 1 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 4, 4, 1 } }); // indices @@ -459,7 +459,7 @@ TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { const auto& engine = get_test_engine(); - const int axis = 0; + auto axis = cldnn::gather_elements::gather_elements_axis::along_b; auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 8, 4, 3 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 4, 3 } }); // indices @@ -599,7 +599,7 @@ TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { const auto& engine = get_test_engine(); - const int axis = 5; + auto axis = cldnn::gather_elements::gather_elements_axis::along_x; auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 3, 4, 4, 2 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 6, 4, 4, 2 } }); // indices @@ -1002,7 +1002,7 @@ TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { const auto& engine = get_test_engine(); - const int axis = -3; + auto axis = cldnn::gather_elements::gather_elements_axis::along_z; auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 5, 1 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 2, 1 } }); // indices @@ -1057,7 +1057,7 @@ TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { TEST(gather_elements_gpu_fp16, d233113_i233115_a2) { const auto& engine = get_test_engine(); - const int axis = 2; + auto axis = cldnn::gather_elements::gather_elements_axis::along_w; auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 3 } }); // data auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 5 } }); // indices From b3af5dbc97e76d7973ad134c9a9ce1e255c5babb Mon Sep 17 00:00:00 2001 From: yunji Date: Fri, 16 Jul 2021 19:52:04 +0900 Subject: [PATCH 07/11] Apply New Runtime API - Resolve conflicts. - Change file location. --- .../src/cldnn_engine/ops/gather_elements.cpp | 2 +- .../single_layer_tests/gather_elements.cpp | 22 +---- .../primitives}/gather_elements.hpp | 0 .../gather/gather_elements_kernel_ref.cpp | 2 +- .../core/cl_kernels/gather_elements_ref.cl | 18 +--- .../thirdparty/clDNN/src/gather_elements.cpp | 2 +- .../clDNN/src/gpu/gather_elements_gpu.cpp | 95 ------------------ .../clDNN/src/impls/ocl/gather_elements.cpp | 86 ++++++++++++++++ .../clDNN/src/impls/ocl/register.cpp | 1 + .../clDNN/src/impls/ocl/register.hpp | 1 + .../clDNN/src/include/gather_elements_inst.h | 2 +- .../tests/test_cases/fusings_gpu_test.cpp | 19 ++-- .../test_cases/gather_elements_gpu_test.cpp | 98 +++++++++---------- 13 files changed, 152 insertions(+), 196 deletions(-) rename inference-engine/thirdparty/clDNN/api/{ => cldnn/primitives}/gather_elements.hpp (100%) delete mode 100644 inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp create mode 100644 inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp diff --git a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp index 7aa061c545dea7..52ddfe13479aa7 100644 --- a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp +++ b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp @@ -8,7 +8,7 @@ #include "ngraph/op/gather_elements.hpp" #include "ngraph/op/constant.hpp" -#include "api/gather_elements.hpp" +#include "cldnn/primitives/gather_elements.hpp" namespace CLDNNPlugin { diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index b4dad74bba3f6e..30b84f007f011b 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -54,15 +54,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_set3, GatherElementsLayerTest, ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); -<<<<<<< HEAD -<<<<<<< HEAD INSTANTIATE_TEST_SUITE_P(smoke_set4, GatherElementsLayerTest, -======= -INSTANTIATE_TEST_CASE_P(yunji_set2, GatherElementsLayerTest, ->>>>>>> Add cldnn unit test implementation -======= -INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, ->>>>>>> Add functional test implementation ::testing::Combine( ::testing::Values(std::vector({3, 2, 3, 8})), // Data shape ::testing::Values(std::vector({2, 2, 3, 8})), // Indices shape @@ -72,19 +64,7 @@ INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD INSTANTIATE_TEST_SUITE_P(smoke_set5, GatherElementsLayerTest, -======= -INSTANTIATE_TEST_CASE_P(yunji_set3, GatherElementsLayerTest, ->>>>>>> Add cldnn unit test implementation -======= - -======= ->>>>>>> code clean up -INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, ->>>>>>> Add functional test implementation ::testing::Combine( ::testing::Values(std::vector({3, 2, 3, 4, 8})), // Data shape ::testing::Values(std::vector({3, 2, 3, 5, 8})), // Indices shape @@ -93,4 +73,4 @@ INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); -} // namespace +} // namespace \ No newline at end of file diff --git a/inference-engine/thirdparty/clDNN/api/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp similarity index 100% rename from inference-engine/thirdparty/clDNN/api/gather_elements.hpp rename to inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index c4afc4c3ae1d82..092555014ef8f1 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -155,7 +155,7 @@ KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const auto dispatchData = SetDefault(newParams, options); auto cldnn_jit = GetJitConstants(newParams); - auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options); + auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options); auto jit = CreateJit(kernelName, cldnn_jit, entry_point); auto& kernel = kd.kernels[0]; FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2, GetFusedPrimitiveInputsCount(params)); diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl index f7005f3450d6b4..5eedb0e1dcaba6 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -1,19 +1,9 @@ -// Copyright (c) 2021 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "include/fetch.cl" -#include "include/include_all.cl" +#include "include/data_types.cl" +#include "include/fetch_data.cl" #define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order) diff --git a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp index 869c31e322e359..ddaa73c04a091b 100644 --- a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp +++ b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp @@ -17,7 +17,7 @@ #include "gather_elements_inst.h" #include "primitive_type_base.h" -#include "error_handler.h" +#include "cldnn/runtime/error_handler.hpp" #include "json_object.h" #include diff --git a/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp deleted file mode 100644 index a7de56fb6c7dd4..00000000000000 --- a/inference-engine/thirdparty/clDNN/src/gpu/gather_elements_gpu.cpp +++ /dev/null @@ -1,95 +0,0 @@ -/* -// Copyright (c) 2021 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -*/ - -#include "gather_elements_inst.h" -#include "primitive_gpu_base.h" -#include "implementation_map.h" -#include "kernel_selector_helper.h" -#include "gather/gather_elements_kernel_selector.h" -#include "gather/gather_elements_kernel_ref.h" -#include "error_handler.h" - -using namespace cldnn; - -namespace cldnn { -namespace gpu { -kernel_selector::gather_elements_axis convert_axis(gather_elements::gather_elements_axis axis) { - switch (axis) { - case gather_elements::along_x: - return kernel_selector::gather_elements_axis::X; - case gather_elements::along_y: - return kernel_selector::gather_elements_axis::Y; - case gather_elements::along_z: - return kernel_selector::gather_elements_axis::Z; - case gather_elements::along_w: - return kernel_selector::gather_elements_axis::W; - case gather_elements::along_f: - return kernel_selector::gather_elements_axis::FEATURE; - case gather_elements::along_b: - return kernel_selector::gather_elements_axis::BATCH; - default: - return kernel_selector::gather_elements_axis::X; - } -} - -struct gather_elements_gpu : typed_primitive_gpu_impl { - using parent = typed_primitive_gpu_impl; - using parent::parent; - -public: - static primitive_impl* create(const gather_elements_node& arg) { - auto gather_elements_params = get_default_params(arg); - auto gather_elements_optional_params = - get_default_optional_params(arg.get_program()); - - gather_elements_params.axis = convert_axis(arg.get_primitive()->axis); - - gather_elements_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout())); - - auto& kernel_selector = kernel_selector::gather_elements_kernel_selector::Instance(); - auto best_kernels = kernel_selector.GetBestKernels(gather_elements_params, gather_elements_optional_params); - - CLDNN_ERROR_BOOL(arg.id(), - "Best_kernel.empty()", - best_kernels.empty(), - "Cannot find a proper kernel with this arguments"); - - auto gather_elements = new gather_elements_gpu(arg, best_kernels[0]); - - return gather_elements; - } -}; - -namespace detail { - -attach_gather_elements_gpu::attach_gather_elements_gpu() { - auto val_fw = gather_elements_gpu::create; - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw); - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw); - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw); - - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw); - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw); - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw); - - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw); - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw); - implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw); -} - -} // namespace detail -} // namespace gpu -} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp new file mode 100644 index 00000000000000..beb4fbd1068dcd --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp @@ -0,0 +1,86 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gather_elements_inst.h" +#include "primitive_base.hpp" +#include "impls/implementation_map.hpp" +#include "kernel_selector_helper.h" +#include "gather/gather_elements_kernel_selector.h" +#include "gather/gather_elements_kernel_ref.h" +#include "cldnn/runtime/error_handler.hpp" + +using namespace cldnn; + +namespace cldnn { +namespace ocl { +kernel_selector::gather_elements_axis convert_axis(gather_elements::gather_elements_axis axis) { + switch (axis) { + case gather_elements::along_x: + return kernel_selector::gather_elements_axis::X; + case gather_elements::along_y: + return kernel_selector::gather_elements_axis::Y; + case gather_elements::along_z: + return kernel_selector::gather_elements_axis::Z; + case gather_elements::along_w: + return kernel_selector::gather_elements_axis::W; + case gather_elements::along_f: + return kernel_selector::gather_elements_axis::FEATURE; + case gather_elements::along_b: + return kernel_selector::gather_elements_axis::BATCH; + default: + return kernel_selector::gather_elements_axis::X; + } +} + +struct gather_elements_impl : typed_primitive_impl_ocl { + using parent = typed_primitive_impl_ocl; + using parent::parent; + + std::unique_ptr clone() const override { + return make_unique(*this); + } + +public: + static primitive_impl* create(const gather_elements_node& arg) { + auto gather_elements_params = get_default_params(arg); + auto gather_elements_optional_params = + get_default_optional_params(arg.get_program()); + + gather_elements_params.axis = convert_axis(arg.get_primitive()->axis); + + gather_elements_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout())); + + auto& kernel_selector = kernel_selector::gather_elements_kernel_selector::Instance(); + auto best_kernels = kernel_selector.GetBestKernels(gather_elements_params, gather_elements_optional_params); + + CLDNN_ERROR_BOOL(arg.id(), + "Best_kernel.empty()", + best_kernels.empty(), + "Cannot find a proper kernel with this arguments"); + + auto gather_elements = new gather_elements_impl(arg, best_kernels[0]); + + return gather_elements; + } +}; + +namespace detail { + +attach_gather_elements_impl::attach_gather_elements_impl() { + implementation_map::add(impl_types::ocl, gather_elements_impl::create, { + std::make_tuple(data_types::f32, format::bfyx), + std::make_tuple(data_types::f16, format::bfyx), + std::make_tuple(data_types::i32, format::bfyx), + std::make_tuple(data_types::f32, format::bfzyx), + std::make_tuple(data_types::f16, format::bfzyx), + std::make_tuple(data_types::i32, format::bfzyx), + std::make_tuple(data_types::f32, format::bfwzyx), + std::make_tuple(data_types::f16, format::bfwzyx), + std::make_tuple(data_types::i32, format::bfwzyx), + }); +} + +} // namespace detail +} // namespace ocl +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp index c62b64de62dff8..86a423a84713e6 100644 --- a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.cpp @@ -30,6 +30,7 @@ void register_implementations() { REGISTER_OCL(eltwise); REGISTER_OCL(fully_connected); REGISTER_OCL(gather); + REGISTER_OCL(gather_elements); REGISTER_OCL(gather_nd); REGISTER_OCL(gemm); REGISTER_OCL(lrn); diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp index aaf2a777cde83a..dcd58574e52558 100644 --- a/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/register.hpp @@ -22,6 +22,7 @@ #include "cldnn/primitives/fully_connected.hpp" #include "cldnn/primitives/gather.hpp" #include "cldnn/primitives/gather_nd.hpp" +#include "cldnn/primitives/gather_elements.hpp" #include "cldnn/primitives/gemm.hpp" #include "cldnn/primitives/lrn.hpp" #include "cldnn/primitives/lstm.hpp" diff --git a/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h b/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h index 2b6952bdf6f015..ebefc9c032dea6 100644 --- a/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h +++ b/inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h @@ -16,7 +16,7 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// #pragma once -#include "api/gather_elements.hpp" +#include "cldnn/primitives/gather_elements.hpp" #include "primitive_inst.h" #include diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 747f575bc47bb3..fb51f20a7ce107 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -178,6 +179,7 @@ class BaseFusingTest : public ::testing::TestWithParam { description << " " << i.original_id << " " << i.kernel_id << std::endl; } SCOPED_TRACE(description.str()); + // Subtract reorders count to handle execution in different layouts when input/output reorders can be added in the graph ASSERT_EQ(fused.get_executed_primitives().size() - (count_reorder ? 0 : reorders_count_fused), p.expected_fused_primitives); ASSERT_EQ(not_fused.get_executed_primitives().size() - (count_reorder ? 0 : reorders_count_not_fused), p.expected_not_fused_primitives); ASSERT_EQ(outputs_ref.size(), outputs_fused.size()); @@ -8411,11 +8413,12 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_nd_activation_scale_eltwise, gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_2, 2, 5 }, gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 5 }, gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 5 }, -}), ); +})); + /* ----------------------------------------------------------------------------------------------------- */ -/* ------------------------------------------ GatherElements cases ------------------------------------------- */ +/* ------------------------------------------ GatherElements cases ------------------------------------- */ /* ----------------------------------------------------------------------------------------------------- */ struct gather_elements_test_params { data_types data_type; @@ -8525,7 +8528,7 @@ TEST_P(gather_elements_quantize, basic) { execute(p); } -INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_quantize, +INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_quantize, ::testing::ValuesIn(std::vector{ gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 3 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 3 }, @@ -8548,7 +8551,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_quantize, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 3 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 3 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 3 }, -}), ); +})); class gather_elements_scale_activation : public GatherElementsPrimitiveFusingTest {}; @@ -8567,7 +8570,7 @@ TEST_P(gather_elements_scale_activation, basic) { execute(p); } -INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_scale_activation, +INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_scale_activation, ::testing::ValuesIn(std::vector{ gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 4 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 4 }, @@ -8590,7 +8593,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_scale_activation, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 4 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 4 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 4 }, -}), ); +})); class gather_elements_activation_scale_eltwise : public GatherElementsPrimitiveFusingTest {}; @@ -8612,7 +8615,7 @@ TEST_P(gather_elements_activation_scale_eltwise, basic) { execute(p); } -INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_activation_scale_eltwise, +INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_activation_scale_eltwise, ::testing::ValuesIn(std::vector{ gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 5 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 5 }, @@ -8635,4 +8638,4 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_elements_activation_scale_eltwise, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 5 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 5 }, gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 5 }, -}), ); +})); diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp index 9b667b8e4ac72e..62f84fb7cb5600 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -1,43 +1,32 @@ -// Copyright (c) 2021 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -/////////////////////////////////////////////////////////////////////////////////////////////////// -#include +#include "test_utils.h" -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include -#include +#include using namespace cldnn; using namespace ::tests; -inline void DoTest(const engine& engine, - const cldnn::memory& input0, // data - const cldnn::memory& input1, // indices +inline void DoTest(engine& engine, + const cldnn::memory::ptr& input0, // data + const cldnn::memory::ptr& input1, // indices const std::vector& expected_results, const tensor& output_tensor, const cldnn::gather_elements::gather_elements_axis axis) { topology topology; - topology.add(input_layout("InputData", input0.get_layout())); - topology.add(input_layout("InputIndices", input1.get_layout())); + topology.add(input_layout("InputData", input0->get_layout())); + topology.add(input_layout("InputIndices", input1->get_layout())); topology.add( - gather_elements("gather_elements", "InputData", "InputIndices", input1.get_layout().format, output_tensor, axis) + gather_elements("gather_elements", "InputData", "InputIndices", input1->get_layout().format, output_tensor, axis) ); network network(engine, topology); @@ -46,7 +35,8 @@ inline void DoTest(const engine& engine, network.set_input_data("InputIndices", input1); auto outputs = network.execute(); auto output = outputs.at("gather_elements").get_memory(); - auto output_ptr = output.pointer(); + // auto output_ptr = output.pointer(); + cldnn::mem_lock output_ptr(output, get_test_stream()); for (size_t i = 0; i < expected_results.size(); ++i) { EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); @@ -54,11 +44,11 @@ inline void DoTest(const engine& engine, } TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_b; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 8, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 8, 3 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 2, 8, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 8, 3 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), @@ -115,11 +105,11 @@ TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { } TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_x; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 3, 5 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), @@ -193,11 +183,11 @@ TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { } TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_x; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 3, 2, 9 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 3, 5, 9 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 3, 2, 9 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 3, 5, 9 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), @@ -292,11 +282,11 @@ TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { } TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_y; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 5, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 2, 3 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 8, 5, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 8, 2, 3 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), @@ -365,11 +355,11 @@ TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { } TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_f; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 5, 4, 4, 1 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 4, 4, 1 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 2, 5, 4, 4, 1 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 2, 2, 4, 4, 1 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), @@ -457,11 +447,11 @@ TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { } TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_b; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 8, 4, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 8, 4, 3 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 3, 2, 8, 4, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 8, 4, 3 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), FLOAT16(5), FLOAT16(2), FLOAT16(0), FLOAT16(7), @@ -597,11 +587,11 @@ TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { } TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_x; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 3, 4, 4, 2 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 6, 4, 4, 2 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 3, 4, 4, 2 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 6, 4, 4, 2 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), @@ -1000,11 +990,11 @@ TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { } TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_z; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 5, 1 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 2, 1 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 5, 1 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 4, 2, 2, 1 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), FLOAT16(5), @@ -1055,11 +1045,11 @@ TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { } TEST(gather_elements_gpu_fp16, d233113_i233115_a2) { - const auto& engine = get_test_engine(); + auto& engine = get_test_engine(); auto axis = cldnn::gather_elements::gather_elements_axis::along_w; - auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 3 } }); // data - auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 5 } }); // indices + auto input0 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 3 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 3, 3, 1, 1, 5 } }); // indices set_values(input0, { FLOAT16(0), FLOAT16(1), FLOAT16(8), From a57c57042c9d4cb836da679e7a9a27f66e4bf998 Mon Sep 17 00:00:00 2001 From: yunji Date: Tue, 20 Jul 2021 20:54:13 +0900 Subject: [PATCH 08/11] Fix reviewed codes. --- .../cldnn_engine/cldnn_primitives_list.hpp | 1 - .../src/cldnn_engine/ops/gather_elements.cpp | 14 +++++++------- .../api/cldnn/primitives/gather_elements.hpp | 17 ++--------------- .../gather/gather_elements_kernel_ref.cpp | 2 ++ .../core/cl_kernels/gather_elements_ref.cl | 13 +++++-------- .../thirdparty/clDNN/src/gather_elements.cpp | 19 ++----------------- .../clDNN/src/impls/ocl/gather_elements.cpp | 6 ++++++ .../test_cases/gather_elements_gpu_test.cpp | 1 - 8 files changed, 24 insertions(+), 49 deletions(-) diff --git a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp index 4081c5be17029e..0c0ddf7e637050 100644 --- a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp +++ b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp @@ -208,7 +208,6 @@ REGISTER_FACTORY(v6, GatherElements); // ------------------------------ Supported v7 ops ------------------------------ // REGISTER_FACTORY(v7, Gather); -// REGISTER_FACTORY(v7, GatherElements); // ------------------------------ Supported v8 ops ------------------------------ // REGISTER_FACTORY(v8, Gather); diff --git a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp index 52ddfe13479aa7..07f20356c76ff0 100644 --- a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp +++ b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp @@ -35,7 +35,7 @@ static cldnn::gather_elements::gather_elements_axis GetGatherElementsAxis(int ax case 3: return cldnn::gather_elements::gather_elements_axis::along_y; case 4: return cldnn::gather_elements::gather_elements_axis::along_z; case 5: return cldnn::gather_elements::gather_elements_axis::along_w; - default: IE_THROW() << "Unsupported ScatterElementsUpdate axis: " << axis; + default: IE_THROW() << "Unsupported GatherElements axis: " << axis; } return cldnn::gather_elements::gather_elements_axis::along_f; // shouldn't get here } @@ -51,11 +51,11 @@ void CreateGatherElementsOp(Program& p, const std::shared_ptrget_output_shape(0).size()); auto primitive = cldnn::gather_elements(layerName, - inputPrimitives[0], - inputPrimitives[1], - outLayout, - CldnnTensorFromIEDims(op->get_output_shape(0)), - GetGatherElementsAxis(axis, rank)); + inputPrimitives[0], + inputPrimitives[1], + outLayout, + CldnnTensorFromIEDims(op->get_output_shape(0)), + GetGatherElementsAxis(axis, rank)); p.AddPrimitive(primitive); p.AddPrimitiveToProfiler(op); @@ -63,4 +63,4 @@ void CreateGatherElementsOp(Program& p, const std::shared_ptr { const primitive_id& indices, const format& output_format, const tensor& output_shape, - // const uint8_t axis = 0, const gather_elements_axis axis, const padding& output_padding = padding()) : primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {} diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index 092555014ef8f1..9959eecfa3df07 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -50,6 +50,8 @@ ParamsKey GatherElementsKernelRef::GetSupportedKey() const { k.EnableInputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); k.EnableInputDataType(Datatype::INT32); + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::INT32); diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl index 5eedb0e1dcaba6..04a12b6c914c6b 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl @@ -5,9 +5,9 @@ #include "include/data_types.cl" #include "include/fetch_data.cl" -#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order) +#define GET_OUTPUT_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order) -KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, +KERNEL(gather_elements_ref)(const __global INPUT0_TYPE* data, const __global INPUT1_TYPE* indices, __global OUTPUT_TYPE* output #if HAS_FUSED_OPS_DECLS @@ -39,7 +39,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, const uint f = dim2 % OUTPUT_FEATURE_NUM; const uint b = dim2 / OUTPUT_FEATURE_NUM; - const int out_idx = GET_UPDATES_INDEX(INPUT1, ORDER); + const int out_idx = GET_OUTPUT_INDEX(INPUT1, ORDER); #if INPUT1_DIMS == 4 size_t data_shape[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X}; @@ -70,10 +70,7 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data; size_t inner_sum = out_idx % max_inner_sum; - if (indices[out_idx] < 0 || indices[out_idx] >= data_shape[AXIS]) { - printf("indices values of GatherElement exceed data size.\n"); - return; - } + uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum; INPUT0_TYPE val = data[idx]; @@ -85,4 +82,4 @@ KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data, #endif } -#undef GET_UPDATES_INDEX +#undef GET_OUTPUT_INDEX diff --git a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp index ddaa73c04a091b..7a3a920aa6277e 100644 --- a/inference-engine/thirdparty/clDNN/src/gather_elements.cpp +++ b/inference-engine/thirdparty/clDNN/src/gather_elements.cpp @@ -1,18 +1,6 @@ -/* -// Copyright (c) 2021 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -*/ #include "gather_elements_inst.h" @@ -40,13 +28,10 @@ layout gather_elements_inst::calc_output_layout(gather_elements_node const& node input_layout_origin.data_type = node.get_fused_output_layout().data_type; } - // const size_t input_dims = input_layout.size(); auto output_type = indices_layout_origin.data_type; auto output_format = op->output_format; auto output_shape = op->output_shape; - // const auto axis = op->axis; - // calculate initial output shape return layout(output_type, output_format, output_shape); } diff --git a/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp b/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp index beb4fbd1068dcd..474bcf42ad5212 100644 --- a/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp +++ b/inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp @@ -72,12 +72,18 @@ attach_gather_elements_impl::attach_gather_elements_impl() { std::make_tuple(data_types::f32, format::bfyx), std::make_tuple(data_types::f16, format::bfyx), std::make_tuple(data_types::i32, format::bfyx), + std::make_tuple(data_types::i8, format::bfyx), + std::make_tuple(data_types::u8, format::bfyx), std::make_tuple(data_types::f32, format::bfzyx), std::make_tuple(data_types::f16, format::bfzyx), std::make_tuple(data_types::i32, format::bfzyx), + std::make_tuple(data_types::i8, format::bfzyx), + std::make_tuple(data_types::u8, format::bfzyx), std::make_tuple(data_types::f32, format::bfwzyx), std::make_tuple(data_types::f16, format::bfwzyx), std::make_tuple(data_types::i32, format::bfwzyx), + std::make_tuple(data_types::i8, format::bfwzyx), + std::make_tuple(data_types::u8, format::bfwzyx), }); } diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp index 62f84fb7cb5600..d0ddfc4ff3a1c0 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -35,7 +35,6 @@ inline void DoTest(engine& engine, network.set_input_data("InputIndices", input1); auto outputs = network.execute(); auto output = outputs.at("gather_elements").get_memory(); - // auto output_ptr = output.pointer(); cldnn::mem_lock output_ptr(output, get_test_stream()); for (size_t i = 0; i < expected_results.size(); ++i) { From d36aa7eef8099b9e2201c7eb4bb9ad1e8a969da0 Mon Sep 17 00:00:00 2001 From: yunji Date: Tue, 20 Jul 2021 21:18:32 +0900 Subject: [PATCH 09/11] Reduce gpu functional test cases and refactor --- .../single_layer_tests/gather_elements.cpp | 353 +++++++----------- .../gather/gather_elements_kernel_ref.cpp | 16 +- .../gather/gather_elements_kernel_ref.h | 20 +- 3 files changed, 135 insertions(+), 254 deletions(-) diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index be951a0bb5840b..452586fac12346 100644 --- a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -25,296 +25,203 @@ const std::vector idxPrecisions = { }; INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest, - ::testing::Combine( - ::testing::Values(std::vector({2, 2})), // Data shape - ::testing::Values(std::vector({2, 2})), // Indices shape - ::testing::ValuesIn(std::vector({-1, 0, 1})), // Axis - ::testing::ValuesIn(inputPrecisions), - ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_GPU)), - GatherElementsLayerTest::getTestCaseName); + ::testing::Combine( + ::testing::Values(std::vector({2, 2})), + ::testing::Values(std::vector({2, 2})), + ::testing::ValuesIn(std::vector({-1, 0, 1})), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest, - ::testing::Combine( - ::testing::Values(std::vector({2, 2, 1})), // Data shape - ::testing::Values(std::vector({4, 2, 1})), // Indices shape - ::testing::ValuesIn(std::vector({0, -3})), // Axis - ::testing::ValuesIn(inputPrecisions), - ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_GPU)), - GatherElementsLayerTest::getTestCaseName); + ::testing::Combine( + ::testing::Values(std::vector({2, 2, 1})), + ::testing::Values(std::vector({4, 2, 1})), + ::testing::ValuesIn(std::vector({0, -3})), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest, - ::testing::Combine( - ::testing::Values(std::vector({2, 2, 3, 5})), // Data shape - ::testing::Values(std::vector({2, 2, 3, 7})), // Indices shape - ::testing::Values(3, -1), // Axis - ::testing::ValuesIn(inputPrecisions), - ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_GPU)), - GatherElementsLayerTest::getTestCaseName); + ::testing::Combine( + ::testing::Values(std::vector({2, 2, 3, 5})), + ::testing::Values(std::vector({2, 2, 3, 7})), + ::testing::Values(3, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest, - ::testing::Combine( - ::testing::Values(std::vector({3, 2, 3, 8})), // Data shape - ::testing::Values(std::vector({2, 2, 3, 8})), // Indices shape - ::testing::Values(0, -4), // Axis - ::testing::ValuesIn(inputPrecisions), - ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_GPU)), - GatherElementsLayerTest::getTestCaseName); + ::testing::Combine( + ::testing::Values(std::vector({3, 2, 3, 8})), + ::testing::Values(std::vector({2, 2, 3, 8})), + ::testing::Values(0, -4), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest, - ::testing::Combine( - ::testing::Values(std::vector({3, 2, 3, 4, 8})), // Data shape - ::testing::Values(std::vector({3, 2, 3, 5, 8})), // Indices shape - ::testing::Values(3, -2), // Axis - ::testing::ValuesIn(inputPrecisions), - ::testing::ValuesIn(idxPrecisions), - ::testing::Values(CommonTestUtils::DEVICE_GPU)), - GatherElementsLayerTest::getTestCaseName); - -const std::vector> ShapesRank4Axis0 = { - std::vector{1, 7, 8, 4}, - std::vector{2, 7, 8, 4}, - std::vector{7, 7, 8, 4}, - std::vector{9, 7, 8, 4}, -}; -const std::vector> ShapesRank4Axis1 = { - std::vector{6, 1, 8, 4}, - std::vector{6, 5, 8, 4}, - std::vector{6, 8, 8, 4}, - std::vector{6, 9, 8, 4}, -}; -const std::vector> ShapesRank4Axis2 = { - std::vector{6, 7, 2, 4}, - std::vector{6, 7, 4, 4}, - std::vector{6, 7, 5, 4}, - std::vector{6, 7, 7, 4}, -}; -const std::vector> ShapesRank4Axis3 = { - std::vector{6, 5, 8, 1}, - std::vector{6, 5, 8, 4}, - std::vector{6, 5, 8, 7}, - std::vector{6, 5, 8, 9}, -}; + ::testing::Combine( + ::testing::Values(std::vector({3, 2, 3, 4, 8})), + ::testing::Values(std::vector({3, 2, 3, 5, 8})), + ::testing::Values(3, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis0, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank4Axis0), // Data shapes - ::testing::ValuesIn(ShapesRank4Axis0), // Indices shpae - ::testing::ValuesIn(std::vector({ 0 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{7, 7, 8, 4}), + ::testing::Values(std::vector{2, 7, 8, 4}), + ::testing::Values(0), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis1, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank4Axis1), // Data shapes - ::testing::ValuesIn(ShapesRank4Axis1), // Indices shpae - ::testing::ValuesIn(std::vector({ 1, -3 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{6, 1, 8, 4}), + ::testing::Values(std::vector{6, 8, 8, 4}), + ::testing::Values(1, -3), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis2, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank4Axis2), // Data shapes - ::testing::ValuesIn(ShapesRank4Axis2), // Indices shpae - ::testing::ValuesIn(std::vector({ 2, -2 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{6, 7, 4, 4}), + ::testing::Values(std::vector{6, 7, 2, 4}), + ::testing::Values(2, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis3, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank4Axis3), // Data shapes - ::testing::ValuesIn(ShapesRank4Axis3), // Indices shpae - ::testing::ValuesIn(std::vector({ 3, -1 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{6, 5, 8, 7}), + ::testing::Values(std::vector{6, 5, 8, 7}), + ::testing::Values(3, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); -const std::vector> ShapesRank5Axis0 = { - std::vector{2, 3, 9, 4, 9}, - std::vector{1, 3, 9, 4, 9}, - std::vector{5, 3, 9, 4, 9}, - std::vector{7, 3, 9, 4, 9}, -}; -const std::vector> ShapesRank5Axis1 = { - std::vector{2, 1, 5, 4, 7}, - std::vector{2, 3, 5, 4, 7}, - std::vector{2, 8, 5, 4, 7}, - std::vector{2, 9, 5, 4, 7}, -}; -const std::vector> ShapesRank5Axis2 = { - std::vector{1, 2, 2, 8, 9}, - std::vector{1, 2, 3, 8, 9}, - std::vector{1, 2, 6, 8, 9}, - std::vector{1, 2, 7, 8, 9}, -}; -const std::vector> ShapesRank5Axis3 = { - std::vector{2, 2, 4, 3, 7}, - std::vector{2, 2, 4, 4, 7}, - std::vector{2, 2, 4, 7, 7}, - std::vector{2, 2, 4, 9, 7}, -}; -const std::vector> ShapesRank5Axis4 = { - std::vector{1, 3, 9, 3, 1}, - std::vector{1, 3, 9, 3, 2}, - std::vector{1, 3, 9, 3, 5}, - std::vector{1, 3, 9, 3, 9}, -}; - INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis0, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank5Axis0), // Data shapes - ::testing::ValuesIn(ShapesRank5Axis0), // Indices shpae - ::testing::ValuesIn(std::vector({ 0 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{2, 3, 9, 4, 9}), + ::testing::Values(std::vector{1, 3, 9, 4, 9}), + ::testing::Values(0), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis1, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank5Axis1), // Data shapes - ::testing::ValuesIn(ShapesRank5Axis1), // Indices shpae - ::testing::ValuesIn(std::vector({ 1, -4 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{2, 3, 5, 4, 7}), + ::testing::Values(std::vector{2, 9, 5, 4, 7}), + ::testing::Values(1, -4), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis2, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank5Axis2), // Data shapes - ::testing::ValuesIn(ShapesRank5Axis2), // Indices shpae - ::testing::ValuesIn(std::vector({ 2, -3 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{1, 2, 6, 8, 9}), + ::testing::Values(std::vector{1, 2, 6, 8, 9}), + ::testing::Values(2, -3), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis3, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank5Axis3), // Data shapes - ::testing::ValuesIn(ShapesRank5Axis3), // Indices shpae - ::testing::ValuesIn(std::vector({ 3, -2 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{2, 2, 4, 7, 7}), + ::testing::Values(std::vector{2, 2, 4, 3, 7}), + ::testing::Values(3, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis4, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank5Axis4), // Data shapes - ::testing::ValuesIn(ShapesRank5Axis4), // Indices shpae - ::testing::ValuesIn(std::vector({ 4, -1 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{1, 3, 9, 3, 2}), + ::testing::Values(std::vector{1, 3, 9, 3, 9}), + ::testing::Values(4, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); -const std::vector> ShapesRank6Axis0 = { - std::vector{1, 3, 2, 4, 4, 3}, - std::vector{3, 3, 2, 4, 4, 3}, - std::vector{6, 3, 2, 4, 4, 3}, - std::vector{7, 3, 2, 4, 4, 3}, -}; -const std::vector> ShapesRank6Axis1 = { - std::vector{1, 2, 2, 3, 5, 9}, - std::vector{1, 5, 2, 3, 5, 9}, - std::vector{1, 6, 2, 3, 5, 9}, - std::vector{1, 9, 2, 3, 5, 9}, -}; -const std::vector> ShapesRank6Axis2 = { - std::vector{2, 3, 2, 7, 2, 1}, - std::vector{2, 3, 5, 7, 2, 1}, - std::vector{2, 3, 8, 7, 2, 1}, - std::vector{2, 3, 9, 7, 2, 1}, -}; -const std::vector> ShapesRank6Axis3 = { - std::vector{1, 3, 4, 2, 1, 3}, - std::vector{1, 3, 4, 4, 1, 3}, - std::vector{1, 3, 4, 5, 1, 3}, - std::vector{1, 3, 4, 8, 1, 3}, -}; -const std::vector> ShapesRank6Axis4 = { - std::vector{1, 3, 2, 4, 1, 3}, - std::vector{1, 3, 2, 4, 4, 3}, - std::vector{1, 3, 2, 4, 6, 3}, - std::vector{1, 3, 2, 4, 7, 3}, -}; -const std::vector> ShapesRank6Axis5 = { - std::vector{2, 1, 7, 8, 1, 2}, - std::vector{2, 1, 7, 8, 1, 3}, - std::vector{2, 1, 7, 8, 1, 4}, - std::vector{2, 1, 7, 8, 1, 6}, -}; - INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis0, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank6Axis0), // Data shapes - ::testing::ValuesIn(ShapesRank6Axis0), // Indices shpae - ::testing::ValuesIn(std::vector({ 0 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{3, 3, 2, 4, 4, 3}), + ::testing::Values(std::vector{7, 3, 2, 4, 4, 3}), + ::testing::Values(0), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis1, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank6Axis1), // Data shapes - ::testing::ValuesIn(ShapesRank6Axis1), // Indices shpae - ::testing::ValuesIn(std::vector({ 1, -5 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{1, 6, 2, 3, 5, 9}), + ::testing::Values(std::vector{1, 6, 2, 3, 5, 9}), + ::testing::Values(1, -5), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis2, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank6Axis2), // Data shapes - ::testing::ValuesIn(ShapesRank6Axis2), // Indices shpae - ::testing::ValuesIn(std::vector({ 2, -4 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{2, 3, 9, 7, 2, 1}), + ::testing::Values(std::vector{2, 3, 5, 7, 2, 1}), + ::testing::Values(2, -4), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis3, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank6Axis3), // Data shapes - ::testing::ValuesIn(ShapesRank6Axis3), // Indices shpae - ::testing::ValuesIn(std::vector({ 3, -3 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{1, 3, 4, 5, 1, 3}), + ::testing::Values(std::vector{1, 3, 4, 4, 1, 3}), + ::testing::Values(3, -3), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis4, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank6Axis4), // Data shapes - ::testing::ValuesIn(ShapesRank6Axis4), // Indices shpae - ::testing::ValuesIn(std::vector({ 4, -2 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{1, 3, 2, 4, 3, 3}), + ::testing::Values(std::vector{1, 3, 2, 4, 6, 3}), + ::testing::Values(4, -2), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis5, GatherElementsLayerTest, ::testing::Combine( - ::testing::ValuesIn(ShapesRank6Axis5), // Data shapes - ::testing::ValuesIn(ShapesRank6Axis5), // Indices shpae - ::testing::ValuesIn(std::vector({ 5, -1 })), - ::testing::ValuesIn(inputPrecisions), // Data precision - ::testing::ValuesIn(idxPrecisions), // Indices precision - ::testing::Values(CommonTestUtils::DEVICE_GPU)), // Device name + ::testing::Values(std::vector{2, 1, 7, 8, 1, 6}), + ::testing::Values(std::vector{2, 1, 7, 8, 1, 5}), + ::testing::Values(5, -1), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), GatherElementsLayerTest::getTestCaseName); } // namespace diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index 9959eecfa3df07..2d126530ecadf9 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -1,18 +1,6 @@ -/* -// Copyright (c) 2021 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -*/ #include "gather_elements_kernel_ref.h" #include "kernel_selector_utils.h" diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h index 5826671790389e..3216875f59a3cd 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h @@ -1,18 +1,6 @@ -/* -// Copyright (c) 2021 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 // -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -*/ #pragma once @@ -26,8 +14,6 @@ struct gather_elements_params : public base_params { gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherElementsAxis::BATCH) {} GatherElementsAxis axis; - - virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); } }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -43,7 +29,7 @@ class GatherElementsKernelRef : public KernelBaseOpenCL { virtual ~GatherElementsKernelRef() {} virtual JitConstants GetJitConstants(const gather_elements_params& params) const; virtual CommonDispatchData SetDefault(const gather_elements_params& params, const optional_params&) const; - KernelsData GetKernelsData(const Params& params, const optional_params& options) const; + KernelsData GetKernelsData(const Params& params, const optional_params& options) const override; ParamsKey GetSupportedKey() const override; std::vector GetSupportedFusedOps() const override { return { FusedOpType::QUANTIZE, From 306c093d33087d2ac0f218c31cbe6866b37ef088 Mon Sep 17 00:00:00 2001 From: yunji Date: Mon, 26 Jul 2021 20:36:59 +0900 Subject: [PATCH 10/11] Change the location of functional test definition. --- .../single_layer_tests/gather_elements.cpp | 3 ++- .../single_layer_tests/gather_elements.hpp | 4 ++++ .../src/single_layer/gather_elements.cpp | 3 --- .../api/cldnn/primitives/gather_elements.hpp | 12 ++++++------ .../test_cases/gather_elements_gpu_test.cpp | 18 +++++++++--------- 5 files changed, 21 insertions(+), 19 deletions(-) diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 30b84f007f011b..63a6fd88c3e241 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -4,7 +4,8 @@ #include -#include "shared_test_classes/single_layer/gather_elements.hpp" +#include "single_layer_tests/gather_elements.hpp" +#include "common_test_utils/test_constants.hpp" using namespace LayerTestsDefinitions; diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp index 9c6329c76b3e81..eea88d4abf3183 100644 --- a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/gather_elements.hpp @@ -8,4 +8,8 @@ namespace LayerTestsDefinitions { +TEST_P(GatherElementsLayerTest, CompareWithRefs) { + Run(); +} + } // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp b/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp index af5832302aa0be..d559e04a53d2c1 100644 --- a/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp +++ b/inference-engine/tests/functional/shared_test_classes/src/single_layer/gather_elements.cpp @@ -48,7 +48,4 @@ void GatherElementsLayerTest::SetUp() { function = std::make_shared(results, params, "gatherEl"); } -TEST_P(GatherElementsLayerTest, CompareWithRefs) { - Run(); -} } // namespace LayerTestsDefinitions diff --git a/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp index ea058a930d181c..fdcbc1df3babd9 100644 --- a/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp +++ b/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp @@ -36,12 +36,12 @@ struct gather_elements : public primitive_base { /// @param output_shape Output shape. /// @param axis Gathering axis. gather_elements(const primitive_id& id, - const primitive_id& data, - const primitive_id& indices, - const format& output_format, - const tensor& output_shape, - const gather_elements_axis axis, - const padding& output_padding = padding()) + const primitive_id& data, + const primitive_id& indices, + const format& output_format, + const tensor& output_shape, + const gather_elements_axis axis, + const padding& output_padding = padding()) : primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {} /// @brief Gather Elements output format diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp index d0ddfc4ff3a1c0..034f9f6699ada5 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp @@ -100,7 +100,7 @@ TEST(gather_elements_gpu_fp16, d3283_i2283_a0) { FLOAT16(2), FLOAT16(10), FLOAT16(7), FLOAT16(3), FLOAT16(3), FLOAT16(10), FLOAT16(6), FLOAT16(1), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 8, 3), axis); + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 8, 3), axis); } TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { @@ -178,7 +178,7 @@ TEST(gather_elements_gpu_fp16, d2235_i2235_a3) { FLOAT16(9), FLOAT16(9), FLOAT16(0), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 3, 5), axis); + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 3, 5), axis); } TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { @@ -277,7 +277,7 @@ TEST(gather_elements_gpu_fp16, d1329_i1359_an1) { FLOAT16(3), FLOAT16(3), FLOAT16(2), FLOAT16(3), FLOAT16(3), }; - DoTest(engine,input0, input1, expected_results, tensor(1, 3, 5, 9), axis); + DoTest(engine, input0, input1, expected_results, tensor(1, 3, 5, 9), axis); } TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { @@ -350,7 +350,7 @@ TEST(gather_elements_gpu_fp16, d12853_i12923_a3) { FLOAT16(1), FLOAT16(7), FLOAT16(10), FLOAT16(0), FLOAT16(9), FLOAT16(4), FLOAT16(5), FLOAT16(5), }; - DoTest(engine,input0, input1, expected_results, tensor(1, 2, 8, 2, 3), axis); + DoTest(engine, input0, input1, expected_results, tensor(1, 2, 8, 2, 3), axis); } TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { @@ -442,7 +442,7 @@ TEST(gather_elements_gpu_fp16, d25441_i22441_an4) { FLOAT16(6), FLOAT16(5), FLOAT16(10), FLOAT16(8), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 4, 4, 1), axis); + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 4, 4, 1), axis); } TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { @@ -582,7 +582,7 @@ TEST(gather_elements_gpu_fp16, d32843_i12843_a0) { FLOAT16(7), FLOAT16(4), FLOAT16(6), FLOAT16(8), FLOAT16(2), FLOAT16(7), FLOAT16(3), FLOAT16(5), }; - DoTest(engine,input0, input1, expected_results, tensor(1, 2, 8, 4, 3), axis); + DoTest(engine, input0, input1, expected_results, tensor(1, 2, 8, 4, 3), axis); } TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { @@ -985,7 +985,7 @@ TEST(gather_elements_gpu_fp16, d223442_i226442_a5) { FLOAT16(3), FLOAT16(3), FLOAT16(7), FLOAT16(8), FLOAT16(3), FLOAT16(8), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 2, 6, 4, 4, 2), axis); + DoTest(engine, input0, input1, expected_results, tensor(2, 2, 6, 4, 4, 2), axis); } TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { @@ -1040,7 +1040,7 @@ TEST(gather_elements_gpu_fp16, d124251_i124221_an3) { FLOAT16(2), FLOAT16(0), FLOAT16(5), FLOAT16(8), }; - DoTest(engine,input0, input1, expected_results, tensor(1, 2, 4, 2, 2, 1), axis); + DoTest(engine, input0, input1, expected_results, tensor(1, 2, 4, 2, 2, 1), axis); } TEST(gather_elements_gpu_fp16, d233113_i233115_a2) { @@ -1137,5 +1137,5 @@ TEST(gather_elements_gpu_fp16, d233113_i233115_a2) { FLOAT16(5), FLOAT16(6), FLOAT16(3), }; - DoTest(engine,input0, input1, expected_results, tensor(2, 3, 3, 1, 1, 5), axis); + DoTest(engine, input0, input1, expected_results, tensor(2, 3, 3, 1, 1, 5), axis); } From faa7c8281de5a725f2ac6ddf62e4ba0d295a5a27 Mon Sep 17 00:00:00 2001 From: yunji Date: Tue, 27 Jul 2021 02:14:23 +0900 Subject: [PATCH 11/11] Fix reviewed code and remove i8 and u8 data type. --- .../src/cldnn_engine/ops/gather_elements.cpp | 4 ++-- .../single_layer_tests/gather_elements.cpp | 4 +--- .../single_layer_tests/gather_elements.cpp | 2 +- .../api/cldnn/primitives/gather_elements.hpp | 2 +- .../kernel_selector/common/common_types.h | 12 ------------ .../gather/gather_elements_kernel_ref.cpp | 19 +++++++++---------- .../gather/gather_elements_kernel_ref.h | 5 +++-- .../gather_elements_kernel_selector.cpp | 2 +- .../gather/gather_elements_kernel_selector.h | 2 +- .../core/cl_kernels/gather_elements_ref.cl | 1 + .../core/kernel_selector_common.cpp | 14 ++------------ .../core/kernel_selector_common.h | 1 - .../clDNN/src/impls/ocl/gather_elements.cpp | 8 +------- .../src/include/kernel_selector_helper.h | 2 +- 14 files changed, 24 insertions(+), 54 deletions(-) diff --git a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp index 07f20356c76ff0..d61382807506c1 100644 --- a/inference-engine/src/cldnn_engine/ops/gather_elements.cpp +++ b/inference-engine/src/cldnn_engine/ops/gather_elements.cpp @@ -12,7 +12,7 @@ namespace CLDNNPlugin { -static cldnn::gather_elements::gather_elements_axis GetGatherElementsAxis(int axis, unsigned rank) { +static cldnn::gather_elements::gather_elements_axis GetGatherAxis(int axis, unsigned rank) { if (axis < 0) axis += rank; if (axis < 0 || axis >= rank) @@ -55,7 +55,7 @@ void CreateGatherElementsOp(Program& p, const std::shared_ptrget_output_shape(0)), - GetGatherElementsAxis(axis, rank)); + GetGatherAxis(axis, rank)); p.AddPrimitive(primitive); p.AddPrimitiveToProfiler(op); diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 63a6fd88c3e241..0220364af315f8 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -17,8 +17,6 @@ const std::vector dPrecisions = { InferenceEngine::Precision::I32, InferenceEngine::Precision::I64, InferenceEngine::Precision::I16, - InferenceEngine::Precision::U8, - InferenceEngine::Precision::I8 }; const std::vector iPrecisions = { InferenceEngine::Precision::I32, @@ -74,4 +72,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_set5, GatherElementsLayerTest, ::testing::ValuesIn(iPrecisions), ::testing::Values(CommonTestUtils::DEVICE_CPU)), GatherElementsLayerTest::getTestCaseName); -} // namespace \ No newline at end of file +} // namespace diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp index 452586fac12346..cbc4e9fed4fc5f 100644 --- a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/gather_elements.cpp @@ -1,6 +1,6 @@ // Copyright (C) 2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -// // +// #include #include diff --git a/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp b/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp index fdcbc1df3babd9..d6d0ca9fdb24f9 100644 --- a/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp +++ b/inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp @@ -55,4 +55,4 @@ struct gather_elements : public primitive_base { /// @} /// @} /// @} -} // namespace cldnn \ No newline at end of file +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h index fbb108a8124baa..dbe6bd7004c672 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h @@ -514,18 +514,6 @@ enum class GatherAxis { BATCH, }; -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// GatherElementsAxis -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -enum class GatherElementsAxis { - X, - Y, - Z, - W, - FEATURE, - BATCH, -}; - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // ScatterUpdateAxis //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp index 2d126530ecadf9..eb01e12a12f0ee 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.cpp @@ -14,17 +14,17 @@ static size_t GetGatherElementsChannelIndex(const gather_elements_params& params size_t inputSize = params.inputs[0].GetDims().size(); switch (params.axis) { - case GatherElementsAxis::X: + case GatherAxis::X: return inputSize - 1; - case GatherElementsAxis::Y: + case GatherAxis::Y: return inputSize - 2; - case GatherElementsAxis::Z: + case GatherAxis::Z: return inputSize - 3; - case GatherElementsAxis::W: + case GatherAxis::W: return 2; - case GatherElementsAxis::FEATURE: + case GatherAxis::FEATURE: return 1; - case GatherElementsAxis::BATCH: + case GatherAxis::BATCH: return 0; default: break; @@ -38,13 +38,9 @@ ParamsKey GatherElementsKernelRef::GetSupportedKey() const { k.EnableInputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); k.EnableInputDataType(Datatype::INT32); - k.EnableInputDataType(Datatype::INT8); - k.EnableInputDataType(Datatype::UINT8); k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::INT32); - k.EnableOutputDataType(Datatype::INT8); - k.EnableOutputDataType(Datatype::UINT8); k.EnableInputLayout(DataLayout::bfyx); k.EnableOutputLayout(DataLayout::bfyx); k.EnableInputLayout(DataLayout::bfzyx); @@ -152,4 +148,7 @@ KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const return { kd }; } +KernelsPriority GatherElementsKernelRef::GetKernelsPriority(const Params& /*params*/, const optional_params& /*options*/) const { + return DONT_USE_IF_HAVE_SOMETHING_ELSE; +} } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h index 3216875f59a3cd..8eec4ae96326fa 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_ref.h @@ -11,9 +11,9 @@ namespace kernel_selector { // gather_elements_params //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// struct gather_elements_params : public base_params { - gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherElementsAxis::BATCH) {} + gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherAxis::BATCH) {} - GatherElementsAxis axis; + GatherAxis axis; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -30,6 +30,7 @@ class GatherElementsKernelRef : public KernelBaseOpenCL { virtual JitConstants GetJitConstants(const gather_elements_params& params) const; virtual CommonDispatchData SetDefault(const gather_elements_params& params, const optional_params&) const; KernelsData GetKernelsData(const Params& params, const optional_params& options) const override; + KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override; ParamsKey GetSupportedKey() const override; std::vector GetSupportedFusedOps() const override { return { FusedOpType::QUANTIZE, diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp index 361e89e6ad5c2b..3a451cf574add9 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/gather/gather_elements_kernel_selector.cpp @@ -24,4 +24,4 @@ gather_elements_kernel_selector::gather_elements_kernel_selector() { Attach