Skip to content

Commit

Permalink
[TF FE][Extensions] Support StaticRegexReplace operation (#41)
Browse files Browse the repository at this point in the history
* [TF FE][Extensions] Support StaticRegexReplace operation

Signed-off-by: Kazantsev, Roman <[email protected]>

* Support global_rewrite case

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update src/utils.cpp

* Update src/utils.cpp

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Feb 27, 2024
1 parent 3147ab9 commit 930a3d3
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 27 deletions.
18 changes: 14 additions & 4 deletions src/regex_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,20 @@
using namespace ov;


// Builds a RegexNormalization op whose search and replacement patterns are
// read from constant inputs.
//
// arguments      - op inputs; inputs 3 and 4 must be produced by Constant nodes
//                  holding the search pattern and the replacement pattern bytes.
// global_replace - stored in m_global_replace. The class declaration gives this
//                  parameter a default value, so this overload also serves
//                  single-argument construction.
//
// Fails with an OpenVINO assertion when fewer than five inputs are supplied or
// when input 3 or input 4 is not a Constant.
RegexNormalization::RegexNormalization(
    const ov::OutputVector& arguments,
    bool global_replace
) : ov::op::Op(arguments),
    m_global_replace(global_replace) {
    OPENVINO_ASSERT(arguments.size() > 4,
                    "RegexNormalization expects at least 5 inputs, got ", arguments.size());

    const auto search_pattern_const = as_type_ptr<Constant>(arguments[3].get_node_shared_ptr());
    const auto replace_pattern_const = as_type_ptr<Constant>(arguments[4].get_node_shared_ptr());
    // as_type_ptr yields a null pointer on a type mismatch; report it instead of dereferencing.
    OPENVINO_ASSERT(search_pattern_const != nullptr,
                    "RegexNormalization expects the search pattern (input 3) to be a Constant");
    OPENVINO_ASSERT(replace_pattern_const != nullptr,
                    "RegexNormalization expects the replace pattern (input 4) to be a Constant");

    const auto search_pattern_buf = static_cast<const char*>(search_pattern_const->get_data_ptr());
    const auto replace_pattern_buf = static_cast<const char*>(replace_pattern_const->get_data_ptr());
    const auto search_pattern = absl::string_view(search_pattern_buf, search_pattern_const->get_byte_size());
    // NOTE(review): if m_replace_pattern is a non-owning view, it aliases the Constant's
    // buffer and relies on that node staying alive - confirm against the member's type.
    m_replace_pattern = absl::string_view(replace_pattern_buf, replace_pattern_const->get_byte_size());
    m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern);
    constructor_validate_and_infer_types();
}


