Skip to content

Commit

Permalink
Fix ft substr bug (#3279)
Browse files Browse the repository at this point in the history
* Optimize CMakeLists.txt (set rpath flags for UNIX instead of NOT APPLE)

* Add substr pos check
  • Loading branch information
joey12300 authored Sep 16, 2022
1 parent 0f464e8 commit 46f395a
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion faster_tokenizer/faster_tokenizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ add_subdirectory(postprocessors)
add_subdirectory(core)
add_subdirectory(utils)
# set the relative path of shared library
if (NOT APPLE)
if (UNIX)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'")
endif()

Expand Down
2 changes: 1 addition & 1 deletion faster_tokenizer/faster_tokenizer/core/added_vocabulary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ bool AddedVocabulary::FindMatch(const std::string& sequence,
if (added_tokens.GetIsSingleWord()) {
bool start_space =
(curr_start == 0) || !EndWithWord(sequence.substr(0, curr_start));
bool stop_space = (curr_end == sequence.length()) ||
bool stop_space = (curr_end >= sequence.length()) ||
!StartWithWord(sequence.substr(curr_end));
if (!start_space || !stop_space) {
// Discard not single word
Expand Down
6 changes: 4 additions & 2 deletions faster_tokenizer/faster_tokenizer/models/faster_wordpiece.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ void FasterWordPiece::AppendTokensToOutput(
if (id == unk_token_id_) {
value = unk_token_;
} else {
auto c_offset = *curr_offset_in_sequence;
c_offset = (std::min)(c_offset, static_cast<int>(sequence.length() - 1));
value = sequence.substr(*curr_offset_in_sequence, token_substr_length);
}

Expand Down Expand Up @@ -286,7 +288,7 @@ std::vector<core::Token> FasterWordPiece::TokenizeWithoutPreTokenize(
&all_tokens);
}
if (all_tokens.size() == 0) {
ResetOutputAppendUNK(0, sequence.size(), &original_num_tokens, &all_tokens);
ResetOutputAppendUNK(0, sequence.size(), &original_num_tokens, &all_tokens);
}
VLOG(6) << "All tokens num from TokenizeWithoutPreTokenize: "
<< all_tokens.size();
Expand Down Expand Up @@ -374,7 +376,7 @@ std::vector<core::Token> FasterWordPiece::TokenizeWithPreTokenize(
&all_tokens);
}
if (all_tokens.size() == 0) {
ResetOutputAppendUNK(0, sequence.size(), &original_num_tokens, &all_tokens);
ResetOutputAppendUNK(0, sequence.size(), &original_num_tokens, &all_tokens);
}
VLOG(6) << "All tokens num from TokenizeWithPreTokenize: "
<< all_tokens.size();
Expand Down
2 changes: 2 additions & 0 deletions faster_tokenizer/faster_tokenizer/models/wordpiece.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ core::Vocab WordPiece::GetVocabFromFile(const std::string& file) {
std::string word_str = word;
auto leading_spaces = word_str.find_first_not_of(WHITESPACE);
if (leading_spaces != std::string::npos) {
leading_spaces = (std::min)(leading_spaces, word_str.length() - 1);
word_str = word_str.substr(leading_spaces);
}
auto trailing_spaces = word_str.find_last_not_of(WHITESPACE);
Expand Down Expand Up @@ -275,6 +276,7 @@ void WordPieceFactory::GetVocabFromFiles(const std::string& files) {
std::string word_str = word;
auto leading_spaces = word_str.find_first_not_of(WHITESPACE);
if (leading_spaces != std::string::npos) {
leading_spaces = (std::min)(leading_spaces, word_str.length() - 1);
word_str = word_str.substr(leading_spaces);
}
auto trailing_spaces = word_str.find_last_not_of(WHITESPACE);
Expand Down
18 changes: 15 additions & 3 deletions faster_tokenizer/faster_tokenizer/normalizers/normalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ limitations under the License. */
#include "faster_tokenizer/normalizers/normalizer.h"
#include "faster_tokenizer/utils/utf8.h"

#include "glog/logging.h"
#include "faster_tokenizer/normalizers/unicode.h"
#include "glog/logging.h"
#include "re2/re2.h"
#include "unicode/edits.h"
#include "unicode/errorcode.h"
Expand Down Expand Up @@ -100,6 +100,8 @@ void NormalizedString::UpdateNormalizedRange(
// Retrieve the original characters that are being replaced. This let us
// compute the change in byte sizes along the way.
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
n_range.first = (std::min)(n_range.first,
static_cast<uint32_t>(normalized_.length() - 1));
std::u32string u32replaced_normalized = conv.from_bytes(
normalized_.substr(n_range.first, n_range.second - n_range.first));
uint32_t initial_removed = 0;
Expand Down Expand Up @@ -332,12 +334,14 @@ NormalizedString& NormalizedString::RStrip() { return LRStrip(false, true); }
const std::string WHITESPACE = " \n\r\t\f\v";

NormalizedString& NormalizedString::LRStrip(bool left, bool right) {
int leading_spaces = 0;
int trailing_spaces = 0;
uint32_t leading_spaces = 0;
uint32_t trailing_spaces = 0;
std::string new_normalized = normalized_;
if (left) {
leading_spaces = new_normalized.find_first_not_of(WHITESPACE);
if (leading_spaces != std::string::npos) {
leading_spaces = (std::min)(
leading_spaces, static_cast<uint32_t>(new_normalized.length() - 1));
new_normalized = new_normalized.substr(leading_spaces);
}
}
Expand Down Expand Up @@ -534,8 +538,16 @@ bool NormalizedString::Slice(core::Range range,
ConvertOffsets(&original_range, false);
}
uint32_t n_shift = original_range.first;

original_range.first =
(std::min)(original_range.first,
static_cast<uint32_t>(this->original_.length() - 1));
normalized->original_ = this->original_.substr(
original_range.first, original_range.second - original_range.first);

normalized_range.first =
(std::min)(normalized_range.first,
static_cast<uint32_t>(this->normalized_.length() - 1));
normalized->normalized_ = this->normalized_.substr(
normalized_range.first,
normalized_range.second - normalized_range.first);
Expand Down
1 change: 1 addition & 0 deletions faster_tokenizer/faster_tokenizer/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void GetVocabFromFiles(const std::string& files,
std::string word_str = word;
auto leading_spaces = word_str.find_first_not_of(WHITESPACE);
if (leading_spaces != std::string::npos) {
leading_spaces = (std::min)(leading_spaces, word_str.length() - 1);
word_str = word_str.substr(leading_spaces);
}
auto trailing_spaces = word_str.find_last_not_of(WHITESPACE);
Expand Down

0 comments on commit 46f395a

Please sign in to comment.