Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IE CLDNN] Added ScatterElementsUpdate op support #4105

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -157,6 +157,7 @@ REGISTER_FACTORY(v3, EmbeddingBagPackedSum);
REGISTER_FACTORY(v3, EmbeddingSegmentsSum);
REGISTER_FACTORY(v3, ExtractImagePatches);
REGISTER_FACTORY(v3, ScatterUpdate);
REGISTER_FACTORY(v3, ScatterElementsUpdate);
// REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion

// ----------------------------- Unsupported v3 ops ----------------------------- //
Expand All @@ -166,7 +167,6 @@ REGISTER_FACTORY(v3, ScatterUpdate);
// REGISTER_FACTORY(v3, NonZero);
// REGISTER_FACTORY(v3, ROIAlign);
// REGISTER_FACTORY(v3, ReadValue);
// REGISTER_FACTORY(v3, ScatterElementsUpdate);
// REGISTER_FACTORY(v3, ScatterNDUpdate);
// REGISTER_FACTORY(v3, ShapeOf);
// REGISTER_FACTORY(v3, TopK);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "cldnn_program.h"
#include "cldnn_common_utils.h"

#include "ngraph/op/scatter_elements_update.hpp"
#include "ngraph/op/constant.hpp"

#include "api/scatter_elements_update.hpp"

namespace CLDNNPlugin {

static inline cldnn::scatter_elements_update::scatter_elements_update_axis GetScatterElementsUpdateAxis(int axis, unsigned rank) {
if (axis < 0)
axis += rank;
if (axis < 0 || axis >= rank)
THROW_IE_EXCEPTION << "ScatterElementsUpdate axis is not correspond to number of dimensions";

// Difference in dimension ordering between IE and clDNN,
// reverse spatial dimensions after batch and feature.
unsigned cldnn_axis = axis;
if (axis >= 2) {
auto spatial_axis = axis - 2;
// Default and minimum number of dimensions is 4
auto spatial_size = std::max(rank, 4u) - 2;
cldnn_axis = spatial_size - spatial_axis - 1 + 2;
}

switch (cldnn_axis) {
case 0: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_b;
case 1: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_f;
case 2: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_x;
case 3: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_y;
case 4: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_z;
case 5: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_w;
default: THROW_IE_EXCEPTION << "Unsupported ScatterElementsUpdate axis: " << axis;
}

return cldnn::scatter_elements_update::scatter_elements_update_axis::along_f; // shouldn't get here
}

void CreateScatterElementsUpdateOp(Program& p, const std::shared_ptr<ngraph::op::v3::ScatterElementsUpdate>& op) {
p.ValidateInputs(op, {4});
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);

size_t rank = op->get_input_shape(0).size();
auto axes_constant = std::dynamic_pointer_cast<ngraph::op::Constant>(op->get_input_node_shared_ptr(3));
if (!axes_constant) {
THROW_IE_EXCEPTION << "Unsupported parameter nodes type in " << op->get_friendly_name() << " (" << op->get_type_name() << ")";
}
int32_t axis = axes_constant->cast_vector<int32_t>()[0];

auto primitive = cldnn::scatter_elements_update(layerName,
inputPrimitives[0],
inputPrimitives[1],
inputPrimitives[2],
GetScatterElementsUpdateAxis(axis, rank));

p.AddPrimitive(primitive);
p.AddPrimitiveToProfiler(op);
}

REGISTER_FACTORY_IMPL(v3, ScatterElementsUpdate);

} // namespace CLDNNPlugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <ngraph/opsets/opset3.hpp>

#include "single_layer_tests/scatter_elements_update.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;

namespace {
// map<inputShape, map<indicesShape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}},
{{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}},
{{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}},
};

// index value should not be random data
const std::vector<std::vector<size_t>> idxValue = {
{1, 0, 4, 6, 2, 3, 7, 5}
};

const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};

const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};

const auto ScatterEltUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)),
::testing::ValuesIn(idxValue),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);

