From 4d4942ee61af1bd15eea8f5b2bb69d4bb803e08b Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Mon, 22 Jul 2024 13:10:03 +0200 Subject: [PATCH] refactor structure, add python sample --- .../benchmark_vanilla_genai.cpp | 19 ++-- .../python/benchmark_vanilla_genai/README.md | 51 +++++++++ .../benchmark_vanilla_genai.py | 48 +++++++++ .../include/openvino/genai/llm_pipeline.hpp | 2 + .../include/openvino/genai/perf_metrics.hpp | 37 +++++-- src/cpp/src/greedy_decoding.cpp | 10 +- src/cpp/src/group_beam_searcher.cpp | 20 ++-- src/cpp/src/llm_pipeline.cpp | 31 +++--- src/cpp/src/perf_counters.cpp | 21 ---- src/cpp/src/perf_counters.hpp | 44 -------- src/cpp/src/perf_metrics.cpp | 100 +++++++++++------- src/cpp/src/tokenizer.cpp | 2 - src/python/py_generate_pipeline.cpp | 18 +++- 13 files changed, 254 insertions(+), 149 deletions(-) create mode 100644 samples/python/benchmark_vanilla_genai/README.md create mode 100755 samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py delete mode 100644 src/cpp/src/perf_counters.cpp delete mode 100644 src/cpp/src/perf_counters.hpp diff --git a/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp b/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp index 6489282b0b..56aaca8cfd 100644 --- a/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp +++ b/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp @@ -11,7 +11,7 @@ int main(int argc, char* argv[]) try { ("p,prompt", "Prompt", cxxopts::value()->default_value("The Sky is blue because")) ("m,model", "Path to model and tokenizers base directory", cxxopts::value()->default_value(".")) ("nw,num_warmup", "Number of warmup iterations", cxxopts::value()->default_value(std::to_string(1))) - ("n,num_iter", "Number of iterations", cxxopts::value()->default_value(std::to_string(1))) + ("n,num_iter", "Number of iterations", cxxopts::value()->default_value(std::to_string(5))) ("d,device", "device", cxxopts::value()->default_value("CPU")) 
("h,help", "Print usage"); @@ -36,7 +36,7 @@ int main(int argc, char* argv[]) try { size_t num_iter = result["num_iter"].as(); ov::genai::GenerationConfig config; - config.max_new_tokens = 100; + config.max_new_tokens = 5; config.num_beam_groups = 3; config.num_beams = 15; @@ -45,17 +45,20 @@ int main(int argc, char* argv[]) try { for (size_t i = 0; i < num_warmup; i++) pipe.generate(prompt, config); - ov::genai::PerfMetrics metrics; - for (size_t i = 0; i < num_iter; i++) { - ov::genai::DecodedResults res = pipe.generate(prompt, config); + ov::genai::DecodedResults res = pipe.generate(prompt, config); + ov::genai::PerfMetrics metrics = res.metrics; + for (size_t i = 0; i < num_iter - 1; i++) { + res = pipe.generate(prompt, config); metrics = metrics + res.metrics; - metrics.load_time = res.metrics.load_time; } std::cout << "Load time: " << metrics.load_time << " ms" << std::endl; + std::cout << "Generate time: " << metrics.mean_generate_duration << " ± " << metrics.std_generate_duration << " ms" << std::endl; + std::cout << "Tokenization time: " << metrics.mean_tokenization_duration << " ± " << metrics.std_tokenization_duration << " ms" << std::endl; + std::cout << "Detokenization time: " << metrics.mean_detokenization_duration << " ± " << metrics.std_detokenization_duration << " ms" << std::endl; std::cout << "ttft: " << metrics.mean_ttft << " ± " << metrics.std_ttft << " ms" << std::endl; - std::cout << "tpot: " << metrics.mean_tpot << " ± " << metrics.std_tpot << " ms" << std::endl; - std::cout << "Tokens/s: " << metrics.mean_throughput << std::endl; + std::cout << "tpot: " << metrics.mean_tpot << " ± " << metrics.std_tpot << " ms " << std::endl; + std::cout << "Tokens/s: " << metrics.mean_throughput << " ± " << metrics.std_throughput << std::endl; return 0; } catch (const std::exception& error) { diff --git a/samples/python/benchmark_vanilla_genai/README.md b/samples/python/benchmark_vanilla_genai/README.md new file mode 100644 index 0000000000..0353f3f2c6 --- 
/dev/null +++ b/samples/python/benchmark_vanilla_genai/README.md @@ -0,0 +1,51 @@ +# Benchmark Vanilla GenAI + +This sample script demonstrates how to benchmark an LLM in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics. + +# ov.genai.PerfMetrics structure +ov.genai.PerfMetrics is a structure which holds performance metrics for each generate call. Each generate call calculates the following metrics: +- mean_ttft + - std_ttft + - mean_tpot + - std_tpot + - load_time + - mean_generate_duration + - std_generate_duration + - mean_tokenization_duration + - std_tokenization_duration + - mean_detokenization_duration + - std_detokenization_duration + - mean_throughput + - std_throughput + - num_generated_tokens + - num_input_tokens + +Performance metrics can be added to one another and accumulated using the += operator or the + operator. In that case the mean and std values are recalculated over all accumulated generate calls. + + +## Download and convert the model and tokenizers + +The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version. + +It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported. + +```sh +pip install --upgrade-strategy eager -r ../../requirements.txt +optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0 +``` + +## Usage + +```sh +python benchmark_vanilla_genai.py [OPTIONS] +``` + +### Options + +- `-p, --prompt` (default: `"The Sky is blue because"`): The prompt used to generate text. +- `-m, --model` (default: `"."`): Path to the model and tokenizers base directory. +- `-nw, --num_warmup` (default: `1`): Number of warmup iterations. +- `-n, --num_iter` (default: `3`): Number of iterations. +- `-d, --device` (default: `"CPU"`): Device to run the model on.
+ + diff --git a/samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py b/samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py new file mode 100755 index 0000000000..d2abbe4de4 --- /dev/null +++ b/samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py @@ -0,0 +1,48 @@ +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import openvino_genai as ov_genai + +def main(): + parser = argparse.ArgumentParser(description="Help command") + parser.add_argument("-p", "--prompt", type=str, default="The Sky is blue because", help="Prompt") + parser.add_argument("-m", "--model", type=str, default=".", help="Path to model and tokenizers base directory") + parser.add_argument("-nw", "--num_warmup", type=int, default=1, help="Number of warmup iterations") + parser.add_argument("-n", "--num_iter", type=int, default=3, help="Number of iterations") + parser.add_argument("-n", "--num_new_tokens", type=int, default=3, help="Maximal number of new tokens") + parser.add_argument("-d", "--device", type=str, default="CPU", help="Device") + + args = parser.parse_args() + + prompt = args.prompt + model_path = args.model + device = args.device + num_warmup = args.num_warmup + num_iter = args.num_iter + + + config = ov_genai.GenerationConfig() + config.max_new_tokens = args.num_new_tokens + + pipe = ov_genai.LLMPipeline(model_path, device) + + for _ in range(num_warmup): + pipe.generate(prompt, config) + + res = pipe.generate(prompt, config) + metrics = res.metrics + for _ in range(num_iter - 1): + res = pipe.generate(prompt, config) + metrics += res.metrics + + print(f"Load time: {metrics.load_time} ms") + print(f"Generate time: {metrics.mean_generate_duration} ± {metrics.std_generate_duration} ms") + print(f"Tokenization time: {metrics.mean_tokenization_duration} ± {metrics.std_tokenization_duration} ms") + print(f"Detokenization time: {metrics.mean_detokenization_duration} ± 
{metrics.std_detokenization_duration} ms") + print(f"ttft: {metrics.mean_ttft} ± {metrics.std_ttft} ms") + print(f"tpot: {metrics.mean_tpot} ± {metrics.std_tpot} ms") + print(f"Tokens/s: {metrics.mean_throughput} ± {metrics.std_throughput}") + +if __name__ == "__main__": + main() diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index 4db3c613e7..14100d4f16 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -31,6 +31,7 @@ using StringInputs = std::variant>; * * @param tokens sequence of resulting tokens * @param scores sum of logarithmic probabilities of all tokens in the sequence +* @param metrics performance metrics with tpot, ttft, etc. of type ov::genai::PerfMetrics */ class EncodedResults { public: @@ -45,6 +46,7 @@ class EncodedResults { * * @param texts vector of resulting sequences * @param scores scores for each sequence +* @param metrics performance metrics with tpot, ttft, etc. of type ov::genai::PerfMetrics */ class DecodedResults { public: diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index a11c4e0374..e66c917e81 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -7,14 +7,34 @@ #include "openvino/genai/visibility.hpp" #include #include +#include namespace ov { namespace genai { using TimePoint = std::chrono::steady_clock::time_point; -struct PerfCounters; +/** +* @brief Structure with raw performance metrics for each generation before any statistics calculated. 
+*/ +struct OPENVINO_GENAI_EXPORTS RawPerfMetrics { + std::vector generate_durations; + std::vector tokenization_durations; + std::vector detokenization_durations; + + std::vector m_times_to_first_token; + std::vector m_new_token_times; + std::vector m_batch_sizes; + std::vector m_durations; + size_t num_generated_tokens; + size_t num_input_tokens; +}; + +/** +* @brief Structure to store performance metric for each generation +* +*/ struct OPENVINO_GENAI_EXPORTS PerfMetrics { // First token time. float mean_ttft; @@ -25,11 +45,13 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { float std_tpot; float load_time; - float start_time; float mean_generate_duration; - float mean_decoding_duration; - float mean_encoding_duration; + float std_generate_duration; + float mean_tokenization_duration; + float std_tokenization_duration; + float mean_detokenization_duration; + float std_detokenization_duration; float mean_throughput; float std_throughput; @@ -37,13 +59,12 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { size_t num_generated_tokens; size_t num_input_tokens; - std::shared_ptr m_counters; - void evaluate(TimePoint start_time); - + void evaluate_statistics(std::optional start_time = std::nullopt); + static float get_duration_ms(std::chrono::steady_clock::duration duration); PerfMetrics operator+(const PerfMetrics& metrics) const; PerfMetrics& operator+=(const PerfMetrics& right); - + RawPerfMetrics raw_counters; }; } // namespace genai diff --git a/src/cpp/src/greedy_decoding.cpp b/src/cpp/src/greedy_decoding.cpp index 0802b87e66..c5bf10a2d1 100644 --- a/src/cpp/src/greedy_decoding.cpp +++ b/src/cpp/src/greedy_decoding.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "openvino/genai/perf_metrics.hpp" -#include "perf_counters.hpp" +// #include "perf_counters.hpp" #include "utils.hpp" namespace ov { @@ -23,7 +23,7 @@ EncodedResults greedy_decoding( size_t max_new_tokens = generation_config.get_max_new_tokens(prompt_len); EncodedResults results; - auto& 
perf_counters = results.metrics.m_counters; + auto& raw_perf_counters = results.metrics.raw_counters; results.scores.resize(running_batch_size); results.tokens.resize(running_batch_size); @@ -54,7 +54,8 @@ EncodedResults greedy_decoding( eos_met[batch] = (out_token == generation_config.eos_token_id); m_model_runner.get_tensor("input_ids").data()[batch] = out_token; } - perf_counters->add_timestamp(running_batch_size); + raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); + raw_perf_counters.m_batch_sizes.emplace_back(batch_size); if (streamer && streamer->put(token_iter_results[0])) { return results; @@ -86,7 +87,8 @@ EncodedResults greedy_decoding( m_model_runner.get_tensor("input_ids").data()[batch] = out_token; } - perf_counters->add_timestamp(running_batch_size); + raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); + raw_perf_counters.m_batch_sizes.emplace_back(batch_size); if (streamer && streamer->put(token_iter_results[0])) return results; diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp index 4f5cb79f2a..784ff1a915 100644 --- a/src/cpp/src/group_beam_searcher.cpp +++ b/src/cpp/src/group_beam_searcher.cpp @@ -366,11 +366,6 @@ std::pair beam_search(ov::InferRequest& lm, auto batch_size = input_ids.get_shape().at(0); auto sequence_length = input_ids.get_shape().at(1); - // Initialize time metric counters. - // ov::genai::TimePoints tok_times; - // tok_times.reserve(config.get_max_new_tokens(sequence_length)); - // tok_times.emplace_back(std::chrono::steady_clock::now()); - // Initialize beam search. const int64_t* prompt_data = input_ids.data(); std::vector> prompts; @@ -407,12 +402,19 @@ std::pair beam_search(ov::InferRequest& lm, std::vector next_tokens; std::vector next_beams; - + + // Reserve for performance counters. 
+ std::vector new_token_times; + std::vector batch_sizes; + new_token_times.reserve(parameters.max_new_tokens); + batch_sizes.reserve(parameters.max_new_tokens); + for (size_t length_count = 0; ; ++length_count) { lm.infer(); std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits")); - // tok_times.emplace_back(std::chrono::steady_clock::now()); + new_token_times.emplace_back(std::chrono::steady_clock::now()); + batch_sizes.emplace_back(batch_size); if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) { // Break the cycle before masks are extended in update_attention_mask_with_beams. @@ -442,6 +444,9 @@ std::pair beam_search(ov::InferRequest& lm, int32_t res_selected_beam_idx = 0; results.scores.reserve(config.num_return_sequences * result.size()); results.tokens.reserve(config.num_return_sequences * result.size()); + auto& raw_perf_counters = results.metrics.raw_counters; + raw_perf_counters.m_new_token_times = new_token_times; + raw_perf_counters.m_batch_sizes = batch_sizes; // align output with HF for (size_t prompt_id = 0; prompt_id < result.size(); prompt_id++) { @@ -471,7 +476,6 @@ std::pair beam_search(ov::InferRequest& lm, } } - // results.metrics = PerfCounters(tok_times); return {results, res_selected_beam_idx}; } diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 81f807c149..5241142afe 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -1,7 +1,6 @@ // Copyright (C) 2023-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 -#include "perf_counters.hpp" #include #include #include @@ -160,14 +159,18 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { m_templated_chat_history.append(answer); m_history.push_back({{"role", "assistant"}, {"content", answer}}); } - - auto& metrics = encoded_results.metrics; - // metrics.tokenization_duration = std::chrono::duration_cast(encode_stop_time - start_time).count(); - // 
metrics.detokenization_duration = std::chrono::duration_cast(decode_stop_time - decode_start_time).count(); - // auto stop_time = std::chrono::steady_clock::now(); - // metrics.generate_durations.emplace_back(std::chrono::duration_cast(stop_time - start_time).count()); - decoded_results.metrics = std::move(metrics); + // generate_durations + decoded_results.metrics = encoded_results.metrics; + + auto& raw_counters = decoded_results.metrics.raw_counters; + auto stop_time = std::chrono::steady_clock::now(); + + raw_counters.generate_durations.emplace_back(PerfMetrics::get_duration_ms(stop_time - start_time)); + raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_duration_ms(encode_stop_time - start_time)); + raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_duration_ms(decode_stop_time - decode_start_time)); + + decoded_results.metrics.evaluate_statistics(start_time); return decoded_results; } @@ -267,13 +270,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { m_is_cache_empty = false; } - - + // If is called without tokenization then that stat will not be reported. 
auto& metrics = result.metrics; - // metrics.batch_size = batch_size; - // metrics.num_generated_tokens = (metrics.m_durations.size() + 1) * batch_size; - metrics.num_input_tokens = batch_size * input_ids.get_shape().at(0); - result.metrics = std::move(metrics); + metrics.num_input_tokens = batch_size * input_ids.get_shape().at(1); + metrics.load_time = this->m_load_time_ms; + metrics.evaluate_statistics(start_time); return result; } @@ -390,7 +391,7 @@ ov::genai::LLMPipeline::LLMPipeline( m_pimpl = make_unique(std::filesystem::path(path), device, config); } auto stop_time = std::chrono::steady_clock::now(); - m_pimpl->m_load_time_ms = std::chrono::duration_cast(stop_time - start_time).count(); + m_pimpl->m_load_time_ms = PerfMetrics::get_duration_ms(stop_time - start_time); } ov::genai::GenerationConfig ov::genai::LLMPipeline::get_generation_config() const { diff --git a/src/cpp/src/perf_counters.cpp b/src/cpp/src/perf_counters.cpp deleted file mode 100644 index c9dac6eca0..0000000000 --- a/src/cpp/src/perf_counters.cpp +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include "perf_counters.hpp" -#include "openvino/genai/perf_metrics.hpp" -#include "openvino/openvino.hpp" -#include -#include -#include - -namespace ov { -namespace genai { - -void PerfCounters::add_timestamp(size_t batch_size) { - m_new_token_times.emplace_back(std::chrono::steady_clock::now()); - m_batch_sizes.emplace_back(batch_size); -} - - -} // namespace genai -} // namespace ov diff --git a/src/cpp/src/perf_counters.hpp b/src/cpp/src/perf_counters.hpp deleted file mode 100644 index 7d33490205..0000000000 --- a/src/cpp/src/perf_counters.hpp +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include -#include - -namespace ov { -namespace genai { - -struct PerfCounters { - std::vector generate_durations; - std::vector 
tokenization_duration; - std::vector detokenization_duration; - size_t num_generated_tokens; - size_t num_input_tokens; - - std::vector m_batch_sizes; - std::vector m_durations; - std::vector m_times_to_first_token; - std::vector m_new_token_times; - void add_timestamp(size_t batch_size); - // void add_gen_finish_timestamp(size_t batch_size); - -}; - -// class StopWatch { -// TimePoint m_start; -// public: -// StopWatch& start() { -// m_start = std::chrono::steady_clock::now(); -// return *this; -// } - -// float split() { -// std::chrono::steady_clock::time_point curr_time = std::chrono::steady_clock::now(); -// return std::chrono::duration_cast(curr_time - m_start).count(); -// } -// }; - -} // namespace genai -} // namespace ov diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 4a8b1d76c6..3947793802 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 #include "openvino/genai/perf_metrics.hpp" -#include "perf_counters.hpp" #include "openvino/openvino.hpp" #include #include @@ -17,7 +16,7 @@ std::pair calc_mean_and_std(const std::vector& durations) { [](const float& acc, const float& duration) -> float { return acc + duration * duration; }); - float std = std::sqrt(sum_square_durations / durations.size() - mean * mean); + float std = std::sqrt(sum_square_durations / durations.size() - mean * mean); return {mean, std}; } @@ -26,48 +25,77 @@ std::pair calc_mean_and_std(const std::vector& durations) { namespace ov { namespace genai { - -void PerfMetrics::evaluate(TimePoint start_time) { - - auto& tok_times = m_counters->m_new_token_times; - auto& batch_sizes = m_counters->m_batch_sizes; - m_counters->m_durations = std::vector(tok_times.size()); - auto ttft = std::chrono::duration_cast(tok_times[0] - start_time).count(); - m_counters->m_times_to_first_token.emplace_back(ttft); +float PerfMetrics::get_duration_ms(std::chrono::steady_clock::duration duration) { + 
return std::chrono::duration_cast(duration).count(); +} - for (size_t i = 0; i < tok_times.size(); ++i) { - m_counters->m_durations[i] = std::chrono::duration_cast(tok_times[i] - start_time).count(); - // If in 10 ms a batch of 5 new tokens is generated then TTOT is 10 ms / 5. - // todo: float check that it's valid for batch > 1. - m_counters->m_durations[i] /= batch_sizes[i]; - start_time = tok_times[i]; - } +void PerfMetrics::evaluate_statistics(std::optional start_time) { + // If start_time is specified then recalculate durations according to start times and calculate statistics only after that. + if (start_time.has_value()) { + auto start_time_val = *start_time; + auto& tok_times = raw_counters.m_new_token_times; + auto& batch_sizes = raw_counters.m_batch_sizes; + raw_counters.m_durations = std::vector(tok_times.size()); - std::tie(mean_tpot, std_tpot) = calc_mean_and_std(m_counters->m_durations); - std::tie(mean_ttft, std_ttft) = calc_mean_and_std(m_counters->m_times_to_first_token); -} + auto ttft = std::chrono::duration_cast(tok_times[0] - start_time_val).count(); + raw_counters.m_times_to_first_token = std::vector(); + raw_counters.m_times_to_first_token.emplace_back(ttft); + num_generated_tokens = 0; + for (size_t i = 0; i < tok_times.size(); ++i) { + raw_counters.m_durations[i] = std::chrono::duration_cast(tok_times[i] - start_time_val).count(); + + // If in 10 ms a batch of 5 new tokens is generated then TPOT is 10 ms / 5. + // todo: double check that it's valid for batch > 1.
+ raw_counters.m_durations[i] /= batch_sizes[i]; + num_generated_tokens += batch_sizes[i]; + start_time_val = tok_times[i]; + } + } -PerfMetrics PerfMetrics::operator+(const PerfMetrics& metrics) const { - PerfMetrics nm; // new metrics - nm.m_counters = m_counters; - auto& new_counters = nm.m_counters; + std::tie(mean_tpot, std_tpot) = calc_mean_and_std(raw_counters.m_durations); + std::tie(mean_ttft, std_ttft) = calc_mean_and_std(raw_counters.m_times_to_first_token); - auto& new_durations = new_counters->m_durations; - auto& new_times_to_first_token = new_counters->m_times_to_first_token; - - auto& counters_to_appnd = metrics.m_counters; - new_durations.insert(new_durations.end(), counters_to_appnd->m_durations.begin(), counters_to_appnd->m_durations.end()); - new_times_to_first_token.insert(new_times_to_first_token.end(), counters_to_appnd->m_times_to_first_token.begin(), counters_to_appnd->m_times_to_first_token.end()); + std::tie(mean_generate_duration, std_generate_duration) = calc_mean_and_std(raw_counters.generate_durations); + std::tie(mean_tokenization_duration, std_tokenization_duration) = calc_mean_and_std(raw_counters.tokenization_durations); + std::tie(mean_detokenization_duration, std_detokenization_duration) = calc_mean_and_std(raw_counters.detokenization_durations); - OPENVINO_ASSERT(metrics.load_time == load_time, "generation metrics can be accumulated only for the same pipeline"); + mean_throughput = 1000.0f / mean_tpot; + std_throughput = (std_tpot * 1000.0f) / (mean_tpot * mean_tpot); +} + +PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const { + OPENVINO_ASSERT(right.load_time == load_time, "generation metrics can be accumulated only for the same pipeline"); - std::tie(nm.mean_tpot, nm.std_tpot) = calc_mean_and_std(new_counters->m_durations); - std::tie(nm.mean_ttft, nm.std_ttft) = calc_mean_and_std(new_counters->m_times_to_first_token); + // Copy left value to res. 
+ PerfMetrics res = *this; + + // Concatenate duration and first token times. + auto& new_durations = res.raw_counters.m_durations; + auto& new_times_to_first_token = res.raw_counters.m_times_to_first_token; + auto& right_durations = right.raw_counters.m_durations; + auto& right_times_to_first_token = right.raw_counters.m_times_to_first_token; - // todo: add tokenization statistics concatenation. + new_durations.insert(new_durations.end(), right_durations.begin(), right_durations.end()); + new_times_to_first_token.insert(new_times_to_first_token.end(), right_times_to_first_token.begin(), right_times_to_first_token.end()); + + // Concatenate tokenization/detokenization and total generation times. + auto& new_tok_durations = res.raw_counters.tokenization_durations; + auto& new_detok_durations = res.raw_counters.detokenization_durations; + auto& new_gen_durations = res.raw_counters.generate_durations; + auto& right_tok_durations = right.raw_counters.tokenization_durations; + auto& right_detok_durations = right.raw_counters.detokenization_durations; + auto& right_gen_durations = right.raw_counters.generate_durations; - return nm; + new_tok_durations.insert(new_tok_durations.end(), right_tok_durations.begin(), right_tok_durations.end()); + new_detok_durations.insert(new_detok_durations.end(), right_detok_durations.begin(), right_detok_durations.end()); + new_gen_durations.insert(new_gen_durations.end(), right_gen_durations.begin(), right_gen_durations.end()); + + res.num_generated_tokens = num_generated_tokens + right.num_generated_tokens; + res.num_input_tokens = num_generated_tokens + right.num_input_tokens; + res.load_time = load_time; + res.evaluate_statistics(); + return res; } PerfMetrics& PerfMetrics::operator+=(const PerfMetrics& right) { @@ -75,7 +103,5 @@ PerfMetrics& PerfMetrics::operator+=(const PerfMetrics& right) { return *this; } - - } // namespace genai } // namespace ov diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index 
501d0e86cf..ac6b925dcb 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -323,8 +323,6 @@ class Tokenizer::TokenizerImpl { // Replace what jinja2cpp doesn't support std::pair replace_str_map[] = { - {"{-", "{"}, - {"{%-", "{%"}, {"'}", "' }"}, {"{'", "{ '"}, {".strip()", ""} diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp index c78c760b6c..d0dc980e4b 100644 --- a/src/python/py_generate_pipeline.cpp +++ b/src/python/py_generate_pipeline.cpp @@ -22,6 +22,7 @@ using ov::genai::GenerationResult; using ov::genai::LLMPipeline; using ov::genai::OptionalGenerationConfig; using ov::genai::PerfMetrics; +using ov::genai::RawPerfMetrics; using ov::genai::SchedulerConfig; using ov::genai::StopCriteria; using ov::genai::StreamerBase; @@ -537,11 +538,24 @@ PYBIND11_MODULE(py_generate_pipeline, m) { .def_readonly("scores", &DecodedResults::scores) .def("__str__", &DecodedResults::operator std::string);; + py::class_(m, "RawPerfMetrics") + .def(py::init<>()) + .def_readonly("generate_durations", &RawPerfMetrics::generate_durations) + .def_readonly("tokenization_durations", &RawPerfMetrics::tokenization_durations) + .def_readonly("detokenization_durations", &RawPerfMetrics::detokenization_durations) + .def_readonly("m_times_to_first_token", &RawPerfMetrics::m_times_to_first_token) + .def_readonly("m_batch_sizes", &RawPerfMetrics::m_batch_sizes) + .def_readonly("m_durations", &RawPerfMetrics::m_durations) + .def_readonly("num_generated_tokens", &RawPerfMetrics::num_generated_tokens) + .def_readonly("num_input_tokens", &RawPerfMetrics::num_input_tokens); + py::class_(m, "PerfMetrics") .def(py::init<>()) .def_readonly("mean_generate_duration", &PerfMetrics::mean_generate_duration) - .def_readonly("mean_decoding_duration", &PerfMetrics::mean_decoding_duration) - .def_readonly("mean_encoding_duration", &PerfMetrics::mean_encoding_duration) + .def_readonly("mean_tokenization_duration", &PerfMetrics::mean_tokenization_duration) 
+ .def_readonly("mean_detokenization_duration", &PerfMetrics::mean_detokenization_duration) + .def_readonly("mean_throughput", &PerfMetrics::mean_throughput) + .def_readonly("std_throughput", &PerfMetrics::std_throughput) .def_readonly("mean_tpot", &PerfMetrics::mean_tpot) .def_readonly("mean_ttft", &PerfMetrics::mean_ttft) .def_readonly("std_tpot", &PerfMetrics::std_tpot)