RegexNormalization::RegexNormalization(
Expand Down
5 changes: 4 additions & 1 deletion src/regex_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ class RegexNormalization : public ov::op::Op {
OPENVINO_OP("RegexNormalization");

RegexNormalization () = default;
RegexNormalization(const ov::OutputVector& arguments); // not used
RegexNormalization(
const ov::OutputVector& arguments,
bool global_replace = true
);
RegexNormalization(
const ov::OutputVector& arguments,
const std::shared_ptr<re2::RE2>& search_pattern_re,
Expand Down
14 changes: 9 additions & 5 deletions src/tensorflow_translators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,15 @@ ov::OutputVector translate_normalize_utf8(const ov::frontend::NodeContext& node)
}

// Translates a StaticRegexReplace node into a RegexNormalization op.
//
// The single input is converted by pre_translate_string_tensor_input, the
// "pattern" and "rewrite" attributes are appended as extra inputs, and the
// "replace_global" attribute (true when absent) is forwarded to the op.
// The packed result is named after the source node before it is returned.
ov::OutputVector translate_static_regex_replace(const ov::frontend::NodeContext& node) {
    const auto node_name = node.get_name();
    FRONT_END_GENERAL_CHECK(node.get_input_size() == 1, "StaticRegexReplace expects only 1 input");
    const auto replace_global = node.get_attribute<bool>("replace_global", true);

    ov::OutputVector inputs = pre_translate_string_tensor_input(node.get_input(0));
    inputs.push_back(string_attribute_to_constant(node, "pattern"));
    inputs.push_back(string_attribute_to_constant(node, "rewrite"));

    // Build the op exactly once, with replace_global, and name the packed result
    // so the converted graph keeps the source node's name.
    const auto regex_replace = std::make_shared<RegexNormalization>(inputs, replace_global);
    auto string_pack_result = post_translate_string_tensor_output(regex_replace->outputs());
    set_node_name(node_name, string_pack_result.get_node_shared_ptr());
    return { string_pack_result };
}

ov::OutputVector translate_regex_split_with_offsets(const ov::frontend::NodeContext& node) {
Expand All @@ -119,22 +123,22 @@ ov::OutputVector translate_regex_split_with_offsets(const ov::frontend::NodeCont
inputs.push_back(delim_regex_pattern);
// TODO: Use node.get_input(2) with keep_delim_regex_pattern, most likely it should be handled in another RegexSplit with `isolate` behaviour
auto outputs = std::make_shared<RegexSplit>(inputs)->outputs();
auto flatten_string_tensor = post_translate_string_tensor_output({outputs[2], outputs[3], outputs[4]});
auto flatten_string_tensor = post_translate_string_tensor_output({ outputs[2], outputs[3], outputs[4] });
return { post_translate_ragged_tensor_output({outputs[0], outputs[1], flatten_string_tensor}) };
}

ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::NodeContext& node) {
FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "WordpieceTokenizeWithOffsets expects 2 inputs");
ov::OutputVector inputs = pre_translate_ragged_string_tensor_input(node.get_input(0));

#if USE_STRING_TENSORS
#if USE_STRING_TENSORS
// It may seem enough to call pre_translate_string_tensor_input that will override Parameter element
// type in case if string tensors are not used.
// But a Parameter is still required to be overridden even if string tensors are used because in TF model
// it is represented not as a string tensor, but as a resource with hash table for lookup that we cannot interpret
// and have to replace by 1D string tensor.
override_parameter(node.get_input(1).get_node_shared_ptr(), element::string, PartialShape{Dimension()});
#endif
override_parameter(node.get_input(1).get_node_shared_ptr(), element::string, PartialShape{ Dimension() });
#endif

auto vocab = pre_translate_string_tensor_input(node.get_input(1));
inputs.insert(inputs.end(), vocab.begin(), vocab.end());
Expand Down
28 changes: 12 additions & 16 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,15 @@ void override_parameter (std::shared_ptr<ov::Node> node, element::Type type, con
}
}

// Converts an output that represents a string tensor into its decomposed outputs.
//
// When `input` is produced by a StringTensorPack node, that node's three input
// values are returned directly. Otherwise a StringTensorUnpack node in
// "begins_ends" mode is created on top of `input` and its outputs are returned.
OutputVector pre_translate_string_tensor_input(const ov::Output<ov::Node>& input) {
    const auto input_node = input.get_node_shared_ptr();

    if (const auto struct_pack = std::dynamic_pointer_cast<StringTensorPack>(input_node)) {
        FRONT_END_GENERAL_CHECK(struct_pack->get_input_size() == 3, "Expected 3 inputs to StringTensorPack which represents a string tensor");
        return struct_pack->input_values();
    }

    return std::make_shared<StringTensorUnpack>(OutputVector{ input }, "begins_ends")->outputs();
}

Expand Down Expand Up @@ -221,3 +209,11 @@ std::shared_ptr<Node> string_attribute_to_constant (const ov::frontend::NodeCont
return std::make_shared<Constant>(element::u8, Shape{value.length()}, (const void*)value.data());
#endif
}

// Gives `node` the friendly name `node_name` and adds to every output tensor
// a name of the form "<node_name>:<output index>".
void set_node_name(const std::string& node_name, const std::shared_ptr<Node>& node) {
    const auto& node_outputs = node->outputs();
    node->set_friendly_name(node_name);

    size_t port = 0;
    for (const auto& output : node_outputs) {
        const std::string tensor_name = node_name + ":" + std::to_string(port);
        output.get_tensor().add_names({ tensor_name });
        ++port;
    }
}
4 changes: 3 additions & 1 deletion src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void unpack_strings_to_tensors(const std::string* strings, const ov::Shape shape

void override_parameter (std::shared_ptr<ov::Node> node, ov::element::Type type, const ov::PartialShape& shape);

ov::OutputVector pre_translate_string_tensor_input(ov::Output<ov::Node> input);
ov::OutputVector pre_translate_string_tensor_input(const ov::Output<ov::Node>& input);

ov::OutputVector pre_translate_ragged_tensor_input(ov::Output<ov::Node> input);

Expand All @@ -68,3 +68,5 @@ bool evaluate_normalization_helper (
std::function<std::string(const std::string&)> normalizer);

std::shared_ptr<ov::Node> string_attribute_to_constant (const ov::frontend::NodeContext& node, const std::string& name);

void set_node_name(const std::string& node_name, const std::shared_ptr<ov::Node>& node);

0 comments on commit 930a3d3

Please sign in to comment.