Support chat conversation for StaticLLMPipeline (#580)
# Overview

Add chat mode support for `StaticLLMPipeline`.

The current implementation is deliberately naive: it aggregates the entire chat history and passes it as a fresh prompt on every `generate` call, so each turn re-runs prefill over the whole conversation.
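
In rough pseudocode, each turn does the following (a sketch assembled from the diff below; `m_history`, `apply_chat_template`, `encode`, and `decode` are real names from this change, while the surrounding scaffolding is illustrative):

```cpp
// Sketch of one chat turn (illustrative; see llm_pipeline_static.cpp below).
DecodedResults chat_turn(const std::string& user_prompt) {
    m_history.push_back({{"role", "user"}, {"content", user_prompt}});
    constexpr bool add_generation_prompt = true;
    // Flatten the whole conversation into a single prompt string...
    std::string full_prompt = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
    // ...and run it through the pipeline as if it were a brand-new request.
    auto encoded = generate(m_tokenizer.encode(full_prompt), m_generation_config, {});
    DecodedResults decoded = {m_tokenizer.decode(encoded.tokens), encoded.scores};
    m_history.push_back({{"role", "assistant"}, {"content", decoded.texts[0]}});
    return decoded;
}
```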

---------

Co-authored-by: Pavel Esir <[email protected]>
TolyaTalamanov and pavel-esir authored Jul 17, 2024
1 parent 3cbd691 commit 7f5e8d2
Showing 3 changed files with 48 additions and 25 deletions.
samples/cpp/chat_sample/chat_sample.cpp (2 changes: 1 addition & 1 deletion)

@@ -10,7 +10,7 @@ int main(int argc, char* argv[]) try {
     std::string prompt;
     std::string model_path = argv[1];
 
-    std::string device = "CPU"; // GPU can be used as well
+    std::string device = "CPU"; // GPU, NPU can be used as well
     ov::genai::LLMPipeline pipe(model_path, "CPU");
 
     ov::genai::GenerationConfig config;
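
For context, the loop in `chat_sample.cpp` that drives this mode looks roughly like the following (a sketch of the sample, not an exact copy; `streamer` is assumed to be the sample's print-as-you-go callback):

```cpp
pipe.start_chat();
std::cout << "question:\n";
while (std::getline(std::cin, prompt)) {
    pipe.generate(prompt, config, streamer);  // streams the answer token by token
    std::cout << "\n----------\nquestion:\n";
}
pipe.finish_chat();
```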
src/cpp/src/llm_pipeline_static.cpp (59 changes: 42 additions & 17 deletions)

@@ -77,18 +77,15 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
     model->reshape(new_shapes);
 }
 
-void fill_tensor(ov::Tensor tensor, int64_t fill_val) {
+void fill_tensor(ov::Tensor tensor, int64_t fill_val, size_t offset = 0u) {
     int64_t* tensor_data = tensor.data<int64_t>();
-    std::fill(tensor_data, tensor_data + tensor.get_size(), fill_val);
+    std::fill(tensor_data + offset, tensor_data + tensor.get_size(), fill_val);
 }
 
-void copy_with_left_offset(const ov::Tensor& orig, ov::Tensor& padded) {
-    const auto orig_size = orig.get_size();
-    const auto padded_size = padded.get_size();
-    const auto kLeftOffset = padded_size - orig_size;
+void copy_with_offset(const ov::Tensor& orig, const int32_t offset, ov::Tensor& padded) {
     int64_t* orig_data = orig.data<int64_t>();
     int64_t* padded_data = padded.data<int64_t>();
-    std::copy(orig_data, orig_data + orig_size, padded_data + kLeftOffset);
+    std::copy(orig_data, orig_data + orig.get_size(), padded_data + offset);
 }
 
 ov::AnyMap extract_config_or_default(const ov::AnyMap& config, const std::string& config_name) {
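
To make the padding scheme concrete, here is a minimal sketch of how the reworked helpers cooperate (names from the code above; sizes, token values, and `pad_token_id` are illustrative):

```cpp
// A 3-token prompt left-padded into a fixed prefill width of 8 (offset = 5):
//   input_ids  (3): [ 5, 6, 7 ]
//   padded ids (8): [ pad, pad, pad, pad, pad, 5, 6, 7 ]
//   attn mask  (8): [ 0,   0,   0,   0,   0,   1, 1, 1 ]
const size_t offset = padded_input_ids.get_size() - input_ids.get_size();
fill_tensor(padded_input_ids, pad_token_id);            // everything starts as padding
copy_with_offset(input_ids, offset, padded_input_ids);  // real tokens, right-aligned
fill_tensor(padded_attention_mask, 0);                  // mask out the padding...
fill_tensor(padded_attention_mask, 1, offset);          // ...attend to the real tokens
```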
@@ -111,7 +108,7 @@ ov::AnyMap extract_config_or_default(const ov::AnyMap& config, const std::string
         { "NPUW_FOLD", "YES" },
         { "NPUW_DCOFF_TYPE", "f16" },
         { "NPUW_DCOFF_SCALE", "YES" },
-        { "NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add_RMSNorm" },
+        { "NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add" },
         { "NPUW_PARALLEL_COMPILE", "YES" },
         { "NPUW_FUNCALL_ASYNC", "YES" }
     };
@@ -179,6 +176,18 @@ StaticLLMPipeline::StaticLLMPipeline(
 ) : StaticLLMPipeline(path, path.string(), device, config) {
 }
 
+void StaticLLMPipeline::start_chat(const std::string& system_message) {
+    if (!system_message.empty()) {
+        m_history.push_back({{"role", "system"}, {"content", system_message}});
+    }
+    m_is_chat_conversation = true;
+};
+
+void StaticLLMPipeline::finish_chat() {
+    m_is_chat_conversation = false;
+    m_history.clear();
+};
+
 void StaticLLMPipeline::prepare_for_new_conversation() {
     fill_tensor(m_prefill_request.get_tensor("input_ids"), m_tokenizer.get_pad_token_id());
     fill_tensor(m_prefill_request.get_tensor("position_ids"), 0u);
@@ -198,9 +207,23 @@ DecodedResults StaticLLMPipeline::generate(
     }
 
     OPENVINO_ASSERT(std::holds_alternative<std::string>(inputs));
-    auto tokenized_input = m_tokenizer.encode(std::get<std::string>(inputs));
+    auto& prompt = std::get<std::string>(inputs);
+
+    if (m_is_chat_conversation) {
+        m_history.push_back({{"role", "user"}, {"content", prompt}});
+        constexpr bool add_generation_prompt = true;
+        prompt = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
+    }
+
+    auto tokenized_input = m_tokenizer.encode(prompt);
     auto encoded_results = generate(tokenized_input, config, streamer);
-    return {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores};
+    DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores};
+
+    if (m_is_chat_conversation) {
+        auto answer = decoded_results.texts[0];
+        m_history.push_back({{"role", "assistant"}, {"content", answer}});
+    }
+    return decoded_results;
 }
 
 EncodedResults StaticLLMPipeline::generate(
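
`apply_chat_template` above is what turns the accumulated `ChatHistory` into the flat prompt described in the overview. A hedged illustration at the call site (the rendered markup is entirely model-specific; the ChatML-like tags below are made up for the example):

```cpp
ov::genai::ChatHistory history;  // ChatHistory as declared in the .hpp diff below
history.push_back({{"role", "system"}, {"content", "You are a helpful assistant."}});
history.push_back({{"role", "user"},   {"content", "What is OpenVINO?"}});
constexpr bool add_generation_prompt = true;
// Produces one flat string, e.g. (illustrative template only):
//   "<|system|>You are a helpful assistant.<|end|><|user|>What is OpenVINO?<|end|><|assistant|>"
std::string prompt = tokenizer.apply_chat_template(history, add_generation_prompt);
```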
@@ -245,22 +268,25 @@ EncodedResults StaticLLMPipeline::generate(
     ov::genai::EncodedResults results;
     // NB: Only batch=1 is supported now
     results.scores.resize(1u);
+    results.scores[0] = 0u;
     results.tokens.resize(1u);
 
-    // NB: Check if input prompt less than maximum size
+    // NB: Check if there is enough space in KV-cache to process input prompt
     auto prompt_len = input_ids.get_size();
     if (prompt_len > m_kvcache_desc.total_size) {
         OPENVINO_THROW("Currently static pipeline only process up to " + std::to_string(m_kvcache_desc.total_size) + " tokens");
     }
 
-    // NB: Reset tensors on every generate call - chat conversation isn't supported yet!
+    // NB: From the "generate" perspective, every call is treated as start of new conversation,
+    // but if continuation is needed, prompt contains information about the entire conversation.
     prepare_for_new_conversation();
 
     auto padded_input_ids = m_prefill_request.get_tensor("input_ids");
-    copy_with_left_offset(input_ids, padded_input_ids);
+    const size_t offset = padded_input_ids.get_size() - input_ids.get_size();
+    copy_with_offset(input_ids, offset, padded_input_ids);
 
     auto padded_attention_mask = m_prefill_request.get_tensor("attention_mask");
-    copy_with_left_offset(attention_mask, padded_attention_mask);
+    fill_tensor(padded_attention_mask, 1u, offset);
 
     auto padded_position_ids = m_prefill_request.get_tensor("position_ids");
     auto* padded_pos_data = padded_position_ids.data<int64_t>();
@@ -271,13 +297,13 @@ EncodedResults StaticLLMPipeline::generate(
     // NB: Now there are prompt_len tokens in KV-cache
     m_kvcache_desc.num_stored_tokens += prompt_len;
     int64_t last_token = utils::argmax(m_prefill_request.get_tensor("logits"), 0);
+    results.tokens[0].push_back(last_token);
     if (streamer_ptr && streamer_ptr->put(last_token)) {
         return results;
     }
 
     padded_attention_mask.copy_to(m_kvcache_request.get_tensor("attention_mask"));
 
-
     // Inputs: input_ids, attention_mask, position_ids, ...
     // Outputs: logits, ...
     const auto kStartInputKVCacheLayers = 3u;
@@ -309,13 +335,12 @@ EncodedResults StaticLLMPipeline::generate(
 
         last_token = utils::argmax(m_kvcache_request.get_tensor("logits"), 0);
         results.tokens[0].push_back(last_token);
-        results.scores[0] = 0u;
 
         if (streamer_ptr && streamer_ptr->put(last_token)) {
             break;
         }
 
-        if (last_token == m_generation_config.eos_token_id) {
+        if (last_token == config.eos_token_id && !config.ignore_eos) {
             break;
         }
 
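
The last hunk swaps the pipeline-level `m_generation_config` for the per-call `config` and starts honouring `ignore_eos`. At the call site that enables something like this (a sketch; `ignore_eos` appears in this diff, while `max_new_tokens` is the usual `GenerationConfig` cap and is assumed here):

```cpp
ov::genai::GenerationConfig config;
config.max_new_tokens = 64;  // upper bound on generated tokens
config.ignore_eos = true;    // don't stop when EOS is sampled
// The decode loop's stop condition is now
//     if (last_token == config.eos_token_id && !config.ignore_eos) break;
// so with ignore_eos=true generation runs until max_new_tokens or the
// static KV-cache capacity is exhausted.
```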
src/cpp/src/llm_pipeline_static.hpp (12 changes: 5 additions & 7 deletions)

@@ -35,13 +35,8 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
         StreamerVariant streamer
     ) override;
 
-    void start_chat(const std::string& system_message) override {
-        OPENVINO_THROW("Currently chat conversation mode isn't supported");
-    };
-    void finish_chat() override {
-        OPENVINO_THROW("Currently chat conversation mode isn't supported");
-    };
-
+    void start_chat(const std::string& system_message) override;
+    void finish_chat() override;
 private:
     void prepare_for_new_conversation();
 
@@ -54,6 +49,9 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
     KVCacheDesc m_kvcache_desc;
     ov::InferRequest m_kvcache_request;
     ov::InferRequest m_prefill_request;
+
+    bool m_is_chat_conversation = false;
+    ChatHistory m_history;
 };
 
 } // namespace genai