From a326cb42e77a2658877a34cb34a448f057989a6b Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Wed, 10 Apr 2024 08:17:55 +0400 Subject: [PATCH] [TF FE] Support StringToHashBucketFast operation (#23946) **Details:** Support StringToHashBucketFast operation. Needed for several models from TF Hub. Merge after https://github.com/openvinotoolkit/openvino_tokenizers/pull/115 **Ticket:** --------- Signed-off-by: Kazantsev, Roman --- .../tensorflow/docs/supported_ops.md | 2 +- src/frontends/tensorflow/src/op_table.cpp | 7 ++- .../test_tf_StringToHashBucketFast.py | 60 +++++++++++++++++++ 3 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py diff --git a/src/frontends/tensorflow/docs/supported_ops.md b/src/frontends/tensorflow/docs/supported_ops.md index f206839f3aacbb..7206a2af141041 100644 --- a/src/frontends/tensorflow/docs/supported_ops.md +++ b/src/frontends/tensorflow/docs/supported_ops.md @@ -1227,7 +1227,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | StringSplitV2NEW | YES | openvino-tokenizers required | | StringStrip | NO | | | StringToHashBucket | NO | | -| StringToHashBucketFast | NO | | +| StringToHashBucketFast | YES | openvino-tokenizers required | | StringToHashBucketStrong | NO | | | StringToNumber | NO | | | StringUpper | NO | | diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index e2abb859e3bc75..3422148db2d771 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -461,7 +461,12 @@ const std::map get_supported_ops() { }; const std::vector get_supported_ops_via_tokenizers() { - return {"RaggedTensorToSparse", "RaggedTensorToTensor", "StaticRegexReplace", "StringLower", "StringSplitV2"}; + return {"RaggedTensorToSparse", + "RaggedTensorToTensor", + "StaticRegexReplace", + "StringLower", + "StringSplitV2", + "StringToHashBucketFast"};
} } // namespace op } // namespace tensorflow
diff --git a/tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py b/tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py
new file mode 100644
index 00000000000000..08812fe7b46228
--- /dev/null
+++ b/tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import platform
+
+import numpy as np
+import pytest
+import tensorflow as tf
+from common.tf_layer_test_class import CommonTFLayerTest
+from common.utils.tf_utils import run_in_jenkins
+
+# seeded so that the sampled input strings are reproducible between runs
+rng = np.random.default_rng(23946)
+
+
+class TestStringToHashBucketFast(CommonTFLayerTest):
+    """Layer test for the TensorFlow StringToHashBucketFast operation."""
+
+    def _prepare_input(self, inputs_info):
+        """Return data for 'input:0' sampled from the current strings dictionary."""
+        assert 'input:0' in inputs_info
+        input_shape = inputs_info['input:0']
+        return {'input:0': rng.choice(self.strings_dictionary, input_shape)}
+
+    def create_string_to_hash_bucket_fast_net(self, input_shape, strings_dictionary, num_buckets):
+        """Build a TF graph feeding a string placeholder into StringToHashBucketFast; return (graph_def, None)."""
+        # kept on the instance because _prepare_input samples its data from it
+        self.strings_dictionary = strings_dictionary
+
+        tf.compat.v1.reset_default_graph()
+        with tf.compat.v1.Session() as sess:
+            input_tensor = tf.compat.v1.placeholder(tf.string, input_shape, 'input')
+            tf.raw_ops.StringToHashBucketFast(input=input_tensor, num_buckets=num_buckets)
+            tf.compat.v1.global_variables_initializer()
+            tf_net = sess.graph_def
+
+        return tf_net, None  # no reference network is built for this test
+
+    @pytest.mark.parametrize("input_shape", [[], [2], [3, 4], [1, 3, 2]])
+    @pytest.mark.parametrize("num_buckets", [1, 4, 7, 11])
+    @pytest.mark.parametrize("strings_dictionary",
+                             [['UPPER CASE SENTENCE', 'lower case sentence', ' UppEr LoweR CAse SENtence', ' '],
+                              ['Первое Предложение', 'второе предложение', ' ', ' ТРЕТЬЕ ПРЕДЛОЖЕНИЕ '],
+                              ['第一句話在這裡', '第二句話在這裡', '第三句話在這裡'],
+                              ['', ' ', '12345 ']])
+    @pytest.mark.precommit
+    @pytest.mark.nightly
+    @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and
+                                 platform.machine() in ['arm', 'armv7l', 'aarch64', 'arm64', 'ARM64'],
+                       reason='Ticket - 126314, 132699')
+    def test_string_to_hash_bucket_fast(self, input_shape, num_buckets, strings_dictionary, ie_device, precision,
+                                        ir_version, temp_dir, use_legacy_frontend):
+        """Run the layer test for one parametrization; skipped on GPU and in Jenkins."""
+        if ie_device == 'GPU':
+            pytest.skip("operation extension is not supported on GPU")
+        if run_in_jenkins():
+            pytest.skip("test is skipped in Jenkins runs")
+        self._test(*self.create_string_to_hash_bucket_fast_net(input_shape=input_shape, num_buckets=num_buckets,
+                                                               strings_dictionary=strings_dictionary),
+                   ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend)