Skip to content

Commit

Permalink
Handle 3rd optional input in CTCGreedyDecoderSeqLen evaluate (#5669)
Browse files Browse the repository at this point in the history
* Handle 3rd optional input in CTCGreedyDecoderSeqLen evaluate

* Add test for OP without blank_index input

* Fix default blank_index calculation
  • Loading branch information
Mikhail Treskin authored May 24, 2021
1 parent 97f0204 commit 66d9853
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
43 changes: 33 additions & 10 deletions ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len)

auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);

Expand All @@ -66,8 +66,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_merge)

auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, true);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);

Expand All @@ -89,8 +89,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_f16)

auto data = make_shared<op::Parameter>(element::f16, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, true);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);

Expand All @@ -112,8 +112,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche

auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);

Expand Down Expand Up @@ -154,8 +154,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche

auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);

Expand All @@ -170,3 +170,26 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche

test_case.run();
}


NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_no_optional_input)
{
const int N = 1;
const int T = 3;
const int C = 3;
const auto data_shape = Shape{N, T, C};
const auto seq_len_shape = Shape{N};

auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);

test_case.add_input<float>({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f});
test_case.add_input<int32_t>({2});
test_case.add_expected_output(Shape{N, T}, vector<int32_t>{1, 0, -1});
test_case.add_expected_output(Shape{N}, vector<int32_t>{2});

test_case.run();
}
9 changes: 7 additions & 2 deletions ngraph/test/runtime/interpreter/evaluates_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2067,12 +2067,17 @@ namespace
using TF = typename element_type_traits<T1>::value_type;
using TI = typename element_type_traits<T2>::value_type;
using TIND1 = typename element_type_traits<TOUT>::value_type;
TI blank_index_val = inputs[0]->get_shape().back() - 1;
const TI *blank_index = &blank_index_val;
if (inputs.size() == 3) {
blank_index = inputs[2]->get_data_ptr<const TI>();
}
if (op->get_sequence_length_type() == element::i32)
{
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
inputs[0]->get_data_ptr<const TF>(),
inputs[1]->get_data_ptr<const TI>(),
inputs[2]->get_data_ptr<const TI>(),
blank_index,
outputs[0]->get_data_ptr<TIND1>(),
outputs[1]->get_data_ptr<int32_t>(),
inputs[0]->get_shape(),
Expand All @@ -2084,7 +2089,7 @@ namespace
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
inputs[0]->get_data_ptr<const TF>(),
inputs[1]->get_data_ptr<const TI>(),
inputs[2]->get_data_ptr<const TI>(),
blank_index,
outputs[0]->get_data_ptr<TIND1>(),
outputs[1]->get_data_ptr<int64_t>(),
inputs[0]->get_shape(),
Expand Down

0 comments on commit 66d9853

Please sign in to comment.