diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt
index 44f8d580b2..5339817c1f 100644
--- a/samples/CMakeLists.txt
+++ b/samples/CMakeLists.txt
@@ -10,7 +10,7 @@ add_subdirectory(cpp/greedy_causal_lm)
 add_subdirectory(cpp/multinomial_causal_lm)
 add_subdirectory(cpp/prompt_lookup_decoding_lm)
 add_subdirectory(cpp/speculative_decoding_lm)
-add_subdirectory(cpp/benchmark_vanilla_genai)
+add_subdirectory(cpp/benchmark_genai)
 install(FILES requirements.txt DESTINATION samples COMPONENT cpp_samples_genai)
diff --git a/samples/cpp/benchmark_vanilla_genai/CMakeLists.txt b/samples/cpp/benchmark_genai/CMakeLists.txt
similarity index 64%
rename from samples/cpp/benchmark_vanilla_genai/CMakeLists.txt
rename to samples/cpp/benchmark_genai/CMakeLists.txt
index e871f5a33a..bfa1592f61 100644
--- a/samples/cpp/benchmark_vanilla_genai/CMakeLists.txt
+++ b/samples/cpp/benchmark_genai/CMakeLists.txt
@@ -12,14 +12,14 @@ FetchContent_Declare(cxxopts
   URL_HASH SHA256=523175f792eb0ff04f9e653c90746c12655f10cb70f1d5e6d6d9491420298a08)
 FetchContent_MakeAvailable(cxxopts)
-add_executable(benchmark_vanilla_genai benchmark_vanilla_genai.cpp)
-target_link_libraries(benchmark_vanilla_genai PRIVATE openvino::genai cxxopts::cxxopts)
-set_target_properties(benchmark_vanilla_genai PROPERTIES
-    COMPILE_PDB_NAME benchmark_vanilla_genai
+add_executable(benchmark_genai benchmark_genai.cpp)
+target_link_libraries(benchmark_genai PRIVATE openvino::genai cxxopts::cxxopts)
+set_target_properties(benchmark_genai PROPERTIES
+    COMPILE_PDB_NAME benchmark_genai
     # Ensure out of box LC_RPATH on macOS with SIP
     INSTALL_RPATH_USE_LINK_PATH ON)
-# target_compile_features(benchmark_vanilla_genai PRIVATE cxx_std_11)
-install(TARGETS benchmark_vanilla_genai
+# target_compile_features(benchmark_genai PRIVATE cxx_std_11)
+install(TARGETS benchmark_genai
     RUNTIME DESTINATION samples_bin/
     COMPONENT samples_bin
     EXCLUDE_FROM_ALL)
diff --git a/samples/cpp/benchmark_genai/README.md b/samples/cpp/benchmark_genai/README.md
new file mode 100644
index 0000000000..bac16c2f7d
--- /dev/null
+++ b/samples/cpp/benchmark_genai/README.md
@@ -0,0 +1,47 @@
+# Benchmarking OpenVINO GenAI
+
+This sample demonstrates how to benchmark an LLM in OpenVINO GenAI. The sample includes functionality for warm-up iterations, generating text, and calculating various performance metrics.
+
+## Download and convert the model and tokenizers
+
+The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
+
+It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported.
+
+```sh
+pip install --upgrade-strategy eager -r ../../requirements.txt
+optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
+```
+
+## Usage
+
+```sh
+benchmark_genai [OPTIONS]
+```
+
+### Options
+
+- `-m, --model`: Path to the model and tokenizers base directory.
+- `-p, --prompt` (default: `"The Sky is blue because"`): The prompt to generate text.
+- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
+- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens.
+- `-n, --num_iter` (default: `3`): Number of iterations.
+- `-d, --device` (default: `"CPU"`): Device to run the model on.
+
+### Output:
+
+```
+benchmark_genai -m TinyLlama-1.1B-Chat-v1.0 -n 10
+```
+
+```
+Load time: 3405.69 ms
+Generate time: 1430.77 ± 3.04 ms
+Tokenization time: 0.51 ± 0.02 ms
+Detokenization time: 0.37 ± 0.01 ms
+TTFT: 81.60 ± 0.54 ms
+TPOT: 71.52 ± 2.72 ms
+Throughput tokens/s: 13.98 ± 0.53
+```
+
+For more information on how the performance metrics are calculated, please follow the [performance metrics tutorial](../../../src/README.md#performance-metrics).
diff --git a/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp b/samples/cpp/benchmark_genai/benchmark_genai.cpp
similarity index 90%
rename from samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp
rename to samples/cpp/benchmark_genai/benchmark_genai.cpp
index a9bc07f641..9610aabe54 100644
--- a/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp
+++ b/samples/cpp/benchmark_genai/benchmark_genai.cpp
@@ -8,11 +8,11 @@ int main(int argc, char* argv[]) try {
     cxxopts::Options options("benchmark_vanilla_genai", "Help command");
     options.add_options()
-    ("p,prompt", "Prompt", cxxopts::value<std::string>()->default_value("The Sky is blue because"))
     ("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
+    ("p,prompt", "Prompt", cxxopts::value<std::string>()->default_value("The Sky is blue because"))
     ("nw,num_warmup", "Number of warmup iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
-    ("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(20)))
-    ("mt,max_new_tokens", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(20)))
+    ("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(3)))
+    ("mt,max_new_tokens", "Maximal number of new tokens", cxxopts::value<size_t>()->default_value(std::to_string(20)))
     ("d,device", "device", cxxopts::value<std::string>()->default_value("CPU"))
     ("h,help", "Print usage");
@@ -38,6 +38,8 @@ int main(int argc, char* argv[]) try {
     ov::genai::GenerationConfig config;
     config.max_new_tokens = result["max_new_tokens"].as<size_t>();
+    config.num_beam_groups = 3;
+    config.num_beams = 15;
     ov::genai::LLMPipeline pipe(model_path, device);
@@ -45,10 +47,10 @@ int main(int argc, char* argv[]) try {
         pipe.generate(prompt, config);
     ov::genai::DecodedResults res = pipe.generate(prompt, config);
-    ov::genai::PerfMetrics metrics = res.metrics;
+    ov::genai::PerfMetrics metrics = res.perf_metrics;
     for (size_t i = 0; i < num_iter - 1; i++) {
         res = pipe.generate(prompt, config);
-        metrics = metrics + res.metrics;
+        metrics = metrics + res.perf_metrics;
     }
     std::cout << "Load time: " << metrics.load_time << " ms" << std::endl;
diff --git a/samples/cpp/benchmark_vanilla_genai/README.md b/samples/cpp/benchmark_vanilla_genai/README.md
deleted file mode 100644
index 50197dad1d..0000000000
--- a/samples/cpp/benchmark_vanilla_genai/README.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# benchmark OpenVINO GenAI sample
-
-TODO: adapt from python sample to c++
\ No newline at end of file
diff --git a/samples/python/benchmark_vanilla_genai/README.md b/samples/python/benchmark_genai/README.md
similarity index 64%
rename from samples/python/benchmark_vanilla_genai/README.md
rename to samples/python/benchmark_genai/README.md
index 13666a7de9..fa4fa85576 100644
--- a/samples/python/benchmark_vanilla_genai/README.md
+++ b/samples/python/benchmark_genai/README.md
@@ -1,28 +1,7 @@
-# Benchmark Vanilla GenAI
+# Benchmarking OpenVINO GenAI
 
 This sample script demonstrates how to benchmark an LLMModel in OpenVINO GenAI.
 The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics.
-# ov.genai.PerfMetrics structure
-ov.genai.PerfMetrics is a structure which holds performance metric for each generate call. Each generate call calcualtes the following metrics:
-- mean_ttft
-  - std_ttft
-  - mean_tpot
-  - std_tpot
-  - load_time
-  - mean_generate_duration
-  - std_generate_duration
-  - mean_tokenization_duration
-  - std_tokenization_duration
-  - mean_detokenization_duration
-  - std_detokenization_duration
-  - mean_throughput
-  - std_throughput
-  - num_generated_tokens
-  - num_input_tokens
-
-Performance metrics can be added to one another and accumulated using the += operator or the + operator. In that case the mean values accumulated by several generate calls will be calculated.
-
-
 ## Download and convert the model and tokenizers
 The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
@@ -45,14 +24,14 @@ python benchmark_vanilla_genai.py [OPTIONS]
 - `-m, --model`: Path to the model and tokenizers base directory.
 - `-p, --prompt` (default: `"The Sky is blue because"`): The prompt to generate text.
 - `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
-- `-mt, --max_new_tokens` (default: `20`): Number of warmup iterations.
 - `-n, --num_iter` (default: `3`): Number of iterations.
+- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens.
 - `-d, --device` (default: `"CPU"`): Device to run the model on.
 ### Output:
 ```
-python benchmark_vanilla_genai.py -m TinyLlama-1.1B-Chat-v1.0/
+python benchmark_genai.py -m TinyLlama-1.1B-Chat-v1.0 -n 10
 ```
 ```
@@ -64,4 +43,5 @@ TTFT: 81.60 ± 0.54 ms
 TPOT: 71.52 ± 2.72 ms
 Throughput tokens/s: 13.98 ± 0.53
 ```
-s
\ No newline at end of file
+
+For more information on how performance metrics are calculated, see [performance metrics readme](../../../src/README.md#performance-metrics).
diff --git a/samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py b/samples/python/benchmark_genai/benchmark_genai.py
similarity index 58%
rename from samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py
rename to samples/python/benchmark_genai/benchmark_genai.py
index 9e4debe847..06bd8b0f48 100755
--- a/samples/python/benchmark_vanilla_genai/benchmark_vanilla_genai.py
+++ b/samples/python/benchmark_genai/benchmark_genai.py
@@ -3,7 +3,6 @@
 import argparse
 import openvino_genai as ov_genai
-import pdb
 def main():
     parser = argparse.ArgumentParser(description="Help command")
@@ -16,6 +15,8 @@ def main():
     args = parser.parse_args()
+    # Perf metrics are stored in DecodedResults.
+    # In order to get DecodedResults instead of a string, the input should be a list.
     prompt = [args.prompt]
     model_path = args.model
     device = args.device
@@ -24,6 +25,8 @@ def main():
     config = ov_genai.GenerationConfig()
     config.max_new_tokens = args.max_new_tokens
+    config.num_beam_groups = 3
+    config.num_beams = 15
     pipe = ov_genai.LLMPipeline(model_path, device)
@@ -31,19 +34,18 @@ def main():
         pipe.generate(prompt, config)
     res = pipe.generate(prompt, config)
-    metrics = res.metrics
+    perf_metrics = res.perf_metrics
     for _ in range(num_iter - 1):
-        # pdb.set_trace()
         res = pipe.generate(prompt, config)
-        metrics += res.metrics
+        perf_metrics += res.perf_metrics
-    print(f"Load time: {metrics.load_time:.2f} ms")
-    print(f"Generate time: {metrics.mean_generate_duration:.2f} ± {metrics.std_generate_duration:.2f} ms")
-    print(f"Tokenization time: {metrics.mean_tokenization_duration:.2f} ± {metrics.std_tokenization_duration:.2f} ms")
-    print(f"Detokenization time: {metrics.mean_detokenization_duration:.2f} ± {metrics.std_detokenization_duration:.2f} ms")
-    print(f"TTFT: {metrics.mean_ttft:.2f} ± {metrics.std_ttft:.2f} ms")
-    print(f"TPOT: {metrics.mean_tpot:.2f} ± {metrics.std_tpot:.2f} ms")
-    print(f"Throughput tokens/s: {metrics.mean_throughput:.2f} ± {metrics.std_throughput:.2f}")
+    print(f"Load time: {perf_metrics.load_time:.2f} ms")
+    print(f"Generate time: {perf_metrics.mean_generate_duration:.2f} ± {perf_metrics.std_generate_duration:.2f} ms")
+    print(f"Tokenization time: {perf_metrics.mean_tokenization_duration:.2f} ± {perf_metrics.std_tokenization_duration:.2f} ms")
+    print(f"Detokenization time: {perf_metrics.mean_detokenization_duration:.2f} ± {perf_metrics.std_detokenization_duration:.2f} ms")
+    print(f"TTFT: {perf_metrics.mean_ttft:.2f} ± {perf_metrics.std_ttft:.2f} ms")
+    print(f"TPOT: {perf_metrics.mean_tpot:.2f} ± {perf_metrics.std_tpot:.2f} ms")
+    print(f"Throughput tokens/s: {perf_metrics.mean_throughput:.2f} ± {perf_metrics.std_throughput:.2f}")
 if __name__ == "__main__":
     main()
diff --git a/samples/python/benchmark_genai/benchmark_genai_automatic.py b/samples/python/benchmark_genai/benchmark_genai_automatic.py
new file mode 100755
index 0000000000..98a00a8c99
--- /dev/null
+++ b/samples/python/benchmark_genai/benchmark_genai_automatic.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2023-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import openvino_genai as ov_genai
+
+def main():
+    parser = argparse.ArgumentParser(description="Help command")
+    parser.add_argument("-m", "--model", type=str, help="Path to model and tokenizers base directory")
+    parser.add_argument("-p", "--prompt", type=str, default="The Sky is blue because", help="Prompt")
+    parser.add_argument("-nw", "--num_warmup", type=int, default=1, help="Number of warmup iterations")
+    parser.add_argument("-n", "--num_iter", type=int, default=5, help="Number of iterations")
+    parser.add_argument("-mt", "--max_new_tokens", type=int, default=20, help="Maximal number of new tokens")
+    parser.add_argument("-d", "--device", type=str, default="CPU", help="Device")
+
+    args = parser.parse_args()
+
+    # Perf metrics are stored in DecodedResults.
+    # In order to get DecodedResults instead of a string, the input should be a list.
+
+    model_path = args.model
+    device = args.device
+    num_warmup = args.num_warmup
+    num_iter = args.num_iter
+
+    config = ov_genai.GenerationConfig()
+    config.max_new_tokens = args.max_new_tokens
+    # config.num_beam_groups = 3
+    # config.num_beams = 15
+
+    pipe = ov_genai.LLMPipeline(model_path, device)
+
+    import pandas as pd
+    metrics_df = pd.DataFrame(columns=['batch_size', 'throughput', 'ttft', 'tpot', 'std_throughput', 'std_ttft', 'std_tpot'])
+
+    batch_sizes = [1, 2, 4, 16, 32, 64, 256]
+    for batch_size in batch_sizes:
+        prompt = [args.prompt] * batch_size
+        for _ in range(num_warmup):
+            pipe.generate(prompt, config)
+
+        res = pipe.generate(prompt, config)
+        metrics = res.perf_metrics
+        for _ in range(num_iter - 1):
+            res = pipe.generate(prompt, config)
+            metrics += res.perf_metrics
+        metrics_df = pd.concat([metrics_df, pd.DataFrame([{
+            'batch_size': batch_size,
+            'throughput': metrics.mean_throughput,
+            'ttft': metrics.mean_ttft,
+            'tpot': metrics.mean_tpot,
+            'std_throughput': metrics.std_throughput,
+            'std_ttft': metrics.std_ttft,
+            'std_tpot': metrics.std_tpot,
+        }])], ignore_index=True)
+
+    metrics_df.to_csv('metrics.csv', index=False)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/README.md b/src/README.md
index 445b88aa58..a5530ea578 100644
--- a/src/README.md
+++ b/src/README.md
@@ -196,6 +196,55 @@ int main(int argc, char* argv[]) {
 }
 ```
+### Performance Metrics
+
+`ov.genai.PerfMetrics` (referred to as `PerfMetrics` for simplicity) is a structure that holds performance metrics for each generate call. `PerfMetrics` holds fields with the mean and standard deviation of the following metrics:
+- `ttft`
+- `tpot`
+- `load_time`
+- `generate_duration`
+- `tokenization_duration`
+- `detokenization_duration`
+- `throughput`
+
+and:
+- `num_generated_tokens`
+- `num_input_tokens`
+
+Performance metrics are stored in the `perf_metrics` field of either `DecodedResults` or `EncodedResults`. In addition to the fields mentioned above, `PerfMetrics` has a member `raw_metrics` of type `ov.genai.RawPerfMetrics` (referred to as `RawPerfMetrics` for simplicity) that contains raw values for the durations of each batch of new token generation, tokenization durations, detokenization durations, and more. These raw metrics are accessible if you wish to calculate your own statistical values such as median or percentiles. However, since mean and standard deviation values are usually sufficient, we will focus on `PerfMetrics`.
+
+```python
+import openvino_genai as ov_genai
+pipe = ov_genai.LLMPipeline(model_path, "CPU")
+res = pipe.generate(["The Sun is yellow because"], max_new_tokens=20)
+perf_metrics = res.perf_metrics
+print(f'generate_duration: {perf_metrics.mean_generate_duration:.2f}')
+print(f'ttft: {perf_metrics.mean_ttft:.2f}')
+print(f'tpot: {perf_metrics.mean_tpot:.2f}')
+```
+output:
+```sh
+generate_duration: 76.28
+ttft: 42.58
+tpot: 3.80
+```
+
+>**Note**: If the input prompt is just a string, the generate function will return only a string without perf_metrics. To obtain perf_metrics, provide the prompt as a list with at least one element or call generate with encoded inputs.
+
+Several `PerfMetrics` objects can be added to each other. In that case their `raw_metrics` are concatenated and the mean/std values are recalculated, which makes it convenient to accumulate statistics from several generate calls when benchmarking.
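+
+If the mean and standard deviation are not enough, custom statistics can also be computed from the raw values. A minimal sketch (it assumes the raw counters are exposed to Python as microsecond values):
+
+```python
+import statistics
+import openvino_genai as ov_genai
+
+pipe = ov_genai.LLMPipeline(model_path, "CPU")
+res = pipe.generate(["The Sun is yellow because"], max_new_tokens=20)
+raw_metrics = res.perf_metrics.raw_metrics
+
+# m_times_to_first_token is assumed to hold raw microsecond counts; convert to ms before reporting.
+ttft_ms = [t / 1000.0 for t in raw_metrics.m_times_to_first_token]
+print(f'median TTFT: {statistics.median(ttft_ms):.2f} ms')
+```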
+ +```python +import openvino_genai as ov_genai +pipe = ov_genai.LLMPipeline(model_path, "CPU") +res_1 = pipe.generate(["The Sun is yellow because"], max_new_tokens=20) +res_2 = pipe.generate(["Why Sky is blue because"], max_new_tokens=20) +perf_metrics = res_1.perf_metrics + res_2.perf_metrics + +print(f'generate_duration: {perf_metrics.mean_generate_duration:.2f}') +print(f'ttft: {perf_metrics.mean_ttft:.2f}') +print(f'tpot: {perf_metrics.mean_tpot:.2f}') +``` + ## How It Works For information on how OpenVINO™ GenAI works, refer to the [How It Works Section](https://github.com/openvinotoolkit/openvino.genai/tree/releases/2024/2/src/docs/HOW_IT_WORKS.md). diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index 14100d4f16..4be298128e 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -37,7 +37,7 @@ class EncodedResults { public: std::vector> tokens; std::vector scores; - PerfMetrics metrics; + PerfMetrics perf_metrics; }; /** @@ -52,7 +52,7 @@ class DecodedResults { public: std::vector texts; std::vector scores; - PerfMetrics metrics; + PerfMetrics perf_metrics; // @brief Convert DecodedResults to a string. operator std::string() const { diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 5779b9b080..44535cf3a2 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -37,23 +37,25 @@ struct OPENVINO_GENAI_EXPORTS RawPerfMetrics { * */ struct OPENVINO_GENAI_EXPORTS PerfMetrics { - // First token time. + // Load time in ms. + float load_time; + + // First token time (in ms). float mean_ttft; float std_ttft; - // Time per output token. + // Time (in ms) per output token. float mean_tpot; float std_tpot; - float load_time; - float mean_generate_duration; float std_generate_duration; - float mean_tokenization_duration; - float std_tokenization_duration; - float mean_detokenization_duration; - float std_detokenization_duration; - + float mean_tokenization_duration = -1; + float std_tokenization_duration = -1; + float mean_detokenization_duration = -1; + float std_detokenization_duration = -1; + + // Tokens per second. float mean_throughput; float std_throughput; @@ -61,11 +63,11 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { size_t num_input_tokens; void evaluate_statistics(std::optional start_time = std::nullopt); - static float get_duration_ms(std::chrono::steady_clock::duration duration); + static float get_microsec(std::chrono::steady_clock::duration duration); PerfMetrics operator+(const PerfMetrics& metrics) const; PerfMetrics& operator+=(const PerfMetrics& right); - RawPerfMetrics raw_counters; + RawPerfMetrics raw_metrics; }; } // namespace genai diff --git a/src/cpp/src/greedy_decoding.cpp b/src/cpp/src/greedy_decoding.cpp index c8fd36cbdd..8b0cf19c1f 100644 --- a/src/cpp/src/greedy_decoding.cpp +++ b/src/cpp/src/greedy_decoding.cpp @@ -24,7 +24,7 @@ EncodedResults greedy_decoding( // Initialize results and performance metrics. 
EncodedResults results; - auto& raw_perf_counters = results.metrics.raw_counters; + auto& raw_perf_counters = results.perf_metrics.raw_metrics; results.scores.resize(running_batch_size); results.tokens.resize(running_batch_size); diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp index 784ff1a915..1b9729b2f6 100644 --- a/src/cpp/src/group_beam_searcher.cpp +++ b/src/cpp/src/group_beam_searcher.cpp @@ -444,7 +444,7 @@ std::pair beam_search(ov::InferRequest& lm, int32_t res_selected_beam_idx = 0; results.scores.reserve(config.num_return_sequences * result.size()); results.tokens.reserve(config.num_return_sequences * result.size()); - auto& raw_perf_counters = results.metrics.raw_counters; + auto& raw_perf_counters = results.perf_metrics.raw_metrics; raw_perf_counters.m_new_token_times = new_token_times; raw_perf_counters.m_batch_sizes = batch_sizes; diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index adac9110e1..1c1bd5ccd8 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -161,16 +161,16 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { } // generate_durations - decoded_results.metrics = encoded_results.metrics; + decoded_results.perf_metrics = encoded_results.perf_metrics; - auto& raw_counters = decoded_results.metrics.raw_counters; + auto& raw_counters = decoded_results.perf_metrics.raw_metrics; auto stop_time = std::chrono::steady_clock::now(); raw_counters.generate_durations = std::vector(); - raw_counters.generate_durations.emplace_back(PerfMetrics::get_duration_ms(stop_time - start_time)); - raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_duration_ms(encode_stop_time - start_time)); - raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_duration_ms(decode_stop_time - decode_start_time)); + raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); + raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); + raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); - decoded_results.metrics.evaluate_statistics(start_time); + decoded_results.perf_metrics.evaluate_statistics(start_time); return decoded_results; } @@ -272,10 +272,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. 
- auto& metrics = result.metrics; + auto& metrics = result.perf_metrics; metrics.num_input_tokens = batch_size * input_ids.get_shape().at(1); metrics.load_time = this->m_load_time_ms; - metrics.raw_counters.generate_durations.emplace_back(PerfMetrics::get_duration_ms(stop_time - start_time)); + metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); metrics.evaluate_statistics(start_time); return result; } @@ -393,7 +393,7 @@ ov::genai::LLMPipeline::LLMPipeline( m_pimpl = make_unique(std::filesystem::path(path), device, config); } auto stop_time = std::chrono::steady_clock::now(); - m_pimpl->m_load_time_ms = PerfMetrics::get_duration_ms(stop_time - start_time) / 1000.0f; + m_pimpl->m_load_time_ms = PerfMetrics::get_microsec(stop_time - start_time) / 1000.0f; } ov::genai::GenerationConfig ov::genai::LLMPipeline::get_generation_config() const { diff --git a/src/cpp/src/multinomial_decoding.cpp b/src/cpp/src/multinomial_decoding.cpp index fc59f00e12..b00c62aed7 100644 --- a/src/cpp/src/multinomial_decoding.cpp +++ b/src/cpp/src/multinomial_decoding.cpp @@ -164,7 +164,7 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner // Initialize results and performance metrics. EncodedResults results; - auto& raw_perf_counters = results.metrics.raw_counters; + auto& raw_perf_counters = results.perf_metrics.raw_metrics; results.scores.resize(batch_size, 0); results.tokens.resize(batch_size); diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index d4dc6c8de6..c319032449 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -9,18 +9,18 @@ namespace { -// std::pair calc_mean_and_std(const std::vector& durations) { std::pair calc_mean_and_std(const std::vector& durations) { + // Accepts time durations in microseconds and returns standard deviation and mean in milliseconds. float mean = std::accumulate(durations.begin(), durations.end(), 0.0f, [](const float& acc, const ov::genai::MicroSeconds& duration) -> float { - return acc + duration.count(); + return acc + duration.count() / 1000.0f; }); mean /= durations.size(); - mean /= 1000.f; float sum_square_durations = std::accumulate(durations.begin(), durations.end(), 0.0f, [](const float& acc, const ov::genai::MicroSeconds& duration) -> float { - return acc + duration.count() * duration.count() / 1000000.0f; + auto d = duration.count() / 1000.0f; + return acc + d * d; }); float std = std::sqrt(sum_square_durations / durations.size() - mean * mean); return {mean, std}; @@ -32,7 +32,7 @@ std::pair calc_mean_and_std(const std::vector(duration).count(); } @@ -40,33 +40,33 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { // If start_tiem is specified then recalcualte durations according to start times and calculate statistics only after that. 
if (start_time.has_value()) { auto start_time_val = *start_time; - auto& tok_times = raw_counters.m_new_token_times; - auto& batch_sizes = raw_counters.m_batch_sizes; - raw_counters.m_durations = std::vector(tok_times.size()); + auto& tok_times = raw_metrics.m_new_token_times; + auto& batch_sizes = raw_metrics.m_batch_sizes; + raw_metrics.m_durations = std::vector(tok_times.size()); auto ttft = tok_times[0] - start_time_val; - raw_counters.m_times_to_first_token = std::vector(); - raw_counters.m_times_to_first_token.emplace_back(ttft); + raw_metrics.m_times_to_first_token = std::vector(); + raw_metrics.m_times_to_first_token.emplace_back(ttft); num_generated_tokens = 0; for (size_t i = 0; i < tok_times.size(); ++i) { - raw_counters.m_durations[i] = tok_times[i] - start_time_val; + raw_metrics.m_durations[i] = tok_times[i] - start_time_val; - // If in 10 ms a batch of 5 new tokens is generated then TTOT is 10 ms / 5. - // todo: float check that it's valid for batch > 1. - raw_counters.m_durations[i] /= batch_sizes[i]; + // If in 10 ms a batch of 5 new tokens is generated then TPOT is 10 / 5 = 2 tok/ms. + raw_metrics.m_durations[i] /= batch_sizes[i]; num_generated_tokens += batch_sizes[i]; start_time_val = tok_times[i]; } } + + // calc_mean_and_std will convert microsecond to milliseconds. + std::tie(mean_tpot, std_tpot) = calc_mean_and_std(raw_metrics.m_durations); + std::tie(mean_ttft, std_ttft) = calc_mean_and_std(raw_metrics.m_times_to_first_token); - std::tie(mean_tpot, std_tpot) = calc_mean_and_std(raw_counters.m_durations); - std::tie(mean_ttft, std_ttft) = calc_mean_and_std(raw_counters.m_times_to_first_token); - - std::tie(mean_generate_duration, std_generate_duration) = calc_mean_and_std(raw_counters.generate_durations); - std::tie(mean_tokenization_duration, std_tokenization_duration) = calc_mean_and_std(raw_counters.tokenization_durations); - std::tie(mean_detokenization_duration, std_detokenization_duration) = calc_mean_and_std(raw_counters.detokenization_durations); + std::tie(mean_generate_duration, std_generate_duration) = calc_mean_and_std(raw_metrics.generate_durations); + std::tie(mean_tokenization_duration, std_tokenization_duration) = calc_mean_and_std(raw_metrics.tokenization_durations); + std::tie(mean_detokenization_duration, std_detokenization_duration) = calc_mean_and_std(raw_metrics.detokenization_durations); - mean_throughput = 1000.0f / mean_tpot; + mean_throughput = 1000.0f / mean_tpot; // tokens per second std_throughput = (std_tpot * 1000.0f) / (mean_tpot * mean_tpot); } @@ -76,22 +76,25 @@ PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const { // Copy left value to res. PerfMetrics res = *this; - // Concatenate duration and first token times. - auto& new_durations = res.raw_counters.m_durations; - auto& new_times_to_first_token = res.raw_counters.m_times_to_first_token; - auto& right_durations = right.raw_counters.m_durations; - auto& right_times_to_first_token = right.raw_counters.m_times_to_first_token; + // Concatenate durations, batch_sizes first token times. 
+ auto& new_durations = res.raw_metrics.m_durations; + auto& new_batch_sizes = res.raw_metrics.m_batch_sizes; + auto& new_times_to_first_token = res.raw_metrics.m_times_to_first_token; + auto& right_durations = right.raw_metrics.m_durations; + auto& right_batch_sizes = right.raw_metrics.m_batch_sizes; + auto& right_times_to_first_token = right.raw_metrics.m_times_to_first_token; new_durations.insert(new_durations.end(), right_durations.begin(), right_durations.end()); new_times_to_first_token.insert(new_times_to_first_token.end(), right_times_to_first_token.begin(), right_times_to_first_token.end()); + new_batch_sizes.insert(new_batch_sizes.end(), right_batch_sizes.begin(), right_batch_sizes.end()); // Concatenate tokenization/detokenization and total generation times. - auto& new_tok_durations = res.raw_counters.tokenization_durations; - auto& new_detok_durations = res.raw_counters.detokenization_durations; - auto& new_gen_durations = res.raw_counters.generate_durations; - auto& right_tok_durations = right.raw_counters.tokenization_durations; - auto& right_detok_durations = right.raw_counters.detokenization_durations; - auto& right_gen_durations = right.raw_counters.generate_durations; + auto& new_tok_durations = res.raw_metrics.tokenization_durations; + auto& new_detok_durations = res.raw_metrics.detokenization_durations; + auto& new_gen_durations = res.raw_metrics.generate_durations; + auto& right_tok_durations = right.raw_metrics.tokenization_durations; + auto& right_detok_durations = right.raw_metrics.detokenization_durations; + auto& right_gen_durations = right.raw_metrics.generate_durations; new_tok_durations.insert(new_tok_durations.end(), right_tok_durations.begin(), right_tok_durations.end()); new_detok_durations.insert(new_detok_durations.end(), right_detok_durations.begin(), right_detok_durations.end()); diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp index e2f89cd962..6c88b3ffcc 100644 --- a/src/python/py_generate_pipeline.cpp +++ b/src/python/py_generate_pipeline.cpp @@ -38,6 +38,17 @@ using PyBindStreamerVariant = std::variant, std::sh template struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...) -> overloaded; +template +std::vector get_ms(const T& instance, U T::*member) { + // Converts c++ duration to float so that it can be used in Python. 
+ std::vector res; + const auto& durations = instance.*member; + res.reserve(durations.size()); + std::transform(durations.begin(), durations.end(), std::back_inserter(res), + [](const auto& duration) { return duration.count(); }); + return res; +} + namespace { auto generate_docstring = R"( @@ -536,17 +547,25 @@ PYBIND11_MODULE(py_generate_pipeline, m) { .def(py::init<>()) .def_property_readonly("texts", [](const DecodedResults &dr) { return handle_utf8_results(dr); }) .def_readonly("scores", &DecodedResults::scores) - .def_readonly("metrics", &DecodedResults::metrics) + .def_readonly("perf_metrics", &DecodedResults::perf_metrics) .def("__str__", &DecodedResults::operator std::string); py::class_(m, "RawPerfMetrics") .def(py::init<>()) .def_readonly("generate_durations", &RawPerfMetrics::generate_durations) - .def_readonly("tokenization_durations", &RawPerfMetrics::tokenization_durations) - .def_readonly("detokenization_durations", &RawPerfMetrics::detokenization_durations) - .def_readonly("m_times_to_first_token", &RawPerfMetrics::m_times_to_first_token) + .def_property_readonly("tokenization_durations", [](const RawPerfMetrics &rw) { + return get_ms(rw, &RawPerfMetrics::tokenization_durations); + }) + .def_property_readonly("detokenization_durations", [](const RawPerfMetrics &rw) { + return get_ms(rw, &RawPerfMetrics::detokenization_durations); + }) + .def_property_readonly("m_times_to_first_token", [](const RawPerfMetrics &rw) { + return get_ms(rw, &RawPerfMetrics::m_times_to_first_token); + }) + .def_property_readonly("m_durations", [](const RawPerfMetrics &rw) { + return get_ms(rw, &RawPerfMetrics::m_durations); + }) .def_readonly("m_batch_sizes", &RawPerfMetrics::m_batch_sizes) - .def_readonly("m_durations", &RawPerfMetrics::m_durations) .def_readonly("num_generated_tokens", &RawPerfMetrics::num_generated_tokens) .def_readonly("num_input_tokens", &RawPerfMetrics::num_input_tokens); @@ -567,7 +586,7 @@ PYBIND11_MODULE(py_generate_pipeline, m) { .def_readonly("load_time", &PerfMetrics::load_time) .def("__add__", &PerfMetrics::operator+) .def("__iadd__", &PerfMetrics::operator+=) - .def_readonly("raw_counters", &PerfMetrics::raw_counters) + .def_readonly("raw_metrics", &PerfMetrics::raw_metrics) ; py::class_(m, "TokenizedInputs") @@ -578,7 +597,7 @@ PYBIND11_MODULE(py_generate_pipeline, m) { py::class_(m, "EncodedResults") .def_readonly("tokens", &EncodedResults::tokens) .def_readonly("scores", &EncodedResults::scores) - .def_readonly("metrics", &EncodedResults::metrics); + .def_readonly("perf_metrics", &EncodedResults::perf_metrics); py::class_>(m, "StreamerBase") // Change the holder form unique_ptr to shared_ptr .def(py::init<>())