forked from QwenLM/qwen.cpp
qwen: extract QwenTokenizer as a separate lib module
use an older stable release of abseil-cpp (issue: abseil/abseil-cpp#1536)
sriduth committed Jul 30, 2024 · 1 parent 6c9cccf · commit bf451ab
Showing 5 changed files with 229 additions and 4 deletions.
@@ -0,0 +1,178 @@
#include "qwen_tokenizer.h"
#include "tiktoken.h"
#include "base64.h"
#include "unordered_dense.h"
#include <algorithm>  // std::remove_if, used in decode()
#include <fcntl.h>
#include <fstream>
#include <numeric>
#include <random>
#include <thread>
#include <sys/stat.h>
#include <memory>

#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/mman.h>
#endif
#if defined(_POSIX_MEMLOCK_RANGE)
#include <sys/resource.h>
#endif
#endif
#endif

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <io.h>
#include <stdio.h>
#include <windows.h>
#endif

namespace qwen {

static constexpr size_t MB = 1024 * 1024;

// Regex used to pre-split text into pieces before byte-pair encoding.
static const std::string PAT_STR = R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|[^\S])|\s+)";

// Accumulates a message and throws it as std::runtime_error on destruction.
class LogMessageFatal {
 public:
  LogMessageFatal(const char* file, int line) { oss_ << file << ':' << line << ' '; }
  [[noreturn]] ~LogMessageFatal() noexcept(false) { throw std::runtime_error(oss_.str()); }
  auto stream() -> std::ostringstream& { return oss_; }

 private:
  std::ostringstream oss_;
};

#define QWEN_THROW ::qwen::LogMessageFatal(__FILE__, __LINE__).stream()
#define QWEN_CHECK(cond) \
  if (!(cond)) \
  QWEN_THROW << "check failed (" #cond ") "

class QwenTokenizer : public IQwenTokenizer {
 public:
  QwenTokenizer(const std::string& tiktoken_path, const QwenTokenizerConfig& config);

  auto encode(const std::string& text, int max_length) const -> std::vector<int>;

  auto decode(const std::vector<int>& ids) const -> std::string;

  auto encode_history(const std::vector<std::string>& history, int max_length) const -> std::vector<int>;

  auto build_prompt(const std::vector<std::string>& history) const -> std::string;

  auto is_special_id(int id) const -> bool;

  tiktoken::tiktoken tokenizer;
  int eos_token_id;
  int im_start_id;
  int im_end_id;
};

std::shared_ptr<IQwenTokenizer> IQwenTokenizer::make(const std::string& tiktoken_path, const QwenTokenizerConfig& config) {
  return std::make_shared<QwenTokenizer>(tiktoken_path, config);
}

// Parses one "<base64-encoded token> <rank>" line of the tiktoken vocabulary file.
static std::pair<std::string, int> _parse(const std::string &line) {
  auto pos = line.find(" ");
  if (pos == std::string::npos) {
    throw std::runtime_error("invalid encoder line: " + line);
  }

  auto token = base64::decode({line.data(), pos});
  int rank = 0;
  try {
    rank = std::stoul(line.substr(pos + 1));
  } catch (const std::exception &) {
    throw std::runtime_error("invalid encoder rank: " + line);
  }

  return {std::move(token), rank};
}

QwenTokenizer::QwenTokenizer(const std::string &tiktoken_path, const QwenTokenizerConfig &config) {
  std::ifstream file(tiktoken_path);
  if (!file) {
    throw std::runtime_error("failed to open encoder file: " + tiktoken_path);
  }

  ankerl::unordered_dense::map<std::string, int> encoder;
  std::string line;
  while (std::getline(file, line)) {
    auto [token, rank] = _parse(line);

    if (!encoder.emplace(std::move(token), rank).second) {
      throw std::runtime_error("duplicate item: " + line);
    }
  }

  // Special tokens are appended after the regular vocabulary, so their ids start at encoder.size().
  std::vector<std::string> special_tokens_s{"<|endoftext|>", "<|im_start|>", "<|im_end|>"};
  char buffer[14];  // large enough for "<|extra_204|>" plus the terminator
  for (size_t i = 0; i < 205; i++) {
    snprintf(buffer, sizeof(buffer), "<|extra_%zu|>", i);
    special_tokens_s.push_back(buffer);
  }
  size_t encoder_size = encoder.size();
  ankerl::unordered_dense::map<std::string, int> special_tokens;
  special_tokens.reserve(special_tokens_s.size());
  for (size_t i = 0; i < special_tokens_s.size(); i++) {
    special_tokens[special_tokens_s[i]] = encoder_size + i;
  }

  tokenizer = tiktoken::tiktoken(std::move(encoder), special_tokens, PAT_STR);
  eos_token_id = config.eos_token_id;
  im_start_id = config.im_start_id;
  im_end_id = config.im_end_id;
}

auto QwenTokenizer::build_prompt(const std::vector<std::string> &history) const -> std::string {
  // History alternates user/assistant turns and must end with a user turn.
  QWEN_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size();

  std::ostringstream oss_prompt;
  oss_prompt << "<|im_start|>system\nYou are a helpful assistant.<|im_end|>";
  for (size_t i = 0; i < history.size() - 1; i += 2) {
    oss_prompt << "\n<|im_start|>user\n" << history[i] << "<|im_end|>\n<|im_start|>" << history[i + 1] << "<|im_end|>";
  }
  oss_prompt << "\n<|im_start|>user\n" << history.back() << "<|im_end|>\n<|im_start|>assistant\n";

  return oss_prompt.str();
}

auto QwenTokenizer::encode(const std::string &text, int max_length) const -> std::vector<int> {
  auto ids = tokenizer.encode(text);
  // Truncate from the front so the most recent max_length tokens are kept.
  if ((int)ids.size() > max_length) {
    ids.erase(ids.begin(), ids.end() - max_length);
  }
  return ids;
}

auto QwenTokenizer::decode(const std::vector<int> &ids) const -> std::string {
  // Drop special ids (eos, im_start, im_end) before converting back to text.
  std::vector<int> normal_ids(ids);
  normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) { return is_special_id(id); }),
                   normal_ids.end());
  auto text = tokenizer.decode(normal_ids);
  return text;
}

auto QwenTokenizer::encode_history(
    const std::vector<std::string> &history, int max_length
) const -> std::vector<int> {
  std::string prompt = build_prompt(history);
  std::vector<int> input_ids = encode(prompt, max_length);
  return input_ids;
}

auto QwenTokenizer::is_special_id(int id) const -> bool {
  return id == eos_token_id || id == im_start_id || id == im_end_id;
}

} // namespace qwen
@@ -0,0 +1,35 @@
#pragma once

#include <memory>  // std::shared_ptr
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

namespace qwen {

struct QwenTokenizerConfig {
  // for tokenizer
  int eos_token_id;
  int pad_token_id;
  int im_start_id;
  int im_end_id;
};

// Abstract tokenizer interface; the concrete implementation lives in the module's .cpp.
class IQwenTokenizer {
 public:
  virtual auto encode(const std::string& text, int max_length) const -> std::vector<int> = 0;

  virtual auto decode(const std::vector<int> &ids) const -> std::string = 0;

  virtual auto encode_history(const std::vector<std::string> &history, int max_length) const -> std::vector<int> = 0;

  virtual auto build_prompt(const std::vector<std::string> &history) const -> std::string = 0;

  virtual auto is_special_id(int id) const -> bool = 0;

  static std::shared_ptr<IQwenTokenizer> make(const std::string& tiktoken_path, const QwenTokenizerConfig& config);
};

} // namespace qwen
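
The extracted module is consumed through the IQwenTokenizer::make factory declared above. A minimal consumer-side sketch, not part of this commit (the tiktoken path and the special-token ids are placeholder assumptions; the real values come from the model's vocabulary file and config):

// Hypothetical consumer of the extracted tokenizer module (illustration only).
#include "qwen_tokenizer.h"

#include <iostream>
#include <vector>

int main() {
  qwen::QwenTokenizerConfig config;
  config.eos_token_id = 151643;  // assumed ids for illustration; read the real ones from the model config
  config.pad_token_id = 151643;
  config.im_start_id = 151644;
  config.im_end_id = 151645;

  // "qwen.tiktoken" is a placeholder path to the model's BPE vocabulary file.
  auto tokenizer = qwen::IQwenTokenizer::make("qwen.tiktoken", config);

  // build_prompt wraps the history in ChatML markers; encode_history tokenizes
  // the prompt and keeps at most max_length of the most recent ids.
  std::vector<int> ids = tokenizer->encode_history({"Hello"}, /*max_length=*/2048);
  std::cout << tokenizer->decode(ids) << "\n";  // decode strips the special ids
  return 0;
}

Linking such a sketch would require the separate tokenizer library target this commit introduces; its build-system name is not shown in the excerpt.
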
Submodule abseil-cpp updated 446 files