diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml
index 107777bf74..2e9d72e263 100644
--- a/.github/workflows/causal_lm_cpp.yml
+++ b/.github/workflows/causal_lm_cpp.yml
@@ -491,7 +491,6 @@ jobs:
          python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
          python -m pip install -r ./samples/requirements.txt
          optimum-cli export openvino --trust-remote-code --weight-format fp16 --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
-          optimum-cli export openvino --trust-remote-code --weight-format fp16 --model Qwen/Qwen-7B-Chat Qwen-7B-Chat --task text-generation-with-past
      - name: run and compare
        run: |
          source ./ov/setupvars.sh
@@ -505,36 +504,22 @@ jobs:
          ./build/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_prompt_lookup.txt
          ./build/samples/cpp/text_generation/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_greedy.txt
+          python ./samples/python/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.py ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_py.txt
          python -c "
          with open('predictions_greedy.txt', 'r') as f:
              predicted_greedy = f.readline()
          with open('predictions_prompt_lookup.txt', 'r') as f:
              predicted_prompt_lookup = f.readline()
+          with open('predictions_py.txt', 'r') as f:
+              predicted_prompt_lookup_py = f.readline()
          assert predicted_greedy == predicted_prompt_lookup
+          assert predicted_greedy == predicted_prompt_lookup_py
+          assert predicted_prompt_lookup == predicted_prompt_lookup_py
          "
          echo "Prompt lookup" passed
-      - name: run and compare (model with seq_length_axis = 1)
-        run: |
-          source ./ov/setupvars.sh
-
-          echo 'Code:```python
-          def add(a, b):
-              return a + b
-          ```
-          Question: Can you please add 2 and 3
-          A:' > ./prompt.txt
-
-          ./build/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm ./Qwen-7B-Chat/ "$(<prompt.txt)" > predictions_prompt_lookup.txt
-          ./build/samples/cpp/text_generation/greedy_causal_lm ./Qwen-7B-Chat/ "$(<prompt.txt)" > predictions_greedy.txt
-          python -c "
-          with open('predictions_greedy.txt', 'r') as f:
-              predicted_greedy = f.readline()
-          with open('predictions_prompt_lookup.txt', 'r') as f:
-              predicted_prompt_lookup = f.readline()
-          assert predicted_greedy == predicted_prompt_lookup
-          "
-          echo "Prompt lookup" passed
-
+      env:
+        PYTHONPATH: "./build/:$PYTHONPATH"
+        LD_LIBRARY_PATH: "./build/openvino_genai/:$LD_LIBRARY_PATH"
  cpp-Phi-1_5:
    runs-on: ubuntu-20.04-16-cores
    defaults:
diff --git a/samples/cpp/prompt_lookup_decoding_lm/CMakeLists.txt b/samples/cpp/prompt_lookup_decoding_lm/CMakeLists.txt
index c899c6e47b..b0ce8b1b60 100644
--- a/samples/cpp/prompt_lookup_decoding_lm/CMakeLists.txt
+++ b/samples/cpp/prompt_lookup_decoding_lm/CMakeLists.txt
@@ -1,8 +1,6 @@
 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-find_package(OpenVINO REQUIRED COMPONENTS Runtime Threading)
-
 find_package(OpenVINOGenAI REQUIRED
     PATHS
         "${CMAKE_BINARY_DIR}"  # Reuse the package from the build.
@@ -10,21 +8,16 @@ find_package(OpenVINOGenAI REQUIRED NO_CMAKE_FIND_ROOT_PATH ) -add_executable(prompt_lookup_decoding_lm prompt_lookup_decoding_lm.cpp) -target_link_libraries(prompt_lookup_decoding_lm PRIVATE openvino::runtime openvino::threading) -set_target_properties(prompt_lookup_decoding_lm PROPERTIES - COMPILE_PDB_NAME prompt_lookup_decoding_lm +set(TARGET_NAME prompt_lookup_decoding_lm) +add_executable(${TARGET_NAME} ${TARGET_NAME}.cpp) +target_link_libraries(${TARGET_NAME} PRIVATE openvino::genai) + +set_target_properties(${TARGET_NAME} PROPERTIES + COMPILE_PDB_NAME ${TARGET_NAME} # Ensure out of box LC_RPATH on macOS with SIP INSTALL_RPATH_USE_LINK_PATH ON) -target_compile_features(prompt_lookup_decoding_lm PRIVATE cxx_std_17) - -get_target_property(genai_imported openvino::genai IMPORTED_LOCATION) -set(OPENVINO_TOKENIZERS_PATH $,${genai_imported},$>) -set(OPENVINO_TOKENIZERS_FILENAME "${CMAKE_SHARED_LIBRARY_PREFIX}openvino_tokenizers${CMAKE_SHARED_LIBRARY_SUFFIX}") -target_compile_definitions(prompt_lookup_decoding_lm PRIVATE - OPENVINO_TOKENIZERS_PATH="${OPENVINO_TOKENIZERS_PATH}/${OPENVINO_TOKENIZERS_FILENAME}") -install(TARGETS prompt_lookup_decoding_lm +install(TARGETS ${TARGET_NAME} RUNTIME DESTINATION samples_bin/ COMPONENT samples_bin EXCLUDE_FROM_ALL) diff --git a/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.cpp b/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.cpp index 282220a4b1..e692110027 100644 --- a/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.cpp +++ b/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.cpp @@ -1,338 +1,45 @@ // Copyright (C) 2023-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -#include #include -#include -namespace { - -// only batch_size = 1 currently supported -constexpr size_t BATCH_SIZE = 1; - -size_t get_seq_len_axis(std::shared_ptr model) { - // sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size], - // therefore usually seq_length_axis = 2 - size_t seq_length_axis = 2; - - // "ReadValue" node is KV cache representation in stateful model - std::string kv_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name); - - for (const auto op : model->get_ops()) { - if (op->get_type_name() != kv_node_type_name) { - continue; - } - - // Shape example: [-1,4,0,64] - auto shape = op->get_input_partial_shape(0); - - for (size_t i = 0; i < shape.rank().get_length(); i++) { - // Find axis = 0. This would be sequence length axis. - if (shape[i] == 0) { - seq_length_axis = i; - } - } - break; - } - - return seq_length_axis; -} - -std::pair tokenize(ov::InferRequest& tokenizer, std::string&& prompt) { - tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {BATCH_SIZE}, &prompt}); - tokenizer.infer(); - return {tokenizer.get_tensor("input_ids"), tokenizer.get_tensor("attention_mask")}; -} - -std::string detokenize(ov::InferRequest& detokenizer, std::vector& tokens) { - detokenizer.set_input_tensor(ov::Tensor{ov::element::i64, {BATCH_SIZE, tokens.size()}, tokens.data()}); - detokenizer.infer(); - return detokenizer.get_output_tensor().data()[0]; -} - -// The following reasons require TextStreamer to keep a cache of previous tokens: -// detokenizer removes starting ' '. 
For example detokenize(tokenize(" a")) == "a", -// but detokenize(tokenize("prefix a")) == "prefix a" -// 1 printable token may consist of 2 token ids: detokenize(incomplete_token_idx) == "�" -struct TextStreamer { - ov::InferRequest detokenizer; - std::vector token_cache; - size_t print_len = 0; - - void put(int64_t token) { - token_cache.push_back(token); - std::string text = detokenize(detokenizer, token_cache); - if (!text.empty() && '\n' == text.back() && text.size() > print_len) { - // Flush the cache after the new line symbol - std::cout << std::string_view{text.data() + print_len, text.size() - print_len}; - token_cache.clear(); - print_len = 0; - return; - } - constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error. - if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) { - // Don't print incomplete text - return; - } else if (text.size() > print_len) { - // It is possible to have a shorter text after adding new token. - // Print to output only if text length is increaeseds. - std::cout << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush; - print_len = text.size(); - } - } - - void end() { - std::string text = detokenize(detokenizer, token_cache); - if (text.size() <= print_len) - return; - std::cout << std::string_view{text.data() + print_len, text.size() - print_len} << '\n'; - token_cache.clear(); - print_len = 0; - } -}; - -ov::Tensor trimm_tensor(ov::Tensor& tensor, uint64_t seq_len_axis, uint64_t new_seq_len) { - // Copy elements from the old to a new tensor and return it. - // Trim kv tensor on sequence length axis - // key/values tensor shape example: [BATCH_SIZE, num_kv_heads, seq_len, head_size] - // Sequence length axis position may vary from one model to another - - auto shape = tensor.get_shape(); - - OPENVINO_ASSERT(seq_len_axis < shape.size(), - "Sequence length axis: ", - seq_len_axis, - " should be less than shape size: ", - shape.size()); - - size_t old_seq_len = shape[seq_len_axis]; - - OPENVINO_ASSERT(new_seq_len <= old_seq_len); - - // if new_seq_len equal to old one no need to copy tensor, return as is - if (old_seq_len == new_seq_len) - return tensor; - - shape[seq_len_axis] = new_seq_len; - - if (seq_len_axis == 0) { - tensor.set_shape(shape); - return tensor; - } - - ov::Coordinate new_shape_begin{0, 0, 0, 0}; - ov::Coordinate new_shape_end{shape}; - - auto new_tensor = ov::Tensor(tensor, new_shape_begin, new_shape_end); - - return new_tensor; -} - -void update_kv_cache(ov::InferRequest request, uint64_t seq_len_axis, uint64_t new_seq_len) { - // trim kv_cache values up to the new_seq_len - auto states = request.query_state(); - ov::parallel_for(states.size(), [&](size_t i) { - ov::Tensor old_tensor = states.at(i).get_state(); - states.at(i).set_state(trimm_tensor(old_tensor, seq_len_axis, new_seq_len)); - }); -} - -class PromptLookupCandidateGenerator { -private: - const size_t max_ngram_size = 3; - size_t num_pred_tokens = 5; - const size_t max_pred_tokens = 20; - -public: - PromptLookupCandidateGenerator(const size_t max_ngram_size, const size_t num_pred_tokens) - : max_ngram_size{max_ngram_size}, - num_pred_tokens{num_pred_tokens} {}; - - std::vector generate_candidates(const std::vector& input_ids) { - const size_t input_length = input_ids.size(); - - for (int32_t ngram_size = max_ngram_size; ngram_size > 0; ngram_size--) { - // extract last ngram_size tokens as search ngram - std::vector ngram = 
std::vector{input_ids.cend() - ngram_size, input_ids.cend()}; - - // find ngram match in input_ids - size_t ngram_i = 0; - for (size_t input_i = 0; input_i < input_length - ngram_size; input_i++) { - if (ngram[ngram_i] != input_ids[input_i]) { - ngram_i = 0; - continue; - } - - ngram_i++; - - if (ngram_i < ngram_size) { - continue; - } - - // match found with the end at input_i - size_t avaliable_num_pred = std::min(input_length - (input_i + 1), num_pred_tokens); - - // return candidates with length of avaliable_num_pred - return std::vector{input_ids.cbegin() + input_i + 1, - input_ids.cbegin() + input_i + 1 + avaliable_num_pred}; - } - } - - return std::vector{}; - } - - void update_candidate_strategy(const size_t num_matches) { - // dynamically adjust number of generated candidates based on number of matches - // we want to balance the benefits of getting assistant tokens correct with the - // cost of forecasting incorrect assistant tokens. - if (num_matches == num_pred_tokens) { - num_pred_tokens = std::min(num_pred_tokens + 2, max_pred_tokens); - } else { - num_pred_tokens = std::max(num_pred_tokens - 1, size_t(1)); - } - } -}; - -int64_t get_eos_token(const std::shared_ptr tokenizer) { - auto rt_info = tokenizer->get_rt_info(); // Get the runtime info for the model - - auto it = rt_info.find("eos_token_id"); - if (it == rt_info.end()) { - throw std::runtime_error("EOS token ID not found in model's runtime information."); - } - return it->second.as(); -} - -} // namespace +#include "openvino/genai/llm_pipeline.hpp" int main(int argc, char* argv[]) try { - if (argc != 3) { + if (3 != argc) { throw std::runtime_error(std::string{"Usage: "} + argv[0] + " ''"); } - // tokenizer model - ov::Core core; - core.add_extension(OPENVINO_TOKENIZERS_PATH); // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt - - const std::string model_dir = std::string{argv[1]}; - - auto tokenizer_model = core.read_model(model_dir + "/openvino_tokenizer.xml"); - // tokenizer and detokenizer work on CPU only - ov::InferRequest tokenizer = core.compile_model(tokenizer_model, "CPU").create_infer_request(); - auto [input_ids, attention_mask] = tokenize(tokenizer, argv[2]); - - std::vector full_input_ids{input_ids.data(), input_ids.data() + input_ids.get_size()}; - - ov::InferRequest detokenizer = - core.compile_model(model_dir + "/openvino_detokenizer.xml", "CPU").create_infer_request(); - TextStreamer text_streamer{std::move(detokenizer)}; - - std::shared_ptr ov_model = core.read_model(model_dir + "/openvino_model.xml"); - - size_t seq_len_axis = get_seq_len_axis(ov_model); - - ov::InferRequest model = core.compile_model(ov_model, "CPU").create_infer_request(); - - model.set_tensor("input_ids", input_ids); - model.set_tensor("attention_mask", attention_mask); - - ov::Tensor position_ids = model.get_tensor("position_ids"); - position_ids.set_shape(input_ids.get_shape()); - std::iota(position_ids.data(), position_ids.data() + position_ids.get_size(), 0); - size_t seq_len = input_ids.get_shape()[1]; - - // set beam_idx for stateful model: no beam search is used and BATCH_SIZE = 1 - model.get_tensor("beam_idx").set_shape({BATCH_SIZE}); - model.get_tensor("beam_idx").data()[0] = 0; - - // To collect kv-cache for the and to get the next token run the very first infer request - model.infer(); - - // logits shape is [BATCH_SIZE, seq_len, vocab_size] - auto logits = model.get_tensor("logits"); - size_t vocab_size = logits.get_shape().back(); - auto data_logits = logits.data() + (seq_len - 1) * vocab_size; - int64_t 
out_token = std::max_element(data_logits, data_logits + vocab_size) - data_logits; - - full_input_ids.push_back(out_token); - - auto first_token = out_token; - text_streamer.put(out_token); - - const int64_t EOS_TOKEN = get_eos_token(tokenizer_model); - - // Prompt lookup decoding is a speculative decoding technique where the draft model replaced - // with string matching in the prompt to generate candidate token sequences. - int max_sequence_length = 100; - PromptLookupCandidateGenerator candidateGenerator{3, 5}; - - while (out_token != EOS_TOKEN && seq_len < max_sequence_length) { - auto candidates = candidateGenerator.generate_candidates(full_input_ids); - - // cut redundant candidates on last iteration - size_t tokens_to_generate = max_sequence_length - seq_len; - candidates.resize(std::min(candidates.size(), tokens_to_generate - 1)); - size_t candidates_size = candidates.size(); - - // candidates_size + 1 tokens will be fed at once in a single infer request. - input_ids.set_shape({BATCH_SIZE, candidates_size + 1}); - input_ids.data()[0] = first_token; - std::copy_n(candidates.begin(), candidates_size, input_ids.data() + 1); - - attention_mask.set_shape({BATCH_SIZE, seq_len + candidates_size + 1}); - std::fill_n(attention_mask.data(), attention_mask.get_size(), 1); - - position_ids.set_shape({BATCH_SIZE, candidates_size + 1}); - std::iota(position_ids.data(), position_ids.data() + position_ids.get_size(), seq_len); - - model.infer(); - - data_logits = logits.data(); // [BATCH_SIZE, 1 + candidates_size, vocab_size] - - // 1. accept current out token (if not eos) - // 2. check if it matches appropriate candidate - // 2.1 if it's match, continue - accept next token - // 2.2 it it's mismatch, stop iteration but still accept current token as it was last token generated by - // model from a valid sequence. - size_t accepted_tokens_number = 0; - for (size_t i = 0; i < candidates_size + 1; i++) { - auto start = data_logits + vocab_size * i; - auto stop = data_logits + vocab_size * (i + 1); - out_token = std::max_element(start, stop) - start; - - if (out_token == EOS_TOKEN) { - break; - } - - text_streamer.put(out_token); - full_input_ids.push_back(out_token); - accepted_tokens_number++; - - if (i == candidates_size || out_token != candidates[i]) { - break; - } - } - - if (accepted_tokens_number > 0) { - candidateGenerator.update_candidate_strategy(accepted_tokens_number - 1); - } - - // After the inference request, key/values have shape [BATCH_SIZE, seq_len + candidates_size, vocab_size]. - // Increment the sequence length by the number of matched tokens, and - // trim the KV cache to match the new sequence length. - seq_len += accepted_tokens_number; - update_kv_cache(model, seq_len_axis, seq_len); - - first_token = out_token; - } - - text_streamer.end(); - // Model is stateful which means that context (kv-cache) which belongs to a particular - // text sequence is accumulated inside the model during the generation loop above. - // This context should be reset before processing the next text sequence. 
-    // While it is not required to reset context in this sample as only one sequence is processed,
-    // it is called for education purposes:
-    model.reset_state();
+    ov::genai::GenerationConfig config;
+    config.max_new_tokens = 100;
+    // Define the number of candidate tokens to be generated per iteration
+    config.num_assistant_tokens = 5;
+    // Define the maximum ngram size used to look for matches in the prompt
+    config.max_ngram_size = 3;
+
+    std::string model_path = argv[1];
+    std::string prompt = argv[2];
+
+    std::string device = "CPU";
+
+    ov::genai::SchedulerConfig scheduler_config;
+    scheduler_config.cache_size = 5;
+
+    ov::genai::LLMPipeline pipe(
+        model_path,
+        device,
+        ov::genai::prompt_lookup(true),
+        ov::genai::scheduler_config(scheduler_config));
+
+    auto streamer = [](std::string subword) {
+        std::cout << subword << std::flush;
+        return false;
+    };
+
+    // Since the streamer is set, the results will
+    // be printed each time a new token is generated.
+    pipe.generate(prompt, config, streamer);
+    std::cout << std::endl;
 } catch (const std::exception& error) {
     try {
         std::cerr << error.what() << '\n';
diff --git a/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp b/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
index dc6761879c..487296566b 100644
--- a/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
+++ b/samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
@@ -29,7 +29,6 @@ int main(int argc, char* argv[]) try {
     ov::genai::SchedulerConfig scheduler_config;
     scheduler_config.cache_size = 5;
 
-    // Different devices require different block sizes, so different scheduler configs need to be set.
     ov::genai::LLMPipeline pipe(
         main_model_path,
         main_device,
diff --git a/samples/python/prompt_lookup_decoding_lm/README.md b/samples/python/prompt_lookup_decoding_lm/README.md
new file mode 100644
index 0000000000..1e5f4003d4
--- /dev/null
+++ b/samples/python/prompt_lookup_decoding_lm/README.md
@@ -0,0 +1,41 @@
+# prompt_lookup_decoding_lm Python sample that supports most popular models like LLaMA 3
+
+[Prompt Lookup decoding](https://github.com/apoorvumang/prompt-lookup-decoding) is an [assisted-generation](https://huggingface.co/blog/assisted-generation#understanding-text-generation-latency) technique where the draft model is replaced with simple string matching in the prompt to generate candidate token sequences. This method is highly effective for input-grounded generation (summarization, document QA, multi-turn chat, code editing), where there is high n-gram overlap between the LLM input (prompt) and the LLM output. Such overlap can be entity names, phrases, or code chunks that the LLM directly copies from the input while generating the output. Prompt lookup exploits this pattern to speed up autoregressive decoding in LLMs, which results in significant speedups with no effect on output quality.
+
+This example showcases inference of text-generation Large Language Models (LLMs): `chatglm`, `LLaMA`, `Qwen` and other models with the same signature. The application doesn't have many configuration options to encourage the reader to explore and modify the source code. Loading `openvino_tokenizers` to `ov::Core` enables tokenization. Run `optimum-cli` to generate IRs for the samples. There is also a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-chatbot) which provides an example of an LLM-powered chatbot in Python.
+
+## Download and convert the model and tokenizers
+
+The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
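> Editor's illustration (not part of this patch): the string-matching lookup described at the top of this README fits in a few lines of plain Python. The helper below is a hypothetical sketch that mirrors the `generate_candidates()` logic added in `continuous_batching_for_prompt_lookup.cpp` further down this diff; the function name and values are illustrative only.

```python
def generate_candidates(input_ids, num_pred_tokens=5, max_ngram_size=3):
    """Find the most recent ngram of the sequence earlier in the sequence and
    propose the tokens that followed it there as draft candidates."""
    n = len(input_ids)
    # Try the longest ngram first, falling back to shorter ones.
    for ngram_size in range(min(max_ngram_size, n - 1), 0, -1):
        ngram = input_ids[-ngram_size:]
        # Scan the prompt left to right, excluding the trailing ngram itself.
        for start in range(n - ngram_size):
            if input_ids[start:start + ngram_size] == ngram:
                follow = start + ngram_size
                return input_ids[follow:follow + num_pred_tokens]
    return []

# The tail [1, 2, 3] already occurred earlier, so the tokens that followed it
# there ([9, 9]) are proposed as candidates for the main model to validate.
assert generate_candidates([7, 1, 2, 3, 9, 9, 1, 2, 3], num_pred_tokens=2) == [9, 9]
```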
+
+It's not required to install [../../export-requirements.txt](../../export-requirements.txt) for deployment if the model has already been exported.
+
+```sh
+source <INSTALL_DIR>/setupvars.sh
+pip install --upgrade-strategy eager -r ../../requirements.txt
+optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
+```
+
+## Run
+
+Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pip install -r ../../deployment-requirements.txt` and then run a sample:
+
+`python prompt_lookup_decoding_lm.py ./TinyLlama-1.1B-Chat-v1.0/ "return 0;"`
+
+
+Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is recommended to run larger models on a dGPU with 32GB+ RAM. For example, the model meta-llama/Llama-2-13b-chat-hf can benefit from being run on a dGPU. Modify the source code to change the device for inference to the GPU.
+
+See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models.
+
+### Troubleshooting
+
+#### Unicode characters encoding error on Windows
+
+Example error:
+```
+UnicodeEncodeError: 'charmap' codec can't encode character '\u25aa' in position 0: character maps to <undefined>
+```
+
+If you encounter the error described in the example when the sample is printing output to the Windows console, it is likely due to the default Windows encoding not supporting certain Unicode characters. To resolve this:
+1. Enable Unicode characters for Windows cmd - open `Region` settings from `Control panel`. `Administrative`->`Change system locale`->`Beta: Use Unicode UTF-8 for worldwide language support`->`OK`. Reboot.
+2. Enable UTF-8 mode by setting environment variable `PYTHONIOENCODING="utf8"`.
diff --git a/samples/python/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.py b/samples/python/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.py
new file mode 100755
index 0000000000..557897b6b1
--- /dev/null
+++ b/samples/python/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import openvino_genai
+
+def streamer(subword):
+    print(subword, end='', flush=True)
+    # Return a flag indicating whether generation should be stopped.
+    # False means continue generation.
+    return False
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('model_dir')
+    parser.add_argument('prompt')
+    args = parser.parse_args()
+
+    device = 'CPU'
+    scheduler_config = openvino_genai.SchedulerConfig()
+    # KV cache parameters
+    scheduler_config.cache_size = 2
+
+    pipe = openvino_genai.LLMPipeline(args.model_dir, device, scheduler_config=scheduler_config, prompt_lookup=True)
+
+    config = openvino_genai.GenerationConfig()
+    config.max_new_tokens = 100
+    # Enable prompt lookup decoding by generating `num_assistant_tokens` candidates per iteration
+    config.num_assistant_tokens = 5
+    # Define the maximum ngram size used to look for matches in the prompt
+    config.max_ngram_size = 3
+
+    # Since the streamer is set, the results will be printed
+    # every time a new token is generated and put into the streamer queue.
+ pipe.generate(args.prompt, config, streamer) + +if '__main__' == __name__: + main() diff --git a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp index 4a0637f2d9..74466ee488 100644 --- a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp +++ b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp @@ -55,10 +55,14 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline { class ImplInterface; class ContinuousBatchingImpl; class ContinuousBatchingForSpeculativeDecodingImpl; + class ContinuousBatchingForPromptLookupImpl; class SpeculativeDecodingImpl; + class PromptLookupImpl; friend class ContinuousBatchingForSpeculativeDecodingImpl; + friend class ContinuousBatchingForPromptLookupImpl; friend class SpeculativeDecodingImpl; + friend class PromptLookupImpl; std::shared_ptr m_impl; diff --git a/src/cpp/include/openvino/genai/generation_config.hpp b/src/cpp/include/openvino/genai/generation_config.hpp index 9d79240aa8..b8b222e347 100644 --- a/src/cpp/include/openvino/genai/generation_config.hpp +++ b/src/cpp/include/openvino/genai/generation_config.hpp @@ -71,9 +71,10 @@ enum class StopCriteria { EARLY, HEURISTIC, NEVER }; * @param frequency_penalty reduces absolute log prob as many times as the token was generated. * @param rng_seed initializes random generator. * - * Speculative decoding parameters: - * @param assistant_confidence_threshold the lower token probability of candidate to be validated by main model in case of static strategy candidates number update. - * @param num_assistant_tokens the defined candidates number to be generated by draft model in case of dynamic strategy candidates number update. + * Assisting generation parameters: + * @param assistant_confidence_threshold the lower token probability of candidate to be validated by main model in case of dynamic strategy candidates number update. + * @param num_assistant_tokens the defined candidates number to be generated by draft model/prompt lookup in case of static strategy candidates number update. + * @param max_ngram_size is maximum ngram to use when looking for matches in the prompt. */ class OPENVINO_GENAI_EXPORTS GenerationConfig { @@ -114,9 +115,10 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig { float frequency_penalty = 0.0f; size_t rng_seed = 0; - // Speculative decoding + // Assisting generation parameters float assistant_confidence_threshold = 0.f; size_t num_assistant_tokens = 0; + size_t max_ngram_size = 0; // EOS special token int64_t eos_token_id = -1; @@ -132,7 +134,10 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig { bool is_greedy_decoding() const; bool is_beam_search() const; bool is_multinomial() const; + OPENVINO_DEPRECATED("Please, use `is_assisting_generation()` instead of `is_speculative_decoding()`. This method will be removed in 2025.0.0 release") bool is_speculative_decoding() const; + bool is_assisting_generation() const; + bool is_prompt_lookup() const; void update_generation_config(const ov::AnyMap& config_map); template diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index 44427d45b1..948baab6f4 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -320,5 +320,12 @@ inline std::pair draft_model( */ static constexpr ov::Property scheduler_config{"scheduler_config"}; +/** +* @brief enable prompt_lookup property serves to activate prompt lookup decoding. 
+* Set `true` to activate this mode. +* And create LLMPipeline instance with this config. +*/ +static constexpr ov::Property prompt_lookup{"prompt_lookup"}; + } // namespace genai } // namespace ov diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index bf0c979d39..6e7e982a4c 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -16,10 +16,12 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl( const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& properties, - const ov::genai::GenerationConfig& generation_config + const ov::genai::GenerationConfig& generation_config, + bool is_validation_mode_enabled ) { m_tokenizer = tokenizer; m_generation_config = generation_config; + m_is_validation_mode_enabled = is_validation_mode_enabled; ov::Core core; diff --git a/src/cpp/src/continuous_batching_impl.hpp b/src/cpp/src/continuous_batching_impl.hpp index 780bff6a31..8da05c6dfa 100644 --- a/src/cpp/src/continuous_batching_impl.hpp +++ b/src/cpp/src/continuous_batching_impl.hpp @@ -58,7 +58,8 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& properties, - const ov::genai::GenerationConfig& generation_config); + const ov::genai::GenerationConfig& generation_config, + bool is_validation_mode_enabled = false); GenerationHandle add_request(uint64_t request_id, const ov::Tensor& input_ids, diff --git a/src/cpp/src/continuous_batching_pipeline.cpp b/src/cpp/src/continuous_batching_pipeline.cpp index 2faad4354e..148eb2fa9f 100644 --- a/src/cpp/src/continuous_batching_pipeline.cpp +++ b/src/cpp/src/continuous_batching_pipeline.cpp @@ -11,6 +11,7 @@ #include "openvino/genai/tokenizer.hpp" #include "continuous_batching_impl.hpp" #include "speculative_decoding/speculative_decoding_impl.hpp" +#include "prompt_lookup/prompt_lookup_impl.hpp" #include "timer.hpp" #include "utils.hpp" #include "debug_utils.hpp" @@ -28,6 +29,15 @@ extract_draft_model_from_config(ov::AnyMap& config) { return draft_model; } +inline bool +extract_prompt_lookup_from_config(ov::AnyMap& config) { + bool res = false; + if (config.find(ov::genai::prompt_lookup.name()) != config.end()) { + res = config.at(ov::genai::prompt_lookup.name()).as(); + config.erase(ov::genai::prompt_lookup.name()); + } + return res; +} ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::path& models_path, const SchedulerConfig& scheduler_config, @@ -36,12 +46,16 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p const ov::AnyMap& tokenizer_properties) { auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); + auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); std::filesystem::path openvino_model_name = "openvino_model.xml"; auto model = utils::singleton_core().read_model((models_path / openvino_model_name).string()); auto tokenizer = ov::genai::Tokenizer(models_path, tokenizer_properties); auto generation_config = utils::from_config_json_if_exists(models_path); - if (draft_model_desr.model == nullptr) { + if (is_prompt_lookup_enabled) { + OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually excluded"); + m_impl = std::make_shared(model, tokenizer, 
scheduler_config, device, properties_without_draft_model, generation_config); + } else if (draft_model_desr.model == nullptr) { m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties, generation_config); } else { auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); @@ -57,11 +71,15 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const ov::AnyMap& properties) { auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); + auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); std::filesystem::path openvino_model_name = "openvino_model.xml"; auto model = utils::singleton_core().read_model((models_path / openvino_model_name).string()); auto generation_config = utils::from_config_json_if_exists(models_path); - if (draft_model_desr.model == nullptr) { + if (is_prompt_lookup_enabled) { + OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually excluded"); + m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model, generation_config); + } else if (draft_model_desr.model == nullptr) { m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties, generation_config); } else { auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); @@ -79,9 +97,13 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const ov::genai::GenerationConfig& generation_config) { auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); + auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); auto model = utils::singleton_core().read_model(model_str, weights_tensor); - if (draft_model_desr.model == nullptr) { + if (is_prompt_lookup_enabled) { + OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually excluded"); + m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model, generation_config); + } else if (draft_model_desr.model == nullptr) { m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties, generation_config); } else { auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); diff --git a/src/cpp/src/generation_config.cpp b/src/cpp/src/generation_config.cpp index 189cfeded7..35ae92d605 100644 --- a/src/cpp/src/generation_config.cpp +++ b/src/cpp/src/generation_config.cpp @@ -132,9 +132,17 @@ bool GenerationConfig::is_multinomial() const { } bool GenerationConfig::is_speculative_decoding() const { + return is_assisting_generation(); +} + +bool GenerationConfig::is_assisting_generation() const { return (assistant_confidence_threshold > 0 || num_assistant_tokens > 0); } +bool GenerationConfig::is_prompt_lookup() const { + return (max_ngram_size > 0 && num_assistant_tokens > 0); +} + void GenerationConfig::validate() const { OPENVINO_ASSERT(eos_token_id == -1 || stop_token_ids.find(eos_token_id) != stop_token_ids.end(), "'stop_token_ids' must contain 'eos_token_id'. 
Please, call 'set_eos_token_id' with 'eos_token_id' value"); @@ -181,9 +189,10 @@ void GenerationConfig::validate() const { OPENVINO_ASSERT(frequency_penalty >= -2.0f && frequency_penalty <= 2.0f, "frequence_penalty penalty must be a [-2; +2]"); OPENVINO_ASSERT(presence_penalty >= -2.0f && presence_penalty <= 2.0f, "presence_penalty penalty must be a [-2; +2]"); } - if (is_speculative_decoding()) { + if (is_assisting_generation()) { if (assistant_confidence_threshold != 0.f) { OPENVINO_ASSERT(num_assistant_tokens == 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually exclusive in `GenerationConfig`"); + OPENVINO_ASSERT(!is_prompt_lookup(), "Parameters `assistant_confidence_threshold` cannot be used while Prompt Lookup decoding"); } else { OPENVINO_ASSERT(num_assistant_tokens > 0, "Parameters `assistant_confidence_threshold` and `num_assistant_tokens` are mutually exclusive in `GenerationConfig`"); }; diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp new file mode 100644 index 0000000000..8c9e520728 --- /dev/null +++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "continuous_batching_for_prompt_lookup.hpp" + +namespace ov::genai { + +std::map +ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::get_generated_request_len() { + std::map result; + for (const auto& request : m_requests) { + const auto request_id = request->get_request_id(); + auto validation_len = request->get_num_tokens_to_validate(); + auto generated_len = request->get_num_processed_tokens() - request->get_prompt_len() + 1; + result.insert({ request_id, { generated_len, validation_len } }); + } + return result; +} + +TokenIds ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate_candidates(const TokenIds& input_ids, size_t num_pred_tokens, size_t max_ngram_size) { + if (num_pred_tokens == 0) { + return std::vector{}; + } + + const size_t input_length = input_ids.size(); + + for (int32_t ngram_size = max_ngram_size; ngram_size > 0; ngram_size--) { + // extract last ngram_size tokens as search ngram + std::vector ngram = std::vector{input_ids.cend() - ngram_size, input_ids.cend()}; + + // find ngram match in input_ids + size_t ngram_i = 0; + for (size_t input_i = 0; input_i < input_length - ngram_size; input_i++) { + if (ngram[ngram_i] != input_ids[input_i]) { + ngram_i = 0; + continue; + } + + ngram_i++; + + if (ngram_i < ngram_size) { + continue; + } + + // match found with the end at input_i + size_t avaliable_num_pred = std::min(input_length - (input_i + 1), num_pred_tokens); + + // return candidates with length of avaliable_num_pred + return std::vector{input_ids.cbegin() + input_i + 1, + input_ids.cbegin() + input_i + 1 + avaliable_num_pred}; + } + } + + return std::vector{}; +} + +void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate_candidates() { + for (auto& request : m_requests) { + const auto prompt = request->get_prompt_ids(); + size_t max_validation_len = 0; + for (auto& running_sequence : request->get_running_sequences()) { + const auto generated_tokens = running_sequence->get_generated_ids(); + TokenIds full_input_ids = prompt; + full_input_ids.insert(full_input_ids.end(), generated_tokens.begin(), generated_tokens.end()); + + size_t min_num_assistant_tokens = 0; + const auto 
sampling_params = request->get_sampling_parameters(); + { + const auto generated_len = running_sequence->get_generated_len(); + const auto left_generated_len = std::min(sampling_params.max_new_tokens, sampling_params.max_length) - generated_len - 1; + min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len); + } + TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size); + + if (!candidates.empty()) { + for (const auto& candidate : candidates) { + running_sequence->append_token(candidate, 0); + } + max_validation_len = std::max(max_validation_len, candidates.size()); + } + } + request->set_num_validated_tokens(max_validation_len); + } +} +} \ No newline at end of file diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.hpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.hpp new file mode 100644 index 0000000000..8962aba0f2 --- /dev/null +++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "openvino/genai/continuous_batching_pipeline.hpp" + +#include "continuous_batching_impl.hpp" + +namespace ov::genai { +class ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl : public ContinuousBatchingPipeline::ContinuousBatchingImpl { +public: + ContinuousBatchingForPromptLookupImpl() = default; + + ContinuousBatchingForPromptLookupImpl( + const std::shared_ptr& model, + const Tokenizer& tokenizer, + const SchedulerConfig& scheduler_config, + const std::string& device, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config, + bool is_validation_mode_enabled = false) : + ContinuousBatchingImpl{ model, + tokenizer, + scheduler_config, + device, + properties, + generation_config, + true } {}; + + void generate_candidates(); + + // { generated_len, validation_len } + using SequenceLen = std::pair; + std::map get_generated_request_len(); + +protected: + TokenIds generate_candidates(const TokenIds& input_ids, size_t num_pred_tokens, size_t max_ngram_size); +}; +} \ No newline at end of file diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp new file mode 100644 index 0000000000..f934a56939 --- /dev/null +++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp @@ -0,0 +1,159 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "prompt_lookup_impl.hpp" +#include "text_callback_streamer.hpp" + +namespace ov::genai { +template struct overloaded : Ts... {using Ts::operator()...;}; +template overloaded(Ts...) 
-> overloaded; + +GenerationHandle +ContinuousBatchingPipeline::PromptLookupImpl::add_request(uint64_t request_id, + const ov::Tensor& input_ids, + ov::genai::GenerationConfig sampling_params) { + OPENVINO_ASSERT(sampling_params.is_prompt_lookup(), "`max_ngram_size` && `num_assistant_tokens` should be specified for `prompt lookup decoding`"); + return m_pipeline->add_request(request_id, input_ids, sampling_params); +}; + +GenerationHandle +ContinuousBatchingPipeline::PromptLookupImpl::add_request(uint64_t request_id, + const std::string& prompt, + ov::genai::GenerationConfig sampling_params) { + OPENVINO_ASSERT(sampling_params.is_prompt_lookup(), "`max_ngram_size` && `num_assistant_tokens` should be specified for `prompt lookup decoding`"); + return m_pipeline->add_request(request_id, prompt, sampling_params); +} + +bool ContinuousBatchingPipeline::PromptLookupImpl::has_non_finished_requests() { + return m_pipeline->has_non_finished_requests(); +} + +void ContinuousBatchingPipeline::PromptLookupImpl::step() { + ManualTimer candidates_timer("prompt_lookup_decoding: generate_candidates()"); + candidates_timer.start(); + m_pipeline->generate_candidates(); + candidates_timer.end(); + m_sd_metrics.draft_duration += candidates_timer.get_duration(); + auto generated_len_before = m_pipeline->get_generated_request_len(); + + ManualTimer main_timer("prompt_lookup_decoding: step()"); + main_timer.start(); + m_pipeline->step(); + main_timer.end(); + m_sd_metrics.main_duration += main_timer.get_duration(); + m_pipeline_metrics = m_pipeline->get_metrics(); + auto generated_len_after = m_pipeline->get_generated_request_len(); + + for (const auto request : generated_len_before) { + auto request_id = request.first; + auto prev_validation_len = request.second.second; + if (prev_validation_len == 0) { + continue; + } + size_t num_matches = prev_validation_len; + float acceptance_rate = 1.f; + if (generated_len_after.count(request.first)) { + auto present_req_len = generated_len_after.at(request.first).first; + auto prev_full_req_len = request.second.first; + + num_matches = (present_req_len - prev_full_req_len - 1); + acceptance_rate = static_cast(num_matches) / static_cast(prev_validation_len); + } + m_sd_metrics.update_acceptance_rate(request_id, acceptance_rate * 100); + m_sd_metrics.update_draft_accepted_tokens(request_id, num_matches); + } + + if (generated_len_after.empty() && 0) { + m_sd_metrics.print(true); + m_sd_metrics.clean_up(); + } +} + +std::vector +ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer) { + ManualTimer generate_timer("speculative_decoding: generate()"); + generate_timer.start(); + OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. 
Use ContinuousBatchingPipeline::add_request"); + OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); + const std::shared_ptr& streamer_ptr = std::visit(overloaded{ + [](std::monostate) -> std::shared_ptr { + return nullptr; + }, + [](const std::shared_ptr& streamer) { + return streamer; + }, + [this](const std::function& streamer) -> std::shared_ptr { + return std::make_unique(m_tokenizer, streamer); + } + }, streamer); + + OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), + "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); + + std::vector main_generations; + for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { + OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); + OPENVINO_ASSERT(sampling_params[request_id].is_prompt_lookup(), "`max_ngram_size` && `num_assistant_tokens` should be specified for `prompt lookup decoding`"); + main_generations.push_back(m_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id])); + } + + std::vector results; + results.reserve(input_ids.size()); + + bool continue_generation = true; + while (has_non_finished_requests() && continue_generation) { + step(); + if (streamer_ptr) { + // not generated tokens like several prompt phase + if (!main_generations.at(0).get()->can_read()) { + continue; + } + std::unordered_map token = main_generations.at(0).get()->back(); + OPENVINO_ASSERT(1 <= token.size()); + OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size()); + for (const auto& gen_token : token.begin()->second.generated_ids) { + continue_generation = !streamer_ptr->put(gen_token); + if (!continue_generation) { + break; + } + } + } + } + if (streamer_ptr) { + streamer_ptr->end(); + } + + for (size_t generation_idx = 0; generation_idx < main_generations.size(); ++generation_idx) { + const auto& generation = main_generations[generation_idx]; + EncodedGenerationResult result; + result.m_request_id = 1; + std::vector generation_outputs = generation->read_all(); + std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) { + return r1.score > r2.score; + }); + + auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size()); + for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { + const auto& generation_output = generation_outputs[generation_output_idx]; + m_sd_metrics.set_generated_len(generation_idx, generation_outputs[generation_output_idx].generated_ids.size()); + result.m_generation_ids.push_back(std::move(generation_output.generated_ids)); + result.m_scores.push_back(generation_output.score); + } + result.m_status = generation->get_status(); + results.push_back(std::move(result)); + } + + OPENVINO_ASSERT(results.size() == input_ids.size()); + generate_timer.end(); + m_sd_metrics.total_duration = generate_timer.get_duration(); + + return results; +} + +SpeculativeDecodingMetrics +ContinuousBatchingPipeline::PromptLookupImpl::get_metrics() { + return m_sd_metrics; +}; +} diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.hpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.hpp new file mode 100644 index 0000000000..dae721741b --- /dev/null +++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.hpp @@ -0,0 +1,49 @@ +// Copyright (C) 2023-2024 Intel 
Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "openvino/genai/continuous_batching_pipeline.hpp" +#include "continuous_batching_impl.hpp" +#include "continuous_batching_for_prompt_lookup.hpp" +#include "speculative_decoding/speculative_decoding_metrics.hpp" +#include "utils.hpp" + +namespace ov::genai { + +class ContinuousBatchingPipeline::PromptLookupImpl : public ContinuousBatchingPipeline::ImplInterface { +protected: + std::shared_ptr m_pipeline; + SpeculativeDecodingMetrics m_sd_metrics; + +public: + PromptLookupImpl(const std::shared_ptr& model, + const Tokenizer& tokenizer, + const SchedulerConfig& scheduler_config, + const std::string& device, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config) { + m_tokenizer = tokenizer; + m_pipeline = std::make_shared(model, tokenizer, scheduler_config, device, properties, generation_config); + }; + + GenerationHandle add_request(uint64_t request_id, + const ov::Tensor& input_ids, + ov::genai::GenerationConfig sampling_params) override; + GenerationHandle add_request(uint64_t request_id, + const std::string& prompt, + ov::genai::GenerationConfig sampling_params) override; + + bool has_non_finished_requests() override; + + void step() override; + + std::vector + generate(const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer) override; + + SpeculativeDecodingMetrics get_metrics(); +}; + +} \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index 06a51b38be..36f274f30f 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -141,7 +141,7 @@ init_request( LogitProcessor& logit_processor, bool is_update_logit_processor, bool is_init_all_sequences_in_request = false) { - OPENVINO_ASSERT(request->get_sampling_parameters().is_speculative_decoding(), + OPENVINO_ASSERT(request->get_sampling_parameters().is_assisting_generation(), "Speculative decoding should have initialized options `assistant_confidence_threshold` xor `num_assistant_tokens` in `GenerationConfig`."); if (candidates.begin()->second.token_ids.empty() && !is_init_all_sequences_in_request) { return 0; @@ -303,7 +303,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m to_generate = false; for (auto& request : m_requests) { const auto& sampling_params = request->get_sampling_parameters(); - if (!sampling_params.is_speculative_decoding()) { + if (!sampling_params.is_assisting_generation()) { // generate only one token in case of non speculative decoding request->pause_generation(true); } else if (request->get_num_processed_tokens() >= request->get_prompt_len() && diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index e4f3b1ad1f..4a0748b5c0 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -182,6 +182,11 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() { m_sd_metrics.update_acceptance_rate(request_id, acceptance_rate * 100); m_sd_metrics.update_draft_accepted_tokens(request_id, (updated_seq_info.inserted_tokens_cnt - 
updated_seq_info.removed_tokens_cnt)); } + + if (main_generated_requests.empty() && 0) { + m_sd_metrics.print(true); + m_sd_metrics.clean_up(); + } } std::vector @@ -266,24 +271,6 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< OPENVINO_ASSERT(results.size() == input_ids.size()); generate_timer.end(); - m_sd_metrics.total_duration = generate_timer.get_duration(); - - // Print Speculative decoding metrics - if (0) { - std::cout << std::endl; - std::cout << "Total duration, ms: " << m_sd_metrics.total_duration << std::endl; - std::cout << "Draft model duration, ms: " << m_sd_metrics.draft_duration << std::endl; - std::cout << "Main model duration, ms: " << m_sd_metrics.main_duration << std::endl; - std::cout << "Draft model duration, %: " << m_sd_metrics.get_draft_duration_percentage() << std::endl; - std::cout << "Main model duration, %: " << m_sd_metrics.get_main_duration_percentage() << std::endl; - std::cout << "Main model iterations: " << m_sd_metrics.get_iteration_number(0) << std::endl; - std::cout << "Token per sec: " << float(sampling_params[0].max_new_tokens) / m_sd_metrics.total_duration << std::endl; - std::cout << "AVG acceptance rate, %: " << m_sd_metrics.get_avg_acceptance_rate(0) << std::endl; - std::cout << "Accepted tokens by draft model: " << m_sd_metrics.get_draft_accepted_tokens_counter(0) << std::endl; - std::cout << "Generated tokens: " << sampling_params[0].max_new_tokens << std::endl; - std::cout << "Accepted token rate, %: " << m_sd_metrics.get_draft_accepted_tokens_percentage(0) << std::endl; - } - return results; } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_metrics.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_metrics.cpp index 42d3f0b750..4e5602482a 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_metrics.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_metrics.cpp @@ -95,4 +95,63 @@ void SpeculativeDecodingMetrics::set_generated_len(int64_t request_id, size_t ge m_generated_len.insert({ request_id, generated_len }); } +size_t SpeculativeDecodingMetrics::get_generated_len(int64_t request_id) { + return m_generated_len.at(request_id); +} + +std::vector SpeculativeDecodingMetrics::get_requests_id() { + std::vector result; + for (const auto& req : m_generated_len) { + result.push_back(req.first); + } + return result; +} + +void SpeculativeDecodingMetrics::print_acceptance_rates() { + for (const auto& a : m_acceptance_rate) { + std::cout << "Request_id: " << a.first << " ||| "; + for (const auto& b : a.second) { + std::cout << b << " "; + } + std::cout << std::endl; + } +} + +void SpeculativeDecodingMetrics::print(bool is_printing_per_request) { + if (total_duration == 0) { + total_duration = draft_duration + main_duration; + } + std::cout << "\n=============================== " << std::endl; + std::cout << "Total duration, ms: " << total_duration << std::endl; + std::cout << "Draft model duration, ms: " << draft_duration << std::endl; + std::cout << "Main model duration, ms: " << main_duration << std::endl; + std::cout << "Draft model duration, %: " << get_draft_duration_percentage() << std::endl; + std::cout << "Main model duration, %: " << get_main_duration_percentage() << std::endl; + std::cout << "AVG acceptance rate, %: " << get_avg_acceptance_rate(-1) << std::endl; + std::cout << "=============================== " << std::endl; + if (is_printing_per_request) { + for (const auto& i : get_requests_id()) { + std::cout << "REQUEST_ID: " << i << std::endl; + 
std::cout << "Main model iterations: " << get_iteration_number(i) << std::endl; + std::cout << "Token per sec: " << float(get_generated_len(i)) / total_duration << std::endl; + std::cout << "AVG acceptance rate, %: " << get_avg_acceptance_rate(i) << std::endl; + std::cout << "Accepted tokens by draft model: " << get_draft_accepted_tokens_counter(i) << std::endl; + std::cout << "Generated tokens: " << get_generated_len(i) << std::endl; + std::cout << "Accepted token rate, %: " << get_draft_accepted_tokens_percentage(i) << std::endl; + std::cout << "=============================== " << std::endl; + } + print_acceptance_rates(); + } + +} + +void SpeculativeDecodingMetrics::clean_up() { + m_acceptance_rate.clear(); + m_draft_accepted_tokens.clear(); + m_generated_len.clear(); + draft_duration = 0; + main_duration = 0; + total_duration = 0; +} + } \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_metrics.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_metrics.hpp index 5256128277..0d9173b99f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_metrics.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_metrics.hpp @@ -28,6 +28,7 @@ class SpeculativeDecodingMetrics { void update_draft_accepted_tokens(int64_t request_id, size_t num_matches); void set_generated_len(int64_t request_id, size_t generated_len); + size_t get_generated_len(int64_t request_id); size_t get_iteration_number(int64_t request_id); @@ -35,5 +36,11 @@ class SpeculativeDecodingMetrics { float get_main_duration_percentage(); float get_inference_duration_percentage(); + std::vector get_requests_id(); + + void print_acceptance_rates(); + void print(bool is_printing_per_request = false); + + void clean_up(); }; } \ No newline at end of file diff --git a/src/python/openvino_genai/__init__.py b/src/python/openvino_genai/__init__.py index 470ddd0cd8..a0b0faf58c 100644 --- a/src/python/openvino_genai/__init__.py +++ b/src/python/openvino_genai/__init__.py @@ -28,7 +28,7 @@ # LLM pipeline from .py_openvino_genai import ( LLMPipeline, - draft_model + draft_model, ) # LoRA diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 6135a187eb..524ff0f921 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -575,6 +575,7 @@ class GenerationConfig: logprobs: int max_length: int max_new_tokens: int + max_ngram_size: int min_new_tokens: int no_repeat_ngram_size: int num_assistant_tokens: int @@ -598,11 +599,13 @@ class GenerationConfig: @typing.overload def __init__(self, **kwargs) -> None: ... + def is_assisting_generation(self) -> bool: + ... def is_beam_search(self) -> bool: ... def is_greedy_decoding(self) -> bool: ... - def is_speculative_decoding(self) -> bool: + def is_prompt_lookup(self) -> bool: ... def set_eos_token_id(self, tokenizer_eos_token_id: int) -> None: ... @@ -2122,11 +2125,7 @@ class WhisperRawPerfMetrics: @property def features_extraction_durations(self) -> list[float]: ... 
-class draft_model: +def draft_model(models_path: os.PathLike, device: str = '', **kwargs) -> openvino._pyopenvino.OVAny: """ - This class is used to enable Speculative Decoding + device on which inference will be performed """ - def __init__(self, models_path: os.PathLike, device: str = '', **kwargs) -> None: - """ - device on which inference will be performed - """ diff --git a/src/python/py_generation_config.cpp b/src/python/py_generation_config.cpp index d24a915dd6..b1a5c6cd2e 100644 --- a/src/python/py_generation_config.cpp +++ b/src/python/py_generation_config.cpp @@ -107,12 +107,14 @@ void init_generation_config(py::module_& m) { .def_readwrite("logprobs", &GenerationConfig::logprobs) .def_readwrite("assistant_confidence_threshold", &GenerationConfig::assistant_confidence_threshold) .def_readwrite("num_assistant_tokens", &GenerationConfig::num_assistant_tokens) + .def_readwrite("max_ngram_size", &GenerationConfig::max_ngram_size) .def_readwrite("include_stop_str_in_output", &GenerationConfig::include_stop_str_in_output) .def_readwrite("stop_token_ids", &GenerationConfig::stop_token_ids) .def_readwrite("adapters", &GenerationConfig::adapters) .def("set_eos_token_id", &GenerationConfig::set_eos_token_id, py::arg("tokenizer_eos_token_id")) .def("is_beam_search", &GenerationConfig::is_beam_search) .def("is_greedy_decoding", &GenerationConfig::is_greedy_decoding) - .def("is_speculative_decoding", &GenerationConfig::is_speculative_decoding) + .def("is_assisting_generation", &GenerationConfig::is_assisting_generation) + .def("is_prompt_lookup", &GenerationConfig::is_prompt_lookup) .def("update_generation_config", static_cast(&ov::genai::GenerationConfig::update_generation_config), py::arg("config_map")); } diff --git a/src/python/py_llm_pipeline.cpp b/src/python/py_llm_pipeline.cpp index b53cc56f10..b1d5136253 100644 --- a/src/python/py_llm_pipeline.cpp +++ b/src/python/py_llm_pipeline.cpp @@ -195,15 +195,14 @@ void init_llm_pipeline(py::module_& m) { .def("get_generation_config", &LLMPipeline::get_generation_config, py::return_value_policy::copy) .def("set_generation_config", &LLMPipeline::set_generation_config, py::arg("config")); - py::class_(m, "draft_model", py::module_local(), "This class is used to enable Speculative Decoding") - .def(py::init([]( + m.def("draft_model", []( const std::filesystem::path& models_path, const std::string& device, const py::kwargs& kwargs ) { ScopedVar env_manager(pyutils::ov_tokenizers_module_path()); return draft_model(models_path, device, pyutils::kwargs_to_any_map(kwargs)).second; - }), + }, py::arg("models_path"), "folder with openvino_model.xml and openvino_tokenizer[detokenizer].xml files", py::arg("device") = "", "device on which inference will be performed"); } diff --git a/src/python/py_openvino_genai.cpp b/src/python/py_openvino_genai.cpp index e821c1cfdc..429f48f30d 100644 --- a/src/python/py_openvino_genai.cpp +++ b/src/python/py_openvino_genai.cpp @@ -21,7 +21,6 @@ using ov::genai::DecodedResults; using ov::genai::EncodedResults; using ov::genai::StreamerBase; using ov::genai::StringInputs; -using ov::genai::draft_model; void init_lora_adapter(py::module_& m); void init_perf_metrics(py::module_& m); diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index f404e63cff..093cd993de 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -19,6 +19,7 @@ file(GLOB src_files "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/sequence_group.cpp" "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/cache_eviction.cpp" 
"${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/sampler.cpp" "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/speculative_decoding/*.cpp" + "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/prompt_lookup/*.cpp" "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/utils/*.cpp" "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/utils.cpp" "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/continuous_batching*.cpp"