Commit

Merge pull request #7 from cgbur/utf8-support
Support utf8 prompts
cgbur authored Aug 16, 2023
2 parents a973fd1 + fddde56 commit fb99f3f
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions src/main.zig
@@ -217,7 +217,6 @@ const Tokenizer = struct {
    /// responsible for freeing the returned list.
    fn encode(self: *const Tokenizer, input: []const u8, allocator: Allocator) ![]u32 {
        var token_buf: []u32 = try allocator.alloc(u32, input.len); // worst case is every byte is a token
-       var token_buf_len: usize = token_buf.len;

        const max_allowed_token_len = 128;
        if (self.max_token_len > max_allowed_token_len) {
@@ -230,9 +229,18 @@ const Tokenizer = struct {
        var fba = std.heap.FixedBufferAllocator.init(&buffer);
        const fixed_allocator = fba.allocator();

-       // first encode every byte as a token
-       for (input, token_buf) |byte, *token| {
-           token.* = self.lookup(&[1]u8{byte}) orelse return error.TokenNotFound;
+       var utf_encoded_buffer: [4]u8 = undefined;
+       var idx: usize = 0;
+       var token_end_idx: usize = 0;
+       while (idx < input.len) {
+           const utf_len = try std.unicode.utf8ByteSequenceLength(input[idx]);
+           var codepoint: u21 = try std.unicode.utf8Decode(input[idx..][0..utf_len]);
+           const encoded_len = try std.unicode.utf8Encode(codepoint, &utf_encoded_buffer);
+           token_buf[token_end_idx] = self.lookup(utf_encoded_buffer[0..encoded_len]) orelse {
+               return error.TokenNotFound;
+           };
+           token_end_idx += 1; // we have one more token now
+           idx += utf_len; // skip over the utf8 sequence
        }

        while (true) {
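
For context, a minimal standalone sketch (not part of the commit) of the codepoint-by-codepoint iteration the new loop performs, using the same std.unicode calls (utf8ByteSequenceLength, utf8Decode, utf8Encode); the input string here is only illustrative:

const std = @import("std");

pub fn main() !void {
    const input = "a中b"; // hypothetical mixed ASCII/UTF-8 prompt
    var utf_encoded_buffer: [4]u8 = undefined;
    var idx: usize = 0;
    while (idx < input.len) {
        // length of this UTF-8 sequence, derived from its first byte
        const utf_len = try std.unicode.utf8ByteSequenceLength(input[idx]);
        // decode the sequence to a codepoint, then re-encode it into a scratch buffer
        const codepoint = try std.unicode.utf8Decode(input[idx..][0..utf_len]);
        const encoded_len = try std.unicode.utf8Encode(codepoint, &utf_encoded_buffer);
        std.debug.print("codepoint U+{X}: {d} byte(s)\n", .{ codepoint, encoded_len });
        idx += utf_len; // advance past the whole sequence, not just one byte
    }
}
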
@@ -241,7 +249,7 @@ const Tokenizer = struct {
            var best_idx: ?usize = null;

            // find the best token to merge
-           for (0..token_buf_len - 1) |i| {
+           for (0..token_end_idx - 1) |i| {
                // check if we are able to merge the token at i with the next token
                const catted = try std.mem.concat(fixed_allocator, u8, &[_][]u8{
                    self.tokens[token_buf[i]],
@@ -260,18 +268,18 @@ const Tokenizer = struct {
            if (best_idx) |best| {
                // merge the best token and shift the rest of the tokens down
                token_buf[best] = best_id;
-               std.mem.copyForwards(u32, token_buf[best + 1 ..], token_buf[best + 2 .. token_buf_len]);
-               token_buf_len -= 1;
+               std.mem.copyForwards(u32, token_buf[best + 1 ..], token_buf[best + 2 .. token_end_idx]);
+               token_end_idx -= 1;
            } else {
                // if we didn't find any tokens to merge, we are done
                break;
            }
        }

-       if (!allocator.resize(token_buf, token_buf_len)) {
+       if (!allocator.resize(token_buf, token_end_idx)) {
            return error.OutOfMemory;
        }
-       return token_buf[0..token_buf_len];
+       return token_buf[0..token_end_idx];
    }
};
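
As a reference for the shrink-to-fit step at the end of encode(), here is a minimal sketch (not from this repository) of the Allocator.resize pattern: resize returns a bool, and only on success is the slice re-cut to the shorter length. The allocator choice and lengths below are illustrative assumptions.

const std = @import("std");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    // over-allocate, use only a prefix, then try to shrink the allocation in place
    var buf: []u32 = try allocator.alloc(u32, 16);
    const used: usize = 5;
    if (allocator.resize(buf, used)) {
        buf = buf[0..used]; // resize succeeded: reslice to the new length
    }
    defer allocator.free(buf);
    std.debug.print("buffer length is now {d}\n", .{buf.len});
}
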

@@ -987,6 +995,7 @@ test "bpe" {
    const tokenizer = try Tokenizer.fromFile("tokenizer.bin", 32000, allocator);
    defer tokenizer.deinit(allocator);

+   try std.testing.expect(tokenizer.lookup("æ") == 233);
    try std.testing.expect(std.mem.eql(u8, tokenizer.tokens[100], "a"));
    try std.testing.expect(tokenizer.max_token_len == 27);
    try std.testing.expect(tokenizer.tokens.len == tokenizer.scores.len);
@@ -1001,4 +1010,12 @@ test "bpe" {
    for (tokenization, 0..) |token, i| {
        try std.testing.expect(token == expected_tokenization[i]);
    }
+   const utf_input: []const u8 = "中";
+   const utf_expected_tokens: []const u32 = &[_]u32{30275};
+   const utf_tokenization = try tokenizer.encode(utf_input, allocator);
+   defer allocator.free(utf_tokenization);
+   try std.testing.expect(utf_tokenization.len == utf_expected_tokens.len);
+   for (utf_tokenization, 0..) |token, i| {
+       try std.testing.expect(token == utf_expected_tokens[i]);
+   }
}
