Skip to content

Commit

Permalink
[IE CLDNN] Added CTCGreedyDecoderSeqLen operation (#4119)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman authored Feb 14, 2021
1 parent 3f5ff2c commit d406a5a
Show file tree
Hide file tree
Showing 13 changed files with 266 additions and 70 deletions.
3 changes: 3 additions & 0 deletions inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,8 @@ REGISTER_FACTORY(v5, Round);
// REGISTER_FACTORY(v5, Loop);
// REGISTER_FACTORY(v5, RNNSequence);

// ------------------------------ Supported v6 ops ------------------------------ //
REGISTER_FACTORY(v6, CTCGreedyDecoderSeqLen);

// --------------------------- Supported internal ops --------------------------- //
REGISTER_FACTORY(internal, NonMaxSuppressionIEInternal);
112 changes: 101 additions & 11 deletions inference-engine/src/cldnn_engine/ops/ctc_greedy_decoder.cpp
Original file line number Diff line number Diff line change
@@ -1,32 +1,122 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "cldnn_program.h"
#include "cldnn_common_utils.h"

#include "ngraph/op/ctc_greedy_decoder.hpp"
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"

#include "api/ctc_greedy_decoder.hpp"
#include "api/reorder.hpp"
#include "api/mutable_data.hpp"

#include "transformations/utils/utils.hpp"

namespace CLDNNPlugin {

void CreateCTCGreedyDecoderOp(Program& p, const std::shared_ptr<ngraph::op::v0::CTCGreedyDecoder>& op) {
p.ValidateInputs(op, {2});
void CreateCommonCTCGreedyDecoderOp(Program& p, const std::shared_ptr<ngraph::Node>& op, bool ctc_merge_repeated) {
p.ValidateInputs(op, {2, 3});
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);

auto primitive = cldnn::ctc_greedy_decoder(layerName,
inputPrimitives[0],
inputPrimitives[1],
op->get_ctc_merge_repeated(),
DataTypeFromPrecision(op->get_output_element_type(0)),
CldnnTensorFromIEDims(op->get_output_shape(0)));
std::vector<cldnn::primitive_id> reorderedInputs;
reorderedInputs.resize(inputPrimitives.size());

for (size_t portIndex = 0; portIndex < inputPrimitives.size(); portIndex++) {
auto inputDataType = DataTypeFromPrecision(op->get_input_element_type(portIndex));
if (inputDataType == cldnn::data_types::i64) {
// clDNN primitive supports only i32 data type for 'sequence_length' and 'blank_index' inputs
// so we need additional reorder if it's provided as i64
auto reorderPrimName = inputPrimitives[portIndex] + "_" + op->get_friendly_name() + Program::m_preProcessTag;
auto targetFormat = DefaultFormatForDims(op->get_input_shape(portIndex).size());
auto preprocessPrim = cldnn::reorder(reorderPrimName,
inputPrimitives[portIndex],
targetFormat,
cldnn::data_types::i32);
p.AddPrimitive(preprocessPrim);
p.AddInnerPrimitiveToProfiler(reorderPrimName, layer_type_name_ID(op), op);
reorderedInputs[portIndex] = (reorderPrimName);
} else {
reorderedInputs[portIndex] = inputPrimitives[portIndex];
}
}

uint32_t blank_index = op->get_input_shape(0).back() - 1;
if (reorderedInputs.size() == 3) {
auto blank_index_node = std::dynamic_pointer_cast<ngraph::op::v0::Constant>(op->get_input_node_shared_ptr(2));
if (!blank_index_node) {
THROW_IE_EXCEPTION << "Unsupported blank_index node type in " << op->get_friendly_name() << " (" << op->get_type_name() << ")";
}
float val;
if (ngraph::shape_size(blank_index_node->get_output_shape(0)) != 1 || !ngraph::op::util::get_single_value(blank_index_node, val)) {
THROW_IE_EXCEPTION << "Unsupported parameter size in " << op->get_friendly_name() << " (" << op->get_type_name() << ")";
}
blank_index = static_cast<uint32_t>(val);
reorderedInputs.pop_back();
}

std::size_t num_output = op->get_output_size();

std::vector<cldnn::memory> shared_memory;
if (num_output == 2) {
auto mutable_precision = op->get_output_element_type(1);
if (mutable_precision == ngraph::element::i64) {
mutable_precision = ngraph::element::i32;
}

cldnn::layout mutableLayout = cldnn::layout(
DataTypeFromPrecision(mutable_precision),
DefaultFormatForDims(op->get_output_shape(1).size()),
CldnnTensorFromIEDims(op->get_output_shape(1)));

shared_memory.emplace_back(cldnn::memory::allocate(p.GetEngine(), mutableLayout));

cldnn::primitive_id ctc_gd_mutable_id_w = layer_type_name_ID(op) + "_md_write";
auto ctc_gd_mutable_prim = cldnn::mutable_data(ctc_gd_mutable_id_w, shared_memory[0]);
p.primitivesToIRLayersMap[ctc_gd_mutable_id_w] = { op->get_friendly_name() };
p.primitiveIDs[ctc_gd_mutable_id_w] = ctc_gd_mutable_id_w;
p.AddPrimitive(ctc_gd_mutable_prim);
reorderedInputs.push_back(ctc_gd_mutable_id_w);
}

auto CTCGreedyDecoderLayerName = num_output == 2 ? layer_type_name_ID(op) + ".0" : layer_type_name_ID(op);
auto primitive = cldnn::ctc_greedy_decoder(
CTCGreedyDecoderLayerName,
reorderedInputs,
blank_index,
ctc_merge_repeated,
CldnnTensorFromIEDims(op->get_output_shape(0)));

// clDNN primitive supports only i32 as output data type
primitive.output_data_type = DataTypeFromPrecision(ngraph::element::i32);

if (num_output == 2) {
primitive.second_output = reorderedInputs.back();
}

p.AddPrimitive(primitive);
p.AddPrimitiveToProfiler(op);

if (num_output == 2) {
cldnn::primitive_id ctc_gd_mutable_id_r = layer_type_name_ID(op) + ".1";
auto ctc_gd_mutable_prim_r = cldnn::mutable_data(ctc_gd_mutable_id_r, { CTCGreedyDecoderLayerName }, shared_memory[0]);
p.primitivesToIRLayersMap[ctc_gd_mutable_id_r] = { op->get_friendly_name() };
p.primitiveIDs[ctc_gd_mutable_id_r] = ctc_gd_mutable_id_r;
p.AddPrimitive(ctc_gd_mutable_prim_r);
}

p.AddPrimitiveToProfiler(CTCGreedyDecoderLayerName, op);
}

// Converts ngraph v0::CTCGreedyDecoder (2 inputs, single output) by delegating
// to the common builder with the op's ctc_merge_repeated attribute.
void CreateCTCGreedyDecoderOp(Program& p, const std::shared_ptr<ngraph::op::v0::CTCGreedyDecoder>& op) {
    CreateCommonCTCGreedyDecoderOp(p, op, op->get_ctc_merge_repeated());
}

// Converts ngraph v6::CTCGreedyDecoderSeqLen (2-3 inputs, up to 2 outputs) by
// delegating to the common builder with the op's merge_repeated attribute.
void CreateCTCGreedyDecoderSeqLenOp(Program& p, const std::shared_ptr<ngraph::op::v6::CTCGreedyDecoderSeqLen>& op) {
    CreateCommonCTCGreedyDecoderOp(p, op, op->get_merge_repeated());
}

REGISTER_FACTORY_IMPL(v0, CTCGreedyDecoder);
REGISTER_FACTORY_IMPL(v6, CTCGreedyDecoderSeqLen);

} // namespace CLDNNPlugin
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,24 +10,31 @@ using namespace LayerTestsDefinitions;
using namespace ngraph::helpers;

namespace {
// Common params
const std::vector<InferenceEngine::Precision> netPrecisions = {
        InferenceEngine::Precision::FP32,
        InferenceEngine::Precision::FP16
};
std::vector<bool> mergeRepeated{true, false};

const auto basicCases = ::testing::Combine(
        ::testing::ValuesIn(netPrecisions),
        ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
        ::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
        ::testing::Values(InferenceEngine::Layout::ANY),
        ::testing::Values(InferenceEngine::Layout::ANY),
        // Shapes vary the last dimension across small, odd, and large values
        // (3 .. 128) to exercise different class counts.
        ::testing::Values(std::vector<size_t>({ 50, 3, 3 }),
                          std::vector<size_t>({ 50, 3, 7 }),
                          std::vector<size_t>({ 50, 3, 8 }),
                          std::vector<size_t>({ 50, 3, 16 }),
                          std::vector<size_t>({ 50, 3, 128 }),
                          std::vector<size_t>({ 50, 3, 49 }),
                          std::vector<size_t>({ 50, 3, 55 }),
                          std::vector<size_t>({ 1, 1, 16 })),
        ::testing::ValuesIn(mergeRepeated),
        ::testing::Values(CommonTestUtils::DEVICE_GPU));

INSTANTIATE_TEST_CASE_P(smoke_CtcGreedyDecoderBasic, CTCGreedyDecoderLayerTest,
                        basicCases,
                        CTCGreedyDecoderLayerTest::getTestCaseName);
}  // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include "single_layer_tests/ctc_greedy_decoder_seq_len.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::helpers;

namespace {

// Input shapes for the probabilities tensor — presumably {N, T, C}; confirm
// against CTCGreedyDecoderSeqLenLayerTest's parameter definition.
std::vector<std::vector<size_t>> inputShape{{1, 1, 1}, {1, 6, 10}, {3, 3, 16}, {5, 3, 55}};

// Floating-point precisions for the probabilities input.
const std::vector<InferenceEngine::Precision> probPrecisions = {
    InferenceEngine::Precision::FP32,
    InferenceEngine::Precision::FP16
};
// Integer precisions for the index-typed inputs.
const std::vector<InferenceEngine::Precision> idxPrecisions = {
    InferenceEngine::Precision::I32,
    InferenceEngine::Precision::I64
};

std::vector<bool> mergeRepeated{true, false};

// set1: several shapes with the integer parameter fixed to 0
// (presumably the blank index — verify against the test fixture).
const auto basicCases = ::testing::Combine(
    ::testing::ValuesIn(inputShape),
    ::testing::ValuesIn(probPrecisions),
    ::testing::ValuesIn(idxPrecisions),
    ::testing::Values(0),
    ::testing::ValuesIn(mergeRepeated),
    ::testing::Values(CommonTestUtils::DEVICE_GPU));

INSTANTIATE_TEST_CASE_P(smoke_set1, CTCGreedyDecoderSeqLenLayerTest,
                        basicCases,
                        CTCGreedyDecoderSeqLenLayerTest::getTestCaseName);

// set2: one fixed shape with the integer parameter swept over {0, 5, 10}.
INSTANTIATE_TEST_CASE_P(smoke_set2, CTCGreedyDecoderSeqLenLayerTest,
                        ::testing::Combine(
                            ::testing::Values(std::vector<size_t>{2, 8, 11}),
                            ::testing::ValuesIn(probPrecisions),
                            ::testing::ValuesIn(idxPrecisions),
                            ::testing::ValuesIn(std::vector<int>{0, 5, 10}),
                            ::testing::ValuesIn(mergeRepeated),
                            ::testing::Values(CommonTestUtils::DEVICE_GPU)),
                        CTCGreedyDecoderSeqLenLayerTest::getTestCaseName);
}  // namespace
30 changes: 15 additions & 15 deletions inference-engine/thirdparty/clDNN/api/ctc_greedy_decoder.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
// Copyright (c) 2020 Intel Corporation
// Copyright (c) 2020-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,24 +32,24 @@ struct ctc_greedy_decoder : public primitive_base<ctc_greedy_decoder> {

/// @brief Constructs ctc_greedy_decoder primitive.
/// @param id This primitive id.
/// @param input Input primitive id.
/// @param input sequence_indicators primitive id.
/// @param ctc_merge_repeated int
/// @param input Input primitive id (input, sequence_indicators, second_output(optional)).
/// @param blank_index Specifies the class index to use for the blank class.
/// @param ctc_merge_repeated Flag for merging repeated labels during the CTC calculation
ctc_greedy_decoder(const primitive_id& id,
const primitive_id& input,
const primitive_id& sequence_indicators,
const bool ctc_merge_repeated,
const data_types data_type,
const tensor output_tensor,
const padding& output_padding = padding())
: primitive_base(id, { input, sequence_indicators },
output_padding, optional_data_type{ data_type }),
ctc_merge_repeated(ctc_merge_repeated),
output_tensor(output_tensor)
{}
const std::vector<primitive_id>& input,
const uint32_t blank_index,
const bool ctc_merge_repeated,
const tensor output_tensor,
const padding& output_padding = padding())
: primitive_base(id, input, output_padding)
, blank_index(blank_index)
, ctc_merge_repeated(ctc_merge_repeated)
, output_tensor(output_tensor) {}

