Skip to content

Commit

Permalink
[IE CLDNN] Gather8 (openvinotoolkit#6430)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-cv authored Jul 6, 2021
1 parent 773307a commit ff2d661
Show file tree
Hide file tree
Showing 13 changed files with 266 additions and 44 deletions.
3 changes: 3 additions & 0 deletions inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,8 @@ REGISTER_FACTORY(v6, MVN);
// ------------------------------ Supported v7 ops ------------------------------ //
REGISTER_FACTORY(v7, Gather);

// ------------------------------ Supported v8 ops ------------------------------ //
REGISTER_FACTORY(v8, Gather);

// --------------------------- Supported internal ops --------------------------- //
REGISTER_FACTORY(internal, NonMaxSuppressionIEInternal);
58 changes: 18 additions & 40 deletions inference-engine/src/cldnn_engine/ops/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ static cldnn::gather::gather_axis GetGatherAxis(int32_t axis, cldnn::format inpu
}
}

void CreateGatherOp(Program& p, const std::shared_ptr<ngraph::op::v1::Gather>& op) {
p.ValidateInputs(op, {2, 3});
template <typename T>
void CreateGatherOpBase(Program& p, const std::shared_ptr<T>& op, const int64_t batch_dim = 0, bool support_neg_ind = false) {
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);

Expand Down Expand Up @@ -92,55 +92,33 @@ void CreateGatherOp(Program& p, const std::shared_ptr<ngraph::op::v1::Gather>& o
reorderedInputs[1],
GetGatherAxis(axis, DefaultFormatForDims(op->get_input_shape(0).size())),
outLayout,
CldnnTensorFromIEDims(op->get_output_shape(0)));
CldnnTensorFromIEDims(op->get_output_shape(0)),
batch_dim,
support_neg_ind);

p.AddPrimitive(gatherPrim);
p.AddPrimitiveToProfiler(op);
}

void CreateGatherOp(Program& p, const std::shared_ptr<ngraph::op::v1::Gather>& op) {
p.ValidateInputs(op, {2, 3});
CreateGatherOpBase<ngraph::op::v1::Gather>(p, op);
}

REGISTER_FACTORY_IMPL(v1, Gather);

void CreateGatherOp(Program& p, const std::shared_ptr<ngraph::op::v7::Gather>& op) {
p.ValidateInputs(op, {2, 3, 4});
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);

int32_t axis = static_cast<int32_t>(op->get_axis());

std::vector<cldnn::primitive_id> reorderedInputs;
reorderedInputs.resize(inputPrimitives.size());

for (size_t portIndex = 0; portIndex < inputPrimitives.size(); portIndex++) {
auto inputDataType = DataTypeFromPrecision(op->get_input_element_type(portIndex));
if (inputDataType == cldnn::data_types::i64) {
// clDNN primitive does not support i64 inputs,
// so we need additional reorders to convert them to i32
auto reorderPrimName = inputPrimitives[portIndex] + "_" + op->get_friendly_name() + Program::m_preProcessTag;
auto targetFormat = DefaultFormatForDims(op->get_input_shape(portIndex).size());
auto preprocessPrim = cldnn::reorder(reorderPrimName,
inputPrimitives[portIndex],
targetFormat,
cldnn::data_types::i32);
p.AddPrimitive(preprocessPrim);
p.AddInnerPrimitiveToProfiler(reorderPrimName, layerName, op);
reorderedInputs[portIndex] = reorderPrimName;
} else {
reorderedInputs[portIndex] = inputPrimitives[portIndex];
}
}
CreateGatherOpBase<ngraph::op::v7::Gather>(p, op, op->get_batch_dims());
}

auto outLayout = DefaultFormatForDims(op->get_output_shape(0).size());
auto gatherPrim = cldnn::gather(layerName,
reorderedInputs[0],
reorderedInputs[1],
GetGatherAxis(axis, DefaultFormatForDims(op->get_input_shape(0).size())),
outLayout,
CldnnTensorFromIEDims(op->get_output_shape(0)),
op->get_batch_dims());
REGISTER_FACTORY_IMPL(v7, Gather);

