From 0e91fae93975d01d690c1265ee9bdad83cf3ee2c Mon Sep 17 00:00:00 2001
From: Irina Efode
Date: Mon, 16 Dec 2024 22:24:20 +0400
Subject: [PATCH] [Streamer] Handle stop strings in case of sampler

---
 src/cpp/src/continuous_batching_impl.cpp | 13 +--
 src/cpp/src/llm_pipeline.cpp             |  2 +-
 src/cpp/src/llm_pipeline_static.cpp      |  2 +-
 src/cpp/src/sampler.cpp                  |  5 +-
 src/cpp/src/sequence_group.hpp           | 22 +++--
 .../speculative_decoding_impl.cpp        |  4 +-
 src/cpp/src/text_callback_streamer.cpp   | 93 +++++++++++++++----
 src/cpp/src/text_callback_streamer.hpp   |  8 +-
 8 files changed, 104 insertions(+), 45 deletions(-)

diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 1e42f5b2d9..442fd6f7c5 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -246,8 +246,8 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<
         [](const std::shared_ptr<StreamerBase>& streamer) {
             return streamer;
         },
-        [this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
-            return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
+        [this, &sampling_params](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
+            return sampling_params.size() == 1 ? std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer, sampling_params.begin()->stop_strings) : std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
         }
     }, streamer);
 
@@ -275,8 +275,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<
         streamer_ptr->end();
     }
 
-    if (!continue_generation) {
-        drop_requests();
-    } else {
-        OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
-    }
+    OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
 
     for (size_t generation_idx = 0; generation_idx < generations.size(); ++generation_idx) {
         const auto& generation = generations[generation_idx];
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index f663b27dd9..89e71f21f8 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -273,7 +273,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
             streamer_ptr = *streamer_obj;
         } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
-            streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
+            streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback, generation_config->stop_strings);
         }
 
         auto batch_size = input_ids.get_shape().at(0);
diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp
index cb83209b4b..01a06230d0 100644
--- a/src/cpp/src/llm_pipeline_static.cpp
+++ b/src/cpp/src/llm_pipeline_static.cpp
@@ -967,7 +967,7 @@ EncodedResults StaticLLMPipeline::generate(
     } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
         streamer_ptr = *streamer_obj;
     } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
-        streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
+        streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback, generation_config->stop_strings);
     }
 
     if (!config.is_greedy_decoding()) {
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index f77463d767..f1abc862e2 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -578,8 +578,6 @@ std::vector Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
         if (!sampling_params.stop_strings.empty()) {
             int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
             if (num_matched_last_tokens) {
-                if (!sampling_params.include_stop_str_in_output)
-                    running_sequence->remove_last_tokens(num_matched_last_tokens);
                 running_sequence->set_status(SequenceStatus::FINISHED);
                 running_sequence->set_finish_reason(GenerationFinishReason::STOP);
                 dropped_seq_ids.push_back(running_sequence->get_id());
@@ -886,8 +884,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             // Notify handle after sampling is done.
             // For non-streaming this is effective only when the generation is finished.
             OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
-            size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
-            sequence_group->notify_handle(num_output_token_to_push);
+            sequence_group->notify_handle();
         } else {
             // we are in prompt processing phase when prompt is split into chunks and processed step by step
         }
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 6755255fe8..c8b4c59486 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -221,6 +221,8 @@ class SequenceGroup {
     // flag to enable/disable token generation, e.g. in speculative decoding scenario
    bool m_is_gen_paused = false;
 
+    size_t m_num_streamed_tokens = 0;
+
     SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
         : m_request_id(request_id),
@@ -612,7 +614,7 @@ class SequenceGroup {
         m_generation_stream->push(std::move(outputs));
     }
 
-    void notify_handle(size_t num_output_token_to_push = 0) {
+    void notify_handle() {
         if (out_of_memory()) {
             set_generation_status(GenerationStatus::IGNORED);
         } else if (has_finished()) {
@@ -626,10 +628,18 @@ class SequenceGroup {
         } else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) {
             // We can stream only when one sequence is returned and we don't use stop strings that would be excluded from the output
             // (after stop string is detected its tokens are already sent)
-            if (num_total_seqs() == 1 &&
-                (m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) {
-                if (num_output_token_to_push)
-                    push_partial_outputs(num_output_token_to_push);
+            if (num_total_seqs() == 1) {
+                const auto generated_len = m_sequences.front()->get_generated_len();
+                // speculative decoding draft handling
+                if (generated_len < m_num_streamed_tokens) {
+                    m_num_streamed_tokens = generated_len;
+                }
+                OPENVINO_ASSERT(generated_len >= m_num_streamed_tokens);
+                auto delta = generated_len - m_num_streamed_tokens;
+
+                size_t num_output_token_to_push = generated_len - m_num_streamed_tokens;
+                push_partial_outputs(num_output_token_to_push);
+                m_num_streamed_tokens += (num_output_token_to_push);
             } else if (has_finished() || out_of_memory()) {
                 push_outputs();
             }
@@ -661,4 +671,4 @@ class SequenceGroup {
         m_generation_stream->push(std::move(outputs));
     }
 };
-}
+}
\ No newline at end of file
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index e4f3b1ad1f..fd9bf00785 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -199,8 +199,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
         [](const std::shared_ptr<StreamerBase>& streamer) {
             return streamer;
         },
-        [this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
-            return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
+        [this, &sampling_params](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
+            return sampling_params.size() == 1 ? std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer, sampling_params.begin()->stop_strings) : std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
         }
     }, streamer);
 
diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp
index 314a7ffa4d..46b4c666b9 100644
--- a/src/cpp/src/text_callback_streamer.cpp
+++ b/src/cpp/src/text_callback_streamer.cpp
@@ -6,32 +6,84 @@
 namespace ov {
 namespace genai {
 
-TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback) {
+std::vector<int64_t> encode_and_process_stop_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) {
+    // encode stop_string
+    ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string).input_ids;
+    size_t tensor_size = ov_encoded_stop_string.get_size();
+    std::vector<int64_t> source_encoded_stop_string(tensor_size), encoded_stop_string;
+    std::copy_n(ov_encoded_stop_string.data<int64_t>(), tensor_size, source_encoded_stop_string.begin());
+    // remove special symbols
+    for (const auto& token_id : source_encoded_stop_string) {
+        if (token_id != tokenizer.get_bos_token_id() &&
+            token_id != tokenizer.get_eos_token_id() &&
+            token_id != tokenizer.get_pad_token_id()) {
+            encoded_stop_string.push_back(token_id);
+        }
+    }
+    return encoded_stop_string;
+}
+
+TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, const std::set<std::string>& stop_strings) {
     m_tokenizer = tokenizer;
     on_finalized_subword_callback = callback;
+    for (const auto& stop_string : stop_strings) {
+        auto encoded_stop_string = encode_and_process_stop_string(stop_string, m_tokenizer);
+        m_max_stop_string_len = std::max(encoded_stop_string.size(), m_max_stop_string_len);
+        m_stop_strings.insert(stop_string);
+    }
 }
 
 bool TextCallbackStreamer::put(int64_t token) {
     std::stringstream res;
-    m_tokens_cache.push_back(token);
-    std::string text = m_tokenizer.decode(m_tokens_cache);
-    if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
-        // Flush the cache after the new line symbol
-        res << std::string_view{text.data() + print_len, text.size() - print_len};
-        m_tokens_cache.clear();
-        print_len = 0;
-        return on_finalized_subword_callback(res.str());
-    }
+    m_tokens_cache_stop_string.push_back(token);
+    if (m_tokens_cache_stop_string.size() > m_max_stop_string_len || token == m_tokenizer.get_eos_token_id()) {
+        std::vector<int64_t> buffer(m_tokens_cache_stop_string.begin(), m_tokens_cache_stop_string.end());
+        std::string text = m_tokenizer.decode(buffer);
+        std::string activated_stop_string = "";
+        for (const auto& stop_string : m_stop_strings) {
+            if (text.find(stop_string) != std::string::npos) {
+                activated_stop_string = stop_string;
+                break;
+            }
+        }
+
+
+        if (activated_stop_string.empty() && token != m_tokenizer.get_eos_token_id()) {
+            m_tokens_cache.push_back(m_tokens_cache_stop_string.front());
+            m_tokens_cache_stop_string.pop_front();
+        } else {
+            m_tokens_cache.insert(m_tokens_cache.end(), m_tokens_cache_stop_string.begin(), m_tokens_cache_stop_string.end());
+            m_tokens_cache_stop_string.clear();
+        }
+
+        text = m_tokenizer.decode(m_tokens_cache);
+        if (!activated_stop_string.empty()) {
+            auto pos = text.find(activated_stop_string);
+            if (pos != std::string::npos) {
+                text.replace(pos, activated_stop_string.length(), "");
+            }
+            m_tokens_cache.clear();
+        }
+
+        if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
+            // Flush the cache after the new line symbol
+            res << std::string_view{text.data() + print_len, text.size() - print_len};
+            m_tokens_cache.clear();
+            print_len = 0;
+            return on_finalized_subword_callback(res.str());
+        }
+
 
-    constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
-    if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
-        // Don't print incomplete text
-        return on_finalized_subword_callback(res.str());
-    } else if (text.size() > print_len) {
-        // It is possible to have a shorter text after adding new token.
-        // Print to output only if text length is increaesed.
-        res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
-        print_len = text.size();
+        constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
+        if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
+            // Don't print incomplete text
+            return on_finalized_subword_callback(res.str());
+        } else {
+            // It is possible to have a shorter text after adding new token.
+            // Print to output only if text length is increaesed.
+            res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
+            print_len = text.size();
+        }
     }
 
     return on_finalized_subword_callback(res.str());
@@ -39,7 +91,8 @@ bool TextCallbackStreamer::put(int64_t token) {
 
 void TextCallbackStreamer::end() {
     std::stringstream res;
-    std::string text = m_tokenizer.decode(m_tokens_cache);
+    std::vector<int64_t> buffer(m_tokens_cache.begin(), m_tokens_cache.end());
+    std::string text = m_tokenizer.decode(buffer);
     if (text.size() <= print_len)
         return ;
     res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
diff --git a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp
index a03b0deccb..ae353d27d5 100644
--- a/src/cpp/src/text_callback_streamer.hpp
+++ b/src/cpp/src/text_callback_streamer.hpp
@@ -3,6 +3,8 @@
 
 #pragma once
 
+#include <list>
+
 #include "openvino/genai/streamer_base.hpp"
 #include "openvino/genai/tokenizer.hpp"
 
@@ -14,14 +16,16 @@ class TextCallbackStreamer: public StreamerBase {
     bool put(int64_t token) override;
     void end() override;
 
-    TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback);
+    TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, const std::set<std::string>& stop_strings = {});
 
     std::function<bool(std::string)> on_finalized_subword_callback = [](std::string words)->bool { return false; };
 
 protected:
     Tokenizer m_tokenizer;
     std::vector<int64_t> m_tokens_cache;
-    size_t print_len = 0;
+    std::list<int64_t> m_tokens_cache_stop_string;
+    size_t print_len = 0, m_max_stop_string_len = 0;
+    std::set<std::string> m_stop_strings;
 };
 
 } // namespace genai
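
Note (editorial, not part of the patch): a minimal usage sketch of the extended
TextCallbackStreamer constructor changed above. The function name
make_streamer_example and the callback body are illustrative assumptions; only
the third constructor argument mirrors the call sites in this patch, where the
set of stop strings comes from GenerationConfig::stop_strings.

    #include <iostream>
    #include <memory>
    #include "text_callback_streamer.hpp"            // internal header under src/cpp/src
    #include "openvino/genai/generation_config.hpp"

    std::shared_ptr<ov::genai::TextCallbackStreamer> make_streamer_example(
            const ov::genai::Tokenizer& tokenizer,
            const ov::genai::GenerationConfig& config) {
        // The callback only ever sees finalized subwords; tokens are held in a side
        // cache until the streamer can rule out a stop-string match, and a matched
        // stop string is removed from the decoded text before printing.
        auto callback = [](std::string subword) -> bool {
            std::cout << subword << std::flush;
            return false;  // false means "keep generating"
        };
        // Stop strings are passed at construction so the streamer knows how many
        // tokens it may need to buffer before flushing text to the callback.
        return std::make_shared<ov::genai::TextCallbackStreamer>(
            tokenizer, callback, config.stop_strings);
    }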