From f0afa720bdfd8f9c99e06c8a189945bc852cef51 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 9 Apr 2024 20:34:41 +0400 Subject: [PATCH 1/3] [TF FE] Support StringToHashBucketFast operation Signed-off-by: Kazantsev, Roman --- .../tensorflow/docs/supported_ops.md | 2 +- src/frontends/tensorflow/src/op_table.cpp | 2 +- .../test_tf_StringToHashBucketFast.py | 60 +++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py diff --git a/src/frontends/tensorflow/docs/supported_ops.md b/src/frontends/tensorflow/docs/supported_ops.md index f206839f3aacbb..7206a2af141041 100644 --- a/src/frontends/tensorflow/docs/supported_ops.md +++ b/src/frontends/tensorflow/docs/supported_ops.md @@ -1227,7 +1227,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | StringSplitV2NEW | YES | openvino-tokenizers required | | StringStrip | NO | | | StringToHashBucket | NO | | -| StringToHashBucketFast | NO | | +| StringToHashBucketFast | YES | openvino-tokenizers required | | StringToHashBucketStrong | NO | | | StringToNumber | NO | | | StringUpper | NO | | diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index e2abb859e3bc75..a9020c02150ec9 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -461,7 +461,7 @@ const std::map get_supported_ops() { }; const std::vector get_supported_ops_via_tokenizers() { - return {"RaggedTensorToSparse", "RaggedTensorToTensor", "StaticRegexReplace", "StringLower", "StringSplitV2"}; + return {"RaggedTensorToSparse", "RaggedTensorToTensor", "StaticRegexReplace", "StringLower", "StringSplitV2", "StringToHashBucketFast"}; } } // namespace op } // namespace tensorflow diff --git a/tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py b/tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py new 
file mode 100644 index 00000000000000..08812fe7b46228 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_StringToHashBucketFast.py @@ -0,0 +1,60 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import platform + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest +from common.utils.tf_utils import run_in_jenkins + +rng = np.random.default_rng() + + +class TestStringToHashBucketFast(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'input:0' in inputs_info + input_shape = inputs_info['input:0'] + inputs_data = {} + sample_data = rng.choice(self.strings_dictionary, input_shape) + inputs_data['input:0'] = sample_data + return inputs_data + + def create_string_to_hash_bucket_fast_net(self, input_shape, strings_dictionary, num_buckets): + self.strings_dictionary = strings_dictionary + + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + input = tf.compat.v1.placeholder(tf.string, input_shape, 'input') + tf.raw_ops.StringToHashBucketFast(input=input, num_buckets=num_buckets) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + @pytest.mark.parametrize("input_shape", [[], [2], [3, 4], [1, 3, 2]]) + @pytest.mark.parametrize("num_buckets", [1, 4, 7, 11]) + @pytest.mark.parametrize("strings_dictionary", + [['UPPER CASE SENTENCE', 'lower case sentence', ' UppEr LoweR CAse SENtence', ' '], + ['Первое Предложение', 'второе предложение', ' ', ' ТРЕТЬЕ ПРЕДЛОЖЕНИЕ '], + ['第一句話在這裡', '第二句話在這裡', '第三句話在這裡'], + ['', ' ', '12345 ']]) + @pytest.mark.precommit + @pytest.mark.nightly + @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ['arm', 'armv7l', + 'aarch64', + 'arm64', 'ARM64'], + reason='Ticket - 126314, 132699') + def test_string_to_hash_bucket_fast(self, input_shape, num_buckets, strings_dictionary, 
ie_device, precision, + ir_version, temp_dir, + use_legacy_frontend): + if ie_device == 'GPU' or run_in_jenkins(): + pytest.skip("operation extension is not supported on GPU") + self._test(*self.create_string_to_hash_bucket_fast_net(input_shape=input_shape, num_buckets=num_buckets, + strings_dictionary=strings_dictionary), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend) From b134adb4b1d42bfdcd880934039e6d605ab7957b Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 9 Apr 2024 20:40:07 +0400 Subject: [PATCH 2/3] Update src/frontends/tensorflow/src/op_table.cpp --- src/frontends/tensorflow/src/op_table.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index a9020c02150ec9..e306dc3ee266b0 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -461,7 +461,12 @@ const std::map get_supported_ops() { }; const std::vector get_supported_ops_via_tokenizers() { - return {"RaggedTensorToSparse", "RaggedTensorToTensor", "StaticRegexReplace", "StringLower", "StringSplitV2", "StringToHashBucketFast"}; + return {"RaggedTensorToSparse", + "RaggedTensorToTensor", + "StaticRegexReplace", + "StringLower", + "StringSplitV2", + "StringToHashBucketFast"}; } } // namespace op } // namespace tensorflow From a9b88e25a6c240866e866948ffc0ff4d01de2082 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 9 Apr 2024 20:40:53 +0400 Subject: [PATCH 3/3] Update src/frontends/tensorflow/src/op_table.cpp --- src/frontends/tensorflow/src/op_table.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index e306dc3ee266b0..3422148db2d771 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -462,11 +462,11 @@ const 
std::map get_supported_ops() { const std::vector get_supported_ops_via_tokenizers() { return {"RaggedTensorToSparse", - "RaggedTensorToTensor", - "StaticRegexReplace", - "StringLower", - "StringSplitV2", - "StringToHashBucketFast"}; + "RaggedTensorToTensor", + "StaticRegexReplace", + "StringLower", + "StringSplitV2", + "StringToHashBucketFast"}; } } // namespace op } // namespace tensorflow