Fixed caching of partially filled blocks.
popovaan committed Jul 23, 2024
1 parent cb5c784 commit 1df5806
Showing 2 changed files with 36 additions and 10 deletions.
38 changes: 30 additions & 8 deletions src/cpp/src/block_manager.hpp
@@ -155,6 +155,7 @@ class BlockAllocator {
}

KVCacheBlock::Ptr allocate_block() {
OPENVINO_ASSERT(!m_enable_prefix_caching);
OPENVINO_ASSERT(can_allocate_blocks(1));
KVCacheBlock::Ptr allocated_block = m_free_blocks.front();
allocated_block->increment();
@@ -171,6 +172,9 @@
cached_blocks[hash] = block;
return block;
}
// TODO: Currently we cache every allocated block, which may be redundant for beam search,
// where blocks of unused candidates do not need to be cached.
// This can probably be improved by caching only the blocks of the resulting finished sequences.
if (cached_blocks.find(hash) != cached_blocks.end()) {
// use cached block from cached_blocks
block = cached_blocks[hash];
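Note: the hash-based allocation path shown in this hunk can be summarized by the sketch below. The types and names are simplified stand-ins (the real code uses KVCacheBlock::Ptr and the BlockAllocator's free/cached block bookkeeping), so treat it as an illustration of the lookup-or-allocate pattern rather than the actual implementation.

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <vector>

// Minimal stand-in for KVCacheBlock.
struct Block {
    size_t ref_count = 0;
    size_t hash = 0;
    void increment() { ++ref_count; }
};
using BlockPtr = std::shared_ptr<Block>;

// Lookup-or-allocate: reuse a block cached under this prefix hash if one
// exists, otherwise take a free block and register it for future reuse.
BlockPtr allocate_block(size_t hash,
                        std::map<size_t, BlockPtr>& cached_blocks,
                        std::vector<BlockPtr>& free_blocks) {
    auto it = cached_blocks.find(hash);
    if (it != cached_blocks.end()) {
        it->second->increment();   // cache hit: KV data for this prefix already exists
        return it->second;
    }
    BlockPtr block = free_blocks.back();  // cache miss: take a free block
    free_blocks.pop_back();
    block->hash = hash;
    block->increment();
    cached_blocks[hash] = block;          // every allocated block is made discoverable
    return block;
}
```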
@@ -334,18 +338,16 @@ class BlockManager {
auto sequence_id = sequence->get_id();
auto block_table = m_block_table[sequence_id];
auto content_length = sequence->get_generated_len() + prompt_ids.size();
size_t num_hashed_tokens_in_last_block = content_length % m_block_size;
size_t allocated_content = block_table.size() * m_block_size;
size_t num_hashed_tokens = block_table.size() * m_block_size;

for (size_t i = 0; i < num_blocks; ++i) {

ov::genai::KVCacheBlock::Ptr block = nullptr;
if (m_enable_prefix_caching) {
// allocated_content += m_block_size;
// if (allocated_content > content_length) {
// allocated_content = content_length;
// }
size_t num_hashed_tokens = (i + 1) * m_block_size + allocated_content <= content_length ? (i + 1) * m_block_size + allocated_content: num_hashed_tokens_in_last_block + allocated_content;
num_hashed_tokens += m_block_size;
if (num_hashed_tokens > content_length) {
num_hashed_tokens = content_length;
}
auto hash = sequence->get_hash(num_hashed_tokens, prompt_ids);
block = m_allocator.allocate_block(hash, num_hashed_tokens, cached_blocks);
}
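The replacement above computes the hash length incrementally: each newly allocated block is hashed over the prefix it completes, and the length is clamped to the content length so the final, partially filled block is hashed over exactly the tokens it holds. A small self-contained illustration of that computation (the helper name and standalone form are mine, not the repository's API):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Prefix lengths over which each newly allocated block would be hashed.
std::vector<size_t> hashed_prefix_lengths(size_t already_allocated_blocks,
                                          size_t num_new_blocks,
                                          size_t block_size,
                                          size_t content_length) {
    std::vector<size_t> lengths;
    size_t num_hashed_tokens = already_allocated_blocks * block_size;
    for (size_t i = 0; i < num_new_blocks; ++i) {
        // Same clamping as in the diff: never hash past the real content.
        num_hashed_tokens = std::min(num_hashed_tokens + block_size, content_length);
        lengths.push_back(num_hashed_tokens);
    }
    return lengths;
}

int main() {
    // 10 tokens of content, block size 4, nothing allocated yet:
    // the three blocks are hashed over 4, 8 and 10 tokens, so the last,
    // partially filled block gets a hash of its actual contents.
    for (size_t n : hashed_prefix_lengths(0, 3, 4, 10))
        std::cout << n << " ";   // prints: 4 8 10
    std::cout << "\n";
}
```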
@@ -483,7 +485,6 @@ class BlockManager {
OPENVINO_ASSERT(num_logical_blocks == num_physical_blocks, "The number of physical and logical blocks must be the same in this code path");
KVCacheBlock::Ptr last_block = block_table.back();
if (last_block->copy_on_write()) {
// TODO: Update hash of block
// we need to fork current block, because reference counter is more than 1
KVCacheBlock::Ptr new_block = nullptr;
if (m_enable_prefix_caching) {
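For context on the copy-on-write branch above (the hunk is truncated here): a last block shared by several sequences cannot be written in place, so the writing sequence is given its own block. The sketch below is an illustration with hypothetical names, not the repository's code.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

struct Block {
    size_t ref_count = 0;
    bool copy_on_write() const { return ref_count > 1; }  // shared with other sequences
};
using BlockPtr = std::shared_ptr<Block>;

// If the last block is shared, fork it: take a private block for this sequence
// and release one reference on the shared block.
BlockPtr maybe_fork_last_block(BlockPtr last_block, std::vector<BlockPtr>& free_blocks) {
    if (!last_block->copy_on_write())
        return last_block;                // sole owner: extend the block in place
    BlockPtr forked = free_blocks.back();
    free_blocks.pop_back();
    forked->ref_count = 1;
    last_block->ref_count -= 1;
    return forked;                        // caller also schedules a KV copy from the old block
}
```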
@@ -503,6 +504,8 @@
// we are the only users of this block
if (m_enable_prefix_caching) {
// update hash of block
// TODO: Caching time can probably be improved here if we store
// a cache of tokens instead of a cache of blocks.
auto prev_hash = last_block->get_hash();
auto hash = sequence->get_hash(seq_group->get_context_len(), seq_group->get_prompt_ids());
last_block->set_hash(hash, seq_group->get_context_len());
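When the sequence is the sole user of the last block, the hunk above recomputes the block's hash over the grown context; prev_hash is captured, presumably so the cache entry can be moved from the old prefix hash to the new one (the rest of the hunk is truncated). A simplified sketch of that re-keying, with hypothetical names and under that assumption:

```cpp
#include <cstddef>
#include <map>
#include <memory>

struct Block {
    size_t hash = 0;
    size_t num_hashed_tokens = 0;
};
using BlockPtr = std::shared_ptr<Block>;

// Re-key a partially filled block after new tokens were appended: drop the
// entry stored under the shorter prefix's hash and register the longer one.
void rehash_last_block(BlockPtr last_block, size_t new_hash, size_t new_context_len,
                       std::map<size_t, BlockPtr>& cached_blocks) {
    size_t prev_hash = last_block->hash;
    last_block->hash = new_hash;
    last_block->num_hashed_tokens = new_context_len;
    cached_blocks.erase(prev_hash);        // stale shorter-prefix entry
    cached_blocks[new_hash] = last_block;  // future sequences match the longer prefix
}
```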
@@ -528,10 +531,12 @@

size_t content_len = 0;
while (content_len < prompt_ids.size()) {
size_t prev_iteration_content_len = content_len;
content_len += block_size;
if (content_len > prompt_ids.size()) {
content_len = prompt_ids.size();
}
// restore fully filled blocks
auto hash = sequence->get_hash(content_len, prompt_ids);
auto block = m_allocator.get_cashed_block(hash, cached_blocks);
if (block != nullptr) {
Expand All @@ -540,6 +545,23 @@ class BlockManager {
group->update_processed_tokens_num(content_len);
}
else {
size_t tokens_len_in_last_block = content_len % block_size;
if (tokens_len_in_last_block != 0) {
// restore partially filled block
for (size_t i = 1; i < block_size; i++) {
if (prev_iteration_content_len + i > prompt_ids.size()) {
break;
}
auto hash = sequence->get_hash(prev_iteration_content_len + i, prompt_ids);
auto block = m_allocator.get_cashed_block(hash, cached_blocks);
if (block != nullptr) {
block->set_timestamp(time(NULL));
m_block_table[seq_id].push_back(block);
group->update_processed_tokens_num(prev_iteration_content_len + i);
break;
}
}
}
break;
}
}
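This hunk is the core of the fix: the restore loop still walks the prompt block by block, but when the lookup for the final, partially filled chunk misses it now also probes the shorter prefix lengths inside that block, so a previously cached partial block can be reused. A condensed sketch of the control flow, with hypothetical callbacks standing in for the hash, cache and block-table operations:

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>

using Hash = size_t;

// Returns how many prompt tokens could be restored from the prefix cache.
size_t restore_from_cache(size_t prompt_len, size_t block_size,
                          const std::function<Hash(size_t)>& hash_of_prefix,
                          const std::function<bool(Hash)>& cache_has,
                          const std::function<void(Hash)>& append_cached_block) {
    size_t restored = 0;
    size_t content_len = 0;
    while (content_len < prompt_len) {
        size_t prev_content_len = content_len;
        content_len = std::min(content_len + block_size, prompt_len);
        if (cache_has(hash_of_prefix(content_len))) {
            // Fully filled block (or the exact final partial block) is cached.
            append_cached_block(hash_of_prefix(content_len));
            restored = content_len;
            continue;
        }
        if (content_len % block_size != 0) {
            // Final chunk is partial and its exact hash missed: probe the
            // shorter prefixes inside this block for an earlier cached fill.
            for (size_t i = 1; i < block_size && prev_content_len + i <= prompt_len; ++i) {
                Hash h = hash_of_prefix(prev_content_len + i);
                if (cache_has(h)) {
                    append_cached_block(h);
                    restored = prev_content_len + i;
                    break;
                }
            }
        }
        break;  // first miss ends the restore; later blocks cannot match
    }
    return restored;
}
```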
8 changes: 6 additions & 2 deletions src/cpp/src/scheduler.hpp
@@ -366,11 +366,15 @@ class Scheduler {
Sequence::Ptr sequence = (*sequence_group)[0];
uint64_t seq_id = sequence->get_id();

// allocate KV blocks
m_block_manager.allocate(sequence, sequence_group->get_prompt_ids(), num_required_blocks);
// and schedule tokens
sequence_group->schedule_tokens(sequence_len);

// allocate KV blocks
if (sequence_group->get_num_processed_tokens() == 0)
m_block_manager.allocate(sequence, sequence_group->get_prompt_ids(), num_required_blocks);
else
m_block_manager.append_slots(sequence_group);

// add information to scheduler_output
{
scheduler_output.m_scheduled_sequence_groups_ids.push_back(sequence_group_id);
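The scheduler change follows from the restore logic: a sequence whose prompt prefix was restored from the cache already has processed tokens and owns blocks, so allocating the full block count again would be wasteful; only the missing slots are appended instead. A minimal, self-contained sketch of the dispatch (hypothetical types):

```cpp
#include <cstddef>
#include <iostream>

struct SequenceGroup {
    size_t num_processed_tokens = 0;  // > 0 once cached prompt blocks were restored
};

enum class Action { AllocateAll, AppendSlots };

// Fresh sequences reserve all required KV blocks up front; partially
// restored sequences only extend their existing block table.
Action choose_allocation(const SequenceGroup& group) {
    return group.num_processed_tokens == 0 ? Action::AllocateAll : Action::AppendSlots;
}

int main() {
    SequenceGroup fresh;                 // no prefix-cache hit on the prompt
    SequenceGroup restored;
    restored.num_processed_tokens = 16;  // e.g. one full cached block of 16 tokens
    std::cout << (choose_allocation(fresh) == Action::AllocateAll) << "\n";     // prints 1
    std::cout << (choose_allocation(restored) == Action::AppendSlots) << "\n";  // prints 1
}
```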
