Skip to content

Commit

Permalink
Move beam search in case of chat scenario to sampler.cpp (#1215)
Browse files Browse the repository at this point in the history
Task [CVS-156578](https://jira.devtools.intel.com/browse/CVS-156578)

- add missed token, if prev generation was finished because max length
was reached
  • Loading branch information
sbalandi authored Dec 20, 2024
1 parent 74cdfc9 commit 05d01ac
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 576 deletions.
455 changes: 0 additions & 455 deletions src/cpp/src/group_beam_searcher.cpp

This file was deleted.

134 changes: 76 additions & 58 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,23 @@
namespace ov {
namespace genai {

std::pair<EncodedResults, int32_t> beam_search(
ov::InferRequest& lm,
ov::Tensor prompts,
ov::Tensor attention_mask,
GenerationConfig config,
std::optional<ov::Tensor> position_ids,
std::optional<int32_t> selected_beam_idx
);

class StatefulLLMPipeline final : public LLMPipelineImplBase {
public:
ov::InferRequest m_model_runner;
bool is_chat_conversation = false;
bool m_trust_encoded_history = true;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
std::vector<int64_t> m_tokenized_chat_history;
ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
size_t m_to_remove_from_hist = 0;
size_t m_kv_cache_seq_length_axis = 2;
Sampler m_sampler;
// Tail of previous output in chat mode is missing in KV cache, let's keep it
std::optional<int64_t> m_last_disappeared_token = std::nullopt;
// If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
// If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
// so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0};

StatefulLLMPipeline(
const ov::InferRequest& request,
Expand Down Expand Up @@ -154,35 +149,44 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// some symbols combinations can be encoded by the tokenizer in different ways
// if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history
// so let's check it out, find the trusted part and use it in on the next step
size_t last_same_hist_token = 0;
size_t trusted_history_length = 0;
if (!m_tokenized_chat_history.empty()) {
std::set<int64_t> stop_tokens = config.stop_token_ids;
last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = last_same_hist_token == SIZE_MAX;
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = trusted_history_length == SIZE_MAX;
}

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (last_same_hist_token != SIZE_MAX) {
m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) {
// does_kv_cache_need_to_update will be true here if beam search is activated
// in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly
// if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager
if (m_kv_history_manager.does_kv_cache_need_to_update()) {
trusted_history_length = m_kv_history_manager.trusted_history_length;
} else {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
// if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it
m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;
}

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token},
new_chat_tokens.input_ids.data<int64_t>() + last_same_hist_token);
{1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length},
new_chat_tokens.input_ids.data<int64_t>() + trusted_history_length);

ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape());
std::fill_n(new_attention_mask.data<int64_t>(), new_tensor.get_shape()[1], 1);

encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token});
{1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length});
new_tensor.copy_to(encoded_input.input_ids);
encoded_input.attention_mask = new_attention_mask;

m_selected_beam = std::nullopt;
m_last_disappeared_token = std::nullopt;
} else {
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
}
m_templated_chat_history = new_templated_chat_history;

m_tokenized_chat_history.clear();
m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size());
std::copy_n(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.get_size(),
Expand Down Expand Up @@ -264,6 +268,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

// Tail of previous output in chat mode is missing in KV cache.
if (m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1);
input_ids = ov::genai::utils::push_front_inputs(input_ids, *m_last_disappeared_token);
}

GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;

// If eos_token_id was not provided, take value from default m_generation_config
Expand Down Expand Up @@ -294,7 +304,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
Expand All @@ -304,10 +314,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// Between subsequent runs attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist;

kv_cache_len = atten_mask_history.get_shape()[1] - m_kv_history_manager.num_tokens_to_remove_from_kv_cache;

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
auto start_atten_hst = atten_mask_history.data<int64_t>();

std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
Expand All @@ -317,6 +329,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
concatenated_attention_mask = attention_mask;
}

size_t prev_attn_mask_size = concatenated_attention_mask.get_shape()[1];

bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
Expand All @@ -330,51 +344,55 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_kv_history_manager.reset();
}

ov::genai::EncodedResults result;
if (config.is_beam_search() && is_chat_conversation) {
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
config, position_ids, m_selected_beam);
} else {
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);

sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}

if (m_sampler.get_seed() != config.rng_seed) {
m_sampler.set_seed(config.rng_seed);
}
sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
}

