Skip to content

Commit

Permalink
Revert "hotfix for RE2 segfaults (#284)" (#293)
Browse files Browse the repository at this point in the history
This reverts commit ce9d51c.
  • Loading branch information
rkazants authored Oct 18, 2024
1 parent ce9d51c commit b0b0dc9
Show file tree
Hide file tree
Showing 8 changed files with 4,370 additions and 4,290 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -603,19 +603,19 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<tr>
<td >SentencePiece</td>
<td >NousResearch/Llama-2-13b-hf</td>
<td >96.73</td>
<td >94.29</td>
<td >245</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >NousResearch/Llama-2-13b-hf_legacy</td>
<td >95.92</td>
<td >97.55</td>
<td >245</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >NousResearch/Llama-2-13b-hf_sp_backend</td>
<td >95.10</td>
<td >97.55</td>
<td >245</td>
</tr>
<tr>
Expand Down Expand Up @@ -717,7 +717,7 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<tr>
<td >SentencePiece</td>
<td >rinna/bilingual-gpt-neox-4b</td>
<td >82.04</td>
<td >80.41</td>
<td >245</td>
</tr>
<tr>
Expand Down
14 changes: 7 additions & 7 deletions python/openvino_tokenizers/tokenizer_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,23 +206,23 @@ def strip_accents_regex(cls) -> "RegexNormalizationStep":

@classmethod
def add_prefix_whitespace_regex(cls) -> "RegexNormalizationStep":
return cls(regex_search_pattern=r"^(\S)", replace_term=r" $1")
return cls(regex_search_pattern=r"^(\S)", replace_term=r" \1")

@classmethod
def add_prefix_whitespace_to_not_whitespace_regex(cls) -> "RegexNormalizationStep":
return cls(regex_search_pattern=r"^([^ ])", replace_term=r" $1")
return cls(regex_search_pattern=r"^([^ ])", replace_term=r" \1")

@classmethod
def replace_spaces_metaspace(cls, replace_term=r"▁") -> "RegexNormalizationStep":
return cls(regex_search_pattern=r" ", replace_term=replace_term)

@classmethod
def prepend_regex(cls, string: str) -> "RegexNormalizationStep":
return cls(regex_search_pattern=r"(^)(.+)", replace_term=rf"{string}$2")
return cls(regex_search_pattern=r"(^)(.+)", replace_term=rf"{string}\2")

@classmethod
def prepend_with_check_regex(cls, string: str, check_string: str) -> "RegexNormalizationStep":
return cls(regex_search_pattern=rf"(^)([^{check_string}])", replace_term=rf"{string}$2")
return cls(regex_search_pattern=rf"(^)([^{check_string}])", replace_term=rf"{string}\2")

@classmethod
def del_control_chars_regex(cls) -> "RegexNormalizationStep":
Expand All @@ -235,7 +235,7 @@ def del_control_chars_regex(cls) -> "RegexNormalizationStep":
def clean_up_tokenization_spaces(cls) -> "RegexNormalizationStep":
return cls(
regex_search_pattern=r" ([\.\?\!\,])| ('[ms])| (') | ('[rv]e)",
replace_term="$1",
replace_term="\1",
)

def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
Expand Down Expand Up @@ -1077,7 +1077,7 @@ class RegexDecodingStep(DecodingStep):
def clean_up_tokenization_spaces(cls) -> "RegexDecodingStep":
return cls(
regex_search_pattern=r" ([\\.\\?\\!,])| ('[ms])| (') | ('[rv]e)| (n't)",
replace_term=r"$1",
replace_term=r"\1",
)

@classmethod
Expand Down Expand Up @@ -1115,7 +1115,7 @@ def strip_forward_space(cls) -> "RegexDecodingStep":
def strip_forward_space_before_not_space(cls) -> "RegexDecodingStep":
return cls(
regex_search_pattern=r"(^ )([^ ])",
replace_term=r"$2",
replace_term=r"\2",
)

@classmethod
Expand Down
54 changes: 47 additions & 7 deletions src/regex_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,28 @@ m_global_replace(global_replace) {
auto search_pattern = std::string(search_pattern_buf, search_pattern_const->get_byte_size());
m_replace_pattern = std::string(replace_pattern_buf, replace_pattern_const->get_byte_size());

m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
auto options = re2::RE2::Options();
options.set_log_errors(false);
m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern, options);

if (m_search_pattern_re->NumberOfCapturingGroups() == -1) {
// If RE2 was unable to process pattern.
m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
m_search_pattern_re = nullptr;
}

constructor_validate_and_infer_types();
}