p.AddPrimitive(gatherPrim);
p.AddPrimitiveToProfiler(op);
void CreateGatherOp(Program& p, const std::shared_ptr<ngraph::op::v8::Gather>& op) {
p.ValidateInputs(op, {2, 3, 4});
CreateGatherOpBase<ngraph::op::v8::Gather>(p, op, op->get_batch_dims(), true);
}

REGISTER_FACTORY_IMPL(v7, Gather);
REGISTER_FACTORY_IMPL(v8, Gather);

} // namespace CLDNNPlugin
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ INSTANTIATE_TEST_SUITE_P(
Gather7LayerTest::getTestCaseName
);

INSTANTIATE_TEST_SUITE_P(
smoke_Gather7Axes4i4b1,
Gather8LayerTest,
GatherAxes4i4b1,
Gather8LayerTest::getTestCaseName
);

INSTANTIATE_TEST_SUITE_P(
smoke_Gather7Axes4i4b2,
Gather8LayerTest,
GatherAxes4i4b1,
Gather8LayerTest::getTestCaseName
);

INSTANTIATE_TEST_SUITE_P(
smoke_Gather7Axes4i8b1,
Gather8LayerTest,
GatherAxes4i8b1,
Gather8LayerTest::getTestCaseName
);

INSTANTIATE_TEST_SUITE_P(
smoke_Gather7Axes4i8b2,
Gather8LayerTest,
GatherAxes4i8b2,
Gather8LayerTest::getTestCaseName
);

const std::vector<std::vector<int>> indices = {
std::vector<int>{0, 3, 2, 1},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ TEST_P(Gather7LayerTest, CompareWithRefs) {
Run();
};

TEST_P(Gather8LayerTest, CompareWithRefs) {
Run();
};

} // namespace LayerTestsDefinitions
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,13 @@ class Gather7LayerTest : public testing::WithParamInterface<gather7ParamsTuple>,
void SetUp() override;
};

class Gather8LayerTest : public testing::WithParamInterface<gather7ParamsTuple>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj);

protected:
void SetUp() override;
};

} // namespace LayerTestsDefinitions
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,47 @@ void Gather7LayerTest::SetUp() {
function = std::make_shared<ngraph::Function>(results, functionParams, "gather");
}

std::string Gather8LayerTest::getTestCaseName(const testing::TestParamInfo<gather7ParamsTuple>& obj) {
std::tuple<int, int> axis_batchIdx;
std::vector<int> indices;
std::vector<size_t> indicesShape, inputShape;
InferenceEngine::Precision netPrecision;
InferenceEngine::Precision inPrc, outPrc;
InferenceEngine::Layout inLayout, outLayout;
std::string targetName;
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetName) = obj.param;
std::ostringstream result;
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
result << "axis=" << std::get<0>(axis_batchIdx) << "_";
result << "batchIdx=" << std::get<1>(axis_batchIdx) << "_";
result << "indicesShape=" << CommonTestUtils::vec2str(indicesShape) << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "inPRC=" << inPrc.name() << "_";
result << "outPRC=" << outPrc.name() << "_";
result << "inL=" << inLayout << "_";
result << "outL=" << outLayout << "_";
result << "trgDev=" << targetName << "_";
return result.str();
}

void Gather8LayerTest::SetUp() {
std::tuple<int, int> axis_batchIdx;
std::vector<size_t> indicesShape;
std::vector<size_t> inputShape;
InferenceEngine::Precision netPrecision;
std::tie(inputShape, indicesShape, axis_batchIdx, netPrecision, inPrc, outPrc, inLayout, outLayout, targetDevice) = GetParam();
int axis = std::get<0>(axis_batchIdx);
int batchIdx = std::get<1>(axis_batchIdx);
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto functionParams = ngraph::builder::makeParams(ngPrc, { inputShape });
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(functionParams));
auto indicesNode = ngraph::builder::makeConstant<int>(ngraph::element::i64, indicesShape, {}, true,
inputShape[axis < 0 ? axis + inputShape.size() : axis] - 1,
1 - static_cast<int>(inputShape[axis < 0 ? axis + inputShape.size() : axis]));
auto axisNode = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape({}), { axis });
auto gather = std::make_shared<ngraph::opset8::Gather>(paramOuts[0], indicesNode, axisNode, batchIdx);
ngraph::ResultVector results{ std::make_shared<ngraph::opset8::Result>(gather) };
function = std::make_shared<ngraph::Function>(results, functionParams, "gather");
}

} // namespace LayerTestsDefinitions
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>

