Skip to content

Commit

Permalink
[TF FE] Fix CTCLoss translator (openvinotoolkit#20775)
Browse files Browse the repository at this point in the history
* Fix CTCLoss translator

Signed-off-by: Kazantsev, Roman <[email protected]>

* Expend layer tests for CTCLoss

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Oct 31, 2023
1 parent 0076f7f commit fc4fe07
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/frontends/tensorflow_common/src/op/ctc_loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ OutputVector translate_ctc_loss_op(const NodeContext& node) {

// retrieve all attributes for CTCLoss
auto preprocess_collapse_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", false);
auto ctc_merge_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", true);
auto ctc_merge_repeated = node.get_attribute<bool>("ctc_merge_repeated", true);
auto time_major = node.get_attribute<bool>("time_major", true);

if (time_major) {
Expand Down
17 changes: 12 additions & 5 deletions tests/layer_tests/tensorflow_tests/test_tf_CTCLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _prepare_input(self, inputs_dict):
inputs_dict[input] = np.random.randint(0, 5, inputs_dict[input]).astype(np.float32)
return inputs_dict

def create_ctcloss_placeholder_const_net(self, inputs, targets):
def create_ctcloss_placeholder_const_net(self, inputs, targets, preprocess_collapse_repeated, ctc_merge_repeated):
seq_lens = np.array([inputs[2]], dtype=np.int32)
x = [targets]

Expand All @@ -36,7 +36,9 @@ def create_ctcloss_placeholder_const_net(self, inputs, targets):
tf_inputs = tf.compat.v1.placeholder(tf.float32, inputs, "inputs")

ctc_loss = tf.raw_ops.CTCLoss(inputs=tf_inputs, labels_indices=indices, labels_values=vals,
sequence_length=seq_lens)
sequence_length=seq_lens,
preprocess_collapse_repeated=preprocess_collapse_repeated,
ctc_merge_repeated=ctc_merge_repeated)
# compute exponent since CTCLoss value is -ln(prob)
tf.math.exp(-ctc_loss[0])

Expand All @@ -54,11 +56,16 @@ def create_ctcloss_placeholder_const_net(self, inputs, targets):
]

@pytest.mark.parametrize("params", test_data)
@pytest.mark.parametrize("preprocess_collapse_repeated", [True, False, None])
@pytest.mark.parametrize("ctc_merge_repeated", [True, False, None])
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
@pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
def test_ctcloss_placeholder_const(self, params, ie_device, precision, ir_version, temp_dir,
def test_ctcloss_placeholder_const(self, params, preprocess_collapse_repeated, ctc_merge_repeated,
ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_ctcloss_placeholder_const_net(**params),
self._test(*self.create_ctcloss_placeholder_const_net(**params,
preprocess_collapse_repeated=preprocess_collapse_repeated,
ctc_merge_repeated=ctc_merge_repeated),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api, custom_eps=1e-2)
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

0 comments on commit fc4fe07

Please sign in to comment.