INSTANTIATE_TEST_CASE_P(smoke_ScatterEltsUpdate, ScatterElementsUpdateLayerTest,
ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
// Copyright (c) 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "primitive.hpp"

namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{

/// @brief
/// @details
struct scatter_elements_update : public primitive_base<scatter_elements_update> {
CLDNN_DECLARE_PRIMITIVE(scatter_elements_update)

enum scatter_elements_update_axis {
along_b,
along_f,
along_x,
along_y,
along_z,
along_w
};

/// @brief Constructs scatter_elements_update primitive.
/// @param id This primitive id.
/// @param dict Input data primitive id.
/// @param idx Input indexes primitive id.
/// @param idupd Input updates primitive id.
/// @param axis Gathering axis.
scatter_elements_update(const primitive_id& id,
const primitive_id& data,
const primitive_id& idx,
const primitive_id& idupd,
const scatter_elements_update_axis axis,
const padding& output_padding = padding())
: primitive_base(id, {data, idx, idupd}, output_padding), axis(axis) {}

/// @brief ScatterElementsUpdate axis
scatter_elements_update_axis axis;
};
/// @}
/// @}
/// @}
} // namespace cldnn
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum class KernelType {
ONE_HOT,
GATHER,
SCATTER_UPDATE,
SCATTER_ELEMENTS_UPDATE,
DEPTH_TO_SPACE,
BATCH_TO_SPACE,
SHUFFLE_CHANNELS,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#include "scatter_elements_update_kernel_ref.h"
#include "kernel_selector_utils.h"
#include <string>
#include <vector>

namespace kernel_selector {
static size_t GetScatterElementsUpdateChannelIndex(const scatter_elements_update_params& params) {
Tensor::DataChannelName name = Tensor::DataChannelName::X;

const size_t input_size = params.inputs[0].GetDims().size();
switch (params.axis) {
case ScatterUpdateAxis::X:
return input_size - 1;
case ScatterUpdateAxis::Y:
return input_size - 2;
case ScatterUpdateAxis::Z:
return input_size - 3;
case ScatterUpdateAxis::W:
return 2;
case ScatterUpdateAxis::FEATURE:
return 1;
case ScatterUpdateAxis::BATCH:
return 0;
default:
break;
}

return DataTensor::Channelndex(params.output.GetLayout(), name);
}

ParamsKey ScatterElementsUpdateKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();
return k;
}

static inline std::string GetOrderString(std::vector<std::string>& order) {
std::string order_str = order[0];
for (size_t i = 1; i < order.size(); i++)
order_str += ", " + order[i];

return order_str;
}

static inline std::vector<std::string> GetDefaultOrder(size_t size) {
std::vector<std::string> default_order;
if (size <= 4) {
default_order = {"b", "f", "y", "x"};
} else if (size == 5) {
default_order = {"b", "f", "z", "y", "x"};
} else if (size == 6) {
default_order = {"b", "f", "w", "z", "y", "x"};
}

return default_order;
}

CommonDispatchData ScatterElementsUpdateKernelRef::SetDefault(const scatter_elements_update_params& params, const optional_params&, bool is_second) const {
CommonDispatchData dispatchData;
const auto& output = params.output;
const auto& indices = params.inputs[1];

const auto& scope = is_second ? indices : output;

switch (params.inputs[0].GetLayout()) {
case DataLayout::bfyx:
dispatchData.gws = {scope.X().v, scope.Y().v, scope.Feature().v * scope.Batch().v};
break;

case DataLayout::bfzyx:
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v, scope.Feature().v * scope.Batch().v};
break;

case DataLayout::bfwzyx:
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v};
break;
default:
throw std::invalid_argument("Unsupported data layout for scatter elements update primitive");
break;
}

dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);

return dispatchData;
}

JitConstants ScatterElementsUpdateKernelRef::GetJitConstants(const scatter_elements_update_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);

jit.AddConstant(MakeJitConstant("AXIS_VALUE", GetScatterElementsUpdateChannelIndex(params)));

if (!params.fused_ops.empty()) {
FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
FusedOpsConfiguration conf2 = { "_SECOND_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
jit.Merge(MakeFusedOpsJitConstants(params, {conf1, conf2}));
}

return jit;
}

bool ScatterElementsUpdateKernelRef::Validate(const Params& p, const optional_params& o) const {
if (p.GetType() != KernelType:: SCATTER_ELEMENTS_UPDATE || o.GetType() != KernelType::SCATTER_ELEMENTS_UPDATE) {
return false;
}

const scatter_elements_update_params& params = static_cast<const scatter_elements_update_params&>(p);

for (auto& fused_op : params.fused_ops) {
if (!IsFusedPrimitiveSupported(fused_op))
return false;
}

return true;
}

KernelsData ScatterElementsUpdateKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
if (!Validate(params, options)) {
return {};
}

KernelData kd = KernelData::Default<scatter_elements_update_params>(params, 2);
scatter_elements_update_params& newParams = *static_cast<scatter_elements_update_params*>(kd.params.get());
auto cldnn_jit = GetJitConstants(newParams);

for (int i = 0; i < 2; i++) {
auto dispatchData = SetDefault(newParams, options, (i == 1));
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);

if (i == 1){
cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true"));
}
std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);

clKernelData& kernel = kd.kernels[i];

FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params));
}

return {kd};
}
} // namespace kernel_selector
Loading