
Dynamic KV cache allocation #1364

Merged
39 commits
a8c290a
Dynamic KV-cache allocation.
popovaan Dec 11, 2024
8b68108
Minor corrections.
popovaan Dec 11, 2024
5c384a0
Fixed cpp tests, added tests of dynamic allocation.
popovaan Dec 12, 2024
4e97cf9
Merge remote-tracking branch 'upstream/master' into dynamic_kv_cache_…
popovaan Dec 12, 2024
0f32f1f
Fixed typo.
popovaan Dec 12, 2024
d3f15fa
Test corrected.
popovaan Dec 12, 2024
ec7ca26
Minor corrections.
popovaan Dec 12, 2024
543cbd6
Merge remote-tracking branch 'upstream/master' into dynamic_kv_cache_…
popovaan Dec 13, 2024
175f241
Minor corrections.
popovaan Dec 13, 2024
a6facc5
Minor correction.
popovaan Dec 13, 2024
34f6d27
Removed unused code.
popovaan Dec 13, 2024
0f50cb7
Merge branch 'master' into dynamic_kv_cache_allocation
popovaan Dec 16, 2024
0c3bb28
Code optimizations.
popovaan Dec 17, 2024
a105a9f
Code corrections.
popovaan Dec 17, 2024
7537997
Merge upstream/master.
popovaan Dec 18, 2024
a8531a5
Added available memory check for GPU.
popovaan Dec 19, 2024
9043ba3
Minor correction.
popovaan Dec 19, 2024
9256f15
Code corrections.
popovaan Dec 19, 2024
d926303
Minor correction.
popovaan Dec 19, 2024
eb4d110
Used correct core instance.
popovaan Dec 19, 2024
f94929c
Moved increasing of cache logic to scheduler.
popovaan Dec 20, 2024
a1e4973
Merge upstream/master.
popovaan Dec 20, 2024
38a42d6
Made scheduler config not needed for prompt lookup.
popovaan Dec 20, 2024
c7d54dd
Minor correction.
popovaan Dec 20, 2024
51cb0a8
Fixed error.
popovaan Dec 20, 2024
c4c8c25
Removed wrong changes.
popovaan Dec 20, 2024
bfcf9ff
Fixed error.
popovaan Dec 20, 2024
11b5e33
Minor correction.
popovaan Dec 20, 2024
64dab76
Removed wrong changes.
popovaan Dec 20, 2024
0d71053
Merge branch 'master' into dynamic_kv_cache_allocation
ilya-lavrenov Dec 20, 2024
bb24a36
Fix.
popovaan Dec 23, 2024
13f9f08
Fix of cache increasing for gpu.
popovaan Dec 23, 2024
f04c06d
Merge branch 'master' into dynamic_kv_cache_allocation
popovaan Dec 24, 2024
eebac1f
Applied comments.
popovaan Dec 24, 2024
2715110
Update src/cpp/src/scheduler.hpp
popovaan Dec 24, 2024
1d3f85b
Update src/cpp/src/scheduler.hpp
popovaan Dec 24, 2024
3fa02d0
Update src/cpp/src/scheduler.hpp
popovaan Dec 24, 2024
a0456d8
Update scheduler.hpp
ilya-lavrenov Dec 24, 2024
e393d3d
Merge branch 'master' into dynamic_kv_cache_allocation
ilya-lavrenov Dec 24, 2024
61 changes: 57 additions & 4 deletions src/cpp/src/block_manager.hpp
@@ -195,6 +195,7 @@ class BlockAllocator {
size_t m_num_layers;
bool m_enable_prefix_caching;
ov::genai::OverwritableBlocksHashStore m_overwriteable_blocks;
bool m_initialized = false;
public:
/**
* Constructs the BlockAllocator.
@@ -205,13 +206,17 @@
* Blocks returned will be vectors with this size, each vector entry to be associated with a separate layer's KV cache.
*/
BlockAllocator(size_t num_blocks, bool enable_prefix_caching, size_t num_layers = 1) :
m_free_blocks_num(num_layers, num_blocks), m_total_num_blocks(num_blocks), m_num_layers(num_layers), m_enable_prefix_caching(enable_prefix_caching), m_overwriteable_blocks(num_layers) {
m_total_num_blocks(num_blocks), m_num_layers(num_layers), m_enable_prefix_caching(enable_prefix_caching), m_overwriteable_blocks(num_layers) {
OPENVINO_ASSERT(num_layers != 0, "num_layers must be non-zero");
m_free_blocks.resize(m_num_layers);
for (auto& per_layer_block_list : m_free_blocks) {
for (int block_id = 0; block_id < m_total_num_blocks; ++block_id) {
per_layer_block_list.push_back(std::make_shared<KVCacheBlock>(block_id));
if (num_blocks > 0) {
m_free_blocks_num = std::vector<size_t>(num_layers, num_blocks);
for (auto& per_layer_block_list : m_free_blocks) {
for (int block_id = 0; block_id < m_total_num_blocks; ++block_id) {
per_layer_block_list.push_back(std::make_shared<KVCacheBlock>(block_id));
}
}
m_initialized = true;
}
}

@@ -220,6 +225,28 @@
// OPENVINO_ASSERT(m_total_num_blocks == m_free_blocks.size());
}

void increase_kv_blocks_number(size_t new_kv_blocks_count) {
OPENVINO_ASSERT(new_kv_blocks_count > m_total_num_blocks, "The new number of blocks should be greater than the previous one.");
if (!m_initialized) {
m_free_blocks_num = std::vector<size_t>(m_num_layers, 0);
m_initialized = true;
}
size_t added_blocks = new_kv_blocks_count - m_total_num_blocks;
for (auto idx = 0; idx < m_free_blocks_num.size(); idx++) {
m_free_blocks_num[idx] += added_blocks;
}
for (auto& per_layer_block_list : m_free_blocks) {
for (int block_id = m_total_num_blocks; block_id < new_kv_blocks_count; ++block_id) {
per_layer_block_list.push_back(std::make_shared<KVCacheBlock>(block_id));
}
}
m_total_num_blocks = new_kv_blocks_count;
}

bool is_initialized() const {
return m_initialized;
}

/**
* Returns the number of free blocks for a given layer.
* @param layer_idx Index of the layer.
@@ -459,6 +486,13 @@ class BlockAllocator {
for (size_t layer_idx = 0; layer_idx < m_num_layers; layer_idx++) sum += num_free_blocks(layer_idx);
return static_cast<float>(m_num_layers * m_total_num_blocks - sum) / (m_num_layers * m_total_num_blocks) * 100;
}

/**
* @return The total number of KV blocks.
*/
size_t get_total_number_of_kv_blocks() const {
return m_total_num_blocks;
}
};

/**
@@ -631,6 +665,10 @@ class BlockManager {
return m_allocator.num_free_blocks(0); // relying on the invariant that all layers have identical number of blocks
}

bool block_allocator_initialized() const {
return m_allocator.is_initialized();
}

/**
* @param num_blocks A number of KV cache blocks
* @return Whether this number of KV cache blocks may be assigned to new sequences.
@@ -713,6 +751,21 @@
return m_allocator.get_used_percentage();
}

/**
* Increases the number of KV blocks.
* @param num_blocks The new number of KV-blocks.
*/
void increase_kv_blocks_number(size_t num_blocks) {
m_allocator.increase_kv_blocks_number(num_blocks);
}

/**
* @return The total number of KV blocks.
*/
size_t get_total_number_of_kv_blocks() const {
return m_allocator.get_total_number_of_kv_blocks();
}

/**
* @brief Forks a sequence, establishing a new sequence from an existing one, reusing
* currently allocated blocks of the existing sequence.
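The growth path added to BlockAllocator above boils down to appending freshly numbered blocks to every per-layer free list, so block IDs already handed out stay valid. A minimal standalone sketch of that policy (simplified types, no prefix caching or overwritable-block store — the ToyBlockAllocator class is illustrative, not part of the PR):

#include <cstddef>
#include <memory>
#include <vector>

// Simplified stand-in for KVCacheBlock: only the id matters for this sketch.
struct Block {
    int id;
    explicit Block(int i) : id(i) {}
};

class ToyBlockAllocator {
    size_t m_total = 0;                                        // current pool size
    std::vector<std::vector<std::shared_ptr<Block>>> m_free;   // per-layer free lists
public:
    explicit ToyBlockAllocator(size_t num_layers) : m_free(num_layers) {}

    // Mirrors increase_kv_blocks_number(): only ids in [m_total, new_total)
    // are appended, so blocks already in use keep their ids and offsets.
    void grow(size_t new_total) {
        for (auto& per_layer : m_free)
            for (size_t id = m_total; id < new_total; ++id)
                per_layer.push_back(std::make_shared<Block>(static_cast<int>(id)));
        m_total = new_total;
    }

    size_t total_blocks() const { return m_total; }
    size_t free_blocks(size_t layer) const { return m_free[layer].size(); }
};

For example, growing from 0 to 8 and later from 8 to 16 blocks leaves every layer with 16 free blocks numbered 0-15, which is the invariant the m_free_blocks_num bookkeeping above maintains.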
108 changes: 98 additions & 10 deletions src/cpp/src/cache_manager.hpp
@@ -15,20 +15,44 @@ class CacheManager {
DeviceConfig m_device_config;
std::vector<ov::Tensor> m_key_cache;
std::vector<ov::Tensor> m_value_cache;
size_t m_num_allocated_kv_blocks = 0;
ov::Core m_core;
ov::InferRequest m_request;

public:
explicit CacheManager(const DeviceConfig &device_config, ov::Core core) :
explicit CacheManager(const DeviceConfig &device_config, ov::InferRequest request, ov::Core core) :
m_device_config(device_config),
m_request(request),
m_core(core) {
m_key_cache.reserve(m_device_config.get_num_layers());
m_value_cache.reserve(m_device_config.get_num_layers());
}

ov::Shape set_first_dim_and_make_static(const ov::PartialShape& shape, size_t dim) {
ov::PartialShape res_shape = shape;
res_shape[0] = dim;
OPENVINO_ASSERT(res_shape.is_static());
return res_shape.to_shape();
}

void allocate_cache_if_needed(size_t num_kv_blocks) {
if (m_num_allocated_kv_blocks >= num_kv_blocks) {
return;
}
if (m_num_allocated_kv_blocks > 0) {
increase_cache(num_kv_blocks);
return;
}
m_num_allocated_kv_blocks = num_kv_blocks;
ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(), num_kv_blocks);
ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(), num_kv_blocks);

const std::string device_name = m_device_config.get_device();

const std::string device_name = device_config.get_device();
if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Tensor key_cache(device_config.get_cache_precision(), device_config.get_key_cache_shape());
ov::Tensor value_cache(device_config.get_cache_precision(), device_config.get_value_cache_shape());
ov::Tensor key_cache(m_device_config.get_cache_precision(), key_cache_shape);
ov::Tensor value_cache(m_device_config.get_cache_precision(), value_cache_shape);

// force allocation
std::memset(key_cache.data(), 0, key_cache.get_byte_size());
@@ -40,15 +64,79 @@
} else {
auto remote_context = m_core.get_default_context(device_name);
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Tensor key_cache = remote_context.create_tensor(device_config.get_cache_precision(),
device_config.get_key_cache_shape());
ov::Tensor value_cache = remote_context.create_tensor(device_config.get_cache_precision(),
device_config.get_value_cache_shape());
ov::Tensor key_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
key_cache_shape);
ov::Tensor value_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
value_cache_shape);

m_key_cache.emplace_back(key_cache);
m_value_cache.emplace_back(value_cache);
}
}
update_request_tensor();
}

void update_request_tensor() {
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
m_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_key_cache[decoder_layer_id]);
m_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_value_cache[decoder_layer_id]);
}
}

void increase_cache(size_t num_kv_blocks) {
OPENVINO_ASSERT(num_kv_blocks > m_num_allocated_kv_blocks);
ov::Shape new_value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(), num_kv_blocks);
ov::Shape new_key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(), num_kv_blocks);

const std::string device_name = m_device_config.get_device();
ov::Coordinate start_key{0,0,0,0};
ov::Coordinate start_value{0,0,0,0};

if (device_name.find("GPU") == std::string::npos) {
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Coordinate end_key(m_key_cache[decoder_layer_id].get_shape());
ov::Coordinate end_value(m_value_cache[decoder_layer_id].get_shape());

ov::Tensor key_cache(m_device_config.get_cache_precision(), new_key_cache_shape);
ov::Tensor value_cache(m_device_config.get_cache_precision(), new_value_cache_shape);

// force allocation
std::memset(key_cache.data(), 0, key_cache.get_byte_size());
std::memset(value_cache.data(), 0, value_cache.get_byte_size());

// copy current cache data
ov::Tensor dst_key_roi(key_cache, start_key, end_key);
ov::Tensor dst_value_roi(value_cache, start_value, end_value);
m_key_cache[decoder_layer_id].copy_to(dst_key_roi);
m_value_cache[decoder_layer_id].copy_to(dst_value_roi);

// set new cache tensors
m_key_cache[decoder_layer_id] = key_cache;
m_value_cache[decoder_layer_id] = value_cache;
}
} else {
auto remote_context = m_core.get_default_context(device_name);
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Coordinate end_key(m_key_cache[decoder_layer_id].get_shape());
ov::Coordinate end_value(m_value_cache[decoder_layer_id].get_shape());

ov::Tensor key_cache = remote_context.create_tensor(m_device_config.get_cache_precision(), new_key_cache_shape);
ov::Tensor value_cache = remote_context.create_tensor(m_device_config.get_cache_precision(), new_value_cache_shape);

// copy current cache data
ov::Tensor dst_key_roi(key_cache, start_key, end_key);
ov::Tensor dst_value_roi(value_cache, start_value, end_value);
m_key_cache[decoder_layer_id].copy_to(dst_key_roi);
m_value_cache[decoder_layer_id].copy_to(dst_value_roi);

// set new cache tensors
m_key_cache[decoder_layer_id] = key_cache;
m_value_cache[decoder_layer_id] = value_cache;
}
}
update_request_tensor();

m_num_allocated_kv_blocks = num_kv_blocks;
}

ov::Tensor get_key_cache(size_t decoder_layer_id) const {
@@ -62,8 +150,8 @@
}

void copy_blocks(const std::map<size_t, std::list<size_t>>& block_copy_map) {
ov::Shape key_shape = m_device_config.get_key_cache_shape();
ov::Shape value_shape = m_device_config.get_value_cache_shape();
ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(), m_num_allocated_kv_blocks);
ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(), m_num_allocated_kv_blocks);

ov::Coordinate key_src_start_roi(key_shape.size(), 0);
ov::Coordinate key_src_end_roi = key_shape;
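The increase_cache() path above grows each cache tensor by allocating a larger one and copying the old contents into a region-of-interest (ROI) view over its leading blocks. A condensed sketch of just that grow-and-copy step for the CPU branch, assuming the standard ov::Tensor ROI constructor and copy_to() (the grow_cache helper is illustrative, not part of the PR):

#include <openvino/openvino.hpp>
#include <cstring>

// Grow a KV-cache tensor to `new_blocks` along dimension 0, preserving data.
// Illustrative helper mirroring the CPU branch of increase_cache() above.
ov::Tensor grow_cache(ov::Tensor old_cache, size_t new_blocks) {
    ov::Shape new_shape = old_cache.get_shape();
    new_shape[0] = new_blocks;                               // blocks dimension

    ov::Tensor bigger(old_cache.get_element_type(), new_shape);
    std::memset(bigger.data(), 0, bigger.get_byte_size());  // force allocation

    // ROI view over the leading blocks of the new tensor, then copy old data in.
    ov::Coordinate begin(new_shape.size(), 0);
    ov::Coordinate end(old_cache.get_shape());
    ov::Tensor roi(bigger, begin, end);
    old_cache.copy_to(roi);
    return bigger;
}

Because only dimension 0 grows, the old tensor is a dense prefix of the new one, so a single ROI copy preserves every already-written block in place.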
8 changes: 2 additions & 6 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -49,11 +49,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), properties).create_infer_request();

// setup KV caches
m_cache_manager = std::make_shared<CacheManager>(device_config, core);
for (size_t decoder_layer_id = 0; decoder_layer_id < device_config.get_num_layers(); ++decoder_layer_id) {
infer_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_key_cache(decoder_layer_id));
infer_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_cache_manager->get_value_cache(decoder_layer_id));
}
m_cache_manager = std::make_shared<CacheManager>(device_config, infer_request, core);

SchedulerConfig updated_config = scheduler_config;
// update KV blocks number in scheduler config
@@ -68,7 +64,7 @@
can_use_partial_preemption = false;
}

m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), updated_config, device_config.get_num_layers(), can_use_partial_preemption);
m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), m_cache_manager, updated_config, device_config.get_num_layers(), can_use_partial_preemption);
// and finally create model runner
bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
m_model_runner = std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction);
32 changes: 15 additions & 17 deletions src/cpp/src/device_config.hpp
@@ -12,7 +12,7 @@
namespace ov::genai {
class DeviceConfig {
ov::element::Type m_kv_cache_type;
ov::Shape m_key_cache_shape, m_value_cache_shape;
ov::PartialShape m_key_cache_shape, m_value_cache_shape;
ov::Shape::value_type m_num_kv_heads, m_head_size, m_num_decoder_layers;
size_t m_num_kv_blocks = 0;
size_t m_block_size = 0;
@@ -80,11 +80,10 @@ class DeviceConfig {
OPENVINO_THROW(m_device, " is not supported by OpenVINO Continuous Batching");
}

OPENVINO_ASSERT(scheduling_config.num_kv_blocks > 0 || scheduling_config.cache_size > 0, "num_kv_blocks or cache_size should be more than zero.");
if (scheduling_config.num_kv_blocks > 0) {
m_num_kv_blocks = scheduling_config.num_kv_blocks;
}
else {
else if (scheduling_config.cache_size > 0) {
m_cache_size = scheduling_config.cache_size;
}
}
@@ -104,23 +103,22 @@
m_head_size += 8;
}

if (m_num_kv_blocks == 0) {
OPENVINO_ASSERT(m_cache_size > 0, "num_kv_blocks or cache_size should be more than zero.");
if (m_num_kv_blocks == 0 && m_cache_size > 0) {
size_t size_in_bytes = m_cache_size * 1024 * 1024 * 1024;
m_num_kv_blocks = size_in_bytes / (m_num_decoder_layers * 2 * m_num_kv_heads * m_block_size * m_head_size * m_kv_cache_type.size());
}

m_key_cache_shape = m_value_cache_shape = ov::Shape{m_num_kv_blocks,
m_num_kv_heads,
m_block_size,
m_head_size};
m_key_cache_shape = m_value_cache_shape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(m_num_kv_heads),
ov::Dimension(m_block_size),
ov::Dimension(m_head_size)};

if (m_device.find("GPU") != std::string::npos) {
// Update key shape, as the key's shape is different from the value's shape
m_key_cache_shape = ov::Shape{m_num_kv_blocks,
m_num_kv_heads,
m_head_size,
m_block_size};
m_key_cache_shape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(m_num_kv_heads),
ov::Dimension(m_head_size),
ov::Dimension(m_block_size)};
}
}

@@ -136,13 +134,13 @@
return m_num_decoder_layers;
}

ov::Shape get_key_cache_shape() const {
OPENVINO_ASSERT(!m_key_cache_shape.empty());
ov::PartialShape get_key_cache_shape() const {
OPENVINO_ASSERT(m_key_cache_shape.size());
return m_key_cache_shape;
}

ov::Shape get_value_cache_shape() const {
OPENVINO_ASSERT(!m_value_cache_shape.empty());
ov::PartialShape get_value_cache_shape() const {
OPENVINO_ASSERT(m_value_cache_shape.size());
return m_value_cache_shape;
}

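The cache_size-to-block-count conversion above divides the byte budget by the footprint of one block across all decoder layers and both key and value tensors. A quick worked check of that arithmetic (the model dimensions below are hypothetical, chosen only for illustration):

#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical model dimensions, for illustration only.
    size_t cache_gb   = 4;    // scheduling_config.cache_size
    size_t layers     = 32;   // m_num_decoder_layers
    size_t kv_heads   = 8;    // m_num_kv_heads
    size_t block_size = 32;   // tokens per KV block
    size_t head_size  = 128;  // m_head_size
    size_t elem_bytes = 2;    // e.g. f16 cache precision

    size_t size_in_bytes = cache_gb * 1024 * 1024 * 1024;
    // One block's footprint: all layers, key + value (hence the factor 2).
    size_t bytes_per_block = layers * 2 * kv_heads * block_size * head_size * elem_bytes;
    std::printf("num_kv_blocks = %zu\n", size_in_bytes / bytes_per_block);  // -> 1024
    return 0;
}

Note that the block count no longer fixes the cache tensor shapes: with the first dimension now ov::Dimension::dynamic(), the actual size is bound later by set_first_dim_and_make_static() in CacheManager.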
26 changes: 25 additions & 1 deletion src/cpp/src/scheduler.hpp
@@ -11,6 +11,7 @@
#include "device_config.hpp"
#include "block_manager.hpp"
#include "sequence_group.hpp"
#include "cache_manager.hpp"

namespace ov::genai {
class Scheduler {
@@ -19,6 +20,11 @@ class Scheduler {
SchedulerConfig m_config;
BlockManager m_block_manager;
friend class CacheStateDumper;
std::shared_ptr<CacheManager> m_cache_manager;
const size_t m_kv_blocks_initial_multiplier = 2;
const float m_cache_growth_factor = 2; // common values: 1.5 or 2
const float m_percentage_threshold_for_cache_increase = 100;
bool m_dynamic_memory_allocation = false;

public:
struct Output {
@@ -36,7 +42,8 @@
float m_cache_usage = 0.0;
};

explicit Scheduler(size_t block_size, const SchedulerConfig & config = {}, size_t num_layers = 1, bool can_use_partial_preemption = true) :
explicit Scheduler(size_t block_size, std::shared_ptr<CacheManager> cache_manager, const SchedulerConfig & config = {}, size_t num_layers = 1, bool can_use_partial_preemption = true) :
m_cache_manager(cache_manager),
m_can_use_partial_preemption(can_use_partial_preemption),
m_config(config),
m_block_manager(m_config.num_kv_blocks, m_config.enable_prefix_caching, block_size, num_layers) {
@@ -45,6 +52,23 @@

Output schedule(std::vector<SequenceGroup::Ptr>& sequence_groups) {
Output scheduler_output;
float eps = 1e-5;

if (!m_block_manager.block_allocator_initialized()) {
size_t prompt_sum_size = 0;
for (auto idx = 0; idx < sequence_groups.size(); idx++) {
prompt_sum_size += sequence_groups[idx]->get_prompt_len();
}
size_t initial_kv_cache_size = prompt_sum_size * m_kv_blocks_initial_multiplier;
m_block_manager.increase_kv_blocks_number(initial_kv_cache_size);
m_dynamic_memory_allocation = true;
}
else if (m_dynamic_memory_allocation && (m_block_manager.get_used_percentage() + eps) > m_percentage_threshold_for_cache_increase) {
size_t new_cache_size = (size_t)(m_block_manager.get_total_number_of_kv_blocks() * m_cache_growth_factor);
m_block_manager.increase_kv_blocks_number(new_cache_size);
}
OPENVINO_ASSERT(m_cache_manager != nullptr, "Cache manager needs to be set in the Scheduler constructor.");
m_cache_manager->allocate_cache_if_needed(m_block_manager.get_total_number_of_kv_blocks());

if (m_config.dynamic_split_fuse) {
// deepspeed-mii case
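The scheduling changes above carry the actual growth policy: on the first schedule() call the pool is sized from the pending prompts, and afterwards it grows geometrically once usage crosses the threshold. A self-contained sketch of that policy using the same constants as the hunk (the KVCacheGrowthPolicy class and its method names are illustrative):

#include <cstddef>

// Illustrative condensation of the policy wired into Scheduler::schedule().
struct KVCacheGrowthPolicy {
    static constexpr size_t kInitialMultiplier = 2;     // m_kv_blocks_initial_multiplier
    static constexpr float  kGrowthFactor      = 2.0f;  // m_cache_growth_factor
    static constexpr float  kUsageThreshold    = 100.0f; // grow only at full usage

    // First scheduling step: size the pool from the summed prompt lengths.
    static size_t initial_blocks(size_t prompt_token_sum) {
        return prompt_token_sum * kInitialMultiplier;
    }

    // Later steps: geometric growth once measured usage hits the threshold.
    static size_t next_blocks(size_t current_blocks, float used_percentage) {
        const float eps = 1e-5f;
        if (used_percentage + eps > kUsageThreshold)
            return static_cast<size_t>(current_blocks * kGrowthFactor);
        return current_blocks;  // unchanged below the threshold
    }
};

After the block manager adopts the new total, the scheduler hands it to CacheManager::allocate_cache_if_needed(), which performs the tensor (re)allocation shown earlier; geometric growth keeps the number of costly reallocate-and-copy steps logarithmic in the final cache size.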