From fc4fe07a0e49b008243cab82ad54912275baef54 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 31 Oct 2023 08:51:18 +0400 Subject: [PATCH] [TF FE] Fix CTCLoss translator (#20775) * Fix CTCLoss translator Signed-off-by: Kazantsev, Roman * Expend layer tests for CTCLoss --------- Signed-off-by: Kazantsev, Roman --- .../tensorflow_common/src/op/ctc_loss.cpp | 2 +- .../tensorflow_tests/test_tf_CTCLoss.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/frontends/tensorflow_common/src/op/ctc_loss.cpp b/src/frontends/tensorflow_common/src/op/ctc_loss.cpp index 1abba8801f2c64..8679379b1c72e3 100644 --- a/src/frontends/tensorflow_common/src/op/ctc_loss.cpp +++ b/src/frontends/tensorflow_common/src/op/ctc_loss.cpp @@ -36,7 +36,7 @@ OutputVector translate_ctc_loss_op(const NodeContext& node) { // retrieve all attributes for CTCLoss auto preprocess_collapse_repeated = node.get_attribute("preprocess_collapse_repeated", false); - auto ctc_merge_repeated = node.get_attribute("preprocess_collapse_repeated", true); + auto ctc_merge_repeated = node.get_attribute("ctc_merge_repeated", true); auto time_major = node.get_attribute("time_major", true); if (time_major) { diff --git a/tests/layer_tests/tensorflow_tests/test_tf_CTCLoss.py b/tests/layer_tests/tensorflow_tests/test_tf_CTCLoss.py index 0a2eae6303386e..805ab3ff52f6fd 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_CTCLoss.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_CTCLoss.py @@ -18,7 +18,7 @@ def _prepare_input(self, inputs_dict): inputs_dict[input] = np.random.randint(0, 5, inputs_dict[input]).astype(np.float32) return inputs_dict - def create_ctcloss_placeholder_const_net(self, inputs, targets): + def create_ctcloss_placeholder_const_net(self, inputs, targets, preprocess_collapse_repeated, ctc_merge_repeated): seq_lens = np.array([inputs[2]], dtype=np.int32) x = [targets] @@ -36,7 +36,9 @@ def create_ctcloss_placeholder_const_net(self, inputs, targets): tf_inputs = tf.compat.v1.placeholder(tf.float32, inputs, "inputs") ctc_loss = tf.raw_ops.CTCLoss(inputs=tf_inputs, labels_indices=indices, labels_values=vals, - sequence_length=seq_lens) + sequence_length=seq_lens, + preprocess_collapse_repeated=preprocess_collapse_repeated, + ctc_merge_repeated=ctc_merge_repeated) # compute exponent since CTCLoss value is -ln(prob) tf.math.exp(-ctc_loss[0]) @@ -54,11 +56,16 @@ def create_ctcloss_placeholder_const_net(self, inputs, targets): ] @pytest.mark.parametrize("params", test_data) + @pytest.mark.parametrize("preprocess_collapse_repeated", [True, False, None]) + @pytest.mark.parametrize("ctc_merge_repeated", [True, False, None]) @pytest.mark.precommit_tf_fe @pytest.mark.nightly @pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182") - def test_ctcloss_placeholder_const(self, params, ie_device, precision, ir_version, temp_dir, + def test_ctcloss_placeholder_const(self, params, preprocess_collapse_repeated, ctc_merge_repeated, + ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): - self._test(*self.create_ctcloss_placeholder_const_net(**params), + self._test(*self.create_ctcloss_placeholder_const_net(**params, + preprocess_collapse_repeated=preprocess_collapse_repeated, + ctc_merge_repeated=ctc_merge_repeated), ie_device, precision, ir_version, temp_dir=temp_dir, - use_new_frontend=use_new_frontend, use_old_api=use_old_api, custom_eps=1e-2) + use_new_frontend=use_new_frontend, use_old_api=use_old_api)