From 930a3d395d59bf2ec78113a64d8174c232290e65 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 27 Feb 2024 17:48:31 +0400 Subject: [PATCH] [TF FE][Extensions] Support StaticRegexReplace operation (#41) * [TF FE][Extensions] Support StaticRegexReplace operation Signed-off-by: Kazantsev, Roman * Support global_rewrite case Signed-off-by: Kazantsev, Roman * Update src/utils.cpp * Update src/utils.cpp --------- Signed-off-by: Kazantsev, Roman --- src/regex_normalization.cpp | 18 ++++++++++++++---- src/regex_normalization.hpp | 5 ++++- src/tensorflow_translators.cpp | 14 +++++++++----- src/utils.cpp | 28 ++++++++++++---------------- src/utils.hpp | 4 +++- 5 files changed, 42 insertions(+), 27 deletions(-) diff --git a/src/regex_normalization.cpp b/src/regex_normalization.cpp index dd95e85d4..31a9563f0 100644 --- a/src/regex_normalization.cpp +++ b/src/regex_normalization.cpp @@ -10,10 +10,20 @@ using namespace ov; -RegexNormalization::RegexNormalization(const ov::OutputVector& arguments) : - ov::op::Op(arguments) { - constructor_validate_and_infer_types(); - } +RegexNormalization::RegexNormalization( + const ov::OutputVector& arguments, + bool global_replace +) : ov::op::Op(arguments), +m_global_replace(global_replace) { + auto search_pattern_const = as_type_ptr(arguments[3].get_node_shared_ptr()); + auto replace_pattern_const = as_type_ptr(arguments[4].get_node_shared_ptr()); + auto search_pattern_buf = static_cast(search_pattern_const->get_data_ptr()); + auto replace_pattern_buf = static_cast(replace_pattern_const->get_data_ptr()); + auto search_pattern = absl::string_view((const char*)search_pattern_buf, search_pattern_const->get_byte_size()); + m_replace_pattern = absl::string_view((const char*)replace_pattern_buf, replace_pattern_const->get_byte_size()); + m_search_pattern_re = std::make_shared(search_pattern); + constructor_validate_and_infer_types(); +} RegexNormalization::RegexNormalization( diff --git a/src/regex_normalization.hpp b/src/regex_normalization.hpp index cef8827f7..2307e6a15 100644 --- a/src/regex_normalization.hpp +++ b/src/regex_normalization.hpp @@ -18,7 +18,10 @@ class RegexNormalization : public ov::op::Op { OPENVINO_OP("RegexNormalization"); RegexNormalization () = default; - RegexNormalization(const ov::OutputVector& arguments); // not used + RegexNormalization( + const ov::OutputVector& arguments, + bool global_replace = true + ); RegexNormalization( const ov::OutputVector& arguments, const std::shared_ptr& search_pattern_re, diff --git a/src/tensorflow_translators.cpp b/src/tensorflow_translators.cpp index 86a6bdd71..56e72b644 100644 --- a/src/tensorflow_translators.cpp +++ b/src/tensorflow_translators.cpp @@ -105,11 +105,15 @@ ov::OutputVector translate_normalize_utf8(const ov::frontend::NodeContext& node) } ov::OutputVector translate_static_regex_replace(const ov::frontend::NodeContext& node) { + auto node_name = node.get_name(); FRONT_END_GENERAL_CHECK(node.get_input_size() == 1, "StaticRegexReplace expects only 1 input"); + auto replace_global = node.get_attribute("replace_global", true); ov::OutputVector inputs = pre_translate_string_tensor_input(node.get_input(0)); inputs.push_back(string_attribute_to_constant(node, "pattern")); inputs.push_back(string_attribute_to_constant(node, "rewrite")); - return { post_translate_string_tensor_output(std::make_shared(inputs)->outputs()) }; + auto string_pack_result = post_translate_string_tensor_output(std::make_shared(inputs, replace_global)->outputs()); + set_node_name(node_name, string_pack_result.get_node_shared_ptr()); + return { string_pack_result }; } ov::OutputVector translate_regex_split_with_offsets(const ov::frontend::NodeContext& node) { @@ -119,7 +123,7 @@ ov::OutputVector translate_regex_split_with_offsets(const ov::frontend::NodeCont inputs.push_back(delim_regex_pattern); // TODO: Use node.get_input(2) with keep_delim_regex_pattern, most likely it should be handled in another RegexSplit with `isolate` behaviour auto outputs = std::make_shared(inputs)->outputs(); - auto flatten_string_tensor = post_translate_string_tensor_output({outputs[2], outputs[3], outputs[4]}); + auto flatten_string_tensor = post_translate_string_tensor_output({ outputs[2], outputs[3], outputs[4] }); return { post_translate_ragged_tensor_output({outputs[0], outputs[1], flatten_string_tensor}) }; } @@ -127,14 +131,14 @@ ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::N FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "WordpieceTokenizeWithOffsets expects 2 inputs"); ov::OutputVector inputs = pre_translate_ragged_string_tensor_input(node.get_input(0)); - #if USE_STRING_TENSORS +#if USE_STRING_TENSORS // It may seem enough to call pre_translate_string_tensor_input that will override Parameter element // type in case if string tensors are not used. // But a Parameter is still required to be overridden even if string tensors are used because in TF model // it is represented not as a string tensor, but as a resource with hash table for lookup that we cannot interpret // and have to replace by 1D string tensor. - override_parameter(node.get_input(1).get_node_shared_ptr(), element::string, PartialShape{Dimension()}); - #endif + override_parameter(node.get_input(1).get_node_shared_ptr(), element::string, PartialShape{ Dimension() }); +#endif auto vocab = pre_translate_string_tensor_input(node.get_input(1)); inputs.insert(inputs.end(), vocab.begin(), vocab.end()); diff --git a/src/utils.cpp b/src/utils.cpp index 334689aeb..11f19082c 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -125,27 +125,15 @@ void override_parameter (std::shared_ptr node, element::Type type, con } } -// TODO: replace NodeContext and input_index by a single input -OutputVector pre_translate_string_tensor_input(ov::Output input) { +OutputVector pre_translate_string_tensor_input(const ov::Output& input) { auto input_node = input.get_node_shared_ptr(); -#if !USE_STRING_TENSORS - override_parameter(input_node, element::u8, PartialShape{Dimension()}); -#endif - if (auto struct_pack = std::dynamic_pointer_cast(input_node)) { FRONT_END_GENERAL_CHECK(struct_pack->get_input_size() == 3, "Expected 3 inputs to StringTensorPack which represents a string tensor"); return struct_pack->input_values(); - } else { - #if USE_STRING_TENSORS || true // always - return std::make_shared(OutputVector{input}, "begins_ends")->outputs(); - #else - // Suppose this is u8 packed string tensor with a single batch dimension - // Unpack this tensor using standard operations - - // Cannot do that because there is not ReinterprectCast operation in OV - // TODO: Find a way to make it without reinterpretation operation or introduce it as an extension (easy) - #endif + } + else { + return std::make_shared(OutputVector{ input }, "begins_ends")->outputs(); } } @@ -221,3 +209,11 @@ std::shared_ptr string_attribute_to_constant (const ov::frontend::NodeCont return std::make_shared(element::u8, Shape{value.length()}, (const void*)value.data()); #endif } + +void set_node_name(const std::string& node_name, const std::shared_ptr& node) { + const auto& outputs = node->outputs(); + node->set_friendly_name(node_name); + for (size_t idx = 0; idx < outputs.size(); ++idx) { + outputs[idx].get_tensor().add_names({ node_name + ":" + std::to_string(idx) }); + } +} diff --git a/src/utils.hpp b/src/utils.hpp index da0634687..7fafb011b 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -52,7 +52,7 @@ void unpack_strings_to_tensors(const std::string* strings, const ov::Shape shape void override_parameter (std::shared_ptr node, ov::element::Type type, const ov::PartialShape& shape); -ov::OutputVector pre_translate_string_tensor_input(ov::Output input); +ov::OutputVector pre_translate_string_tensor_input(const ov::Output& input); ov::OutputVector pre_translate_ragged_tensor_input(ov::Output input); @@ -68,3 +68,5 @@ bool evaluate_normalization_helper ( std::function normalizer); std::shared_ptr string_attribute_to_constant (const ov::frontend::NodeContext& node, const std::string& name); + +void set_node_name(const std::string& node_name, const std::shared_ptr& node);