uint32_t blank_index;
bool ctc_merge_repeated;
tensor output_tensor;
primitive_id second_output;
};
/// @}
/// @}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 Intel Corporation
// Copyright (c) 2020-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -24,11 +24,23 @@ JitConstants CTCGreedyDecoderKernelBase::GetJitConstants(const ctc_greedy_decode

jit.AddConstants({
MakeJitConstant("ctc_merge_repeated_", params.merge_repeated),
MakeJitConstant("T_", inp.Batch().v),
MakeJitConstant("N_", inp.Feature().v),
MakeJitConstant("blank_index_", params.blank_index),
MakeJitConstant("C_", inp.Y().v)
});

if (params.outputs_num == 2) {
jit.AddConstants({
MakeJitConstant("SECOND_OUTPUT_EXIST", 1),
MakeJitConstant("N_", inp.Batch().v),
MakeJitConstant("T_", inp.Feature().v)
});
} else {
jit.AddConstants({
MakeJitConstant("T_", inp.Batch().v),
MakeJitConstant("N_", inp.Feature().v)
});
};

return jit;
}

Expand Down Expand Up @@ -71,6 +83,10 @@ KernelsData CTCGreedyDecoderKernelBase::GetCommonKernelsData(const Params& param
2, // input and sequence indicators
GetFusedPrimitiveInputsCount(params));

if (orgParams.outputs_num == 2) {
kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, 2});
}

return {kd};
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 Intel Corporation
// Copyright (c) 2020-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,8 @@ struct ctc_greedy_decoder_params : public base_params {
ctc_greedy_decoder_params() : base_params(KernelType::CTC_GREEDY_DECODER) {}

bool merge_repeated = true;
uint32_t blank_index;
uint32_t outputs_num = 1;
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 Intel Corporation
// Copyright (c) 2020-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -21,9 +21,13 @@ ParamsKey CTCGreedyDecoderKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::INT64);

k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT64);

k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
Expand Down
Loading

0 comments on commit d406a5a

Please sign in to comment.