std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
m_sampler, requests, position_ids, std::nullopt, m_selected_beam);
if (m_sampler.get_seed() != config.rng_seed) {
m_sampler.set_seed(config.rng_seed);
}

ov::genai::EncodedResults result;
std::tie(result, m_last_disappeared_token) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
streamer_ptr, m_sampler, requests, position_ids, std::nullopt);

if (is_chat_conversation) {
// force remove from kv_cache last answer
if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_kv_history_manager.trusted_history_length = m_tokenized_chat_history.size();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
} else {
reset_kv_state();
m_selected_beam = std::nullopt;
m_last_disappeared_token = std::nullopt;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));

auto stop_time = std::chrono::steady_clock::now();

// If is called without tokenization then that stat will not be reported.
Expand All @@ -388,10 +406,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_history = {};
Expand All @@ -409,10 +427,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_history.clear();
Expand Down
39 changes: 18 additions & 21 deletions src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
#include <regex>
#include <vector>

#include "utils.hpp"
#include "debug_utils.hpp"
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"

#include "debug_utils.hpp"

#include "utils.hpp"

namespace ov {
namespace genai {
Expand Down Expand Up @@ -51,16 +50,15 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
}


std::pair<EncodedResults, int32_t> get_lm_encoded_results(
std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
ov::InferRequest& m_llm,
const ov::Tensor& input_ids,
const ov::Tensor& attention_mask,
const std::shared_ptr<StreamerBase>& streamer_ptr,
Sampler& sampler,
std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids,
std::optional<EmbeddingsModel> m_embedding,
std::optional<int32_t> selected_beam_idx
std::optional<EmbeddingsModel> m_embedding
) {
std::vector<GenerationHandle> generations;
for (SequenceGroup::Ptr sequence_group : sequence_groups) {
Expand Down Expand Up @@ -105,7 +103,7 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
m_llm.set_tensor("position_ids", *position_ids);

ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size});
std::fill_n(beam_idx.data<int32_t>(), batch_size, selected_beam_idx.has_value() ? *selected_beam_idx : 0);
std::fill_n(beam_idx.data<int32_t>(), batch_size, 0);
m_llm.set_tensor("beam_idx", beam_idx);

// "Prompt" phase
Expand Down Expand Up @@ -171,13 +169,13 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
// apply strides to shift to a next sequence
input_ids_data += num_scheduled_tokens;

// for different sequences iteration of beams started from 0, but we collect it to one input_ids#
// for different sequences iteration of beams started from 0, but we collect it to one input_ids
next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id()));
}
}

for (size_t i = 0; i < sequence_groups.size(); i++) {
beam_offets[sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]);
for (size_t i = 0; i < active_sequence_groups.size(); i++) {
beam_offets[active_sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (active_sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]);
}

if (m_embedding.has_value()) {
Expand Down Expand Up @@ -212,15 +210,10 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
streamer_ptr->end();
}

// Collect results

size_t next_selected_beam = 0;
for (size_t i = 0; i < sequence_groups.size(); i++) {
auto request = sequence_groups[i];
std::vector<GenerationOutput> generation_outputs;
auto sampling_params = request->get_sampling_parameters();
const auto& sequences = request->get_finished_sequences();
size_t num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, sequences.size());
for (auto& sequence_group : sequence_groups) {
auto sampling_params = sequence_group->get_sampling_parameters();
const auto& sequences = sequence_group->get_finished_sequences();
size_t num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, sequences.size());

for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) {
const auto & sequence = sequences[seq_id];
Expand All @@ -229,13 +222,17 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
results.tokens.push_back(sequence->get_generated_ids());
results.scores.push_back(score);
}
// next_selected_beam = sampler.last_selected_beam(request);
}

for (SequenceGroup::Ptr sequence_group : sequence_groups)
sampler.clear_request_info(sequence_group->get_request_id());

return {results, next_selected_beam};
// it is not saved in KV cache, we need to add it for some cases
std::optional<int64_t> last_token_of_best_sequence = std::nullopt;
if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped())
last_token_of_best_sequence = results.tokens[0].back();

return {results, last_token_of_best_sequence};
}

} // namespace genai
Expand Down
Loading

0 comments on commit 05d01ac

Please sign in to comment.