From e111fd8e2cd3fe3b06e0fc868c491f9e35adeb4f Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 13 Oct 2024 08:10:06 +0800 Subject: [PATCH] Fix --- paddle/fluid/framework/CMakeLists.txt | 5 - paddle/fluid/framework/feed_fetch_type.h | 20 +- paddle/fluid/framework/operator.cc | 3 + paddle/fluid/framework/string_array.h | 117 +--- paddle/fluid/framework/tensor_ref_array.h | 12 +- paddle/fluid/framework/type_info.cc | 2 - paddle/fluid/imperative/prepared_operator.h | 17 +- paddle/fluid/operators/CMakeLists.txt | 4 +- .../ops_signature/faster_tokenizer_sig.cc | 33 + paddle/fluid/operators/string/CMakeLists.txt | 2 +- .../operators/string/faster_tokenizer_op.cc | 423 +----------- .../operators/string/faster_tokenizer_op.h | 210 ------ .../pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/phi/core/CMakeLists.txt | 1 + paddle/phi/core/kernel_registry.cc | 14 + paddle/phi/core/kernel_utils.h | 5 + paddle/phi/core/utils/type_info.cc | 3 + paddle/phi/core/vocab/CMakeLists.txt | 1 + .../core/vocab}/phi_tensor_base_vector.h | 14 +- .../core/vocab}/string_array.cc | 37 +- paddle/phi/core/vocab/string_array.h | 142 ++++ paddle/phi/infermeta/ternary.cc | 16 + paddle/phi/infermeta/ternary.h | 11 + .../kernels/cpu/faster_tokenizer_kernel.cc | 617 ++++++++++++++++++ .../phi/ops/yaml/inconsistent/static_ops.yaml | 11 + paddle/phi/ops/yaml/op_compat.yaml | 6 + 26 files changed, 932 insertions(+), 795 deletions(-) create mode 100644 paddle/fluid/operators/ops_signature/faster_tokenizer_sig.cc delete mode 100644 paddle/fluid/operators/string/faster_tokenizer_op.h create mode 100644 paddle/phi/core/vocab/CMakeLists.txt rename paddle/{fluid/framework => phi/core/vocab}/phi_tensor_base_vector.h (92%) rename paddle/{fluid/framework => phi/core/vocab}/string_array.cc (78%) create mode 100644 paddle/phi/core/vocab/string_array.h create mode 100644 paddle/phi/kernels/cpu/faster_tokenizer_kernel.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt 
index 516b80506b40eb..b2e5e539c9bd83 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -101,11 +101,6 @@ foreach(OP_DEF_FILE ${OP_DEF_FILES}) endforeach() file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/op_def.pbtxt "{\"\",\"\"}};\n}") -cc_library( - string_array - SRCS string_array.cc - DEPS utf8proc phi common) - cc_library( data_type SRCS data_type.cc diff --git a/paddle/fluid/framework/feed_fetch_type.h b/paddle/fluid/framework/feed_fetch_type.h index 6be31a062ef07b..5f7c81c90f18f4 100644 --- a/paddle/fluid/framework/feed_fetch_type.h +++ b/paddle/fluid/framework/feed_fetch_type.h @@ -21,13 +21,12 @@ limitations under the License. */ #include "paddle/fluid/framework/string_array.h" #include "paddle/phi/core/extended_tensor.h" -namespace paddle { -namespace framework { -using FeedType = - paddle::variant; +namespace phi { +using FeedType = paddle:: + variant; using FetchType = paddle::variant; template <> @@ -40,9 +39,16 @@ struct PhiVectorType { const char *type_name = "PhiVectorFetchType"; }; -using FeedList = paddle::framework::PhiVector; -using FetchList = paddle::framework::PhiVector; +using FeedList = PhiVector; +using FetchList = PhiVector; +} // namespace phi +namespace paddle { +namespace framework { +using FeedType = phi::FeedType; +using FetchType = phi::FetchType; +using FeedList = phi::FeedList; +using FetchList = phi::FetchList; using FetchUnmergedList = std::vector>; inline bool data_is_lod_tensor(const FetchType &data) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index adc6dfcf20afc7..0e04b3cd48eab5 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -3253,6 +3253,9 @@ void OperatorWithKernel::BuildPhiKernelContext( } else if (var->IsType()) { tensor_in = &(var->Get()); phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); + } else if (var->IsType()) { + tensor_in = &(var->Get()); + 
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); } else if (var->IsType()) { tensor_in = &(var->Get()); phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); diff --git a/paddle/fluid/framework/string_array.h b/paddle/fluid/framework/string_array.h index ddcc15e3dca591..fc3f3d8146a981 100644 --- a/paddle/fluid/framework/string_array.h +++ b/paddle/fluid/framework/string_array.h @@ -14,119 +14,4 @@ limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/phi_tensor_base_vector.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/extended_tensor.h" - -namespace paddle { -namespace framework { - -// Note(YuanRisheng): Vocab is mainly used for faster_tokenizer_op and we don't -// recommend widely use it. Because faster_tokenizer_op may be deleted in the -// future and this class will be deleted. - -class Vocab : public phi::ExtendedTensor, - public phi::TypeInfoTraits { - public: - Vocab() = default; - - Vocab(Vocab&& other) = default; - - Vocab(const Vocab& other) = default; - - Vocab& operator=(const Vocab& other) = default; - - Vocab& operator=(Vocab&& other) = default; - - Vocab& operator=( - const std::unordered_map& other) { - this->data_ = other; - return *this; - } - - /// \brief Destroy the Vocab and release exclusive resources. - virtual ~Vocab() = default; - - public: - /// \brief Returns the name of the class for type traits. - /// \return The name of the class. 
- static const char* name() { return "Vocab"; } - - size_t size() const { return data_.size(); } - - void clear() { data_.clear(); } - - void emplace(const std::wstring& key, std::int32_t value) { - data_.emplace(key, value); - } - - std::int32_t at(const std::wstring& key) { return data_.at(key); } - - std::int32_t at(const std::wstring& key) const { return data_.at(key); } - - std::unordered_map::iterator find( - const std::wstring& key) { - return data_.find(key); - } - - std::unordered_map::const_iterator find( - const std::wstring& key) const { - return data_.find(key); - } - - std::unordered_map::iterator begin() { - return data_.begin(); - } - - std::unordered_map::const_iterator begin() const { - return data_.begin(); - } - - std::unordered_map::iterator end() { - return data_.end(); - } - - std::unordered_map::const_iterator end() const { - return data_.end(); - } - - private: - std::unordered_map data_; -}; - -// Note(YuanRisheng): PhiVector is essentially a vector that only used for PHI -// Kernel. It can be used when you define a non-tensor type that needs to be -// stored in a vector as PHI kernel argument. - -template <> -struct PhiVectorType { - const char* type_name = "PhiVectorString"; -}; - -using String = std::string; -using Strings = PhiVector; - -// Convert the std::string type to the std::string type. -bool ConvertStrToWstr(const std::string& src, std::wstring* res); -// Convert the std::wstring type to the std::string type. -void ConvertWstrToStr(const std::wstring& src, std::string* res); -// Normalization Form Canonical Decomposition. -void NFD(const std::string& s, std::string* ret); - -// Write the data which is type of -// std::unordered_map to ostream. -void StringMapToStream(std::ostream& os, - const std::unordered_map& data); - -// Read the data which is type of -// std::unordered_map from istream. 
-void StringMapFromStream(std::istream& is, - std::unordered_map* data); -} // namespace framework -} // namespace paddle +#include "paddle/phi/core/vocab/string_array.h" diff --git a/paddle/fluid/framework/tensor_ref_array.h b/paddle/fluid/framework/tensor_ref_array.h index d5f5e0b61f2f9d..80211301b6976a 100644 --- a/paddle/fluid/framework/tensor_ref_array.h +++ b/paddle/fluid/framework/tensor_ref_array.h @@ -14,15 +14,17 @@ #pragma once -#include "paddle/fluid/framework/phi_tensor_base_vector.h" - -namespace paddle { -namespace framework { +#include "paddle/phi/core/vocab/phi_tensor_base_vector.h" +namespace phi { template <> -struct PhiVectorType { +struct PhiVectorType { const char* type_name = "VariableRefArray"; }; +} // namespace phi + +namespace paddle { +namespace framework { using VariableRefArray = PhiVector; diff --git a/paddle/fluid/framework/type_info.cc b/paddle/fluid/framework/type_info.cc index daa91dde9d6dbe..f52db412ac01d8 100644 --- a/paddle/fluid/framework/type_info.cc +++ b/paddle/fluid/framework/type_info.cc @@ -39,8 +39,6 @@ bool TypeInfoTraits::classof(const BaseT* obj) { } template class TypeInfoTraits; -template class TypeInfoTraits; -template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index af224dfc5d282d..ed077dfb82a1aa 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -34,6 +34,7 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/core/vocab/string_array.h" COMMON_DECLARE_bool(use_mkldnn); @@ -307,8 +308,14 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); continue; } else if (input_defs[i].type_index == 
- std::type_index(typeid( - paddle::optional>))) { + std::type_index( + typeid(paddle::optional)) || + input_defs[i].type_index == + std::type_index(typeid(paddle::optional)) || + input_defs[i].type_index == + std::type_index( + typeid(paddle::optional< + std::vector>))) { kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr); auto end_idx = start_idx + 1; kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); @@ -338,6 +345,12 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, } else if (var.template IsType()) { tensor_in = &(var.template Get()); kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); + } else if (var.template IsType()) { + tensor_in = &(var.template Get()); + kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); + } else if (var.template IsType()) { + tensor_in = &(var.template Get()); + kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); } else { PADDLE_THROW(common::errors::Unimplemented( "Unsupported input `%s` type when call pt kernel.", diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6d0bca80b96a52..0944bd7c5773f5 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -81,8 +81,8 @@ op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_o op_library(run_program_op DEPS executor_cache ${OP_HEADER_DEPS}) target_link_libraries(run_program_op phi common) op_library(quantize_linear_op DEPS phi common) -op_library(save_combine_op DEPS string_array phi common) -op_library(load_combine_op DEPS string_array) +op_library(save_combine_op DEPS phi) +op_library(load_combine_op DEPS phi) op_library(activation_op SRCS activation_op.cc DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/ops_signature/faster_tokenizer_sig.cc b/paddle/fluid/operators/ops_signature/faster_tokenizer_sig.cc new file mode 100644 index 00000000000000..0ec3e8ebca8066 --- /dev/null +++ 
b/paddle/fluid/operators/ops_signature/faster_tokenizer_sig.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FasterTokenizerOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature("faster_tokenizer", + {"Vocab", "Text", "TextPair"}, + {"do_lower_case", + "is_split_into_words", + "max_seq_len", + "pad_to_max_seq_len"}, + {"InputIds", "SegmentIds"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(faster_tokenizer, + phi::FasterTokenizerOpArgumentMapping); diff --git a/paddle/fluid/operators/string/CMakeLists.txt b/paddle/fluid/operators/string/CMakeLists.txt index 1da2e8e455da0c..2065455e61f422 100644 --- a/paddle/fluid/operators/string/CMakeLists.txt +++ b/paddle/fluid/operators/string/CMakeLists.txt @@ -3,4 +3,4 @@ if(WITH_UNITY_BUILD) # Load Unity Build rules for operators in paddle/fluid/operators/sequence_ops. 
include(unity_build_rule.cmake) endif() -register_operators(DEPS op_version_registry utf8proc string_array) +register_operators(DEPS op_version_registry phi) diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.cc b/paddle/fluid/operators/string/faster_tokenizer_op.cc index 10e08e86dc6855..c02c3d07524470 100644 --- a/paddle/fluid/operators/string/faster_tokenizer_op.cc +++ b/paddle/fluid/operators/string/faster_tokenizer_op.cc @@ -9,10 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/string/faster_tokenizer_op.h" - #include - #include #include #include @@ -24,427 +21,12 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/string_array.h" namespace paddle { namespace operators { -using std::ifstream; -using std::int64_t; -using std::size_t; -using std::string; -using std::unordered_map; -using std::unordered_set; -using std::vector; -using std::wstring; - -const wstring kStripChars = L" \t\n\r\v\f"; - -inline bool IsControl(const wchar_t& ch) { - if (ch == L'\t' || ch == L'\n' || ch == L'\r') return false; - auto cat = utf8proc_category(ch); - if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) return true; - return false; -} - -inline bool IsChineseChar(const wchar_t& ch) { - if ((ch >= 0x4E00 && ch <= 0x9FFF) || (ch >= 0x3400 && ch <= 0x4DBF) || - (ch >= 0x20000 && ch <= 0x2A6DF) || (ch >= 0x2A700 && ch <= 0x2B73F) || - (ch >= 0x2B740 && ch <= 0x2B81F) || (ch >= 0x2B820 && ch <= 0x2CEAF) || - (ch >= 0xF900 && ch <= 0xFAFF) || (ch >= 0x2F800 && ch <= 0x2FA1F)) - return true; - return false; -} - -inline bool IsWhiteSpace(const wchar_t& ch) { - if (ch == L' ' || ch == L'\t' || ch == L'\n' || ch == L'\r') return true; - auto cat = utf8proc_category(ch); - if (cat == UTF8PROC_CATEGORY_ZS) return 
true; - return false; -} - -inline bool IsPunctuation(const wchar_t& ch) { - if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) || - (ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) - return true; - auto cat = utf8proc_category(ch); - if (cat == UTF8PROC_CATEGORY_PD || cat == UTF8PROC_CATEGORY_PS || - cat == UTF8PROC_CATEGORY_PE || cat == UTF8PROC_CATEGORY_PC || - cat == UTF8PROC_CATEGORY_PO // sometimes ¶ belong SO - || cat == UTF8PROC_CATEGORY_PI || cat == UTF8PROC_CATEGORY_PF) - return true; - return false; -} - -BasicTokenizer::BasicTokenizer(bool do_lower_case /* = true */) - : do_lower_case_(do_lower_case) {} - -wchar_t BasicTokenizer::do_lower_case(wchar_t ch) const { - wchar_t new_ch = utf8proc_tolower(ch); - return new_ch; -} - -void BasicTokenizer::Tokenize(const string& text, vector* res) const { - std::wstring unicode_text; - bool status = framework::ConvertStrToWstr(text, &unicode_text); - if (!status) { - // String is converted into wstring failedly. - return; - } - std::wstring cache_text = L""; - auto PushCacheText = [&]() { - if (!cache_text.empty()) { - res->emplace_back(cache_text); - cache_text = L""; - } - }; - for (auto& ch : unicode_text) { - if (ch == 0 || ch == 0xfffd || IsControl(ch)) { - continue; - } - if (do_lower_case_) { - ch = do_lower_case(ch); - } - if (IsChineseChar(ch) || IsPunctuation(ch)) { - PushCacheText(); - res->emplace_back(std::wstring{ch}); - } else if (IsWhiteSpace(ch)) { - PushCacheText(); - } else { - cache_text += ch; - } - } - PushCacheText(); -} - -WordPieceTokenizer::WordPieceTokenizer( - const framework::Vocab* vocab, - const wstring& unk_token /* = L"[UNK]"*/, - const size_t max_input_chars_per_word /* = 100 */) - : vocab_(vocab), - unk_token_(unk_token), - max_input_chars_per_word_(max_input_chars_per_word) { - unk_token_id_ = vocab_->at(unk_token_); -} - -void WordPieceTokenizer::Tokenize(const wstring& text, - vector* token_ids) const { - size_t len = text.size(); - if (len > max_input_chars_per_word_) { 
- token_ids->emplace_back(unk_token_id_); - return; - } - - auto it = vocab_->find(text); - if (it != vocab_->end()) { - token_ids->emplace_back(it->second); - return; - } - - size_t start = 0; - vector wordpiece_ids; - while (start < len) { - size_t end = len; - std::wstring cur_substr; - int64_t cur_substr_id = 0; - while (start < end) { - std::wstring sub = text.substr(start, end - start); - if (start > 0) { - sub.insert(0, L"##"); - } - auto it = vocab_->find(sub); - if (it != vocab_->end()) { - cur_substr = sub; - cur_substr_id = it->second; - break; - } - end -= 1; - } - - if (cur_substr.empty()) { - token_ids->emplace_back(unk_token_id_); - return; - } else { - start = end; - wordpiece_ids.emplace_back(cur_substr_id); - } - } - for (auto& token_id : wordpiece_ids) { - token_ids->emplace_back(token_id); - } -} - -BertTokenizer::BertTokenizer(const framework::Vocab* vocab, - bool do_lower_case /* = false */, - const wstring& unk_token /* = L"[UNK]" */, - const wstring& pad_token /* = L"[PAD]" */, - const wstring& cls_token /* = L"[CLS]" */, - const wstring& mask_token /* = L"[MASK]" */, - const wstring& sep_token /* = L"[SEP]" */, - const string& padding_site /* = "right" */) - : do_lower_case_(do_lower_case), - unk_token_(unk_token), - pad_token_(pad_token), - cls_token_(cls_token), - mask_token_(mask_token), - sep_token_(sep_token), - padding_site_(padding_site), - vocab_(vocab), - basic_tokenizer_(do_lower_case_), - word_piece_tokenizer_(vocab_, unk_token) { - unk_token_id_ = vocab_->at(unk_token_); - pad_token_id_ = vocab_->at(pad_token_); - cls_token_id_ = vocab_->at(cls_token_); - mask_token_id_ = vocab_->at(mask_token_); - sep_token_id_ = vocab_->at(sep_token_); - - all_special_tokens_ = vector( - {unk_token_, pad_token_, cls_token_, mask_token_, sep_token_}); - all_special_token_ids_ = unordered_set({unk_token_id_, - pad_token_id_, - cls_token_id_, - mask_token_id_, - sep_token_id_}); -} - -void BertTokenizer::Tokenize(const string& text, - vector* 
split_token_ids) const { - std::vector tmp_tokens; - basic_tokenizer_.Tokenize(text, &tmp_tokens); - if (tmp_tokens.empty()) return; - split_token_ids->reserve(tmp_tokens.size()); - for (auto& w_token : tmp_tokens) { - const auto& vec_size = w_token.size(); - if (vec_size == 1) { - if (IsChineseChar(w_token[0])) { - auto vocab_it = vocab_->find(w_token); - if (vocab_it != vocab_->end()) { - split_token_ids->emplace_back(vocab_it->second); - } else { - split_token_ids->emplace_back(unk_token_id_); - } - } else { - word_piece_tokenizer_.Tokenize(w_token, split_token_ids); - } - } else if (vec_size > 1) { - word_piece_tokenizer_.Tokenize(w_token, split_token_ids); - } else { - continue; - } - } -} - -void BertTokenizer::BuildInputsWithSpecialTokens( - vector* inputs, - const vector& token_ids_0, - const vector& token_ids_1 /* = vector() */) const { - if (token_ids_1.empty()) { - inputs->clear(); - inputs->resize(token_ids_0.size() + 2); - inputs->at(0) = cls_token_id_; - size_t i = 1; - for (auto& token_id : token_ids_0) { - inputs->at(i) = token_id; - ++i; - } - inputs->at(i) = sep_token_id_; - } else { - inputs->clear(); - inputs->resize(token_ids_0.size() + token_ids_1.size() + 3); - inputs->at(0) = cls_token_id_; - size_t i = 1; - for (auto& token_id : token_ids_0) { - inputs->at(i) = token_id; - ++i; - } - inputs->at(i) = sep_token_id_; - ++i; - for (auto& token_id : token_ids_1) { - inputs->at(i) = token_id; - ++i; - } - inputs->at(i) = sep_token_id_; - } -} - -int64_t BertTokenizer::GetNumSpecialTokensToAdd(const bool pair) const { - if (pair) { - return 3; - } else { - return 2; - } -} - -void BertTokenizer::CreateTokenTypeIdsFromSequences( - vector* token_type_ids, - const vector& token_ids_0, - const vector& token_ids_1 /* = vector() */) const { - if (token_ids_1.empty()) { - vector tmp(token_ids_0.size() + 2, 0); - token_type_ids->swap(tmp); - } else { - vector tmp(token_ids_0.size() + token_ids_1.size() + 3, 0); - for (size_t i = token_ids_0.size() + 2; i 
< tmp.size(); i++) { - tmp[i] = 1; - } - token_type_ids->swap(tmp); - } -} - -void BertTokenizer::TruncateSequence( - vector* ids, - vector* pair_ids, - const size_t num_tokens_to_remove /* = 0 */, - const size_t stride /* = 0 */) const { - for (size_t i = 0; i < num_tokens_to_remove; i++) { - if ((pair_ids->empty()) || (ids->size() > pair_ids->size())) { - ids->pop_back(); - } else { - pair_ids->pop_back(); - } - } -} - -int64_t BertTokenizer::GetPadTokenID() const { return pad_token_id_; } - -int BertTokenizer::Encode( - unordered_map>* encoded_inputs, - const string& text, - const string& text_pair /* = "" */, - bool is_split_into_words /* = false */, - const size_t max_seq_len /* = 0 */, - bool pad_to_max_seq_len /* = false */) const { - vector ids; - vector pair_ids; - if (!is_split_into_words) { - Tokenize(text, &ids); - if (ids.empty()) return 0; - if (!text_pair.empty()) { - Tokenize(text_pair, &pair_ids); - if (pair_ids.empty()) return 0; - } - } else { - std::wstring unicode_text; - bool status_a = framework::ConvertStrToWstr(text, &unicode_text); - if (!status_a) { - return 0; - } - for (size_t i = 0; i < unicode_text.size(); i++) { - wstring token = unicode_text.substr(i, 1); - auto it = vocab_->find(token); - if (it != vocab_->end()) { - ids.emplace_back(it->second); - } else { - ids.emplace_back(unk_token_id_); - } - } - } - - bool pair = false; - if (!pair_ids.empty()) { - pair = true; - } - - size_t len_ids = ids.size(); - size_t len_pair_ids = pair_ids.size(); - - // Truncation: Handle max sequence length - // If max_seq_len == 0, then do nothing and keep the real length. - // If max_seq_len > 0 and - // all the input sequence len is over the max_seq_len, - // then we truncate it. 
- size_t total_len = len_ids + len_pair_ids + GetNumSpecialTokensToAdd(pair); - if (max_seq_len > 0 && total_len > max_seq_len) { - TruncateSequence(&ids, &pair_ids, total_len - max_seq_len); - } - - // Add special tokens - vector sequence; - BuildInputsWithSpecialTokens(&sequence, ids, pair_ids); - size_t seq_len = sequence.size(); - vector token_type_ids; - CreateTokenTypeIdsFromSequences(&token_type_ids, ids, pair_ids); - - // Build output dictionary - encoded_inputs->emplace("input_ids", sequence); - encoded_inputs->emplace("token_type_ids", token_type_ids); - // Check lengths - if (max_seq_len > 0 && seq_len > max_seq_len) { - VLOG(3) << "There is something wrong with the input sequence length." - " Please check it."; - // Failed. - return 0; - } - - // Padding - bool needs_to_be_padded = false; - if (pad_to_max_seq_len && max_seq_len > 0 && (seq_len < max_seq_len)) { - needs_to_be_padded = true; - } - - if (needs_to_be_padded) { - int64_t difference = static_cast(max_seq_len - seq_len); - size_t pad_start = max_seq_len - 1 - difference; - encoded_inputs->at("token_type_ids").resize(max_seq_len); - for (size_t i = max_seq_len - 1; i > pad_start; i--) { - encoded_inputs->at("token_type_ids")[i] = pad_token_id_; - } - - encoded_inputs->at("input_ids").resize(max_seq_len); - for (size_t i = max_seq_len - 1; i > pad_start; i--) { - encoded_inputs->at("input_ids")[i] = pad_token_id_; - } - } - return 1; -} - -void BertTokenizer::BatchEncode( - vector>>* batch_encode_inputs, - const framework::Strings& batch_text, - const framework::Strings& batch_text_pair /* = vector() */, - bool is_split_into_words /* = false */, - const size_t max_seq_len /* = 0 */, - bool pad_to_max_seq_len /* = false */) const { - bool has_text_pair = false; - if (batch_text_pair.size() != 0) { - has_text_pair = true; - } - - size_t batch_size = batch_text.size(); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (size_t i = 0; i < batch_size; i++) { - unordered_map> res; - if 
(has_text_pair) { - auto status = Encode(&res, - batch_text[i], - batch_text_pair[i], - is_split_into_words, - max_seq_len, - pad_to_max_seq_len); - if (!status) { - res["input_ids"] = - std::vector{cls_token_id_, sep_token_id_, cls_token_id_}; - res["token_type_ids"] = std::vector{0, 0, 1}; - } - } else { - auto status = Encode(&res, - batch_text[i], - {}, - is_split_into_words, - max_seq_len, - pad_to_max_seq_len); - - if (!status) { - res["input_ids"] = std::vector{cls_token_id_, sep_token_id_}; - res["token_type_ids"] = std::vector{0, 0}; - } - } - batch_encode_inputs->at(i) = std::move(res); - } -} - class FasterTokenizerOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -532,6 +114,3 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(faster_tokenizer, ops::FasterTokenizerOp, ops::FasterTokenizerOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - faster_tokenizer, CPU, ALL_LAYOUT, ops::FasterTokenizerKernel, int64_t) {} diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.h b/paddle/fluid/operators/string/faster_tokenizer_op.h deleted file mode 100644 index 1f848cb393fae2..00000000000000 --- a/paddle/fluid/operators/string/faster_tokenizer_op.h +++ /dev/null @@ -1,210 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include - -#include -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/string_array.h" - -namespace paddle { -namespace operators { - -using std::endl; -using std::int64_t; -using std::shared_ptr; -using std::size_t; -using std::string; -using std::unordered_map; -using std::unordered_set; -using std::vector; -using std::wcout; -using std::wstring; - -inline bool IsControl(const wchar_t& ch); -inline bool IsChineseChar(const wchar_t& ch); -inline bool IsWhiteSpace(const wchar_t& ch); - -using Vocab = unordered_map; -using InvVocab = unordered_map; - -class BasicTokenizer { - public: - explicit BasicTokenizer(bool do_lower_case = true); - void Tokenize(const string& text, vector* res) const; - - private: - wchar_t do_lower_case(wchar_t ch) const; - - bool do_lower_case_; -}; - -class WordPieceTokenizer { - public: - explicit WordPieceTokenizer(const framework::Vocab* vocab, - const wstring& unk_token = L"[UNK]", - const size_t max_input_chars_per_word = 100); - void Tokenize(const wstring& text, vector* output) const; - - private: - const framework::Vocab* vocab_; - wstring unk_token_{L"[UNK]"}; - int64_t unk_token_id_; - size_t max_input_chars_per_word_; -}; - -class BertTokenizer { - public: - explicit BertTokenizer(const framework::Vocab* vocab, - bool do_lower_case = false, - const wstring& unk_token = L"[UNK]", - const wstring& pad_token = L"[PAD]", - const wstring& cls_token = L"[CLS]", - const wstring& mask_token = L"[MASK]", - const wstring& sep_token = L"[SEP]", - const string& padding_site = "right"); - - void Tokenize(const string& text, vector* split_tokens) const; - void BuildInputsWithSpecialTokens( - vector* res, - const vector& token_ids_0, - const vector& token_ids_1 = vector()) const; - void CreateTokenTypeIdsFromSequences( - vector* token_type_ids, - const vector& token_ids_0, - const vector& token_ids_1 = vector()) const; - void TruncateSequence(vector* ids, - 
vector* pair_ids, - const size_t num_tokens_to_remove = 0, - const size_t stride = 0) const; - int64_t GetNumSpecialTokensToAdd(const bool pair = false) const; - int Encode(unordered_map>* encoded_inputs, - const string& text, - const string& text_pair = "", - bool is_split_into_words = false, - const size_t max_seq_len = 0, - bool pad_to_max_seq_len = false) const; - void BatchEncode( - vector>>* batch_encode_inputs, - const framework::Strings& batch_text, - const framework::Strings& batch_text_pair = framework::Strings(), - bool is_split_into_words = false, - const size_t max_seq_len = 0, - bool pad_to_max_seq_len = false) const; - - int64_t GetPadTokenID() const; - - private: - bool do_lower_case_; - wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_; - string padding_site_; - const framework::Vocab* vocab_; - BasicTokenizer basic_tokenizer_; - WordPieceTokenizer word_piece_tokenizer_; - int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_, - sep_token_id_; - vector all_special_tokens_; - unordered_set all_special_token_ids_; - InvVocab inv_vocab_; -}; - -template -class FasterTokenizerKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* text = ctx.Input("Text"); - auto* vocab = ctx.Input("Vocab"); - - auto* input_ids = ctx.Output("InputIds"); - auto* seg_ids = ctx.Output("SegmentIds"); - - auto do_lower_case = static_cast(ctx.Attr("do_lower_case")); - auto is_split_into_words = - static_cast(ctx.Attr("is_split_into_words")); - auto max_seq_len = static_cast(ctx.Attr("max_seq_len")); - auto pad_to_max_seq_len = - static_cast(ctx.Attr("pad_to_max_seq_len")); - - auto* text_pair = ctx.Input("TextPair"); - if (text_pair && text->size() != text_pair->size()) { - VLOG(3) << "The input text(list[str]) and text pair (list[str]) must" - << "be the same number of text sequence. 
Please check the input!"; - return; - } - - BertTokenizer tokenizer(vocab, do_lower_case); - size_t batch_max_seq_len = 0; - size_t batch_size = text->size(); - - vector>> batch_encode_inputs( - batch_size); - if (text_pair) { - tokenizer.BatchEncode(&batch_encode_inputs, - *text, - *text_pair, - is_split_into_words, - max_seq_len, - pad_to_max_seq_len); - } else { - tokenizer.BatchEncode(&batch_encode_inputs, - *text, - framework::Strings(), - is_split_into_words, - max_seq_len, - pad_to_max_seq_len); - } - - for (size_t i = 0; i < batch_size; ++i) { - size_t seq_len = batch_encode_inputs[i]["input_ids"].size(); - if (seq_len > batch_max_seq_len) { - batch_max_seq_len = seq_len; - } - } - - input_ids->Resize( - common::make_ddim({static_cast(batch_size), - static_cast(batch_max_seq_len)})); - auto* input_ids_data = input_ids->mutable_data(ctx.GetPlace()); - seg_ids->Resize( - common::make_ddim({static_cast(batch_size), - static_cast(batch_max_seq_len)})); - auto* seg_ids_data = seg_ids->mutable_data(ctx.GetPlace()); - - auto pad_token_id = tokenizer.GetPadTokenID(); - for (size_t i = 0; i < batch_size; i++) { - auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"]; - auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"]; - const size_t& seq_len = encoder_input_ids.size(); - // Copy the memory - std::memcpy(input_ids_data + i * batch_max_seq_len, - encoder_input_ids.data(), - seq_len * sizeof(T)); - std::memcpy(seg_ids_data + i * batch_max_seq_len, - encoder_seg_ids.data(), - seq_len * sizeof(T)); - std::memset(input_ids_data + i * batch_max_seq_len + seq_len, - pad_token_id, - (batch_max_seq_len - seq_len) * sizeof(T)); - std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, - pad_token_id, - (batch_max_seq_len - seq_len) * sizeof(T)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 
b5443b66351b24..2aed1c476e2c76 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -177,6 +177,7 @@ 'dgc', 'dpsgd', 'embedding_grad_sparse', + 'faster_tokenizer', 'ftrl', 'fused_adam_', 'fused_batch_norm_act_', diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 5e8ff5d5fc2ef7..39045a10ff3c11 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(distributed) add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) +add_subdirectory(vocab) if(WITH_GPU) proto_library(external_error_proto SRCS external_error.proto) diff --git a/paddle/phi/core/kernel_registry.cc b/paddle/phi/core/kernel_registry.cc index 172ad23e9302f0..0cc66aafd7be95 100644 --- a/paddle/phi/core/kernel_registry.cc +++ b/paddle/phi/core/kernel_registry.cc @@ -19,6 +19,7 @@ #include "paddle/phi/core/custom_kernel.h" #include "paddle/phi/core/kernel_utils.h" +#include "paddle/phi/core/vocab/string_array.h" namespace phi { @@ -88,6 +89,13 @@ void SetKernelArgsDef(const std::vector& args_type, default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == + std::type_index(typeid( + const paddle::optional&))) { // NOLINT + args_def->AppendInput(default_key.backend(), + default_tensor_layout, + default_key.dtype(), + arg_type); } else if (arg_type == std::type_index(typeid( const std::vector&))) { // NOLINT @@ -95,6 +103,12 @@ void SetKernelArgsDef(const std::vector& args_type, default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid( + const paddle::optional&))) { // NOLINT + args_def->AppendInput(default_key.backend(), + default_tensor_layout, + default_key.dtype(), + arg_type); } else if (arg_type == std::type_index(typeid( const std::vector&))) { // NOLINT diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 
801a69498b4c92..d6fdc7cb80a4a6 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -27,6 +27,7 @@ #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/string_tensor.h" #include "paddle/phi/core/tensor_array.h" +#include "paddle/phi/core/vocab/string_array.h" namespace phi { @@ -319,6 +320,7 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(ExtendedTensor); + PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(ExtendedTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(ExtendedTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorBase); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SelectedRows); @@ -340,6 +342,9 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(TensorArray); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorArray); + PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(phi::Strings); + PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(phi::Strings); + /* Attribute Helpers */ PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(bool); diff --git a/paddle/phi/core/utils/type_info.cc b/paddle/phi/core/utils/type_info.cc index b419338401eeac..fe9878d685412a 100644 --- a/paddle/phi/core/utils/type_info.cc +++ b/paddle/phi/core/utils/type_info.cc @@ -26,6 +26,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/string_tensor.h" #include "paddle/phi/core/tensor_array.h" #include "paddle/phi/core/utils/type_info.h" +#include "paddle/phi/core/vocab/string_array.h" namespace phi { @@ -50,6 +51,8 @@ template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; +template class TypeInfoTraits; +template class TypeInfoTraits; template class TypeInfoTraits; template class TypeInfoTraits; diff --git a/paddle/phi/core/vocab/CMakeLists.txt b/paddle/phi/core/vocab/CMakeLists.txt new file mode 100644 index 00000000000000..d0b065227d154d --- /dev/null +++ b/paddle/phi/core/vocab/CMakeLists.txt @@ -0,0 +1 @@ +collect_srcs(core_srcs SRCS string_array.cc) diff --git a/paddle/fluid/framework/phi_tensor_base_vector.h b/paddle/phi/core/vocab/phi_tensor_base_vector.h similarity index 92% rename from paddle/fluid/framework/phi_tensor_base_vector.h rename to paddle/phi/core/vocab/phi_tensor_base_vector.h index 1d775383de8090..f2389ba4826824 100644 --- a/paddle/fluid/framework/phi_tensor_base_vector.h +++ b/paddle/phi/core/vocab/phi_tensor_base_vector.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -20,8 +20,7 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/extended_tensor.h" -namespace paddle { -namespace framework { +namespace phi { template struct PhiVectorType; @@ -97,5 +96,14 @@ class PhiVector : public phi::ExtendedTensor, std::vector data_; }; +} // namespace phi + +namespace paddle { +namespace framework { +template +using PhiVector = phi::PhiVector; + +template +using PhiVectorType = phi::PhiVectorType; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/string_array.cc b/paddle/phi/core/vocab/string_array.cc similarity index 78% rename from paddle/fluid/framework/string_array.cc rename to paddle/phi/core/vocab/string_array.cc index 96aa8d04988aa4..4a9b8df9439fcf 100644 --- a/paddle/fluid/framework/string_array.cc +++ b/paddle/phi/core/vocab/string_array.cc @@ -1,26 +1,23 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/string_array.h" - +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/vocab/string_array.h" #include - #include - #include "glog/logging.h" -namespace paddle::framework { +namespace phi { std::wstring_convert> kConverter; @@ -100,4 +97,4 @@ void StringMapFromStream(std::istream& is, } } -} // namespace paddle::framework +} // namespace phi diff --git a/paddle/phi/core/vocab/string_array.h b/paddle/phi/core/vocab/string_array.h new file mode 100644 index 00000000000000..73cdcfd793470f --- /dev/null +++ b/paddle/phi/core/vocab/string_array.h @@ -0,0 +1,142 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/extended_tensor.h" +#include "paddle/phi/core/vocab/phi_tensor_base_vector.h" + +namespace phi { +template <> +struct PhiVectorType { + const char* type_name = "PhiVectorString"; +}; + +// Note(YuanRisheng): Vocab is mainly used for faster_tokenizer_op and we don't +// recommend widely use it. Because faster_tokenizer_op may be deleted in the +// future and this class will be deleted. + +class Vocab : public phi::ExtendedTensor, + public phi::TypeInfoTraits { + public: + Vocab() = default; + + Vocab(Vocab&& other) = default; + + Vocab(const Vocab& other) = default; + + Vocab& operator=(const Vocab& other) = default; + + Vocab& operator=(Vocab&& other) = default; + + Vocab& operator=( + const std::unordered_map& other) { + this->data_ = other; + return *this; + } + + /// \brief Destroy the Vocab and release exclusive resources. + virtual ~Vocab() = default; + + public: + /// \brief Returns the name of the class for type traits. + /// \return The name of the class. 
+ static const char* name() { return "Vocab"; } + + size_t size() const { return data_.size(); } + + void clear() { data_.clear(); } + + void emplace(const std::wstring& key, std::int32_t value) { + data_.emplace(key, value); + } + + std::int32_t at(const std::wstring& key) { return data_.at(key); } + + std::int32_t at(const std::wstring& key) const { return data_.at(key); } + + std::unordered_map::iterator find( + const std::wstring& key) { + return data_.find(key); + } + + std::unordered_map::const_iterator find( + const std::wstring& key) const { + return data_.find(key); + } + + std::unordered_map::iterator begin() { + return data_.begin(); + } + + std::unordered_map::const_iterator begin() const { + return data_.begin(); + } + + std::unordered_map::iterator end() { + return data_.end(); + } + + std::unordered_map::const_iterator end() const { + return data_.end(); + } + + private: + std::unordered_map data_; +}; + +// Note(YuanRisheng): PhiVector is essentially a vector that is only used for PHI +// Kernel. It can be used when you define a non-tensor type that needs to be +// stored in a vector as PHI kernel argument. + +using String = std::string; +using Strings = PhiVector; + +// Convert the std::string type to the std::wstring type. +bool ConvertStrToWstr(const std::string& src, std::wstring* res); +// Convert the std::wstring type to the std::string type. +void ConvertWstrToStr(const std::wstring& src, std::string* res); +// Normalization Form Canonical Decomposition. +void NFD(const std::string& s, std::string* ret); + +// Write the data which is type of +// std::unordered_map to ostream. +void StringMapToStream(std::ostream& os, + const std::unordered_map& data); + +// Read the data which is type of +// std::unordered_map from istream. 
+void StringMapFromStream(std::istream& is, + std::unordered_map* data); +} // namespace phi + +namespace paddle { +namespace framework { +using Vocab = phi::Vocab; +using Strings = phi::Strings; +using String = phi::String; +using phi::ConvertStrToWstr; +using phi::ConvertWstrToStr; +using phi::NFD; +using phi::StringMapFromStream; +using phi::StringMapToStream; +} // namespace framework +} // namespace paddle diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 5925634f8d87cf..c9282b43d4e5e3 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -742,6 +742,22 @@ void InstanceNormInferMeta(const MetaTensor& x, } } +void FasterTokenizerInferMeta(const MetaTensor& vocab, + const MetaTensor& text, + const MetaTensor& text_pair, + bool do_lower_case, + bool is_split_into_words, + int max_seq_len, + bool pad_to_max_seq_len, + MetaTensor* input_ids, + MetaTensor* segment_ids, + MetaConfig config) { + input_ids->set_dims({-1, -1}); + segment_ids->set_dims({-1, -1}); + input_ids->set_dtype(phi::DataType::INT64); + segment_ids->set_dtype(phi::DataType::INT64); +} + void GlobalGatherInferMeta(const MetaTensor& x, const MetaTensor& local_count, const MetaTensor& global_count, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index f5f6307a8fa0d0..a86e06239d518b 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -152,6 +152,17 @@ void InstanceNormInferMeta(const MetaTensor& x, MetaTensor* saved_variance, MetaConfig config = MetaConfig()); +void FasterTokenizerInferMeta(const MetaTensor& vocab, + const MetaTensor& text, + const MetaTensor& text_pair, + bool do_lower_case, + bool is_split_into_words, + int max_seq_len, + bool pad_to_max_seq_len, + MetaTensor* input_ids, + MetaTensor* segment_ids, + MetaConfig config = MetaConfig()); + void GlobalGatherInferMeta(const MetaTensor& x, const MetaTensor& local_count, const MetaTensor& global_count, diff --git 
a/paddle/phi/kernels/cpu/faster_tokenizer_kernel.cc b/paddle/phi/kernels/cpu/faster_tokenizer_kernel.cc new file mode 100644 index 00000000000000..e27db0c181ac68 --- /dev/null +++ b/paddle/phi/kernels/cpu/faster_tokenizer_kernel.cc @@ -0,0 +1,617 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/core/vocab/string_array.h" + +namespace phi { + +using std::endl; +using std::ifstream; +using std::int64_t; +using std::shared_ptr; +using std::size_t; +using std::string; +using std::unordered_map; +using std::unordered_set; +using std::vector; +using std::wcout; +using std::wstring; +using Strings = paddle::framework::Strings; + +inline bool IsControl(const wchar_t& ch); +inline bool IsChineseChar(const wchar_t& ch); +inline bool IsWhiteSpace(const wchar_t& ch); + +using InvVocab = unordered_map; + +class BasicTokenizer { + public: + explicit BasicTokenizer(bool do_lower_case = true); + void Tokenize(const string& text, vector* res) const; + + private: + wchar_t do_lower_case(wchar_t ch) const; + + bool do_lower_case_; +}; + +class WordPieceTokenizer { + public: + explicit WordPieceTokenizer(const paddle::framework::Vocab* vocab, + const wstring& unk_token = L"[UNK]", + const size_t max_input_chars_per_word = 
100); + void Tokenize(const wstring& text, vector* output) const; + + private: + const paddle::framework::Vocab* vocab_; + wstring unk_token_{L"[UNK]"}; + int64_t unk_token_id_; + size_t max_input_chars_per_word_; +}; + +class BertTokenizer { + public: + explicit BertTokenizer(const paddle::framework::Vocab* vocab, + bool do_lower_case = false, + const wstring& unk_token = L"[UNK]", + const wstring& pad_token = L"[PAD]", + const wstring& cls_token = L"[CLS]", + const wstring& mask_token = L"[MASK]", + const wstring& sep_token = L"[SEP]", + const string& padding_site = "right"); + + void Tokenize(const string& text, vector* split_tokens) const; + void BuildInputsWithSpecialTokens( + vector* res, + const vector& token_ids_0, + const vector& token_ids_1 = vector()) const; + void CreateTokenTypeIdsFromSequences( + vector* token_type_ids, + const vector& token_ids_0, + const vector& token_ids_1 = vector()) const; + void TruncateSequence(vector* ids, + vector* pair_ids, + const size_t num_tokens_to_remove = 0, + const size_t stride = 0) const; + int64_t GetNumSpecialTokensToAdd(const bool pair = false) const; + int Encode(unordered_map>* encoded_inputs, + const string& text, + const string& text_pair = "", + bool is_split_into_words = false, + const size_t max_seq_len = 0, + bool pad_to_max_seq_len = false) const; + void BatchEncode( + vector>>* batch_encode_inputs, + const Strings& batch_text, + const Strings& batch_text_pair = Strings(), + bool is_split_into_words = false, + const size_t max_seq_len = 0, + bool pad_to_max_seq_len = false) const; + + int64_t GetPadTokenID() const; + + private: + bool do_lower_case_; + wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_; + string padding_site_; + const paddle::framework::Vocab* vocab_; + BasicTokenizer basic_tokenizer_; + WordPieceTokenizer word_piece_tokenizer_; + int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_, + sep_token_id_; + vector all_special_tokens_; + unordered_set 
all_special_token_ids_; + InvVocab inv_vocab_; +}; + +const wstring kStripChars = L" \t\n\r\v\f"; + +inline bool IsControl(const wchar_t& ch) { + if (ch == L'\t' || ch == L'\n' || ch == L'\r') return false; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) return true; + return false; +} + +inline bool IsChineseChar(const wchar_t& ch) { + if ((ch >= 0x4E00 && ch <= 0x9FFF) || (ch >= 0x3400 && ch <= 0x4DBF) || + (ch >= 0x20000 && ch <= 0x2A6DF) || (ch >= 0x2A700 && ch <= 0x2B73F) || + (ch >= 0x2B740 && ch <= 0x2B81F) || (ch >= 0x2B820 && ch <= 0x2CEAF) || + (ch >= 0xF900 && ch <= 0xFAFF) || (ch >= 0x2F800 && ch <= 0x2FA1F)) + return true; + return false; +} + +inline bool IsWhiteSpace(const wchar_t& ch) { + if (ch == L' ' || ch == L'\t' || ch == L'\n' || ch == L'\r') return true; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_ZS) return true; + return false; +} + +inline bool IsPunctuation(const wchar_t& ch) { + if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) || + (ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) + return true; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_PD || cat == UTF8PROC_CATEGORY_PS || + cat == UTF8PROC_CATEGORY_PE || cat == UTF8PROC_CATEGORY_PC || + cat == UTF8PROC_CATEGORY_PO // sometimes ¶ belongs to SO + || cat == UTF8PROC_CATEGORY_PI || cat == UTF8PROC_CATEGORY_PF) + return true; + return false; +} + +BasicTokenizer::BasicTokenizer(bool do_lower_case /* = true */) + : do_lower_case_(do_lower_case) {} + +wchar_t BasicTokenizer::do_lower_case(wchar_t ch) const { + wchar_t new_ch = utf8proc_tolower(ch); + return new_ch; +} + +void BasicTokenizer::Tokenize(const string& text, vector* res) const { + std::wstring unicode_text; + bool status = phi::ConvertStrToWstr(text, &unicode_text); + if (!status) { + // Failed to convert the string into a wstring. 
+ return; + } + std::wstring cache_text = L""; + auto PushCacheText = [&]() { + if (!cache_text.empty()) { + res->emplace_back(cache_text); + cache_text = L""; + } + }; + for (auto& ch : unicode_text) { + if (ch == 0 || ch == 0xfffd || IsControl(ch)) { + continue; + } + if (do_lower_case_) { + ch = do_lower_case(ch); + } + if (IsChineseChar(ch) || IsPunctuation(ch)) { + PushCacheText(); + res->emplace_back(std::wstring{ch}); + } else if (IsWhiteSpace(ch)) { + PushCacheText(); + } else { + cache_text += ch; + } + } + PushCacheText(); +} + +WordPieceTokenizer::WordPieceTokenizer( + const paddle::framework::Vocab* vocab, + const wstring& unk_token /* = L"[UNK]"*/, + const size_t max_input_chars_per_word /* = 100 */) + : vocab_(vocab), + unk_token_(unk_token), + max_input_chars_per_word_(max_input_chars_per_word) { + unk_token_id_ = vocab_->at(unk_token_); +} + +void WordPieceTokenizer::Tokenize(const wstring& text, + vector* token_ids) const { + size_t len = text.size(); + if (len > max_input_chars_per_word_) { + token_ids->emplace_back(unk_token_id_); + return; + } + + auto it = vocab_->find(text); + if (it != vocab_->end()) { + token_ids->emplace_back(it->second); + return; + } + + size_t start = 0; + vector wordpiece_ids; + while (start < len) { + size_t end = len; + std::wstring cur_substr; + int64_t cur_substr_id = 0; + while (start < end) { + std::wstring sub = text.substr(start, end - start); + if (start > 0) { + sub.insert(0, L"##"); + } + auto it = vocab_->find(sub); + if (it != vocab_->end()) { + cur_substr = sub; + cur_substr_id = it->second; + break; + } + end -= 1; + } + + if (cur_substr.empty()) { + token_ids->emplace_back(unk_token_id_); + return; + } else { + start = end; + wordpiece_ids.emplace_back(cur_substr_id); + } + } + for (auto& token_id : wordpiece_ids) { + token_ids->emplace_back(token_id); + } +} + +BertTokenizer::BertTokenizer(const paddle::framework::Vocab* vocab, + bool do_lower_case /* = false */, + const wstring& unk_token /* = L"[UNK]" 
*/, + const wstring& pad_token /* = L"[PAD]" */, + const wstring& cls_token /* = L"[CLS]" */, + const wstring& mask_token /* = L"[MASK]" */, + const wstring& sep_token /* = L"[SEP]" */, + const string& padding_site /* = "right" */) + : do_lower_case_(do_lower_case), + unk_token_(unk_token), + pad_token_(pad_token), + cls_token_(cls_token), + mask_token_(mask_token), + sep_token_(sep_token), + padding_site_(padding_site), + vocab_(vocab), + basic_tokenizer_(do_lower_case_), + word_piece_tokenizer_(vocab_, unk_token) { + unk_token_id_ = vocab_->at(unk_token_); + pad_token_id_ = vocab_->at(pad_token_); + cls_token_id_ = vocab_->at(cls_token_); + mask_token_id_ = vocab_->at(mask_token_); + sep_token_id_ = vocab_->at(sep_token_); + + all_special_tokens_ = vector( + {unk_token_, pad_token_, cls_token_, mask_token_, sep_token_}); + all_special_token_ids_ = unordered_set({unk_token_id_, + pad_token_id_, + cls_token_id_, + mask_token_id_, + sep_token_id_}); +} + +void BertTokenizer::Tokenize(const string& text, + vector* split_token_ids) const { + std::vector tmp_tokens; + basic_tokenizer_.Tokenize(text, &tmp_tokens); + if (tmp_tokens.empty()) return; + split_token_ids->reserve(tmp_tokens.size()); + for (auto& w_token : tmp_tokens) { + const auto& vec_size = w_token.size(); + if (vec_size == 1) { + if (IsChineseChar(w_token[0])) { + auto vocab_it = vocab_->find(w_token); + if (vocab_it != vocab_->end()) { + split_token_ids->emplace_back(vocab_it->second); + } else { + split_token_ids->emplace_back(unk_token_id_); + } + } else { + word_piece_tokenizer_.Tokenize(w_token, split_token_ids); + } + } else if (vec_size > 1) { + word_piece_tokenizer_.Tokenize(w_token, split_token_ids); + } else { + continue; + } + } +} + +void BertTokenizer::BuildInputsWithSpecialTokens( + vector* inputs, + const vector& token_ids_0, + const vector& token_ids_1 /* = vector() */) const { + if (token_ids_1.empty()) { + inputs->clear(); + inputs->resize(token_ids_0.size() + 2); + inputs->at(0) = 
cls_token_id_; + size_t i = 1; + for (auto& token_id : token_ids_0) { + inputs->at(i) = token_id; + ++i; + } + inputs->at(i) = sep_token_id_; + } else { + inputs->clear(); + inputs->resize(token_ids_0.size() + token_ids_1.size() + 3); + inputs->at(0) = cls_token_id_; + size_t i = 1; + for (auto& token_id : token_ids_0) { + inputs->at(i) = token_id; + ++i; + } + inputs->at(i) = sep_token_id_; + ++i; + for (auto& token_id : token_ids_1) { + inputs->at(i) = token_id; + ++i; + } + inputs->at(i) = sep_token_id_; + } +} + +int64_t BertTokenizer::GetNumSpecialTokensToAdd(const bool pair) const { + if (pair) { + return 3; + } else { + return 2; + } +} + +void BertTokenizer::CreateTokenTypeIdsFromSequences( + vector* token_type_ids, + const vector& token_ids_0, + const vector& token_ids_1 /* = vector() */) const { + if (token_ids_1.empty()) { + vector tmp(token_ids_0.size() + 2, 0); + token_type_ids->swap(tmp); + } else { + vector tmp(token_ids_0.size() + token_ids_1.size() + 3, 0); + for (size_t i = token_ids_0.size() + 2; i < tmp.size(); i++) { + tmp[i] = 1; + } + token_type_ids->swap(tmp); + } +} + +void BertTokenizer::TruncateSequence( + vector* ids, + vector* pair_ids, + const size_t num_tokens_to_remove /* = 0 */, + const size_t stride /* = 0 */) const { + for (size_t i = 0; i < num_tokens_to_remove; i++) { + if ((pair_ids->empty()) || (ids->size() > pair_ids->size())) { + ids->pop_back(); + } else { + pair_ids->pop_back(); + } + } +} + +int64_t BertTokenizer::GetPadTokenID() const { return pad_token_id_; } + +int BertTokenizer::Encode( + unordered_map>* encoded_inputs, + const string& text, + const string& text_pair /* = "" */, + bool is_split_into_words /* = false */, + const size_t max_seq_len /* = 0 */, + bool pad_to_max_seq_len /* = false */) const { + vector ids; + vector pair_ids; + if (!is_split_into_words) { + Tokenize(text, &ids); + if (ids.empty()) return 0; + if (!text_pair.empty()) { + Tokenize(text_pair, &pair_ids); + if (pair_ids.empty()) return 0; + } 
+ } else { + std::wstring unicode_text; + bool status_a = phi::ConvertStrToWstr(text, &unicode_text); + if (!status_a) { + return 0; + } + for (size_t i = 0; i < unicode_text.size(); i++) { + wstring token = unicode_text.substr(i, 1); + auto it = vocab_->find(token); + if (it != vocab_->end()) { + ids.emplace_back(it->second); + } else { + ids.emplace_back(unk_token_id_); + } + } + } + + bool pair = false; + if (!pair_ids.empty()) { + pair = true; + } + + size_t len_ids = ids.size(); + size_t len_pair_ids = pair_ids.size(); + + // Truncation: Handle max sequence length + // If max_seq_len == 0, then do nothing and keep the real length. + // If max_seq_len > 0 and + // all the input sequence len is over the max_seq_len, + // then we truncate it. + size_t total_len = len_ids + len_pair_ids + GetNumSpecialTokensToAdd(pair); + if (max_seq_len > 0 && total_len > max_seq_len) { + TruncateSequence(&ids, &pair_ids, total_len - max_seq_len); + } + + // Add special tokens + vector sequence; + BuildInputsWithSpecialTokens(&sequence, ids, pair_ids); + size_t seq_len = sequence.size(); + vector token_type_ids; + CreateTokenTypeIdsFromSequences(&token_type_ids, ids, pair_ids); + + // Build output dictionary + encoded_inputs->emplace("input_ids", sequence); + encoded_inputs->emplace("token_type_ids", token_type_ids); + // Check lengths + if (max_seq_len > 0 && seq_len > max_seq_len) { + VLOG(3) << "There is something wrong with the input sequence length." + " Please check it."; + // Failed. 
+ return 0; + } + + // Padding + bool needs_to_be_padded = false; + if (pad_to_max_seq_len && max_seq_len > 0 && (seq_len < max_seq_len)) { + needs_to_be_padded = true; + } + + if (needs_to_be_padded) { + int64_t difference = static_cast(max_seq_len - seq_len); + size_t pad_start = max_seq_len - 1 - difference; + encoded_inputs->at("token_type_ids").resize(max_seq_len); + for (size_t i = max_seq_len - 1; i > pad_start; i--) { + encoded_inputs->at("token_type_ids")[i] = pad_token_id_; + } + + encoded_inputs->at("input_ids").resize(max_seq_len); + for (size_t i = max_seq_len - 1; i > pad_start; i--) { + encoded_inputs->at("input_ids")[i] = pad_token_id_; + } + } + return 1; +} + +void BertTokenizer::BatchEncode( + vector>>* batch_encode_inputs, + const Strings& batch_text, + const Strings& batch_text_pair /* = vector() */, + bool is_split_into_words /* = false */, + const size_t max_seq_len /* = 0 */, + bool pad_to_max_seq_len /* = false */) const { + bool has_text_pair = false; + if (batch_text_pair.size() != 0) { + has_text_pair = true; + } + + size_t batch_size = batch_text.size(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (size_t i = 0; i < batch_size; i++) { + unordered_map> res; + if (has_text_pair) { + auto status = Encode(&res, + batch_text[i], + batch_text_pair[i], + is_split_into_words, + max_seq_len, + pad_to_max_seq_len); + if (!status) { + res["input_ids"] = + std::vector{cls_token_id_, sep_token_id_, cls_token_id_}; + res["token_type_ids"] = std::vector{0, 0, 1}; + } + } else { + auto status = Encode(&res, + batch_text[i], + {}, + is_split_into_words, + max_seq_len, + pad_to_max_seq_len); + + if (!status) { + res["input_ids"] = std::vector{cls_token_id_, sep_token_id_}; + res["token_type_ids"] = std::vector{0, 0}; + } + } + batch_encode_inputs->at(i) = std::move(res); + } +} + +template +void FasterTokenizerKernel(const Context& dev_ctx, + const phi::ExtendedTensor& vocab_in, + const phi::ExtendedTensor& text_in, + const 
paddle::optional& text_pair_in, + bool do_lower_case, + bool is_split_into_words, + int max_seq_len, + bool pad_to_max_seq_len, + DenseTensor* input_ids, + DenseTensor* segment_ids) { + const auto* vocab = + reinterpret_cast(&vocab_in); + const auto* text = reinterpret_cast(&text_in); + const auto* text_pair = + reinterpret_cast(text_pair_in.get_ptr()); + auto* seg_ids = segment_ids; + if (text_pair && text->size() != text_pair->size()) { + VLOG(3) << "The input text(list[str]) and text pair (list[str]) must" + << "be the same number of text sequence. Please check the input!"; + return; + } + + BertTokenizer tokenizer(vocab, do_lower_case); + size_t batch_max_seq_len = 0; + size_t batch_size = text->size(); + + vector>> batch_encode_inputs( + batch_size); + if (text_pair) { + tokenizer.BatchEncode(&batch_encode_inputs, + *text, + *text_pair, + is_split_into_words, + max_seq_len, + pad_to_max_seq_len); + } else { + tokenizer.BatchEncode(&batch_encode_inputs, + *text, + Strings(), + is_split_into_words, + max_seq_len, + pad_to_max_seq_len); + } + + for (size_t i = 0; i < batch_size; ++i) { + size_t seq_len = batch_encode_inputs[i]["input_ids"].size(); + if (seq_len > batch_max_seq_len) { + batch_max_seq_len = seq_len; + } + } + + input_ids->Resize( + common::make_ddim({static_cast(batch_size), + static_cast(batch_max_seq_len)})); + auto* input_ids_data = dev_ctx.template Alloc(input_ids); + seg_ids->Resize(common::make_ddim({static_cast(batch_size), + static_cast(batch_max_seq_len)})); + auto* seg_ids_data = dev_ctx.template Alloc(seg_ids); + + auto pad_token_id = tokenizer.GetPadTokenID(); + for (size_t i = 0; i < batch_size; i++) { + auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"]; + auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"]; + const size_t& seq_len = encoder_input_ids.size(); + // Copy the memory + std::memcpy(input_ids_data + i * batch_max_seq_len, + encoder_input_ids.data(), + seq_len * sizeof(T)); + std::memcpy(seg_ids_data 
+ i * batch_max_seq_len, + encoder_seg_ids.data(), + seq_len * sizeof(T)); + std::memset(input_ids_data + i * batch_max_seq_len + seq_len, + pad_token_id, + (batch_max_seq_len - seq_len) * sizeof(T)); + std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, + pad_token_id, + (batch_max_seq_len - seq_len) * sizeof(T)); + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + faster_tokenizer, CPU, ALL_LAYOUT, phi::FasterTokenizerKernel, int64_t) {} diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index b427ae205d970f..a4d43c5d36dd57 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -1026,6 +1026,17 @@ data_type : logits backward: c_softmax_with_cross_entropy_grad +- op: faster_tokenizer + args: (Tensor vocab, Tensor text, Tensor text_pair, bool do_lower_case = false, + bool is_split_into_words = false, int max_seq_len = 0, bool pad_to_max_seq_len + = false) + output: Tensor (input_ids), Tensor (segment_ids) + infer_meta: + func: FasterTokenizerInferMeta + kernel: + func: faster_tokenizer + optional: text_pair + - op: fused_attention args: (Tensor x, Tensor ln_scale, Tensor ln_bias, Tensor qkv_weight, Tensor qkv_bias, Tensor cache_kv, Tensor src_mask, Tensor out_linear_weight, Tensor out_linear_bias, Tensor ln_scale_2, Tensor ln_bias_2, int num_heads, bool transpose_qkv_wb, bool pre_layer_norm, float epsilon, float attn_dropout_rate, bool is_test, bool attn_dropout_fix_seed, int attn_dropout_seed, str attn_dropout_implementation, float dropout_rate, bool dropout_fix_seed, int dropout_seed, str dropout_implementation, float ln_epsilon, bool add_residual, int ring_id) output: Tensor(ln_mean), Tensor(ln_var), Tensor(ln_out), Tensor(qkv_out), Tensor(qkv_bias_out), Tensor(transpose_out_2), Tensor(qk_out), Tensor(qktv_out), Tensor(softmax_out), Tensor(attn_dropout_mask_out), Tensor(attn_dropout_out), Tensor(src_mask_out), Tensor(fmha_out), 
Tensor(out_linear_out), Tensor(dropout_mask_out), Tensor(ln_mean_2), Tensor(ln_var_2), Tensor(bias_dropout_residual_out), Tensor(cache_kv_out), Tensor(out) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index fbf5d8f15f2a9f..899a43d6e8287f 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -1249,6 +1249,12 @@ out_scale : OutScale out_scales : OutScales +- op : faster_tokenizer + inputs: + {vocab : Vocab, text : Text, text_pair : TextPair} + outputs: + {input_ids : InputIds, segment_ids : SegmentIds} + - op : fc inputs : input : Input