#include "ngraph_functions/utils/data_utils.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,19 @@ struct gather : public primitive_base<gather> {
/// @param axis Gathering axis.
/// @param output_shape Output shape.
/// @param batch_dim Batch_dim
/// @param support_neg_ind Support negative indexes
gather(const primitive_id& id,
const primitive_id& dict,
const primitive_id& idx,
const gather_axis axis,
const format& output_format,
const tensor& output_shape,
const int64_t batch_dim = 0,
const padding& output_padding = padding())
: primitive_base(id, {dict, idx}, output_padding), axis(axis), output_format(output_format), output_shape(output_shape), batch_dim(batch_dim) {}
const bool support_neg_ind = false,
const padding& output_padding = padding()
)
: primitive_base(id, {dict, idx}, output_padding), axis(axis), output_format(output_format),
output_shape(output_shape), batch_dim(batch_dim), support_neg_ind(support_neg_ind) {}

/// @brief Gathering axis
gather_axis axis;
Expand All @@ -53,6 +57,8 @@ struct gather : public primitive_base<gather> {
tensor output_shape;
/// @brief Gathering batch_dim
int64_t batch_dim;
/// @brief Support negative indexes
bool support_neg_ind;
};
/// @}
/// @}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ static int64_t GetGatherBatchDim(const gather_params& params) {
return params.batch_dim;
}

static inline std::string GetGatherMaxIndexDim(const gather_params& params) {
return std::to_string(params.inputs[0].GetDims().at(params.inputs[0].GetDims().size() - GetGatherChannelIndex(params) - 1).v);
}

static inline std::string GetOrderString(std::vector<std::string>& order) {
std::string order_str = order[0];
for (size_t i = 1; i < order.size(); i++)
Expand Down Expand Up @@ -168,6 +172,8 @@ JitConstants GatherKernelRef::GetJitConstants(const gather_params& params) const

jit.AddConstant(MakeJitConstant("DICTIONARY_INDEX_ORDER", GetDictionaryIndexOrder(params, GetGatherChannelIndex(params))));
jit.AddConstant(MakeJitConstant("INDICES_INDEX_ORDER", GetIndecesIdxOrder(params, GetGatherChannelIndex(params), GetGatherBatchDim(params))));
if (params.support_neg_ind)
jit.AddConstant(MakeJitConstant("INDEX_DIM", GetGatherMaxIndexDim(params)));

if (!params.fused_ops.empty()) {
std::vector<std::string> idx_order = GetOrder(params.inputs[0].GetDims().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ namespace kernel_selector {
// gather_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct gather_params : public base_params {
gather_params() : base_params(KernelType::GATHER), axis(GatherAxis::BATCH), batch_dim(0) {}
gather_params() : base_params(KernelType::GATHER), axis(GatherAxis::BATCH), batch_dim(0), support_neg_ind(false) {}

GatherAxis axis;
int64_t batch_dim;
bool support_neg_ind;
virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); }
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,19 @@
#include "include/data_types.cl"
#include "include/fetch_data.cl"

#define INPUT_AXIS_INDEX (uint)indices[indices_idx]
#ifdef INDEX_DIM
inline uint FUNC(get_positive_index)(int in)
{
if(in < 0)
return in + INDEX_DIM;
else
return in;
}
#define INPUT_AXIS_INDEX (uint)FUNC_CALL(get_positive_index)(indices[indices_idx])
#else
#define INPUT_AXIS_INDEX (uint)(indices[indices_idx])
#endif

#define GET_DICTIONARY_INDEX(idx_order) INPUT0_GET_INDEX(idx_order)
#define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order)
Expand Down
1 change: 1 addition & 0 deletions inference-engine/thirdparty/clDNN/src/gpu/gather_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct gather_gpu : typed_primitive_gpu_impl<gather> {

gather_params.axis = convert_axis(arg.get_primitive()->axis);
gather_params.batch_dim = size_t(arg.get_primitive()->batch_dim);
gather_params.support_neg_ind = arg.get_primitive()->support_neg_ind;

gather_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));

Expand Down
Loading

0 comments on commit ff2d661

Please sign in to comment.