Prefix caching improvements #758

Merged
100 changes: 44 additions & 56 deletions src/cpp/src/block_manager.hpp
@@ -15,7 +15,6 @@ class KVCacheBlock {
int m_ref_count;
int m_index;
size_t m_hash;
size_t m_num_hashed_tokens;
std::chrono::time_point<std::chrono::system_clock> m_timestamp;
public:
using Ptr = std::shared_ptr<KVCacheBlock>;
@@ -55,13 +54,8 @@ class KVCacheBlock {
return m_hash;
}

size_t get_num_hashed_tokens() const {
return m_num_hashed_tokens;
}

void set_hash(size_t hash, size_t num_hashed_tokens) {
void set_hash(size_t hash) {
m_hash = hash;
m_num_hashed_tokens = num_hashed_tokens;
}

void set_timestamp(const std::chrono::time_point<std::chrono::system_clock>& timestamp) {
@@ -75,46 +69,42 @@


class Evictor {
std::map<size_t, KVCacheBlock::Ptr> blocks;
public:
void add(size_t hash, KVCacheBlock::Ptr block) {
blocks[hash] = block;
}

static bool block_is_less(const std::pair<size_t, KVCacheBlock::Ptr>& lhs, const std::pair<size_t, KVCacheBlock::Ptr>& rhs) {
return lhs.second->get_timestamp() < rhs.second->get_timestamp();
std::map<size_t, KVCacheBlock::Ptr> m_blocks;
public:
void add(KVCacheBlock::Ptr block) {
m_blocks[block->get_hash()] = block;
}

KVCacheBlock::Ptr get_block(size_t hash) {
if (blocks.find(hash)== blocks.end())
auto it = m_blocks.find(hash);
if (it == m_blocks.end())
{
return nullptr;
}
KVCacheBlock::Ptr block = blocks[hash];
KVCacheBlock::Ptr block = it->second;
block->set_timestamp(std::chrono::system_clock::now());
block->increment();
blocks.erase(hash);
m_blocks.erase(it);
return block;
}

KVCacheBlock::Ptr get_lru_block() {
if (!blocks.size()) {
if (!m_blocks.size()) {
return nullptr;
}
auto hash_block = std::min_element(std::begin(blocks), std::end(blocks), block_is_less);
auto hash_block = std::min_element(std::begin(m_blocks), std::end(m_blocks), [](const auto& lhs, const auto& rhs) -> bool { return lhs.second->get_timestamp() < rhs.second->get_timestamp(); });
auto block = hash_block->second;
block->set_timestamp(std::chrono::system_clock::now());
block->increment();
blocks.erase(hash_block->first);
m_blocks.erase(hash_block->first);
return block;
}

size_t num_blocks() const {
return blocks.size();
return m_blocks.size();
}
};
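For orientation, the Evictor above is an LRU pool keyed by block hash: blocks whose reference count drops to zero are parked here instead of on the free list, so a later sequence producing the same hash can reclaim them with their KV content intact; when nothing matches, the least-recently-used entry is recycled. A minimal self-contained sketch of the same idea, with simplified stand-in types (hypothetical names, not the PR's API):

#include <algorithm>
#include <chrono>
#include <map>
#include <memory>

// Simplified stand-in for KVCacheBlock: only the fields the evictor needs.
struct Block {
    size_t hash;
    std::chrono::time_point<std::chrono::system_clock> timestamp;
};
using BlockPtr = std::shared_ptr<Block>;

class LruEvictor {
    std::map<size_t, BlockPtr> m_blocks;  // hash -> parked block
public:
    void add(BlockPtr block) { m_blocks[block->hash] = block; }

    // Exact-match reuse: a sequence with the same hash reclaims its old block.
    BlockPtr get_block(size_t hash) {
        auto it = m_blocks.find(hash);
        if (it == m_blocks.end())
            return nullptr;
        BlockPtr block = it->second;
        m_blocks.erase(it);
        return block;
    }

    // Nothing matched: recycle the block that has been parked the longest.
    BlockPtr get_lru_block() {
        if (m_blocks.empty())
            return nullptr;
        auto oldest = std::min_element(m_blocks.begin(), m_blocks.end(),
            [](const auto& lhs, const auto& rhs) {
                return lhs.second->timestamp < rhs.second->timestamp;
            });
        BlockPtr block = oldest->second;
        m_blocks.erase(oldest);
        return block;
    }
};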


class BlockAllocator {
std::list<KVCacheBlock::Ptr> m_free_blocks;
ov::genai::Evictor m_evictor;
@@ -146,7 +136,7 @@ class BlockAllocator {
if (block->is_free()) {
if (m_enable_prefix_caching)
{
m_evictor.add(block->get_hash(), block);
m_evictor.add(block);
}
else {
m_free_blocks.push_back(block);
@@ -163,29 +153,28 @@
return allocated_block;
}

KVCacheBlock::Ptr allocate_block(size_t hash, size_t num_hashed_tokens, std::map<uint64_t, KVCacheBlock::Ptr>& cached_blocks) {
KVCacheBlock::Ptr allocate_block(size_t hash, std::map<uint64_t, KVCacheBlock::Ptr>& cached_blocks) {
OPENVINO_ASSERT(m_enable_prefix_caching);
OPENVINO_ASSERT(can_allocate_blocks(1));
auto block = m_evictor.get_block(hash);
auto it = cached_blocks.find(hash);
if (block != nullptr) {
// use cached block from evictor
cached_blocks[hash] = block;
return block;
}
// TODO: Currently we cache all allocated blocks which might be redundant for beam search,
// where blocks of non-used candidates are not needed in cache.
// This part can be improved if we cache only blocks for prompt.
if (cached_blocks.find(hash) != cached_blocks.end()) {
if (it != cached_blocks.end()) {
// use cached block from cached_blocks
block = cached_blocks[hash];
cached_blocks[hash]->increment();
return block;
it->second->increment();
return it->second;
}
if (m_free_blocks.size() > 0) {
// allocate new empty block
KVCacheBlock::Ptr allocated_block = m_free_blocks.front();
allocated_block->increment();
allocated_block->set_hash(hash, num_hashed_tokens);
allocated_block->set_hash(hash);
cached_blocks[hash] = allocated_block;

m_free_blocks.pop_front();
@@ -197,7 +186,7 @@
cached_blocks.erase(block->get_hash());

// update block with new hash
block->set_hash(hash, num_hashed_tokens);
block->set_hash(hash);
cached_blocks[hash] = block;
return block;
}
@@ -209,15 +198,14 @@
auto block = m_evictor.get_block(hash);
if (block != nullptr) {
// use cached block from evictor
cached_blocks[hash] = block;
return block;
}
if (cached_blocks.find(hash) != cached_blocks.end()) {
auto it = cached_blocks.find(hash);
if (it != cached_blocks.end()) {
// use cached block from cached_blocks
// TODO: add tokens validation in case of hash collision
block = cached_blocks[hash];
cached_blocks[hash]->increment();
return block;
it->second->increment();
return it->second;
}
return nullptr;
}
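Read together, the allocation paths above give prefix-caching allocation a fixed lookup order. The following is a hedged restatement of BlockAllocator::allocate_block's control flow using the class's own member names (a sketch, not the literal implementation; note that the evictor's getters already bump the ref-count and timestamp):

// Lookup order:
// 1. Evictor: this hash was freed earlier; reclaim the block KV-intact.
// 2. cached_blocks: this hash is still live in another sequence; share it.
// 3. Free list: take an empty block and key it under the new hash.
// 4. LRU eviction: recycle the oldest parked block, re-keying it from its
//    old hash to the new one in cached_blocks.
KVCacheBlock::Ptr allocate_block_sketch(size_t hash,
                                        std::map<uint64_t, KVCacheBlock::Ptr>& cached_blocks) {
    if (auto block = m_evictor.get_block(hash)) {
        cached_blocks[hash] = block;                  // (1)
        return block;
    }
    if (auto it = cached_blocks.find(hash); it != cached_blocks.end()) {
        it->second->increment();                      // (2)
        return it->second;
    }
    if (!m_free_blocks.empty()) {
        auto block = m_free_blocks.front();           // (3)
        m_free_blocks.pop_front();
        block->increment();
        block->set_hash(hash);
        cached_blocks[hash] = block;
        return block;
    }
    auto block = m_evictor.get_lru_block();           // (4)
    cached_blocks.erase(block->get_hash());
    block->set_hash(hash);
    cached_blocks[hash] = block;
    return block;
}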
@@ -299,11 +287,10 @@ class BlockManager {
return m_allocator.can_allocate_blocks(num_blocks);
}

void allocate(ov::genai::Sequence::CPtr sequence, size_t num_blocks, const ov::genai::TokenIds& prompt_ids = {}) {
void allocate(ov::genai::Sequence::Ptr sequence, size_t num_blocks, const ov::genai::TokenIds& prompt_ids = {}) {
OPENVINO_ASSERT(num_blocks > 0 && can_allocate_blocks(num_blocks));
if (m_enable_prefix_caching) {
OPENVINO_ASSERT(prompt_ids.size() > 0, "prompt_ids should be set for hash calculation.");
}
OPENVINO_ASSERT(!m_enable_prefix_caching || prompt_ids.size() > 0, "prompt_ids should be set for hash calculation.");

auto sequence_id = sequence->get_id();
auto block_table = m_block_table[sequence_id];
auto content_length = sequence->get_generated_len() + prompt_ids.size();
@@ -317,8 +304,8 @@
if (num_hashed_tokens > content_length) {
num_hashed_tokens = content_length;
}
auto hash = sequence->get_hash(num_hashed_tokens, prompt_ids);
block = m_allocator.allocate_block(hash, num_hashed_tokens, cached_blocks);
auto hash = sequence->get_hash(num_hashed_tokens);
block = m_allocator.allocate_block(hash, cached_blocks);
}
else {
block = m_allocator.allocate_block();
@@ -433,14 +420,14 @@ class BlockManager {
return blocks_count;
}

std::map<size_t, std::list<size_t>> append_slots(SequenceGroup::CPtr seq_group) {
std::map<size_t, std::list<size_t>> append_slots(SequenceGroup::Ptr seq_group) {

size_t num_logical_blocks = seq_group->get_num_logical_blocks();
std::vector<Sequence::CPtr> running_sequences = seq_group->get_running_sequences();
std::vector<Sequence::Ptr> running_sequences = seq_group->get_running_sequences();

std::map<size_t, std::list<size_t>> copy_blocks_map;
for (size_t i = 0; i < running_sequences.size(); ++i) {
Sequence::CPtr sequence = running_sequences[i];
Sequence::Ptr sequence = running_sequences[i];
auto seq_id = sequence->get_id();
auto& block_table = m_block_table[seq_id];
size_t num_physical_blocks = block_table.size();
@@ -455,8 +442,8 @@
// we need to fork current block, because reference counter is more than 1
KVCacheBlock::Ptr new_block = nullptr;
if (m_enable_prefix_caching) {
auto hash = sequence->get_hash(seq_group->get_context_len(), seq_group->get_prompt_ids());
new_block = m_allocator.allocate_block(hash, seq_group->get_context_len(), cached_blocks);
auto hash = sequence->get_hash();
new_block = m_allocator.allocate_block(hash, cached_blocks);
cached_blocks[hash] = new_block;
}
else {
@@ -472,8 +459,8 @@
if (m_enable_prefix_caching) {
// update hash of block
auto prev_hash = last_block->get_hash();
auto hash = sequence->get_hash(seq_group->get_context_len(), seq_group->get_prompt_ids());
last_block->set_hash(hash, seq_group->get_context_len());
auto hash = sequence->get_hash();
last_block->set_hash(hash);
cached_blocks.erase(prev_hash);
cached_blocks[hash] = last_block;
}
@@ -486,7 +473,7 @@
}


void _restore_cached_blocks(SequenceGroup::Ptr group, size_t block_size) {
void restore_cached_blocks(SequenceGroup::Ptr group, size_t block_size) {
auto prompt_ids = group->get_prompt_ids();
auto sequences = group->get_not_finished_sequences();
OPENVINO_ASSERT(sequences.size() == 1);
@@ -502,8 +489,8 @@
content_len = prompt_ids.size();
}
// restore fully filled blocks
auto hash = sequence->get_hash(content_len, prompt_ids);
auto block = m_allocator.get_cached_block(hash, cached_blocks);
auto full_block_hash = sequence->get_hash(content_len);
auto block = m_allocator.get_cached_block(full_block_hash, cached_blocks);
if (block != nullptr) {
block->set_timestamp(std::chrono::system_clock::now());
m_block_table[seq_id].push_back(block);
@@ -515,19 +502,20 @@
if (prev_iteration_content_len + i > prompt_ids.size()) {
break;
}
auto hash = sequence->get_hash(prev_iteration_content_len + i, prompt_ids);
auto hash = sequence->get_hash(prev_iteration_content_len + i);
auto block = m_allocator.get_cached_block(hash, cached_blocks);
if (block != nullptr) {
block->set_timestamp(std::chrono::system_clock::now());
m_block_table[seq_id].push_back(block);
group->update_processed_tokens_num(prev_iteration_content_len + i);

size_t new_tokens_count_in_block = std::min(content_len, prev_iteration_content_len + block_size);
if (new_tokens_count_in_block > prev_iteration_content_len + i) {
cached_blocks.erase(hash);
auto new_hash = sequence->get_hash(new_tokens_count_in_block, prompt_ids);
auto new_hash = sequence->get_hash(new_tokens_count_in_block);
block->set_hash(new_hash);
cached_blocks[new_hash] = block;
}
m_block_table[seq_id].push_back(block);

break;
}
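To illustrate restore_cached_blocks, here is a worked walk-through under assumed numbers (block_size = 4, a 10-token prompt, and a cache warmed by an earlier request that shared only the first 6 tokens); the probing of the partially filled block is abbreviated, since part of that loop is outside the hunks shown above:

// block 0: full-block hash over tokens [0..4)  -> hit in cached_blocks;
//          attach the block and count its 4 tokens as already processed.
// block 1: full-block hash over tokens [4..8)  -> miss (only 6 tokens were shared),
//          so probe hashes of the partially filled content of this block;
//          the 6-token hash (block 0's hash + tokens [4..6)) hits:
//          attach the block, call update_processed_tokens_num(6), and, since this
//          request will fill the block further (up to 8 tokens), re-key the
//          cached entry under the new hash via set_hash().
// Restore stops at the first partial match; the remaining tokens are computed
// normally.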
9 changes: 8 additions & 1 deletion src/cpp/src/continuous_batching_pipeline.cpp
@@ -140,7 +140,14 @@ class ContinuousBatchingPipeline::Impl {
sampling_params.set_eos_token_id(m_tokenizer.get_eos_token_id());
sampling_params.validate();
SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids,
sampling_params, m_scheduler->get_config().block_size);
sampling_params,
m_scheduler->get_config().block_size,
m_scheduler->get_config().enable_prefix_caching);
sequence_group->set_sequence_group_ptr(sequence_group);
if (m_scheduler->get_config().enable_prefix_caching) {
m_scheduler->restore_cached_blocks(sequence_group);
}

{
std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
m_awaiting_requests.push_back(sequence_group);
21 changes: 5 additions & 16 deletions src/cpp/src/scheduler.hpp
@@ -38,9 +38,6 @@ class Scheduler {
Output schedule(std::vector<SequenceGroup::Ptr>& sequence_groups) {
Output scheduler_output;

if (m_config.enable_prefix_caching)
_restore_cached_blocks(sequence_groups);

if (m_config.dynamic_split_fuse) {
// deepspeed-mii case
// generation phase is always scheduled first
@@ -81,6 +78,10 @@ class Scheduler {
m_block_manager.fork_sequence(parent_id, child_id);
}

void restore_cached_blocks(const SequenceGroup::Ptr& sequence_group) {
m_block_manager.restore_cached_blocks(sequence_group, m_config.block_size);
}

const SchedulerConfig& get_config() const {
return m_config;
}
@@ -152,15 +153,6 @@ class Scheduler {
return std::numeric_limits<size_t>::max();
}

void _restore_cached_blocks(const std::vector<SequenceGroup::Ptr>& sequence_groups) {
for (size_t sequence_group_id = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
if (sequence_group->can_generate_tokens() || sequence_group->num_running_seqs() != 1)
continue;
m_block_manager._restore_cached_blocks(sequence_group, m_config.block_size);
}
}

void _apply_preemption(size_t sequence_group_id, const std::vector<SequenceGroup::Ptr>& sequence_groups) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];

@@ -353,10 +345,7 @@ class Scheduler {
sequence_group->schedule_tokens(sequence_len);

// allocate KV blocks
if (sequence_group->get_num_processed_tokens() == 0)
m_block_manager.allocate(sequence, num_required_blocks, sequence_group->get_prompt_ids());
else
m_block_manager.append_slots(sequence_group);
m_block_manager.append_slots(sequence_group);

// add information to scheduler_output
{
60 changes: 60 additions & 0 deletions src/cpp/src/sequence_group.cpp
@@ -0,0 +1,60 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <string_view>
#include "sequence_group.hpp"

namespace ov {
namespace genai {
size_t Sequence::_make_hash(size_t content_length) {
auto sequence_group = get_sequence_group_ptr();
Contributor:

Do we need to add an assert that content_length corresponds to the last uncomputed block? E.g., if we have one block with a hash, but content_length is 3x block_size.

E.g. block_start_idx / block_size == m_prefix_hashes.size()

Contributor Author:

There is a case when block_start_idx / block_size < m_prefix_hashes.size().
When we restore blocks of the prompt, we first check the hash of the full block, which is saved in m_prefix_hashes; then, if we couldn't find the hash of the full block in cached_blocks, we check hashes of the partially completed content of this block. So content_length in this case is less than m_prefix_hashes.size() * block_size.

So I added the assert block_start_idx / block_size <= m_prefix_hashes.size().

auto block_size = sequence_group->get_block_size();
size_t block_start_idx = content_length - (content_length % block_size);
if (block_start_idx == content_length) {
block_start_idx -= block_size;
}

// hash of current block depends on prefix hashes
std::vector<int64_t> content;
size_t prefix_hashes_needed_count = block_start_idx / block_size;
OPENVINO_ASSERT(prefix_hashes_needed_count <= m_prefix_hashes.size());
content.insert(content.end(), m_prefix_hashes.begin(), m_prefix_hashes.begin() + prefix_hashes_needed_count);

// get tokens corresponding to current block
const auto prompt_ids = sequence_group->get_prompt_ids();
OPENVINO_ASSERT(content_length <= prompt_ids.size() + m_generated_ids.size());
if (block_start_idx < prompt_ids.size()) {
content.insert(content.end(), prompt_ids.begin() + block_start_idx, prompt_ids.begin() + std::min(prompt_ids.size(), content_length));
}
if (content_length > prompt_ids.size()) {
size_t start = block_start_idx < prompt_ids.size() ? 0 : block_start_idx - prompt_ids.size();
content.insert(content.end(), m_generated_ids.begin() + start, m_generated_ids.begin() + content_length - prompt_ids.size());
}
const char* data = reinterpret_cast<const char*>(content.data());
std::size_t size = content.size() * sizeof(content[0]);
return std::hash<std::string_view>{}(std::string_view(data, size));
}

// Each KV block can be uniquely identified by
// the tokens within the block and the tokens in the prefix before the block.
// hash(prefix tokens + block tokens) <--> KV Block
size_t Sequence::get_hash(size_t content_length) {

auto sequence_group = get_sequence_group_ptr();
OPENVINO_ASSERT(sequence_group, "Hash computation requires setting of sequence_group ptr.");
auto content_len = content_length == 0 ? sequence_group->get_context_len() : content_length;
auto block_size = sequence_group->get_block_size();
size_t cur_content = block_size * (m_prefix_hashes.size() + 1);
while (cur_content <= content_len)
{
m_prefix_hashes.push_back(_make_hash(cur_content));
cur_content += block_size;
}
if (content_len % block_size == 0) {
return m_prefix_hashes[content_len / block_size - 1];
}

return _make_hash(content_len);
}
} // namespace genai
} // namespace ov
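A self-contained model of the chained hashing above may help (hypothetical names; block_size = 4 and the token values are made up). Each full block contributes one entry to prefix_hashes, and a partially filled last block is hashed over all preceding block hashes plus however many of its tokens exist, mirroring Sequence::_make_hash:

#include <cstdint>
#include <functional>
#include <iostream>
#include <string_view>
#include <vector>

// Hash over (prefix block hashes + tokens of the current block).
// Simplified relative to _make_hash: all tokens come from one vector.
static size_t make_block_hash(const std::vector<size_t>& prefix_hashes,
                              const std::vector<int64_t>& tokens,
                              size_t block_start, size_t content_length) {
    std::vector<int64_t> content(prefix_hashes.begin(), prefix_hashes.end());
    content.insert(content.end(), tokens.begin() + block_start,
                   tokens.begin() + content_length);
    const char* data = reinterpret_cast<const char*>(content.data());
    return std::hash<std::string_view>{}(
        std::string_view(data, content.size() * sizeof(content[0])));
}

int main() {
    const size_t block_size = 4;
    std::vector<int64_t> tokens = {11, 12, 13, 14, 21, 22, 23, 24, 31, 32}; // 10 tokens

    // Full blocks [0..4) and [4..8): each hash folds in the hashes before it.
    std::vector<size_t> prefix_hashes;
    for (size_t end = block_size; end <= tokens.size(); end += block_size) {
        prefix_hashes.push_back(
            make_block_hash(prefix_hashes, tokens, end - block_size, end));
    }

    // Partial last block [8..10): hashed over both full-block hashes + 2 tokens.
    size_t partial = make_block_hash(prefix_hashes, tokens,
                                     prefix_hashes.size() * block_size,
                                     tokens.size());

    std::cout << "full-block hashes: " << prefix_hashes.size()
              << ", partial-block hash: " << partial << "\n";
}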