Skip to content

Commit

Permalink
[TF FE][Extensions] Support StringLower operation (#44)
Browse files Browse the repository at this point in the history
Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Feb 28, 2024
1 parent 6b9fa42 commit 649413a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/ov_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ OPENVINO_CREATE_EXTENSIONS(
std::make_shared<ov::OpExtension<TemplateExtension::SentencepieceStreamDetokenizer>>(),
std::make_shared<ov::frontend::ConversionExtension>("SentencepieceOp", translate_sentencepiece_op),
std::make_shared<ov::frontend::ConversionExtension>("RaggedTensorToSparse", translate_sentencepiece_tokenizer),
std::make_shared<ov::frontend::ConversionExtension>("StringLower", translate_string_lower),
}));
//! [ov_extension:entry_point]
// clang-format on
10 changes: 10 additions & 0 deletions src/tensorflow_translators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,13 @@ ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::N
);
return { post_translate_ragged_tensor_output(wp_tokenizer->outputs()) };
}

ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node) {
auto node_name = node.get_name();
FRONT_END_GENERAL_CHECK(node.get_input_size() == 1, "StringLower expects only 1 input");
auto encoding = node.get_attribute<std::string>("encoding", "");
ov::OutputVector inputs = pre_translate_string_tensor_input(node.get_input(0));
auto string_lower_result = post_translate_string_tensor_output(std::make_shared<CaseFold>(inputs, encoding)->outputs());
set_node_name(node_name, string_lower_result.get_node_shared_ptr());
return { string_lower_result };
}
1 change: 1 addition & 0 deletions src/tensorflow_translators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ ov::OutputVector translate_normalize_utf8(const ov::frontend::NodeContext& node)
ov::OutputVector translate_static_regex_replace(const ov::frontend::NodeContext& node);
ov::OutputVector translate_regex_split_with_offsets(const ov::frontend::NodeContext& node);
ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::NodeContext& node);
ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node);

0 comments on commit 649413a

Please sign in to comment.