-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
1,121 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
inference-engine/src/cldnn_engine/ops/scatter_elements_update.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "cldnn_program.h" | ||
#include "cldnn_common_utils.h" | ||
|
||
#include "ngraph/op/scatter_elements_update.hpp" | ||
#include "ngraph/op/constant.hpp" | ||
|
||
#include "api/scatter_elements_update.hpp" | ||
|
||
namespace CLDNNPlugin { | ||
|
||
static inline cldnn::scatter_elements_update::scatter_elements_update_axis GetScatterElementsUpdateAxis(int axis, unsigned rank) { | ||
if (axis < 0) | ||
axis += rank; | ||
if (axis < 0 || axis >= rank) | ||
THROW_IE_EXCEPTION << "ScatterElementsUpdate axis is not correspond to number of dimensions"; | ||
|
||
// Difference in dimension ordering between IE and clDNN, | ||
// reverse spatial dimensions after batch and feature. | ||
unsigned cldnn_axis = axis; | ||
if (axis >= 2) { | ||
auto spatial_axis = axis - 2; | ||
// Default and minimum number of dimensions is 4 | ||
auto spatial_size = std::max(rank, 4u) - 2; | ||
cldnn_axis = spatial_size - spatial_axis - 1 + 2; | ||
} | ||
|
||
switch (cldnn_axis) { | ||
case 0: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_b; | ||
case 1: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_f; | ||
case 2: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_x; | ||
case 3: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_y; | ||
case 4: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_z; | ||
case 5: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_w; | ||
default: THROW_IE_EXCEPTION << "Unsupported ScatterElementsUpdate axis: " << axis; | ||
} | ||
|
||
return cldnn::scatter_elements_update::scatter_elements_update_axis::along_f; // shouldn't get here | ||
} | ||
|
||
void CreateScatterElementsUpdateOp(Program& p, const std::shared_ptr<ngraph::op::v3::ScatterElementsUpdate>& op) { | ||
p.ValidateInputs(op, {4}); | ||
auto inputPrimitives = p.GetInputPrimitiveIDs(op); | ||
std::string layerName = layer_type_name_ID(op); | ||
|
||
size_t rank = op->get_input_shape(0).size(); | ||
auto axes_constant = std::dynamic_pointer_cast<ngraph::op::Constant>(op->get_input_node_shared_ptr(3)); | ||
if (!axes_constant) { | ||
THROW_IE_EXCEPTION << "Unsupported parameter nodes type in " << op->get_friendly_name() << " (" << op->get_type_name() << ")"; | ||
} | ||
int32_t axis = axes_constant->cast_vector<int32_t>()[0]; | ||
|
||
auto primitive = cldnn::scatter_elements_update(layerName, | ||
inputPrimitives[0], | ||
inputPrimitives[1], | ||
inputPrimitives[2], | ||
GetScatterElementsUpdateAxis(axis, rank)); | ||
|
||
p.AddPrimitive(primitive); | ||
p.AddPrimitiveToProfiler(op); | ||
} | ||
|
||
REGISTER_FACTORY_IMPL(v3, ScatterElementsUpdate); | ||
|
||
} // namespace CLDNNPlugin |
48 changes: 48 additions & 0 deletions
48
...nctional/plugin/gpu/shared_tests_instances/single_layer_tests/scatter_elements_update.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (C) 2020 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <vector> | ||
#include <ngraph/opsets/opset3.hpp> | ||
|
||
#include "single_layer_tests/scatter_elements_update.hpp" | ||
#include "common_test_utils/test_constants.hpp" | ||
|
||
using namespace LayerTestsDefinitions; | ||
using namespace ngraph::opset3; | ||
|
||
namespace { | ||
// map<inputShape, map<indicesShape, axis>> | ||
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape { | ||
{{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}}, | ||
{{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}}, | ||
{{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}}, | ||
}; | ||
|
||
// index value should not be random data | ||
const std::vector<std::vector<size_t>> idxValue = { | ||
{1, 0, 4, 6, 2, 3, 7, 5} | ||
}; | ||
|
||
const std::vector<InferenceEngine::Precision> inputPrecisions = { | ||
InferenceEngine::Precision::FP32, | ||
InferenceEngine::Precision::FP16, | ||
InferenceEngine::Precision::I32, | ||
}; | ||
|
||
const std::vector<InferenceEngine::Precision> idxPrecisions = { | ||
InferenceEngine::Precision::I32, | ||
InferenceEngine::Precision::I64, | ||
}; | ||
|
||
const auto ScatterEltUpdateCases = ::testing::Combine( | ||
::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)), | ||
::testing::ValuesIn(idxValue), | ||
::testing::ValuesIn(inputPrecisions), | ||
::testing::ValuesIn(idxPrecisions), | ||
::testing::Values(CommonTestUtils::DEVICE_GPU) | ||
); | ||
|
||
INSTANTIATE_TEST_CASE_P(smoke_ScatterEltsUpdate, ScatterElementsUpdateLayerTest, | ||
ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName); | ||
} // namespace |
63 changes: 63 additions & 0 deletions
63
inference-engine/thirdparty/clDNN/api/scatter_elements_update.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
// Copyright (c) 2020 Intel Corporation | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
*/ | ||
|
||
/////////////////////////////////////////////////////////////////////////////////////////////////// | ||
#pragma once | ||
#include "primitive.hpp" | ||
|
||
namespace cldnn { | ||
/// @addtogroup cpp_api C++ API | ||
/// @{ | ||
/// @addtogroup cpp_topology Network Topology | ||
/// @{ | ||
/// @addtogroup cpp_primitives Primitives | ||
/// @{ | ||
|
||
/// @brief | ||
/// @details | ||
struct scatter_elements_update : public primitive_base<scatter_elements_update> { | ||
CLDNN_DECLARE_PRIMITIVE(scatter_elements_update) | ||
|
||
enum scatter_elements_update_axis { | ||
along_b, | ||
along_f, | ||
along_x, | ||
along_y, | ||
along_z, | ||
along_w | ||
}; | ||
|
||
/// @brief Constructs scatter_elements_update primitive. | ||
/// @param id This primitive id. | ||
/// @param dict Input data primitive id. | ||
/// @param idx Input indexes primitive id. | ||
/// @param idupd Input updates primitive id. | ||
/// @param axis Gathering axis. | ||
scatter_elements_update(const primitive_id& id, | ||
const primitive_id& data, | ||
const primitive_id& idx, | ||
const primitive_id& idupd, | ||
const scatter_elements_update_axis axis, | ||
const padding& output_padding = padding()) | ||
: primitive_base(id, {data, idx, idupd}, output_padding), axis(axis) {} | ||
|
||
/// @brief ScatterElementsUpdate axis | ||
scatter_elements_update_axis axis; | ||
}; | ||
/// @} | ||
/// @} | ||
/// @} | ||
} // namespace cldnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
174 changes: 174 additions & 0 deletions
174
...kernel_selector/core/actual_kernels/scatter_update/scatter_elements_update_kernel_ref.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
/* | ||
// Copyright (c) 2021 Intel Corporation | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
*/ | ||
|
||
#include "scatter_elements_update_kernel_ref.h" | ||
#include "kernel_selector_utils.h" | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace kernel_selector { | ||
static size_t GetScatterElementsUpdateChannelIndex(const scatter_elements_update_params& params) { | ||
Tensor::DataChannelName name = Tensor::DataChannelName::X; | ||
|
||
const size_t input_size = params.inputs[0].GetDims().size(); | ||
switch (params.axis) { | ||
case ScatterUpdateAxis::X: | ||
return input_size - 1; | ||
case ScatterUpdateAxis::Y: | ||
return input_size - 2; | ||
case ScatterUpdateAxis::Z: | ||
return input_size - 3; | ||
case ScatterUpdateAxis::W: | ||
return 2; | ||
case ScatterUpdateAxis::FEATURE: | ||
return 1; | ||
case ScatterUpdateAxis::BATCH: | ||
return 0; | ||
default: | ||
break; | ||
} | ||
|
||
return DataTensor::Channelndex(params.output.GetLayout(), name); | ||
} | ||
|
||
ParamsKey ScatterElementsUpdateKernelRef::GetSupportedKey() const { | ||
ParamsKey k; | ||
k.EnableInputDataType(Datatype::F16); | ||
k.EnableInputDataType(Datatype::F32); | ||
k.EnableInputDataType(Datatype::INT32); | ||
k.EnableOutputDataType(Datatype::F16); | ||
k.EnableOutputDataType(Datatype::F32); | ||
k.EnableOutputDataType(Datatype::INT32); | ||
k.EnableOutputDataType(Datatype::INT8); | ||
k.EnableOutputDataType(Datatype::UINT8); | ||
k.EnableInputLayout(DataLayout::bfyx); | ||
k.EnableOutputLayout(DataLayout::bfyx); | ||
k.EnableInputLayout(DataLayout::bfzyx); | ||
k.EnableOutputLayout(DataLayout::bfzyx); | ||
k.EnableInputLayout(DataLayout::bfwzyx); | ||
k.EnableOutputLayout(DataLayout::bfwzyx); | ||
k.EnableTensorOffset(); | ||
k.EnableTensorPitches(); | ||
k.EnableBatching(); | ||
k.EnableDifferentTypes(); | ||
return k; | ||
} | ||
|
||
static inline std::string GetOrderString(std::vector<std::string>& order) { | ||
std::string order_str = order[0]; | ||
for (size_t i = 1; i < order.size(); i++) | ||
order_str += ", " + order[i]; | ||
|
||
return order_str; | ||
} | ||
|
||
static inline std::vector<std::string> GetDefaultOrder(size_t size) { | ||
std::vector<std::string> default_order; | ||
if (size <= 4) { | ||
default_order = {"b", "f", "y", "x"}; | ||
} else if (size == 5) { | ||
default_order = {"b", "f", "z", "y", "x"}; | ||
} else if (size == 6) { | ||
default_order = {"b", "f", "w", "z", "y", "x"}; | ||
} | ||
|
||
return default_order; | ||
} | ||
|
||
CommonDispatchData ScatterElementsUpdateKernelRef::SetDefault(const scatter_elements_update_params& params, const optional_params&, bool is_second) const { | ||
CommonDispatchData dispatchData; | ||
const auto& output = params.output; | ||
const auto& indices = params.inputs[1]; | ||
|
||
const auto& scope = is_second ? indices : output; | ||
|
||
switch (params.inputs[0].GetLayout()) { | ||
case DataLayout::bfyx: | ||
dispatchData.gws = {scope.X().v, scope.Y().v, scope.Feature().v * scope.Batch().v}; | ||
break; | ||
|
||
case DataLayout::bfzyx: | ||
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v, scope.Feature().v * scope.Batch().v}; | ||
break; | ||
|
||
case DataLayout::bfwzyx: | ||
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v}; | ||
break; | ||
default: | ||
throw std::invalid_argument("Unsupported data layout for scatter elements update primitive"); | ||
break; | ||
} | ||
|
||
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); | ||
|
||
return dispatchData; | ||
} | ||
|
||
JitConstants ScatterElementsUpdateKernelRef::GetJitConstants(const scatter_elements_update_params& params) const { | ||
JitConstants jit = MakeBaseParamsJitConstants(params); | ||
|
||
jit.AddConstant(MakeJitConstant("AXIS_VALUE", GetScatterElementsUpdateChannelIndex(params))); | ||
|
||
if (!params.fused_ops.empty()) { | ||
FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() }; | ||
FusedOpsConfiguration conf2 = { "_SECOND_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() }; | ||
jit.Merge(MakeFusedOpsJitConstants(params, {conf1, conf2})); | ||
} | ||
|
||
return jit; | ||
} | ||
|
||
bool ScatterElementsUpdateKernelRef::Validate(const Params& p, const optional_params& o) const { | ||
if (p.GetType() != KernelType:: SCATTER_ELEMENTS_UPDATE || o.GetType() != KernelType::SCATTER_ELEMENTS_UPDATE) { | ||
return false; | ||
} | ||
|
||
const scatter_elements_update_params& params = static_cast<const scatter_elements_update_params&>(p); | ||
|
||
for (auto& fused_op : params.fused_ops) { | ||
if (!IsFusedPrimitiveSupported(fused_op)) | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
KernelsData ScatterElementsUpdateKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { | ||
if (!Validate(params, options)) { | ||
return {}; | ||
} | ||
|
||
KernelData kd = KernelData::Default<scatter_elements_update_params>(params, 2); | ||
scatter_elements_update_params& newParams = *static_cast<scatter_elements_update_params*>(kd.params.get()); | ||
auto cldnn_jit = GetJitConstants(newParams); | ||
|
||
for (int i = 0; i < 2; i++) { | ||
auto dispatchData = SetDefault(newParams, options, (i == 1)); | ||
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options); | ||
|
||
if (i == 1){ | ||
cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true")); | ||
} | ||
std::string jit = CreateJit(kernelName, cldnn_jit, entry_point); | ||
|
||
clKernelData& kernel = kd.kernels[i]; | ||
|
||
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params)); | ||
} | ||
|
||
return {kd}; | ||
} | ||
} // namespace kernel_selector |
Oops, something went wrong.