From f2010de9fbcf69ff44b465535c3ff9efeb749f7e Mon Sep 17 00:00:00 2001
From: Sylwia Kuros
Date: Fri, 26 Jul 2024 08:47:09 +0200
Subject: [PATCH 1/5] Update requirements.txt

---
 llm_bench/python/requirements.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llm_bench/python/requirements.txt b/llm_bench/python/requirements.txt
index d83cd5a376..ed80a66deb 100644
--- a/llm_bench/python/requirements.txt
+++ b/llm_bench/python/requirements.txt
@@ -7,7 +7,6 @@ openvino_genai
 auto-gptq>=0.5.1 # for gptq
 pillow
 torch
-torchvision<0.19.0
 transformers>=4.40.0
 diffusers>=0.22.0
 #optimum is in dependency list of optimum-intel

From 4bd1a26a08cca1895475add911bc53d8eff34a6c Mon Sep 17 00:00:00 2001
From: Anastasiia Pnevskaia
Date: Fri, 26 Jul 2024 08:51:58 +0200
Subject: [PATCH 2/5] Prefix caching. (#639)

Implementation of prefix caching.

Ticket: CVS-138669
---
 .../openvino/genai/scheduler_config.hpp |   8 +
 src/cpp/src/block_manager.hpp           | 258 +++++++++++++++++-
 src/cpp/src/scheduler.hpp               |  28 +-
 src/cpp/src/sequence_group.hpp          |  21 ++
 src/python/py_generate_pipeline.cpp     |   5 +-
 tests/cpp/CMakeLists.txt                |   5 +-
 tests/cpp/block_manager.cpp             |  31 ++-
 tests/cpp/evictor.cpp                   |  54 ++++
 tests/cpp/scheduler.cpp                 |  68 +++++
 9 files changed, 443 insertions(+), 35 deletions(-)
 create mode 100644 tests/cpp/evictor.cpp

diff --git a/src/cpp/include/openvino/genai/scheduler_config.hpp b/src/cpp/include/openvino/genai/scheduler_config.hpp
index 787060d07e..d9bf7a7b41 100644
--- a/src/cpp/include/openvino/genai/scheduler_config.hpp
+++ b/src/cpp/include/openvino/genai/scheduler_config.hpp
@@ -30,5 +30,13 @@ struct SchedulerConfig {
     // max number of scheduled sequences (you can think of it as "max batch size")
     std::size_t max_num_seqs = 256;
+
+    // Enable caching of KV-blocks.
+    // When turned on, all previously computed KV-caches are kept in memory for future use.
+    // KV-cache blocks can be overwritten once the KV-cache limit is reached, but blocks are not released.
+    // This results in higher RAM usage; the maximum RAM usage is determined by the cache_size or num_kv_blocks parameters.
+    // When turned off, only the KV-cache required for the current batch calculation is kept in memory,
+    // and once a sequence has finished generation its cache is released.
+    bool enable_prefix_caching = false;
 };
 }

diff --git a/src/cpp/src/block_manager.hpp b/src/cpp/src/block_manager.hpp
index ab60b7f5ff..3b1a663235 100644
--- a/src/cpp/src/block_manager.hpp
+++ b/src/cpp/src/block_manager.hpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include <chrono>
 
 #include "sequence_group.hpp"
@@ -13,13 +14,17 @@
 namespace ov::genai {
 class KVCacheBlock {
     int m_ref_count;
     int m_index;
+    size_t m_hash;
+    size_t m_num_hashed_tokens;
+    std::chrono::time_point<std::chrono::system_clock> m_timestamp;
 public:
     using Ptr = std::shared_ptr<KVCacheBlock>;
     using CPtr = std::shared_ptr<const KVCacheBlock>;
 
     explicit KVCacheBlock(int index)
         : m_ref_count(0),
-          m_index(index) { }
+          m_index(index),
+          m_timestamp(std::chrono::system_clock::now()) { }
 
     int get_index() const {
         return m_index;
@@ -34,6 +39,7 @@ class KVCacheBlock {
     }
 
     void release() {
+        OPENVINO_ASSERT(m_ref_count > 0);
         --m_ref_count;
     }
 
@@ -44,15 +50,79 @@ class KVCacheBlock {
     int get_references_count() const {
         return m_ref_count;
     }
+
+    size_t get_hash() const {
+        return m_hash;
+    }
+
+    size_t get_num_hashed_tokens() const {
+        return m_num_hashed_tokens;
+    }
+
+    void set_hash(size_t hash, size_t num_hashed_tokens) {
+        m_hash = hash;
+        m_num_hashed_tokens = num_hashed_tokens;
+    }
+
+    void set_timestamp(const std::chrono::time_point<std::chrono::system_clock>& timestamp) {
+        m_timestamp = timestamp;
+    }
+
+    std::chrono::time_point<std::chrono::system_clock> get_timestamp() {
+        return m_timestamp;
+    }
+};
+
+
+class Evictor {
+    std::map<size_t, KVCacheBlock::Ptr> blocks;
+public:
+    void add(size_t hash, KVCacheBlock::Ptr block) {
+        blocks[hash] = block;
+    }
+
+    static bool block_is_less(const std::pair<size_t, KVCacheBlock::Ptr>& lhs, const std::pair<size_t, KVCacheBlock::Ptr>& rhs) {
+        return lhs.second->get_timestamp() < rhs.second->get_timestamp();
+    }
+
+    KVCacheBlock::Ptr get_block(size_t hash) {
+        if (blocks.find(hash) == blocks.end()) {
+            return nullptr;
+        }
+        KVCacheBlock::Ptr block = blocks[hash];
+        block->set_timestamp(std::chrono::system_clock::now());
+        block->increment();
+        blocks.erase(hash);
+        return block;
+    }
+
+    KVCacheBlock::Ptr get_lru_block() {
+        if (!blocks.size()) {
+            return nullptr;
+        }
+        auto hash_block = std::min_element(std::begin(blocks), std::end(blocks), block_is_less);
+        auto block = hash_block->second;
+        block->set_timestamp(std::chrono::system_clock::now());
+        block->increment();
+        blocks.erase(hash_block->first);
+        return block;
+    }
+
+    size_t num_blocks() const {
+        return blocks.size();
+    }
 };
 
 class BlockAllocator {
     std::list<KVCacheBlock::Ptr> m_free_blocks;
+    ov::genai::Evictor m_evictor;
     int m_total_num_blocks;
+    bool m_enable_prefix_caching;
 public:
-    BlockAllocator(int num_blocks) :
-        m_total_num_blocks(num_blocks) {
+    BlockAllocator(int num_blocks, bool enable_prefix_caching) :
+        m_total_num_blocks(num_blocks), m_enable_prefix_caching(enable_prefix_caching) {
         for (int block_id = 0; block_id < m_total_num_blocks; ++block_id) {
             m_free_blocks.push_back(std::make_shared<KVCacheBlock>(block_id));
         }
@@ -64,21 +134,28 @@ class BlockAllocator {
     }
 
     size_t num_free_blocks() const {
-        return m_free_blocks.size();
+        return m_free_blocks.size() + m_evictor.num_blocks();
     }
 
     bool can_allocate_blocks(size_t num_blocks) const {
-        return num_blocks <= m_free_blocks.size();
+        return num_blocks <= num_free_blocks();
     }
 
     void free(KVCacheBlock::Ptr block) {
         block->release();
         if (block->is_free()) {
-            m_free_blocks.push_back(block);
+            if (m_enable_prefix_caching) {
+                m_evictor.add(block->get_hash(), block);
+            }
+            else {
+                m_free_blocks.push_back(block);
+            }
         }
     }
 
     KVCacheBlock::Ptr allocate_block() {
+        OPENVINO_ASSERT(!m_enable_prefix_caching);
         OPENVINO_ASSERT(can_allocate_blocks(1));
         KVCacheBlock::Ptr allocated_block = m_free_blocks.front();
        allocated_block->increment();
@@ -86,20 +163,83 @@
         return allocated_block;
     }
 
+    KVCacheBlock::Ptr allocate_block(size_t hash, size_t num_hashed_tokens, std::map<size_t, KVCacheBlock::Ptr>& cached_blocks) {
+        OPENVINO_ASSERT(m_enable_prefix_caching);
+        OPENVINO_ASSERT(can_allocate_blocks(1));
+        auto block = m_evictor.get_block(hash);
+        if (block != nullptr) {
+            // use cached block from evictor
+            cached_blocks[hash] = block;
+            return block;
+        }
+        // TODO: Currently we cache all allocated blocks which might be redundant for beam search,
+        // where blocks of non-used candidates are not needed in cache.
+        // This part can be improved if we cache only blocks for prompt.
+        if (cached_blocks.find(hash) != cached_blocks.end()) {
+            // use cached block from cached_blocks
+            block = cached_blocks[hash];
+            cached_blocks[hash]->increment();
+            return block;
+        }
+        if (m_free_blocks.size() > 0) {
+            // allocate new empty block
+            KVCacheBlock::Ptr allocated_block = m_free_blocks.front();
+            allocated_block->increment();
+            allocated_block->set_hash(hash, num_hashed_tokens);
+            cached_blocks[hash] = allocated_block;
+
+            m_free_blocks.pop_front();
+            return allocated_block;
+        }
+        if (m_evictor.num_blocks() > 0) {
+            // get least recently used block from evictor and reuse it
+            KVCacheBlock::Ptr block = m_evictor.get_lru_block();
+            cached_blocks.erase(block->get_hash());
+
+            // update block with new hash
+            block->set_hash(hash, num_hashed_tokens);
+            cached_blocks[hash] = block;
+            return block;
+        }
+        // out of memory
+        return nullptr;
+    }
+
+    KVCacheBlock::Ptr get_cached_block(size_t hash, std::map<size_t, KVCacheBlock::Ptr>& cached_blocks) {
+        auto block = m_evictor.get_block(hash);
+        if (block != nullptr) {
+            // use cached block from evictor
+            cached_blocks[hash] = block;
+            return block;
+        }
+        if (cached_blocks.find(hash) != cached_blocks.end()) {
+            // use cached block from cached_blocks
+            // TODO: add tokens validation in case of hash collision
+            block = cached_blocks[hash];
+            cached_blocks[hash]->increment();
+            return block;
+        }
+        return nullptr;
+    }
+
     float get_used_percentage() const {
-        return static_cast<float>(m_total_num_blocks - m_free_blocks.size()) / m_total_num_blocks;
+        return static_cast<float>(m_total_num_blocks - num_free_blocks()) / m_total_num_blocks;
     }
 };
 
 class BlockManager {
     BlockAllocator m_allocator;
+    bool m_enable_prefix_caching;
+    size_t m_block_size;
+    // TODO: caching time can probably be improved if we use a prefix tree
+    std::map<size_t, KVCacheBlock::Ptr> cached_blocks;
 
     // stores blocks for each sequence (not sequence group)
     // the same block can be seen in multiple block_tables for different sequences
     std::map<uint64_t, std::vector<KVCacheBlock::Ptr>> m_block_table;
 public:
-    BlockManager(int num_blocks)
-        : m_allocator(num_blocks) { }
+    BlockManager(int num_blocks, bool enable_prefix_caching, size_t block_size)
+        : m_allocator(num_blocks, enable_prefix_caching), m_enable_prefix_caching(enable_prefix_caching), m_block_size(block_size) { }
 
     ~BlockManager() {
         // sanity check that all sequences are freed
@@ -195,11 +335,32 @@
         return m_allocator.can_allocate_blocks(num_blocks);
     }
 
-    void allocate(uint64_t sequence_id, size_t num_blocks) {
+    void allocate(ov::genai::Sequence::CPtr sequence, size_t num_blocks, const ov::genai::TokenIds& prompt_ids = {}) {
         OPENVINO_ASSERT(num_blocks > 0 && can_allocate_blocks(num_blocks));
+        if (m_enable_prefix_caching) {
+            OPENVINO_ASSERT(prompt_ids.size() > 0, "prompt_ids should be set for hash calculation.");
+        }
+        auto sequence_id = sequence->get_id();
+        auto block_table = m_block_table[sequence_id];
+        auto content_length =
sequence->get_generated_len() + prompt_ids.size(); + size_t num_hashed_tokens = block_table.size() * m_block_size; for (size_t i = 0; i < num_blocks; ++i) { - m_block_table[sequence_id].push_back(m_allocator.allocate_block()); + + ov::genai::KVCacheBlock::Ptr block = nullptr; + if (m_enable_prefix_caching) { + num_hashed_tokens += m_block_size; + if (num_hashed_tokens > content_length) { + num_hashed_tokens = content_length; + } + auto hash = sequence->get_hash(num_hashed_tokens, prompt_ids); + block = m_allocator.allocate_block(hash, num_hashed_tokens, cached_blocks); + } + else { + block = m_allocator.allocate_block(); + } + OPENVINO_ASSERT(block != nullptr); + m_block_table[sequence_id].push_back(block); } } @@ -324,21 +485,36 @@ class BlockManager { if (num_logical_blocks > num_physical_blocks) { OPENVINO_ASSERT(can_allocate_blocks(num_logical_blocks - num_physical_blocks)); - allocate(seq_id, num_logical_blocks - num_physical_blocks); + allocate(sequence, num_logical_blocks - num_physical_blocks, seq_group->get_prompt_ids()); } else { OPENVINO_ASSERT(num_logical_blocks == num_physical_blocks, "A number of physical and logic blocks must be the same in this code path"); KVCacheBlock::Ptr last_block = block_table.back(); - if (last_block->copy_on_write()) { // we need to fork current block, because reference counter is more than 1 - KVCacheBlock::Ptr new_block = m_allocator.allocate_block(); + KVCacheBlock::Ptr new_block = nullptr; + if (m_enable_prefix_caching) { + auto hash = sequence->get_hash(seq_group->get_context_len(), seq_group->get_prompt_ids()); + new_block = m_allocator.allocate_block(hash, seq_group->get_context_len(), cached_blocks); + cached_blocks[hash] = new_block; + } + else { + new_block = m_allocator.allocate_block(); + } block_table[num_physical_blocks - 1] = new_block; // write information about block forking for later usage in CacheManager copy_blocks_map[last_block->get_index()].push_back(new_block->get_index()); // release `last_block` usage m_allocator.free(last_block); } else { - // nothing to do, because we are the only users of this block + // we are the only users of this block + if (m_enable_prefix_caching) { + // update hash of block + auto prev_hash = last_block->get_hash(); + auto hash = sequence->get_hash(seq_group->get_context_len(), seq_group->get_prompt_ids()); + last_block->set_hash(hash, seq_group->get_context_len()); + cached_blocks.erase(prev_hash); + cached_blocks[hash] = last_block; + } } } } @@ -346,5 +522,57 @@ class BlockManager { // it returns information which blocks should be forked by CacheManager return copy_blocks_map; } + + + void _restore_cached_blocks(SequenceGroup::Ptr group, size_t block_size) { + auto prompt_ids = group->get_prompt_ids(); + auto sequences = group->get_not_finished_sequences(); + OPENVINO_ASSERT(sequences.size() == 1); + auto sequence = sequences[0]; + auto seq_id = sequence->get_id(); + auto& block_table = m_block_table[seq_id]; + + size_t content_len = 0; + while (content_len < prompt_ids.size()) { + size_t prev_iteration_content_len = content_len; + content_len += block_size; + if (content_len > prompt_ids.size()) { + content_len = prompt_ids.size(); + } + // restore fully filled blocks + auto hash = sequence->get_hash(content_len, prompt_ids); + auto block = m_allocator.get_cached_block(hash, cached_blocks); + if (block != nullptr) { + block->set_timestamp(std::chrono::system_clock::now()); + m_block_table[seq_id].push_back(block); + group->update_processed_tokens_num(content_len); + } + else { + // restore 
partially filled block + for (size_t i = 1; i < block_size; i++) { + if (prev_iteration_content_len + i > prompt_ids.size()) { + break; + } + auto hash = sequence->get_hash(prev_iteration_content_len + i, prompt_ids); + auto block = m_allocator.get_cached_block(hash, cached_blocks); + if (block != nullptr) { + block->set_timestamp(std::chrono::system_clock::now()); + m_block_table[seq_id].push_back(block); + group->update_processed_tokens_num(prev_iteration_content_len + i); + + size_t new_tokens_count_in_block = std::min(content_len, prev_iteration_content_len + block_size); + if (new_tokens_count_in_block > prev_iteration_content_len + i) { + cached_blocks.erase(hash); + auto new_hash = sequence->get_hash(new_tokens_count_in_block, prompt_ids); + cached_blocks[new_hash] = block; + } + + break; + } + } + break; + } + } + } }; } diff --git a/src/cpp/src/scheduler.hpp b/src/cpp/src/scheduler.hpp index ca749137db..c52ed8d7a6 100644 --- a/src/cpp/src/scheduler.hpp +++ b/src/cpp/src/scheduler.hpp @@ -10,7 +10,6 @@ #include "openvino/genai/scheduler_config.hpp" #include "block_manager.hpp" #include "sequence_group.hpp" -#include "block_manager.hpp" namespace ov::genai { class Scheduler { @@ -34,11 +33,14 @@ class Scheduler { }; explicit Scheduler(const SchedulerConfig & config = {}) : - m_config(config), m_block_manager(m_config.num_kv_blocks) { } + m_config(config), m_block_manager(m_config.num_kv_blocks, m_config.enable_prefix_caching, m_config.block_size) { } Output schedule(std::vector& sequence_groups) { Output scheduler_output; + if (m_config.enable_prefix_caching) + _restore_cached_blocks(sequence_groups); + if (m_config.dynamic_split_fuse) { // deepspeed-mii case // generation phase is always scheduled first @@ -167,6 +169,15 @@ class Scheduler { return std::numeric_limits::max(); } + void _restore_cached_blocks(const std::vector& sequence_groups) { + for (size_t sequence_group_id = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) { + SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id]; + if (sequence_group->can_generate_tokens() || sequence_group->num_running_seqs() != 1) + continue; + m_block_manager._restore_cached_blocks(sequence_group, m_config.block_size); + } + } + void _apply_preemption(size_t sequence_group_id, const std::vector& sequence_groups) { SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id]; @@ -222,7 +233,7 @@ class Scheduler { if (num_scheduled_tokens > 0) { // allocate KV blocks if required if (num_scheduled_blocks > 0) - m_block_manager.allocate(seq_id, num_scheduled_blocks); + m_block_manager.allocate(sequence, num_scheduled_blocks, sequence_group->get_prompt_ids()); // and schedule tokens sequence_group->schedule_tokens(num_scheduled_tokens); @@ -326,7 +337,8 @@ class Scheduler { // prompt phases can have a single running sequence OPENVINO_ASSERT(num_running_seqs == 1); // here we also assume that sequence must be scheduler in a single shot and has no already generated context - OPENVINO_ASSERT(sequence_group->get_context_len() == 0); + if (!m_config.enable_prefix_caching) + OPENVINO_ASSERT(sequence_group->get_context_len() == 0); size_t num_available_tokens_in_megabatch = m_config.max_num_batched_tokens - scheduler_output.m_total_num_scheduled_tokens; size_t sequence_len = sequence_group->get_num_available_tokens_for_batching(); @@ -354,11 +366,15 @@ class Scheduler { Sequence::Ptr sequence = (*sequence_group)[0]; uint64_t seq_id = sequence->get_id(); - // allocate KV blocks - 
m_block_manager.allocate(seq_id, num_required_blocks); // and schedule tokens sequence_group->schedule_tokens(sequence_len); + // allocate KV blocks + if (sequence_group->get_num_processed_tokens() == 0) + m_block_manager.allocate(sequence, num_required_blocks, sequence_group->get_prompt_ids()); + else + m_block_manager.append_slots(sequence_group); + // add information to scheduler_output { scheduler_output.m_scheduled_sequence_groups_ids.push_back(sequence_group_id); diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 88b86b4484..d5b9506b2c 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "openvino/genai/generation_handle.hpp" #include "openvino/genai/generation_config.hpp" @@ -121,6 +122,21 @@ class Sequence { float score = cumulative_log_prob / std::pow(current_length, sampling_params.length_penalty); return score; } + + // Each KV block can be uniquely identified by + // the tokens within the block and the tokens in the prefix before the block. + // hash(prefix tokens + block tokens) <--> KV Block + size_t get_hash(size_t content_length, const ov::genai::TokenIds& prompt_ids) const { + std::vector content; + OPENVINO_ASSERT(content_length <= prompt_ids.size() + m_generated_ids.size()); + content.insert( content.end(), prompt_ids.begin(), prompt_ids.begin() + std::min(prompt_ids.size(), content_length)); + if (content_length > prompt_ids.size()) { + content.insert(content.end(), m_generated_ids.begin(), m_generated_ids.begin() + content_length - prompt_ids.size()); + } + const char* data = reinterpret_cast(content.data()); + std::size_t size = content.size() * sizeof(content[0]); + return std::hash{}(std::string_view(data, size)); + } }; // contains a list of Sequences in generic case (beam search or parallel sampling) @@ -345,6 +361,11 @@ class SequenceGroup { clear_scheduled_tokens(); } + void update_processed_tokens_num(size_t processed_tokens) { + m_num_processed_tokens = processed_tokens; + m_max_content_len = processed_tokens; + } + void clear_waiting_sequences() { for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) { if (m_sequences[seq_id]->is_waiting()) { diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp index 8a1a226bc1..f2dea4b830 100644 --- a/src/python/py_generate_pipeline.cpp +++ b/src/python/py_generate_pipeline.cpp @@ -591,9 +591,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) { .def_readwrite("num_kv_blocks", &SchedulerConfig::num_kv_blocks) .def_readwrite("cache_size", &SchedulerConfig::cache_size) .def_readwrite("block_size", &SchedulerConfig::block_size) - .def_readwrite("cache_size", &SchedulerConfig::cache_size) .def_readwrite("dynamic_split_fuse", &SchedulerConfig::dynamic_split_fuse) - .def_readwrite("max_num_seqs", &SchedulerConfig::max_num_seqs); + .def_readwrite("max_num_seqs", &SchedulerConfig::max_num_seqs) + .def_readwrite("enable_prefix_caching", &SchedulerConfig::enable_prefix_caching); + py::class_(m, "ContinuousBatchingPipeline") .def(py::init([](const std::string& model_path, const SchedulerConfig& scheduler_config, const std::string& device, const std::map& llm_plugin_config, const std::map& tokenizer_plugin_config) { diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 025a58a507..083b911416 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -4,6 +4,9 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(googletest) set(TEST_TARGET_NAME 
"tests_continuous_batching") -add_executable(${TEST_TARGET_NAME} scheduler.cpp block_manager.cpp logit_filtering.cpp cache_manager.cpp generate_config.cpp) +file(GLOB tests_src + "*.cpp" +) +add_executable(${TEST_TARGET_NAME} ${tests_src}) target_link_libraries(${TEST_TARGET_NAME} PUBLIC openvino::genai gtest_main) target_include_directories(${TEST_TARGET_NAME} PRIVATE "${PROJECT_SOURCE_DIR}/src/cpp/src") diff --git a/tests/cpp/block_manager.cpp b/tests/cpp/block_manager.cpp index b3c89535a6..4621c184f5 100644 --- a/tests/cpp/block_manager.cpp +++ b/tests/cpp/block_manager.cpp @@ -10,30 +10,39 @@ #include "scheduler.hpp" TEST(TestBlockManager, general_test) { - ov::genai::BlockManager bm = ov::genai::BlockManager(6); + ov::genai::BlockManager bm = ov::genai::BlockManager(6, false, 4); + ov::genai::TokenIds prompt_ids; - bm.allocate(0, 6); - EXPECT_TRUE(bm.has_block_table(0)); - EXPECT_EQ(bm.get_block_table(0).size(), 6); + ov::genai::SequenceGroup::Ptr sequence_group = std::make_shared( + 0, + ov::Tensor(ov::element::i64, { + prompt_ids.size()}, prompt_ids.data()), + ov::genai::beam_search(), + 4); + auto sequence = sequence_group->get_not_finished_sequences()[0]; + bm.allocate(sequence, 6); + auto seq_id = sequence->get_id(); + EXPECT_TRUE(bm.has_block_table(seq_id)); + EXPECT_EQ(bm.get_block_table(seq_id).size(), 6); EXPECT_EQ(bm.num_free_blocks(), 0); - bm.free_sequence_partially_single_runnning_sequence(0, 4); - EXPECT_EQ(bm.get_block_table(0).size(), 2); + bm.free_sequence_partially_single_runnning_sequence(seq_id, 4); + EXPECT_EQ(bm.get_block_table(seq_id).size(), 2); EXPECT_EQ(bm.num_free_blocks(), 4); - bm.free_sequence(0); - EXPECT_FALSE(bm.has_block_table(0)); + bm.free_sequence(seq_id); + EXPECT_FALSE(bm.has_block_table(seq_id)); EXPECT_EQ(bm.num_free_blocks(), 6); - bm.allocate(0, 2); - bm.fork_sequence(0, 1); + bm.allocate(sequence, 2); + bm.fork_sequence(seq_id, 1); EXPECT_TRUE(bm.has_block_table(1)); EXPECT_EQ(bm.get_block_table(1).back()->get_references_count(), 2); } TEST(TestBlockManager, required_blocks_count) { - ov::genai::BlockManager bm = ov::genai::BlockManager(8); + ov::genai::BlockManager bm = ov::genai::BlockManager(8, false, 4); std::vector tokens = {0,1,2,3,4}; ov::genai::SequenceGroup::Ptr sequence_group = std::make_shared( diff --git a/tests/cpp/evictor.cpp b/tests/cpp/evictor.cpp new file mode 100644 index 0000000000..9867dfa2b5 --- /dev/null +++ b/tests/cpp/evictor.cpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "openvino/runtime/core.hpp" +#include "scheduler.hpp" +#include +#include + +TEST(TestEvictor, general_test) { + ov::genai::Evictor evictor; + auto block0 = std::make_shared(0); + block0->set_hash(77, 1); + std::this_thread::sleep_until(std::chrono::system_clock::now() + std::chrono::seconds(1)); + auto block1 = std::make_shared(1); + block1->set_hash(56, 2); + std::this_thread::sleep_until(std::chrono::system_clock::now() + std::chrono::seconds(1)); + auto block2 = std::make_shared(2); + block2->set_hash(23, 3); + std::this_thread::sleep_until(std::chrono::system_clock::now() + std::chrono::seconds(1)); + evictor.add(block0->get_hash(), block0); + evictor.add(block1->get_hash(), block1); + evictor.add(block2->get_hash(), block2); + EXPECT_EQ(evictor.num_blocks(), 3); + + auto block = evictor.get_block(56); + EXPECT_EQ(block->get_index(), 1); + EXPECT_EQ(block->get_hash(), 56); + EXPECT_EQ(block->get_references_count(), 1); + EXPECT_EQ(evictor.num_blocks(), 2); + + 
EXPECT_EQ(evictor.get_block(44), nullptr); + EXPECT_EQ(evictor.num_blocks(), 2); + + EXPECT_EQ(evictor.get_lru_block()->get_index(), 0); + EXPECT_EQ(evictor.num_blocks(), 1); + + auto block3 = std::make_shared(7); + block3->set_hash(12, 4); + std::this_thread::sleep_until(std::chrono::system_clock::now() + std::chrono::seconds(1)); + auto block4 = std::make_shared(10); + block4->set_hash(99, 5); + std::this_thread::sleep_until(std::chrono::system_clock::now() + std::chrono::seconds(1)); + evictor.add(block3->get_hash(), block3); + evictor.add(block4->get_hash(), block4); + block2->set_timestamp(std::chrono::system_clock::now()); + + EXPECT_EQ(evictor.get_lru_block()->get_index(), 7); + EXPECT_EQ(evictor.get_lru_block()->get_index(), 10); + EXPECT_EQ(evictor.get_lru_block()->get_index(), 2); + EXPECT_EQ(evictor.get_lru_block(), nullptr); + EXPECT_EQ(evictor.num_blocks(), 0); +} diff --git a/tests/cpp/scheduler.cpp b/tests/cpp/scheduler.cpp index b4114dd1b2..5468fd014b 100644 --- a/tests/cpp/scheduler.cpp +++ b/tests/cpp/scheduler.cpp @@ -366,3 +366,71 @@ TEST(TestScheduler, test_partially_preempted_prompt) { EXPECT_FALSE(scheduler.has_block_table(idx0)); } } + + + +TEST(TestScheduler, prefix_caching_test) { + std::array configs = {SchedulerConfig(), SchedulerConfig()}; + configs.at(0).max_num_batched_tokens = 32; + configs.at(0).num_kv_blocks = 100; + configs.at(0).block_size = 4; + configs.at(0).dynamic_split_fuse = false; + configs.at(0).max_num_seqs = 5; + configs.at(0).enable_prefix_caching = true; + configs.at(1).max_num_batched_tokens = 32; + configs.at(1).num_kv_blocks = 100; + configs.at(1).block_size = 4; + configs.at(1).dynamic_split_fuse = true; + configs.at(1).max_num_seqs = 5; + configs.at(1).enable_prefix_caching = true; + for (auto scheduler_config: configs) { + std::vector prompt_tokens = {0,1,2,3,4,5,6,7}; + std::vector histrory_tokens = {}; + // schedule prompt + Scheduler scheduler = Scheduler(scheduler_config); + + size_t chat_iterations = 10; + + for (size_t chat_iteration = 0; chat_iteration < chat_iterations; chat_iteration++) { + std::vector tokens = histrory_tokens; + tokens.insert(tokens.end(), prompt_tokens.begin(), prompt_tokens.end()); + SequenceGroup::Ptr sequence_group = std::make_shared(0, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()), + ov::genai::greedy(), scheduler_config.block_size); + std::vector requests = {sequence_group}; + + auto out1 = scheduler.schedule(requests); + if (chat_iteration == 0) + EXPECT_EQ(out1.m_total_num_scheduled_tokens, prompt_tokens.size()); + else + EXPECT_EQ(out1.m_total_num_scheduled_tokens, prompt_tokens.size() + 1); + for (auto seq: requests) { + std::vector running_sequences = seq->get_running_sequences(); + running_sequences[0]->append_token(23, 0.7); + seq->finish_iteration(); + } + + // schedule generate + size_t num_generate_tokens = 10; + for (size_t i = 0; i < num_generate_tokens; i++) { + auto out2 = scheduler.schedule(requests); + EXPECT_EQ(out2.m_total_num_scheduled_tokens, 1); + for (auto seq: requests) { + std::vector running_sequences = seq->get_running_sequences(); + running_sequences[0]->append_token(16, 0.9); + seq->finish_iteration(); + } + } + + // finish sequence + auto sequence = requests[0]->get_running_sequences()[0]; + sequence->set_status(SequenceStatus::FINISHED); + auto idx0 = sequence->get_id(); + scheduler.free_sequence(idx0); + auto generated_ids = sequence->get_generated_ids(); + + histrory_tokens.insert(histrory_tokens.end(), prompt_tokens.begin(), prompt_tokens.end()); + 
+        histrory_tokens.insert(histrory_tokens.end(), generated_ids.begin(), generated_ids.end());
+        }
+    }
+
+}

From 12d933fdf6c32d46a72363152cd849feb5452a71 Mon Sep 17 00:00:00 2001
From: Damian Kalinowski
Date: Tue, 30 Jul 2024 15:15:50 +0200
Subject: [PATCH 3/5] Coverity fixes related to OVMS (#706)

---
 src/cpp/src/block_manager.hpp   | 1 -
 src/cpp/src/tokenizers_path.hpp | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/cpp/src/block_manager.hpp b/src/cpp/src/block_manager.hpp
index 3b1a663235..3e80217f14 100644
--- a/src/cpp/src/block_manager.hpp
+++ b/src/cpp/src/block_manager.hpp
@@ -277,7 +277,6 @@ class BlockManager {
             }
             phisical_blocks_released += released_count;
         }
-        phisical_blocks_released = phisical_blocks_released;
 
         return num_required_blocks <= phisical_blocks_released;
     }
diff --git a/src/cpp/src/tokenizers_path.hpp b/src/cpp/src/tokenizers_path.hpp
index d2c3ef3b5e..4899daccc4 100644
--- a/src/cpp/src/tokenizers_path.hpp
+++ b/src/cpp/src/tokenizers_path.hpp
@@ -86,7 +86,7 @@ std::filesystem::path tokenizers_relative_to_genai() {
 // was already defined.
 class ScopedVar {
 public:
-    bool was_already_set;
+    bool was_already_set{false};
     static constexpr char ENVIRONMENT_VARIABLE_NAME[] = "OPENVINO_TOKENIZERS_PATH_GENAI";
     explicit ScopedVar(const std::string& environment_variable_value) {
 #ifdef _WIN32

From 42281319e90523d646c47692a43bdcc6b78ecb49 Mon Sep 17 00:00:00 2001
From: Oleg Pipikin
Date: Tue, 30 Jul 2024 20:20:31 +0200
Subject: [PATCH 4/5] Fix to throw exception in case of empty chat template in chat scenario (#697)

---
 samples/cpp/chat_sample/README.md    |  8 ++++++++
 samples/python/chat_sample/README.md | 10 ++++++++++
 src/cpp/src/llm_pipeline.cpp         |  1 +
 src/cpp/src/tokenizer.cpp            |  7 +++++--
 4 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/samples/cpp/chat_sample/README.md b/samples/cpp/chat_sample/README.md
index a2eccb4d3d..3f736985c2 100644
--- a/samples/cpp/chat_sample/README.md
+++ b/samples/cpp/chat_sample/README.md
@@ -34,3 +34,11 @@ UnicodeEncodeError: 'charmap' codec can't encode character '\u25aa' in position
 If you encounter the error described in the example when sample is printing output to the Windows console, it is likely due to the default Windows encoding not supporting certain Unicode characters. To resolve this:
 1. Enable Unicode characters for Windows cmd - open `Region` settings from `Control panel`. `Administrative`->`Change system locale`->`Beta: Use Unicode UTF-8 for worldwide language support`->`OK`. Reboot.
 2. Enable UTF-8 mode by setting environment variable `PYTHONIOENCODING="utf8"`.
+
+#### Missing chat template
+
+If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work around this, manually add the chat template to the tokenizer_config.json of your model.
+The following template can be used as a default, but it may not work properly with every model:
+```
+"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}",
+```
diff --git a/samples/python/chat_sample/README.md b/samples/python/chat_sample/README.md
index 983789d0eb..c07023391f 100644
--- a/samples/python/chat_sample/README.md
+++ b/samples/python/chat_sample/README.md
@@ -22,3 +22,13 @@ To enable Unicode characters for Windows cmd open `Region` settings from `Contro
 Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is recommended to run larger models on a dGPU with 32GB+ RAM. For example, the model meta-llama/Llama-2-13b-chat-hf can benefit from being run on a dGPU. Modify the source code to change the device for inference to the GPU.
 
 See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models.
+
+
+## Troubleshooting
+### Missing chat template
+
+If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work around this, manually add the chat template to the tokenizer_config.json of your model.
+The following template can be used as a default, but it may not work properly with every model:
+```
+"chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}",
+```
\ No newline at end of file
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 507d988a6a..1594dbd583 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -271,6 +271,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         m_history.push_back({{"role", "system"}, {"content", system_message}});
         constexpr bool add_generation_prompt = false;
+
         m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
     }
diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp
index b1e36033ee..748daa5875 100644
--- a/src/cpp/src/tokenizer.cpp
+++ b/src/cpp/src/tokenizer.cpp
@@ -368,6 +368,11 @@ class Tokenizer::TokenizerImpl {
                                     bool add_generation_prompt,
                                     const std::string& chat_template) const {
         auto chat_tpl = chat_template.empty() ? m_chat_template : chat_template;
+        OPENVINO_ASSERT(!chat_tpl.empty(),
+                        "Chat template wasn't found. This may indicate that the model wasn't trained for a chat scenario."
+                        " Please add 'chat_template' to tokenizer_config.json to use the model in a chat scenario."
+                        " For more information see the Troubleshooting section in README.md");
+
         // Jinja2Cpp does not support slicing, e.g. [1:].
         // In templates slicing is used typically in the header to find system prompt.
         // If header contains that typical expression we update template and
@@ -433,8 +438,6 @@
             "For example: user{user_prompt}model");
         }
     }
-
-
 };
 
 Tokenizer::Tokenizer(const std::string& tokenizer_path, const ov::AnyMap& plugin_config) {

From 3f55103816cc9417857e9d2ef98fe3404e76a10a Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Wed, 31 Jul 2024 12:14:27 +0400
Subject: [PATCH 5/5] update optimum commit for master (#710)

---
 llm_bench/python/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llm_bench/python/requirements.txt b/llm_bench/python/requirements.txt
index ed80a66deb..e7f7dfcd10 100644
--- a/llm_bench/python/requirements.txt
+++ b/llm_bench/python/requirements.txt
@@ -10,7 +10,7 @@ torch
 transformers>=4.40.0
 diffusers>=0.22.0
 #optimum is in dependency list of optimum-intel
-git+https://github.com/huggingface/optimum-intel.git@439d61f79cf55d5d0b28334f577b6ac3c5ced28f#egg=optimum-intel
+git+https://github.com/eaidova/optimum-intel.git@ea/remove_bf16_rotary_emb_patching#egg=optimum-intel
 git+https://github.com/openvinotoolkit/nncf.git@develop#egg=nncf
 packaging
 psutil
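
For reference, the following minimal C++ sketch shows how the `enable_prefix_caching` flag introduced in PATCH 2/5 could be wired into a chat-style loop, mirroring the `prefix_caching_test` added in `tests/cpp/scheduler.cpp`. The `SchedulerConfig` fields are taken from the diff; the `ContinuousBatchingPipeline` header name, constructor, `generate()` overload and `m_generation_ids` result field are assumptions inferred from the Python bindings in `py_generate_pipeline.cpp` and may differ from the actual API.

```cpp
#include <string>
#include <vector>

#include "openvino/genai/continuous_batching_pipeline.hpp"  // assumed header name
#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/scheduler_config.hpp"

int main() {
    // Fields added/used by PATCH 2/5: keep KV blocks of finished sequences so that
    // identical prompt prefixes in later requests are restored by hash instead of recomputed.
    ov::genai::SchedulerConfig scheduler_config;
    scheduler_config.max_num_batched_tokens = 32;
    scheduler_config.num_kv_blocks = 100;
    scheduler_config.block_size = 4;
    scheduler_config.dynamic_split_fuse = false;
    scheduler_config.max_num_seqs = 5;
    scheduler_config.enable_prefix_caching = true;

    // Assumed constructor signature (model path, scheduler config, device).
    ov::genai::ContinuousBatchingPipeline pipe("./model_dir", scheduler_config, "CPU");

    // Each chat turn resends the whole history; with prefix caching on, the KV blocks
    // computed for earlier turns are looked up in the block cache rather than recomputed.
    std::string history;
    std::vector<std::string> user_turns = {"Hello!", "What did I just say?"};
    for (const auto& user_turn : user_turns) {
        history += user_turn;
        auto results = pipe.generate({history}, {ov::genai::greedy()});  // assumed overload
        history += results[0].m_generation_ids[0];                       // assumed result field
    }
    return 0;
}
```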