
Commit

[Streamer] Handle stop strings in case of sampler
iefode committed Dec 16, 2024
1 parent 9e9b409 commit 0e91fae
Showing 8 changed files with 104 additions and 45 deletions.
13 changes: 4 additions & 9 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -246,8 +246,8 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
         [](const std::shared_ptr<StreamerBase>& streamer) {
             return streamer;
         },
-        [this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
-            return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
+        [this, &sampling_params](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
+            return sampling_params.size() == 1 ? std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer, sampling_params.begin()->stop_strings) : std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
         }
     }, streamer);
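The lambdas above form an overload set dispatched with std::visit over the streamer variant, and stop strings are forwarded only when a single request is being generated, since one streamer can track only one stop-string set. For readers unfamiliar with the pattern, a minimal, self-contained sketch of the same dispatch; StreamerBase here is a stand-in for the GenAI class and resolve() is a hypothetical helper, not pipeline code:

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <variant>

struct StreamerBase {                      // stand-in for ov::genai::StreamerBase
    virtual bool put(int64_t token) = 0;
    virtual void end() = 0;
    virtual ~StreamerBase() = default;
};

using StreamerVariant = std::variant<std::monostate,
                                     std::shared_ptr<StreamerBase>,
                                     std::function<bool(std::string)>>;

template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

std::shared_ptr<StreamerBase> resolve(const StreamerVariant& v) {
    return std::visit(overloaded{
        [](std::monostate) -> std::shared_ptr<StreamerBase> { return nullptr; },
        [](const std::shared_ptr<StreamerBase>& s) { return s; },
        [](const std::function<bool(std::string)>& cb) -> std::shared_ptr<StreamerBase> {
            // The commit wraps cb in a TextCallbackStreamer at this point,
            // passing stop_strings when sampling_params.size() == 1.
            return nullptr;  // placeholder for that wrapping step
        }
    }, v);
}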

@@ -275,8 +275,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
         m_requests.clear();
     };
 
-    bool continue_generation = true;
-    while (has_non_finished_requests() && continue_generation) {
+    while (has_non_finished_requests()) {
         try {
             step();
         } catch (...) {
@@ -297,11 +296,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
         streamer_ptr->end();
     }
 
-    if (!continue_generation) {
-        drop_requests();
-    } else {
-        OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
-    }
+    OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
 
     for (size_t generation_idx = 0; generation_idx < generations.size(); ++generation_idx) {
         const auto& generation = generations[generation_idx];
2 changes: 1 addition & 1 deletion src/cpp/src/llm_pipeline.cpp
@@ -273,7 +273,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
         streamer_ptr = *streamer_obj;
     } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
-        streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
+        streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback, generation_config->stop_strings);
     }
 
     auto batch_size = input_ids.get_shape().at(0);
2 changes: 1 addition & 1 deletion src/cpp/src/llm_pipeline_static.cpp
@@ -967,7 +967,7 @@ EncodedResults StaticLLMPipeline::generate(
     } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
         streamer_ptr = *streamer_obj;
     } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
-        streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
+        streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback, generation_config->stop_strings);
     }
 
     if (!config.is_greedy_decoding()) {
5 changes: 1 addition & 4 deletions src/cpp/src/sampler.cpp
@@ -578,8 +578,6 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
     if (!sampling_params.stop_strings.empty()) {
         int num_matched_last_tokens = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), sampling_params.stop_strings);
         if (num_matched_last_tokens) {
-            if (!sampling_params.include_stop_str_in_output)
-                running_sequence->remove_last_tokens(num_matched_last_tokens);
             running_sequence->set_status(SequenceStatus::FINISHED);
             running_sequence->set_finish_reason(GenerationFinishReason::STOP);
             dropped_seq_ids.push_back(running_sequence->get_id());
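With the trimming removed, the sampler only detects the stop string and finishes the sequence; excluding it from the output is now the streamer's job. The internals of match_stop_string are not part of this diff, but conceptually it reports how many trailing generated tokens complete a stop string. A hedged toy of that idea, with already-decoded string pieces standing in for token ids:

#include <set>
#include <string>
#include <vector>

// Toy only, not the real helper: how many trailing pieces complete a stop string?
int match_stop_string_toy(const std::vector<std::string>& pieces,
                          const std::set<std::string>& stop_strings) {
    std::string tail;
    for (size_t n = 1; n <= pieces.size(); ++n) {
        tail = pieces[pieces.size() - n] + tail;   // grow the decoded suffix
        for (const auto& stop : stop_strings) {
            if (tail.size() >= stop.size() &&
                tail.compare(tail.size() - stop.size(), stop.size(), stop) == 0) {
                return static_cast<int>(n);        // last n tokens end with a stop string
            }
        }
    }
    return 0;  // nothing matched at the current tail
}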
@@ -886,8 +884,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             // Notify handle after sampling is done.
             // For non-streaming this is effective only when the generation is finished.
             OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request);
-            size_t num_output_token_to_push = num_tokens_to_process - max_removed_tokens_per_request + 1;
-            sequence_group->notify_handle(num_output_token_to_push);
+            sequence_group->notify_handle();
         } else {
             // we are in prompt processing phase when prompt is split into chunks and processed step by step
         }
22 changes: 16 additions & 6 deletions src/cpp/src/sequence_group.hpp
@@ -221,6 +221,8 @@ class SequenceGroup {
     // flag to enable/disable token generation, e.g. in speculative decoding scenario
     bool m_is_gen_paused = false;
 
+    size_t m_num_streamed_tokens = 0;
+
 
     SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
         : m_request_id(request_id),
@@ -612,7 +614,7 @@ class SequenceGroup {
         m_generation_stream->push(std::move(outputs));
     }
 
-    void notify_handle(size_t num_output_token_to_push = 0) {
+    void notify_handle() {
         if (out_of_memory()) {
             set_generation_status(GenerationStatus::IGNORED);
         } else if (has_finished()) {
@@ -626,10 +628,18 @@ class SequenceGroup {
         } else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) {
             // We can stream only when one sequence is returned and we don't use stop strings that would be excluded from the output
             // (after stop string is detected its tokens are already sent)
-            if (num_total_seqs() == 1 &&
-                (m_sampling_params.stop_strings.empty() || m_sampling_params.include_stop_str_in_output)) {
-                if (num_output_token_to_push)
-                    push_partial_outputs(num_output_token_to_push);
+            if (num_total_seqs() == 1) {
+                const auto generated_len = m_sequences.front()->get_generated_len();
+                // speculative decoding draft handling
+                if (generated_len < m_num_streamed_tokens) {
+                    m_num_streamed_tokens = generated_len;
+                }
+                OPENVINO_ASSERT(generated_len >= m_num_streamed_tokens);
+                auto delta = generated_len - m_num_streamed_tokens;
+
+                size_t num_output_token_to_push = generated_len - m_num_streamed_tokens;
+                push_partial_outputs(num_output_token_to_push);
+                m_num_streamed_tokens += (num_output_token_to_push);
             } else if (has_finished() || out_of_memory()) {
                 push_outputs();
             }
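Since notify_handle() no longer receives a token count, the group derives it as generated length minus what was already streamed, clamping the counter when speculative decoding rolls back rejected draft tokens. A runnable toy of that bookkeeping (the function mirrors the diff's arithmetic but is illustrative, not the class method):

#include <cstddef>
#include <iostream>

size_t tokens_to_push(size_t generated_len, size_t& num_streamed_tokens) {
    // Speculative decoding can roll back rejected draft tokens, so the
    // generated length may drop below what was already streamed.
    if (generated_len < num_streamed_tokens)
        num_streamed_tokens = generated_len;
    size_t delta = generated_len - num_streamed_tokens;
    num_streamed_tokens += delta;
    return delta;
}

int main() {
    size_t streamed = 0;
    std::cout << tokens_to_push(3, streamed) << '\n';  // 3: first three tokens
    std::cout << tokens_to_push(5, streamed) << '\n';  // 2: two more arrived
    std::cout << tokens_to_push(4, streamed) << '\n';  // 0: rollback, nothing new
}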
@@ -661,4 +671,4 @@ class SequenceGroup {
         m_generation_stream->push(std::move(outputs));
     }
 };
-}
+}
4 changes: 2 additions & 2 deletions src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -199,8 +199,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
         [](const std::shared_ptr<StreamerBase>& streamer) {
             return streamer;
         },
-        [this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
-            return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
+        [this, &sampling_params](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
+            return sampling_params.size() == 1 ? std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer, sampling_params.begin()->stop_strings) : std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
         }
     }, streamer);

93 changes: 73 additions & 20 deletions src/cpp/src/text_callback_streamer.cpp
@@ -6,40 +6,93 @@
 namespace ov {
 namespace genai {
 
-TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback) {
+std::vector<int64_t> encode_and_process_stop_string(const std::string& stop_string, ov::genai::Tokenizer& tokenizer) {
+    // encode stop_string
+    ov::Tensor ov_encoded_stop_string = tokenizer.encode(stop_string).input_ids;
+    size_t tensor_size = ov_encoded_stop_string.get_size();
+    std::vector<int64_t> source_encoded_stop_string(tensor_size), encoded_stop_string;
+    std::copy_n(ov_encoded_stop_string.data<int64_t>(), tensor_size, source_encoded_stop_string.begin());
+    // remove special symbols
+    for (const auto& token_id : source_encoded_stop_string) {
+        if (token_id != tokenizer.get_bos_token_id() &&
+            token_id != tokenizer.get_eos_token_id() &&
+            token_id != tokenizer.get_pad_token_id()) {
+            encoded_stop_string.push_back(token_id);
+        }
+    }
+    return encoded_stop_string;
+}
+
+TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, const std::set<std::string>& stop_strings) {
     m_tokenizer = tokenizer;
     on_finalized_subword_callback = callback;
+    for (const auto& stop_string : stop_strings) {
+        auto encoded_stop_string = encode_and_process_stop_string(stop_string, m_tokenizer);
+        m_max_stop_string_len = std::max(encoded_stop_string.size(), m_max_stop_string_len);
+        m_stop_strings.insert(stop_string);
+    }
 }
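A note on the helper above: tokenizers routinely prepend or append special ids (BOS, EOS, padding) when encoding a bare string, so without the filter the encoded stop string would look longer than it is, inflating m_max_stop_string_len and making the streamer hold back more tokens than necessary.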

 bool TextCallbackStreamer::put(int64_t token) {
     std::stringstream res;
-    m_tokens_cache.push_back(token);
-    std::string text = m_tokenizer.decode(m_tokens_cache);
-    if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
-        // Flush the cache after the new line symbol
-        res << std::string_view{text.data() + print_len, text.size() - print_len};
-        m_tokens_cache.clear();
-        print_len = 0;
-        return on_finalized_subword_callback(res.str());
-    }
+    m_tokens_cache_stop_string.push_back(token);
+    if (m_tokens_cache_stop_string.size() > m_max_stop_string_len || token == m_tokenizer.get_eos_token_id()) {
+        std::vector<int64_t> buffer(m_tokens_cache_stop_string.begin(), m_tokens_cache_stop_string.end());
+        std::string text = m_tokenizer.decode(buffer);
+        std::string activated_stop_string = "";
+        for (const auto& stop_string : m_stop_strings) {
+            if (text.find(stop_string) != std::string::npos) {
+                activated_stop_string = stop_string;
+                break;
+            }
+        }
+
+        if (activated_stop_string.empty() && token != m_tokenizer.get_eos_token_id()) {
+            m_tokens_cache.push_back(m_tokens_cache_stop_string.front());
+            m_tokens_cache_stop_string.pop_front();
+        } else {
+            m_tokens_cache.insert(m_tokens_cache.end(), m_tokens_cache_stop_string.begin(), m_tokens_cache_stop_string.end());
+            m_tokens_cache_stop_string.clear();
+        }
+
+        text = m_tokenizer.decode(m_tokens_cache);
+        if (!activated_stop_string.empty()) {
+            auto pos = text.find(activated_stop_string);
+            if (pos != std::string::npos) {
+                text.replace(pos, activated_stop_string.length(), "");
+            }
+            m_tokens_cache.clear();
+        }
+
+        if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
+            // Flush the cache after the new line symbol
+            res << std::string_view{text.data() + print_len, text.size() - print_len};
+            m_tokens_cache.clear();
+            print_len = 0;
+            return on_finalized_subword_callback(res.str());
+        }
 
-    constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
-    if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
-        // Don't print incomplete text
-        return on_finalized_subword_callback(res.str());
-    } else if (text.size() > print_len) {
-        // It is possible to have a shorter text after adding new token.
-        // Print to output only if text length is increaesed.
-        res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
-        print_len = text.size();
-    }
+        constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
+        if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
+            // Don't print incomplete text
+            return on_finalized_subword_callback(res.str());
+        } else {
+            // It is possible to have a shorter text after adding new token.
+            // Print to output only if text length is increaesed.
+            res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
+            print_len = text.size();
+        }
+    }
 
     return on_finalized_subword_callback(res.str());
 }
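The put() logic above stages each incoming token in m_tokens_cache_stop_string and releases the oldest one into m_tokens_cache only once the window is full and still contains no stop string; on a match, the held tokens are cut from the output instead. A standalone toy of that sliding-window idea, with characters standing in for decoded tokens (unlike the real code, the toy also suppresses held text that merely precedes the stop string):

#include <deque>
#include <iostream>
#include <string>

int main() {
    const std::string stop = "User:";
    const size_t max_len = stop.size();   // stands in for m_max_stop_string_len
    std::deque<char> held;                // stands in for m_tokens_cache_stop_string
    for (char c : std::string("Blue.\nUser: next")) {
        held.push_back(c);
        if (held.size() <= max_len)
            continue;                     // window could still become a stop string
        std::string window(held.begin(), held.end());
        if (window.find(stop) != std::string::npos) {
            std::cout << "\n[stop string detected; held text suppressed]\n";
            break;                        // the held characters are never emitted
        }
        std::cout << held.front();        // oldest char can no longer join a match
        held.pop_front();
    }
    // prints "Blue." followed by the suppression notice
}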

 void TextCallbackStreamer::end() {
     std::stringstream res;
-    std::string text = m_tokenizer.decode(m_tokens_cache);
+    std::vector<int64_t> buffer(m_tokens_cache.begin(), m_tokens_cache.end());
+    std::string text = m_tokenizer.decode(buffer);
     if (text.size() <= print_len)
         return ;
     res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
8 changes: 6 additions & 2 deletions src/cpp/src/text_callback_streamer.hpp
@@ -3,6 +3,8 @@
 
 #pragma once
 
+#include <list>
+
 #include "openvino/genai/streamer_base.hpp"
 #include "openvino/genai/tokenizer.hpp"
 
@@ -14,14 +16,16 @@ class TextCallbackStreamer: public StreamerBase {
     bool put(int64_t token) override;
     void end() override;
 
-    TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback);
+    TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback, const std::set<std::string>& stop_strings = {});
 
     std::function<bool(std::string)> on_finalized_subword_callback = [](std::string words)->bool { return false; };
 
 protected:
     Tokenizer m_tokenizer;
     std::vector<int64_t> m_tokens_cache;
-    size_t print_len = 0;
+    std::list<int64_t> m_tokens_cache_stop_string;
+    size_t print_len = 0, m_max_stop_string_len = 0;
+    std::set<std::string> m_stop_strings;
 };
 
 } // namespace genai
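End to end, the new default parameter means a callback streamer now honors stop strings. A hedged usage sketch against the public API (the model path is a placeholder, and the behavior assumes this commit's defaults, where a detected stop string is suppressed from the streamed text):

#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    ov::genai::LLMPipeline pipe("/path/to/model", "CPU");   // placeholder path
    ov::genai::GenerationConfig config = pipe.get_generation_config();
    config.max_new_tokens = 128;
    config.stop_strings = {"\nUser:"};           // generation halts on this string
    config.include_stop_str_in_output = false;   // and it is not echoed to the callback
    auto on_chunk = [](std::string chunk) {
        std::cout << chunk << std::flush;
        return false;                            // false == keep generating
    };
    pipe.generate("Why is the sky blue?", config, on_chunk);
}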
