Skip to content

Commit

Permalink
[IE VPU] support GatherElements* (Gather + GatherElements)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrejsokolov committed Feb 11, 2021
1 parent 3eb0b5c commit 90afef4
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 34 deletions.
6 changes: 3 additions & 3 deletions inference-engine/cmake/vpu_dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ include_guard(GLOBAL)

set(VPU_SUPPORTED_FIRMWARES usb-ma2x8x pcie-ma2x8x)
set(VPU_SUPPORTED_FIRMWARES_HASH
"b9e4c2cff51d17f0751219586906be3611c593aca01b43907518df1d762672ea"
"6d89b52d723c1ba2c361575a53eda951392b6ce818733c30b78f16c75caa7892")
"70117ed4385573a01f4ffc01299078e4cd88501035fe49ff9db6b401ffa85962"
"7e0f86089fb704da543e0aa2c37ffa14e0d42b01e2e820ca1944fd4e5552ff77")

#
# Default packages
#

set(FIRMWARE_PACKAGE_VERSION 1609)
set(FIRMWARE_PACKAGE_VERSION 1618)
set(VPU_CLC_MA2X8X_VERSION "movi-cltools-20.09.2")

#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,12 @@ class StageBuilder final {
const Data& input,
const Data& output);

Stage addGatherElementsStage(const Model &model,
const std::string &name,
const ie::CNNLayerPtr &layer,
const Data &input, const Data &indices,
const Data &output, int32_t axis);
Stage addGatherElementsStage(const Model &model,
const std::string &name,
const ie::CNNLayerPtr &layer,
const DataVector &inputs,
const Data &output, int32_t axis,
bool rowIndicesMode);
};

} // namespace vpu
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ FrontEnd::FrontEnd(StageBuilder::Ptr stageBuilder, const ie::ICore* core)
{"HSwish", LAYER_PARSER(parseHSwish)},
{"Ceiling", LAYER_PARSER(parseCeiling)},
{"GatherElements", LAYER_PARSER(parseGatherElements)},
{"ExpGatherElements", LAYER_PARSER(parseGatherElements)},
{"Round", LAYER_PARSER(parseRound)},
}} {
VPU_THROW_UNLESS(_core != nullptr, "Argument core is null");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -25,6 +25,13 @@ class GatherElementsStage final : public StageNode {
const auto input2 = inputEdge(1)->input();
const auto output = outputEdge(0)->output();

const auto rowIndicesMode = attrs().get<int32_t>("rowIndicesMode");
if (rowIndicesMode) {
const auto input3 = inputEdge(2)->input();
orderInfo.setInput(inputEdge(2),
DimsOrder::fromNumDims(input3->desc().numDims()));
}

orderInfo.setInput(inputEdge(0),
DimsOrder::fromNumDims(input1->desc().numDims()));
orderInfo.setInput(inputEdge(1),
Expand All @@ -47,12 +54,16 @@ class GatherElementsStage final : public StageNode {
getBatchSupportInfoImpl(StageDataInfo<BatchSupport> &batchInfo) override {}

StageSHAVEsRequirements getSHAVEsRequirementsImpl() const override {
return StageSHAVEsRequirements::NotNeeded;
const auto axis = attrs().get<int32_t>("axis");
const auto rank = inputEdge(0)->input()->desc().numDims();
const auto rowIndicesMode = attrs().get<int32_t>("rowIndicesMode");

return (rowIndicesMode || (axis == rank - 1)) ? StageSHAVEsRequirements::NeedMax : StageSHAVEsRequirements::NotNeeded;
}

void initialCheckImpl() const override {
VPU_THROW_UNLESS(numInputs() == 2,
"{} stage with name {} must have only 1 output, actually "
VPU_THROW_UNLESS(numInputs() == 2 || numInputs() == 3,
"{} stage with name {} must have 2 or 3 inputs only, actually "
"provided {} inputs",
type(), name(), numInputs());
VPU_THROW_UNLESS(numOutputs() == 1,
Expand All @@ -63,14 +74,19 @@ class GatherElementsStage final : public StageNode {
"First input and output must have the same DataType, "
"actual input type is {} and output type is {}",
inputs()[0]->desc().type(), outputs()[0]->desc().type());
assertInputsOutputsTypes(
this, {{DataType::U8, DataType::FP16, DataType::S32}, {DataType::S32}},
{{DataType::U8, DataType::FP16, DataType::S32}});

DataTypesRequirement inputDataTypes = {{DataType::U8, DataType::FP16, DataType::S32}, {DataType::S32}};
if (numInputs() == 3)
inputDataTypes.push_back({DataType::S32});

assertInputsOutputsTypes(this, inputDataTypes, {{DataType::U8, DataType::FP16, DataType::S32}});
}

void serializeParamsImpl(BlobSerializer &serializer) const override {
const auto axis = attrs().get<int32_t>("axis");
const auto rowIndicesMode = attrs().get<int32_t>("rowIndicesMode");
serializer.append(axis);
serializer.append(rowIndicesMode);
}

void serializeDataImpl(BlobSerializer &serializer) const override {
Expand All @@ -81,6 +97,13 @@ class GatherElementsStage final : public StageNode {
input0->serializeBuffer(serializer);
output->serializeBuffer(serializer);
input1->serializeBuffer(serializer);

const auto rowIndicesMode = attrs().get<int32_t>("rowIndicesMode");

if (rowIndicesMode) {
auto rowIndices = inputEdge(2)->input();
rowIndices->serializeBuffer(serializer);
}
}
};

Expand All @@ -89,21 +112,23 @@ class GatherElementsStage final : public StageNode {
Stage StageBuilder::addGatherElementsStage(const Model &model,
const std::string &name,
const ie::CNNLayerPtr &layer,
const Data &input, const Data &indices,
const Data &output, int32_t axis) {
const DataVector &inputs,
const Data &output, int32_t axis,
bool rowIndicesMode) {
auto stage = model->addNewStage<GatherElementsStage>(
layer->name, StageType::GatherElements, layer, {input, indices}, {output});
layer->name, StageType::GatherElements, layer, inputs, {output});

stage->attrs().set<int32_t>("axis", axis);
stage->attrs().set<int32_t>("rowIndicesMode", rowIndicesMode);

return stage;
}

void FrontEnd::parseGatherElements(const Model &model, const ie::CNNLayerPtr &layer,
const DataVector &inputs,
const DataVector &outputs) const {
VPU_THROW_UNLESS(layer, "CNNLayer pointer is null.");
VPU_THROW_UNLESS(inputs.size() == 2,
VPU_THROW_UNLESS(layer != nullptr, "CNNLayer pointer is null.");
VPU_THROW_UNLESS(inputs.size() == 2 || inputs.size() == 3,
"{} layer with name {} must have 2 inputs, actually "
"provided {} inputs",
layer->type, layer->name, inputs.size());
Expand All @@ -112,19 +137,31 @@ void FrontEnd::parseGatherElements(const Model &model, const ie::CNNLayerPtr &la
"provided {} outputs",
layer->type, layer->name, outputs.size());

bool rowIndicesMode = (inputs.size() == 3);

const auto axis = layer->GetParamAsInt("axis");
const auto rank = inputs[0]->desc().numDims();

VPU_THROW_UNLESS(rank >= 1, "rank has to be more than or equal to 1, actually {}", rank);
VPU_THROW_UNLESS(inputs[1]->desc().numDims() == rank, "rank of the second input must be equal to {}, actually {}",
rank, inputs[1]->desc().numDims());
VPU_THROW_UNLESS(outputs[0]->desc().numDims() == rank, "rank of output must be equal to {}, actually {}",
rank, outputs[0]->desc().numDims());
VPU_THROW_UNLESS(axis >= 0 && axis < rank, "axis must be in the range of [0, {}) , actually {}",
rank, axis);

_stageBuilder->addGatherElementsStage(model, layer->name, layer, inputs[0],
inputs[1], outputs[0], axis);

if (rowIndicesMode) {
VPU_THROW_UNLESS(inputs[1]->desc().numDims() == rank + 1, "rank of the second input must be equal to {}, actually {}",
rank + 1, inputs[1]->desc().numDims());
VPU_THROW_UNLESS(inputs[2]->desc().numDims() == 2, "rank of the third input must be equal to 2, actually {}",
2, inputs[2]->desc().numDims());
VPU_THROW_UNLESS(outputs[0]->desc().numDims() == rank + 1, "rank of output must be equal to {}, actually {}",
rank + 1, outputs[0]->desc().numDims());
VPU_THROW_UNLESS(axis == rank - 1, "axis must be equal to {}, actually {}", rank - 1, axis);
} else {
VPU_THROW_UNLESS(inputs[1]->desc().numDims() == rank, "rank of the second input must be equal to {}, actually {}",
rank, inputs[1]->desc().numDims());
VPU_THROW_UNLESS(outputs[0]->desc().numDims() == rank, "rank of output must be equal to {}, actually {}",
rank, outputs[0]->desc().numDims());
VPU_THROW_UNLESS(axis >= 0 && axis < rank, "axis must be in the range of [0, {}) , actually {}",
rank, axis);
}

_stageBuilder->addGatherElementsStage(model, layer->name, layer, inputs, outputs[0], axis, rowIndicesMode);
}

}// namespace vpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//


#include <vector>

#include "shared_test_classes/single_layer/gather_elements.hpp"

#include <vpu/private_plugin_config.hpp>

using namespace LayerTestsDefinitions;

namespace {

class GatherElementsLayerTestVPU : public GatherElementsLayerTest {
protected:
void SetUp() override {
configuration[InferenceEngine::MYRIAD_DETECT_NETWORK_BATCH] = CONFIG_VALUE(NO);
GatherElementsLayerTest::SetUp();
}
};

TEST_P(GatherElementsLayerTestVPU, GatherElementsTests) {
Run();
}

const std::vector<InferenceEngine::Precision> dPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
};

const std::vector<InferenceEngine::Precision> iPrecisions = {
InferenceEngine::Precision::I32
};

INSTANTIATE_TEST_CASE_P(smoke_GatherElements1, GatherElementsLayerTestVPU,
::testing::Combine(
::testing::Values(std::vector<size_t>({2, 2})), // Data shape
::testing::Values(std::vector<size_t>({2, 2})), // Indices shape
::testing::Values(0, 1), // Axis
::testing::ValuesIn(dPrecisions),
::testing::ValuesIn(iPrecisions),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)),
GatherElementsLayerTest::getTestCaseName);

INSTANTIATE_TEST_CASE_P(smoke_GatherElements2, GatherElementsLayerTestVPU,
::testing::Combine(
::testing::Values(std::vector<size_t>({2, 65, 300})), // Data shape
::testing::Values(std::vector<size_t>({2, 65, 64})), // Indices shape
::testing::Values(2), // Axis
::testing::ValuesIn(dPrecisions),
::testing::ValuesIn(iPrecisions),
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)),
GatherElementsLayerTest::getTestCaseName);


} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,5 @@ std::vector<std::string> disabledTestPatterns() {
".*DSR_GatherStaticDataDynamicIdx.*f32.*1.3.200.304.*",
// TODO: Issue 47315
".*ProposalLayerTest.*",
// TODO: Issue 46755
".*DSR_GatherElements.*",
// TODO: Issue 46756
".*smoke_Gather_GatherElements.*"
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ INSTANTIATE_TEST_CASE_P(smoke_DynamicGatherElements, DSR_GatherElements,
GatherTestCase{DataShapeWithUpperBound{{1000}, {}}, DataShapeWithUpperBound{{800}, {1000}}, 0},
GatherTestCase{DataShapeWithUpperBound{{1000, 4}, {}}, DataShapeWithUpperBound{{100, 4}, {800, 4}}, 0},
GatherTestCase{DataShapeWithUpperBound{{4, 1000}, {}}, DataShapeWithUpperBound{{4, 100}, {4, 800}}, 1},
GatherTestCase{DataShapeWithUpperBound{{300, 3, 64, 608}, {}}, DataShapeWithUpperBound{{300, 3, 64, 60}, {300, 3, 64, 64}}, 3},
GatherTestCase{DataShapeWithUpperBound{{30, 3, 64, 608}, {}}, DataShapeWithUpperBound{{30, 3, 64, 60}, {30, 3, 64, 64}}, 3},
GatherTestCase{DataShapeWithUpperBound{{800}, {1000}}, DataShapeWithUpperBound{{200}, {800}}, 0},
GatherTestCase{DataShapeWithUpperBound{{800, 4}, {1000, 4}}, DataShapeWithUpperBound{{300, 4}, {800, 4}}, 0},
GatherTestCase{DataShapeWithUpperBound{{4, 800}, {4, 1000}}, DataShapeWithUpperBound{{4, 700}, {4, 750}}, 1}),
Expand Down

0 comments on commit 90afef4

Please sign in to comment.