diff --git a/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp b/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp index 2017aa45a8b19a..18b96ae75fde54 100644 --- a/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp +++ b/ngraph/test/backend/ctc_greedy_decoder_seq_len.in.cpp @@ -43,8 +43,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len) auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, false); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, false); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -66,8 +66,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_merge) auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, true); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, true); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -89,8 +89,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_f16) auto data = make_shared(element::f16, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, true); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, true); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -112,8 +112,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, false); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, false); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -154,8 +154,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche auto data = make_shared(element::f32, data_shape); auto seq_len = make_shared(element::i32, seq_len_shape); - auto blanck_index = op::Constant::create(element::i32, Shape{}, {2}); - auto decoder = make_shared(data, seq_len, blanck_index, false); + auto blank_index = op::Constant::create(element::i32, Shape{}, {2}); + auto decoder = make_shared(data, seq_len, blank_index, false); auto function = make_shared(decoder, ParameterVector{data, seq_len}); auto test_case = test::TestCase(function); @@ -170,3 +170,26 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche test_case.run(); } + + +NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_no_optional_input) +{ + const int N = 1; + const int T = 3; + const int C = 3; + const auto data_shape = Shape{N, T, C}; + const auto seq_len_shape = Shape{N}; + + auto data = make_shared(element::f32, data_shape); + auto seq_len = make_shared(element::i32, seq_len_shape); + auto decoder = make_shared(data, seq_len, false); + auto function = make_shared(decoder, ParameterVector{data, seq_len}); + auto test_case = test::TestCase(function); + + test_case.add_input({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f}); + test_case.add_input({2}); + test_case.add_expected_output(Shape{N, T}, vector{1, 0, -1}); + test_case.add_expected_output(Shape{N}, vector{2}); + + test_case.run(); +} \ No newline at end of file diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index a3b51e57d5aef4..2c537d44d04c02 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -2065,12 +2065,17 @@ namespace using TF = typename element_type_traits::value_type; using TI = typename element_type_traits::value_type; using TIND1 = typename element_type_traits::value_type; + TI blank_index_val = inputs[0]->get_shape().back() - 1; + const TI *blank_index = &blank_index_val; + if (inputs.size() == 3) { + blank_index = inputs[2]->get_data_ptr(); + } if (op->get_sequence_length_type() == element::i32) { runtime::reference::ctc_greedy_decoder_seq_len( inputs[0]->get_data_ptr(), inputs[1]->get_data_ptr(), - inputs[2]->get_data_ptr(), + blank_index, outputs[0]->get_data_ptr(), outputs[1]->get_data_ptr(), inputs[0]->get_shape(), @@ -2082,7 +2087,7 @@ namespace runtime::reference::ctc_greedy_decoder_seq_len( inputs[0]->get_data_ptr(), inputs[1]->get_data_ptr(), - inputs[2]->get_data_ptr(), + blank_index, outputs[0]->get_data_ptr(), outputs[1]->get_data_ptr(), inputs[0]->get_shape(),