Skip to content

Commit

Permalink
Scatter Elements Update for cldnn
Browse files Browse the repository at this point in the history
  • Loading branch information
isanghao committed Feb 3, 2021
1 parent 359c2ca commit 654a857
Show file tree
Hide file tree
Showing 18 changed files with 1,121 additions and 4 deletions.
4 changes: 2 additions & 2 deletions inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -157,6 +157,7 @@ REGISTER_FACTORY(v3, EmbeddingBagPackedSum);
REGISTER_FACTORY(v3, EmbeddingSegmentsSum);
REGISTER_FACTORY(v3, ExtractImagePatches);
REGISTER_FACTORY(v3, ScatterUpdate);
REGISTER_FACTORY(v3, ScatterElementsUpdate);
// REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion

// ----------------------------- Unsupported v3 ops ----------------------------- //
Expand All @@ -166,7 +167,6 @@ REGISTER_FACTORY(v3, ScatterUpdate);
// REGISTER_FACTORY(v3, NonZero);
// REGISTER_FACTORY(v3, ROIAlign);
// REGISTER_FACTORY(v3, ReadValue);
// REGISTER_FACTORY(v3, ScatterElementsUpdate);
// REGISTER_FACTORY(v3, ScatterNDUpdate);
// REGISTER_FACTORY(v3, ShapeOf);
// REGISTER_FACTORY(v3, TopK);
Expand Down
68 changes: 68 additions & 0 deletions inference-engine/src/cldnn_engine/ops/scatter_elements_update.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "cldnn_program.h"
#include "cldnn_common_utils.h"

#include "ngraph/op/scatter_elements_update.hpp"
#include "ngraph/op/constant.hpp"

#include "api/scatter_elements_update.hpp"

namespace CLDNNPlugin {

static inline cldnn::scatter_elements_update::scatter_elements_update_axis GetScatterElementsUpdateAxis(int axis, unsigned rank) {
if (axis < 0)
axis += rank;
if (axis < 0 || axis >= rank)
THROW_IE_EXCEPTION << "ScatterElementsUpdate axis is not correspond to number of dimensions";

// Difference in dimension ordering between IE and clDNN,
// reverse spatial dimensions after batch and feature.
unsigned cldnn_axis = axis;
if (axis >= 2) {
auto spatial_axis = axis - 2;
// Default and minimum number of dimensions is 4
auto spatial_size = std::max(rank, 4u) - 2;
cldnn_axis = spatial_size - spatial_axis - 1 + 2;
}

switch (cldnn_axis) {
case 0: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_b;
case 1: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_f;
case 2: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_x;
case 3: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_y;
case 4: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_z;
case 5: return cldnn::scatter_elements_update::scatter_elements_update_axis::along_w;
default: THROW_IE_EXCEPTION << "Unsupported ScatterElementsUpdate axis: " << axis;
}

return cldnn::scatter_elements_update::scatter_elements_update_axis::along_f; // shouldn't get here
}

void CreateScatterElementsUpdateOp(Program& p, const std::shared_ptr<ngraph::op::v3::ScatterElementsUpdate>& op) {
p.ValidateInputs(op, {4});
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);

size_t rank = op->get_input_shape(0).size();
auto axes_constant = std::dynamic_pointer_cast<ngraph::op::Constant>(op->get_input_node_shared_ptr(3));
if (!axes_constant) {
THROW_IE_EXCEPTION << "Unsupported parameter nodes type in " << op->get_friendly_name() << " (" << op->get_type_name() << ")";
}
int32_t axis = axes_constant->cast_vector<int32_t>()[0];

auto primitive = cldnn::scatter_elements_update(layerName,
inputPrimitives[0],
inputPrimitives[1],
inputPrimitives[2],
GetScatterElementsUpdateAxis(axis, rank));

p.AddPrimitive(primitive);
p.AddPrimitiveToProfiler(op);
}

REGISTER_FACTORY_IMPL(v3, ScatterElementsUpdate);

} // namespace CLDNNPlugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <ngraph/opsets/opset3.hpp>

#include "single_layer_tests/scatter_elements_update.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;

namespace {
// map<inputShape, map<indicesShape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}},
{{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}},
{{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}},
};

// index value should not be random data
const std::vector<std::vector<size_t>> idxValue = {
{1, 0, 4, 6, 2, 3, 7, 5}
};

const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};

const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};

const auto ScatterEltUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)),
::testing::ValuesIn(idxValue),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);

INSTANTIATE_TEST_CASE_P(smoke_ScatterEltsUpdate, ScatterElementsUpdateLayerTest,
ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName);
} // namespace
63 changes: 63 additions & 0 deletions inference-engine/thirdparty/clDNN/api/scatter_elements_update.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
// Copyright (c) 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "primitive.hpp"

namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{

/// @brief
/// @details
struct scatter_elements_update : public primitive_base<scatter_elements_update> {
CLDNN_DECLARE_PRIMITIVE(scatter_elements_update)

enum scatter_elements_update_axis {
along_b,
along_f,
along_x,
along_y,
along_z,
along_w
};

/// @brief Constructs scatter_elements_update primitive.
/// @param id This primitive id.
/// @param dict Input data primitive id.
/// @param idx Input indexes primitive id.
/// @param idupd Input updates primitive id.
/// @param axis Gathering axis.
scatter_elements_update(const primitive_id& id,
const primitive_id& data,
const primitive_id& idx,
const primitive_id& idupd,
const scatter_elements_update_axis axis,
const padding& output_padding = padding())
: primitive_base(id, {data, idx, idupd}, output_padding), axis(axis) {}

/// @brief ScatterElementsUpdate axis
scatter_elements_update_axis axis;
};
/// @}
/// @}
/// @}
} // namespace cldnn
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum class KernelType {
ONE_HOT,
GATHER,
SCATTER_UPDATE,
SCATTER_ELEMENTS_UPDATE,
DEPTH_TO_SPACE,
BATCH_TO_SPACE,
SHUFFLE_CHANNELS,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#include "scatter_elements_update_kernel_ref.h"
#include "kernel_selector_utils.h"
#include <string>
#include <vector>

namespace kernel_selector {
static size_t GetScatterElementsUpdateChannelIndex(const scatter_elements_update_params& params) {
Tensor::DataChannelName name = Tensor::DataChannelName::X;

const size_t input_size = params.inputs[0].GetDims().size();
switch (params.axis) {
case ScatterUpdateAxis::X:
return input_size - 1;
case ScatterUpdateAxis::Y:
return input_size - 2;
case ScatterUpdateAxis::Z:
return input_size - 3;
case ScatterUpdateAxis::W:
return 2;
case ScatterUpdateAxis::FEATURE:
return 1;
case ScatterUpdateAxis::BATCH:
return 0;
default:
break;
}

return DataTensor::Channelndex(params.output.GetLayout(), name);
}

ParamsKey ScatterElementsUpdateKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();
return k;
}

static inline std::string GetOrderString(std::vector<std::string>& order) {
std::string order_str = order[0];
for (size_t i = 1; i < order.size(); i++)
order_str += ", " + order[i];

return order_str;
}

static inline std::vector<std::string> GetDefaultOrder(size_t size) {
std::vector<std::string> default_order;
if (size <= 4) {
default_order = {"b", "f", "y", "x"};
} else if (size == 5) {
default_order = {"b", "f", "z", "y", "x"};
} else if (size == 6) {
default_order = {"b", "f", "w", "z", "y", "x"};
}

return default_order;
}

CommonDispatchData ScatterElementsUpdateKernelRef::SetDefault(const scatter_elements_update_params& params, const optional_params&, bool is_second) const {
CommonDispatchData dispatchData;
const auto& output = params.output;
const auto& indices = params.inputs[1];

const auto& scope = is_second ? indices : output;

switch (params.inputs[0].GetLayout()) {
case DataLayout::bfyx:
dispatchData.gws = {scope.X().v, scope.Y().v, scope.Feature().v * scope.Batch().v};
break;

case DataLayout::bfzyx:
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v, scope.Feature().v * scope.Batch().v};
break;

case DataLayout::bfwzyx:
dispatchData.gws = {scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v};
break;
default:
throw std::invalid_argument("Unsupported data layout for scatter elements update primitive");
break;
}

dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);

return dispatchData;
}

JitConstants ScatterElementsUpdateKernelRef::GetJitConstants(const scatter_elements_update_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);

jit.AddConstant(MakeJitConstant("AXIS_VALUE", GetScatterElementsUpdateChannelIndex(params)));

if (!params.fused_ops.empty()) {
FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
FusedOpsConfiguration conf2 = { "_SECOND_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
jit.Merge(MakeFusedOpsJitConstants(params, {conf1, conf2}));
}

return jit;
}

bool ScatterElementsUpdateKernelRef::Validate(const Params& p, const optional_params& o) const {
if (p.GetType() != KernelType:: SCATTER_ELEMENTS_UPDATE || o.GetType() != KernelType::SCATTER_ELEMENTS_UPDATE) {
return false;
}

const scatter_elements_update_params& params = static_cast<const scatter_elements_update_params&>(p);

for (auto& fused_op : params.fused_ops) {
if (!IsFusedPrimitiveSupported(fused_op))
return false;
}

return true;
}

KernelsData ScatterElementsUpdateKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
if (!Validate(params, options)) {
return {};
}

KernelData kd = KernelData::Default<scatter_elements_update_params>(params, 2);
scatter_elements_update_params& newParams = *static_cast<scatter_elements_update_params*>(kd.params.get());
auto cldnn_jit = GetJitConstants(newParams);

for (int i = 0; i < 2; i++) {
auto dispatchData = SetDefault(newParams, options, (i == 1));
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);

if (i == 1){
cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true"));
}
std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);

clKernelData& kernel = kd.kernels[i];

FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params));
}

return {kd};
}
} // namespace kernel_selector
Loading

0 comments on commit 654a857

Please sign in to comment.