From 1b7f7b36d5f5d0b3b3aa7ccc767c6adb000fdad1 Mon Sep 17 00:00:00 2001 From: Mikhail Treskin Date: Tue, 18 May 2021 13:08:17 +0300 Subject: [PATCH 1/3] Handle 3rd optional input in CTCGreedyDecoderSeqLen evaluate --- ngraph/test/runtime/interpreter/evaluates_map.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index a3b51e57d5aef4..4eeb54dfd99fc7 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -2065,12 +2065,17 @@ namespace using TF = typename element_type_traits::value_type; using TI = typename element_type_traits::value_type; using TIND1 = typename element_type_traits::value_type; + TI blank_index_val = *inputs[0]->get_shape().end() - 1; + const TI *blank_index = &blank_index_val; + if (inputs.size() == 3) { + blank_index = inputs[2]->get_data_ptr(); + } if (op->get_sequence_length_type() == element::i32) { runtime::reference::ctc_greedy_decoder_seq_len( inputs[0]->get_data_ptr(), inputs[1]->get_data_ptr(), - inputs[2]->get_data_ptr(), + blank_index, outputs[0]->get_data_ptr(), outputs[1]->get_data_ptr(), inputs[0]->get_shape(), @@ -2082,7 +2087,7 @@ namespace runtime::reference::ctc_greedy_decoder_seq_len( inputs[0]->get_data_ptr(), inputs[1]->get_data_ptr(), - inputs[2]->get_data_ptr(), + blank_index, outputs[0]->get_data_ptr(), outputs[1]->get_data_ptr(), inputs[0]->get_shape(), From 4bdb945d8e755fa0bc57f4e4f476b8f3e7c16687 Mon Sep 17 00:00:00 2001 From: Mikhail Treskin Date: Thu, 20 May 2021 15:38:19 +0300 Subject: [PATCH 2/3] Add test for OP without blank_index input --- .../backend/ctc_greedy_decoder_seq_len.in.cpp | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp b/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp index 2017aa45a8b19a..18b96ae75fde54 100644 --- a/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp +++ b/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp @@ -43,8 +43,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len) auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, false); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, false); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -66,8 +66,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_merge) auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, true); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, true); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -89,8 +89,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_f16) auto data = make_shared(element::f16, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, true); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, true); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -112,8 +112,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, false); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, false); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -154,8 +154,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, false); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, false); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -170,3 +170,26 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche test_case.run(); } + + +NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_no_optional_input) +{ + const int N = 1; + const int T = 3; + const int C = 3; + const auto data_shape = Shape{N, T, C}; + const auto seq_len_shape = Shape{N}; + + auto data = make_shared(element::f32, data_shape); + auto seq_len = make_shared(element::i32, seq_len_shape); + auto decoder = make_shared(data, seq_len, false); + auto function = make_shared(decoder, ParameterVector{data, seq_len}); + auto test_case = test::TestCase(function); + + test_case.add_input({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f}); + test_case.add_input({2}); + test_case.add_expected_output(Shape{N, T}, vector{1, 0, -1}); + test_case.add_expected_output(Shape{N}, vector{2}); + + test_case.run(); +} \ No newline at end of file From 541ec6bcd32e28d42beee8081109a453be794a29 Mon Sep 17 00:00:00 2001 From: Mikhail Treskin Date: Fri, 21 May 2021 11:22:16 +0300 Subject: [PATCH 3/3] Fix default blank_index calculation --- ngraph/test/runtime/interpreter/evaluates_map.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index 4eeb54dfd99fc7..2c537d44d04c02 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -2065,7 +2065,7 @@ namespace using TF = typename element_type_traits::value_type; using TI = typename element_type_traits::value_type; using TIND1 = typename element_type_traits::value_type; - TI blank_index_val = *inputs[0]->get_shape().end() - 1; + TI blank_index_val = inputs[0]->get_shape().back() - 1; const TI *blank_index = &blank_index_val; if (inputs.size() == 3) { blank_index = inputs[2]->get_data_ptr();