RegexNormalization::RegexNormalization(
const ov::OutputVector& arguments,
const std::shared_ptr<re2::RE2>& search_pattern_re,
const std::shared_ptr<PCRE2Wrapper>& search_pattern_pcre2,
const std::string replace_pattern,
bool global_replace
) : ov::op::Op(arguments),
m_search_pattern_re(search_pattern_re),
m_search_pattern_pcre2(search_pattern_pcre2),
m_replace_pattern(replace_pattern),
m_global_replace(global_replace) {
Expand All @@ -47,14 +57,25 @@ RegexNormalization::RegexNormalization(
const char* replace_pattern_buf;
std::string search_pattern;

if (m_search_pattern_pcre2 == nullptr) {
if (m_search_pattern_re == nullptr || m_search_pattern_pcre2 == nullptr) {
search_pattern_buf = static_cast<const char*>(search_pattern_const->get_data_ptr());
replace_pattern_buf = static_cast<const char*>(replace_pattern_const->get_data_ptr());
search_pattern = std::string(search_pattern_buf, search_pattern_const->get_byte_size());
m_replace_pattern = std::string(replace_pattern_buf, replace_pattern_const->get_byte_size());
};

auto options = re2::RE2::Options();
options.set_log_errors(false);
if (m_search_pattern_re == nullptr) {
auto options = re2::RE2::Options();
options.set_log_errors(false);
m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern, options);
}

if (m_search_pattern_re->NumberOfCapturingGroups() == -1 && m_search_pattern_pcre2 == nullptr) {
m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
m_search_pattern_re = nullptr;
}

constructor_validate_and_infer_types();
}

Expand Down Expand Up @@ -83,17 +104,36 @@ bool RegexNormalization::evaluate(ov::TensorVector& outputs, const ov::TensorVec
const bool has_skips = (inputs.size() == 6);
const auto pattern_input = 3 + has_skips;

if (m_search_pattern_pcre2 == nullptr) {
std::string search_pattern = std::string(inputs[pattern_input].data<const char>(), inputs[pattern_input].get_size());
std::string search_pattern;
if (m_search_pattern_re == nullptr || m_search_pattern_pcre2 == nullptr) {
search_pattern = std::string(inputs[pattern_input].data<const char>(), inputs[pattern_input].get_size());
m_replace_pattern = std::string(inputs[pattern_input + 1].data<const char>(), inputs[pattern_input + 1].get_size());

auto options = re2::RE2::Options();
options.set_log_errors(false);
m_search_pattern_re = std::make_shared<re2::RE2>(search_pattern, options);
}

if ((m_search_pattern_re == nullptr) || (m_search_pattern_re->NumberOfCapturingGroups() == -1 && m_search_pattern_pcre2 == nullptr)) {
m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(search_pattern);
m_search_pattern_re = nullptr;
}

return evaluate_normalization_helper(
outputs, inputs,
[this](const std::string& str) -> std::string {
if (m_search_pattern_pcre2) {
return m_search_pattern_pcre2->substitute(str, m_replace_pattern, m_global_replace);
std::string result = str;

// Use RE2 where possible, and fallback to PCRE2 if RE2 was not able to process.
if (m_search_pattern_re) {
if (m_global_replace) {
re2::RE2::GlobalReplace(&result, *m_search_pattern_re, m_replace_pattern);
} else {
re2::RE2::Replace(&result, *m_search_pattern_re, m_replace_pattern);
};
return result;
} else if (m_search_pattern_pcre2) {
return m_search_pattern_pcre2->substitute(result, m_replace_pattern, m_global_replace);
} else {
return str;
}
Expand Down
4 changes: 4 additions & 0 deletions src/regex_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <openvino/op/op.hpp>
#include "openvino/opsets/opset13.hpp"
#include <re2/re2.h>
#include <pcre2.h>

using namespace ov;
Expand All @@ -25,6 +26,7 @@ class RegexNormalization : public ov::op::Op {
);
RegexNormalization(
const ov::OutputVector& arguments,
const std::shared_ptr<re2::RE2>& search_pattern_re,
const std::shared_ptr<PCRE2Wrapper>& search_pattern_rcre2,
const std::string replace_pattern,
bool global_replace = true
Expand All @@ -35,6 +37,7 @@ class RegexNormalization : public ov::op::Op {
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return std::make_shared<RegexNormalization>(
inputs,
m_search_pattern_re,
m_search_pattern_pcre2,
m_replace_pattern,
m_global_replace
Expand All @@ -52,6 +55,7 @@ class RegexNormalization : public ov::op::Op {
return true;
}
private:
mutable std::shared_ptr<re2::RE2> m_search_pattern_re;
mutable std::shared_ptr<PCRE2Wrapper> m_search_pattern_pcre2;
mutable std::string m_replace_pattern;
bool m_global_replace = true;
Expand Down
51 changes: 42 additions & 9 deletions src/regex_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const std::map<std::string, RegexSplit::SplitMode> split_modes_map = {
void RegexSplit::compile_pattern_if_necessary(std::string split_pattern) const {
m_split_mode = split_modes_map.at(m_behaviour);

if (m_search_pattern_pcre2) {
if (m_search_pattern_re2 || m_search_pattern_pcre2) {
return;
}

Expand All @@ -35,7 +35,18 @@ void RegexSplit::compile_pattern_if_necessary(std::string split_pattern) const {
tmp_stream << "(" << split_pattern << ")+";
split_pattern = tmp_stream.str();
}
m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(split_pattern);

if (m_search_pattern_re2 == nullptr) {
auto options = re2::RE2::Options();
options.set_log_errors(false);
m_search_pattern_re2 = std::make_shared<re2::RE2>(split_pattern, options);
}

if (m_search_pattern_re2->NumberOfCapturingGroups() == -1) {
// If RE2 was unable to process pattern use PCRE2.
m_search_pattern_pcre2 = std::make_shared<PCRE2Wrapper>(split_pattern);
m_search_pattern_re2 = nullptr;
}
}


Expand All @@ -49,12 +60,14 @@ RegexSplit::RegexSplit(const ov::OutputVector& arguments, const std::string& beh

RegexSplit::RegexSplit(
const ov::OutputVector& arguments,
const std::shared_ptr<re2::RE2>& search_pattern_re2,
const std::shared_ptr<PCRE2Wrapper>& search_pattern_pcre2,
const std::string& behaviour,
bool invert,
int max_splits
) :
ov::op::Op(arguments),
m_search_pattern_re2(search_pattern_re2),
m_search_pattern_pcre2(search_pattern_pcre2),
m_behaviour(behaviour),
m_invert(invert),
Expand All @@ -72,13 +85,15 @@ RegexSplit::RegexSplit(

RegexSplit::RegexSplit(
const ov::OutputVector& arguments,
const std::shared_ptr<re2::RE2>& search_pattern_re2,
const std::shared_ptr<PCRE2Wrapper>& search_pattern_pcre2,
const std::shared_ptr<std::set<std::string>>& skip_tokens,
const std::string& behaviour,
bool invert,
int max_splits
) :
ov::op::Op(arguments),
m_search_pattern_re2(search_pattern_re2),
m_search_pattern_pcre2(search_pattern_pcre2),
m_skip_tokens(skip_tokens),
m_behaviour(behaviour),
Expand Down Expand Up @@ -125,14 +140,31 @@ bool RegexSplit::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inp
auto split_pattern = std::string(inputs[5].data<const char>(), inputs[5].get_size());
compile_pattern_if_necessary(split_pattern);

auto get_next_match = [this](const std::string& str, size_t curr_start) -> std::optional<std::pair<size_t, size_t>>{
auto match = this->m_search_pattern_pcre2->match(str, curr_start);
if (match.first != SIZE_MAX && match.first != match.second) {
return match;
} else {
// If RE2 didn't compiled successfully fallback to PCRE2 matcher.
std::function<std::optional<std::pair<size_t, size_t>>(const std::string&, size_t)> get_next_match;
if (m_search_pattern_re2) {
get_next_match = [this](const std::string& str, size_t curr_start) -> std::optional<std::pair<size_t, size_t>>{
re2::StringPiece result;
bool flag = this->m_search_pattern_re2->Match(str, curr_start, str.length(), RE2::UNANCHORED, &result, 1);
if (flag) {
size_t start = result.data() - str.data();
size_t end = start + result.length();
if (start != end) {
return std::pair(start, end);
}
}
return std::nullopt;
}
};
};
} else {
get_next_match = [this](const std::string& str, size_t curr_start) -> std::optional<std::pair<size_t, size_t>>{
auto match = this->m_search_pattern_pcre2->match(str, curr_start);
if (match.first != SIZE_MAX && match.first != match.second) {
return match;
} else {
return std::nullopt;
}
};
}

auto input_size = get_input_size();
if (input_size == 9 && m_skip_tokens == nullptr && inputs[6].get_size() > 0) {
Expand Down Expand Up @@ -206,6 +238,7 @@ bool RegexSplit::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inp
new_ends[ragged_offset++] = ends[ragged_col];
} else {
size_t start = 0;
re2::StringPiece result;
uint32_t num_splits = 0;

size_t last_begin = -1;
Expand Down
6 changes: 5 additions & 1 deletion src/regex_split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ class RegexSplit : public ov::op::Op {
RegexSplit(const ov::OutputVector& arguments, const std::string& behaviour = "remove", bool invert = false);
RegexSplit(
const ov::OutputVector& arguments,
const std::shared_ptr<re2::RE2>& search_pattern_re2,
const std::shared_ptr<PCRE2Wrapper>& search_pattern_pcre2,
const std::string& behaviour = "remove",
bool invert = false,
int max_splits = -1
);
RegexSplit(
const ov::OutputVector& arguments,
const std::shared_ptr<re2::RE2>& search_pattern_re2,
const std::shared_ptr<PCRE2Wrapper>& search_pattern_pcre2,
const std::shared_ptr<std::set<std::string>>& skip_tokens,
const std::string& behaviour = "remove",
Expand All @@ -36,7 +38,7 @@ class RegexSplit : public ov::op::Op {
void validate_and_infer_types() override;

std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
return std::make_shared<RegexSplit>(inputs, m_search_pattern_pcre2,
return std::make_shared<RegexSplit>(inputs, m_search_pattern_re2, m_search_pattern_pcre2,
m_skip_tokens, m_behaviour, m_invert, m_max_splits);
}

Expand All @@ -61,7 +63,9 @@ class RegexSplit : public ov::op::Op {
CONTIGUOUS, // Contiguous is not used during evaluate, replaced with isolated with patched pattern in ctor.
};


private:
mutable std::shared_ptr<re2::RE2> m_search_pattern_re2;
mutable std::shared_ptr<PCRE2Wrapper> m_search_pattern_pcre2;
mutable std::shared_ptr<std::set<std::string>> m_skip_tokens;
mutable std::string m_behaviour = "remove";
Expand Down
7 changes: 3 additions & 4 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,10 @@ std::string PCRE2Wrapper::substitute(const std::string& orig_str,
return orig_str;
}

// Usually found pattern is replaced by shorter string, but set 4 times more space for safety.
// Also set min length to 16 to avoid too small buffer when single ASCII symbol is replaced with several UTF-8 symbols.
// Usually found pattern is replaced by shorter string, but set 3 times more space for safety.
// Allocate dynamically since lenght depends dynamically on the lenght of input string.
// Allocated memory will be freed at the exit from function.
size_t buffer_length = sizeof(PCRE2_UCHAR) * subject_length * 4 + 16;
size_t buffer_length = sizeof(PCRE2_UCHAR) * subject_length * 4;
PCRE2_UCHAR* buffer = (PCRE2_UCHAR*) std::malloc(buffer_length);
if (buffer == nullptr) {
std::cerr << "Memory allocation failed" << std::endl;
Expand Down Expand Up @@ -321,7 +320,7 @@ std::string PCRE2Wrapper::substitute(const std::string& orig_str,
std::free(buffer);
return orig_str;
}
auto res = std::string(reinterpret_cast<char*>(buffer), buffer_length);
auto res = std::string(reinterpret_cast<char*>(buffer), subject_length);
std::free(buffer);
pcre2_match_data_free(match_data);
return res;
Expand Down
Loading

0 comments on commit b0b0dc9

Please sign in to comment.