From 7cab496c63a598dcb96027c9a88d3c96ef1b5b48 Mon Sep 17 00:00:00 2001
From: Pavel Esir
Date: Fri, 19 Jul 2024 13:01:02 +0200
Subject: [PATCH] add detokenization metric; refactor: split into perf_counters
 & perf_metrics

---
 .../benchmark_vanilla_genai.cpp               |  8 +-
 .../openvino/genai/generation_metrics.hpp     | 40 ---------
 .../include/openvino/genai/llm_pipeline.hpp   |  6 +-
 .../include/openvino/genai/perf_metrics.hpp   | 50 ++++++++++++
 src/cpp/src/generation_metrics.cpp            | 62 --------------
 src/cpp/src/greedy_decoding.cpp               | 19 ++---
 src/cpp/src/group_beam_searcher.cpp           | 19 +++--
 src/cpp/src/llm_pipeline.cpp                  | 30 +++++--
 src/cpp/src/perf_counters.cpp                 | 21 +++++
 src/cpp/src/perf_counters.hpp                 | 44 ++++++++++
 src/cpp/src/perf_metrics.cpp                  | 81 +++++++++++++++++++
 src/cpp/src/tokenizer.cpp                     |  2 +
 src/cpp/src/utils.hpp                         | 14 ++++
 src/python/py_generate_pipeline.cpp           | 14 ++++
 tests/python_tests/ov_genai_test_utils.py     |  2 +
 15 files changed, 282 insertions(+), 130 deletions(-)
 delete mode 100644 src/cpp/include/openvino/genai/generation_metrics.hpp
 create mode 100644 src/cpp/include/openvino/genai/perf_metrics.hpp
 delete mode 100644 src/cpp/src/generation_metrics.cpp
 create mode 100644 src/cpp/src/perf_counters.cpp
 create mode 100644 src/cpp/src/perf_counters.hpp
 create mode 100644 src/cpp/src/perf_metrics.cpp

diff --git a/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp b/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp
index ccb7650b84..6489282b0b 100644
--- a/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp
+++ b/samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp
@@ -37,23 +37,25 @@ int main(int argc, char* argv[]) try {
 
     ov::genai::GenerationConfig config;
     config.max_new_tokens = 100;
+    config.num_beam_groups = 3;
+    config.num_beams = 15;
 
     ov::genai::LLMPipeline pipe(model_path, device);
 
     for (size_t i = 0; i < num_warmup; i++)
         pipe.generate(prompt, config);
 
-    ov::genai::GenerationMetrics metrics;
+    ov::genai::PerfMetrics metrics;
     for (size_t i = 0; i < num_iter; i++) {
         ov::genai::DecodedResults res = pipe.generate(prompt, config);
         metrics = metrics + res.metrics;
         metrics.load_time = res.metrics.load_time;
     }
-
+    std::cout << "Load time: " << metrics.load_time << " ms" << std::endl;
     std::cout << "ttft: " << metrics.mean_ttft << " ± " << metrics.std_ttft << " ms" << std::endl;
     std::cout << "tpot: " << metrics.mean_tpot << " ± " << metrics.std_tpot << " ms" << std::endl;
-    std::cout << "Tokens/s: " << metrics.get_tokens_per_sec().first << std::endl;
+    std::cout << "Tokens/s: " << metrics.mean_throughput << std::endl;
 
     return 0;
 } catch (const std::exception& error) {
diff --git a/src/cpp/include/openvino/genai/generation_metrics.hpp b/src/cpp/include/openvino/genai/generation_metrics.hpp
deleted file mode 100644
index 7129e5c52b..0000000000
--- a/src/cpp/include/openvino/genai/generation_metrics.hpp
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright (C) 2023-2024 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#include <chrono>
-#include <cstddef>
-#include <utility>
-#include <vector>
-
-namespace ov {
-namespace genai {
-
-using TimePoints = std::vector<std::chrono::steady_clock::time_point>;
-
-struct GenerationMetrics {
-    GenerationMetrics() = default;
-
-    GenerationMetrics(const TimePoints& tok_times, size_t batch_size = 1);
-    GenerationMetrics(const std::vector<float>& durations, const std::vector<float>& times_to_first_token, size_t batch_size = 1);
-
-    // First token time.
-    float mean_ttft;
-    float std_ttft;
-    std::vector<float> times_to_first_token;
-
-    // Time per output token.
-    float mean_tpot;
-    float std_tpot;
-    std::vector<float> durations;
-
-    std::pair<float, float> get_tokens_per_sec() const;
-    size_t batch_size;
-    float load_time;
-
-    GenerationMetrics operator+(GenerationMetrics const& metrics) const;
-};
-
-} // namespace genai
-} // namespace ov
diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp
index 9f0c9fba97..4db3c613e7 100644
--- a/src/cpp/include/openvino/genai/llm_pipeline.hpp
+++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -11,7 +11,7 @@
 #include "openvino/genai/generation_config.hpp"
 #include "openvino/genai/tokenizer.hpp"
 #include "openvino/genai/streamer_base.hpp"
-#include "openvino/genai/generation_metrics.hpp"
+#include "openvino/genai/perf_metrics.hpp"
 
 namespace ov {
 namespace genai {
@@ -36,7 +36,7 @@ class EncodedResults {
 public:
     std::vector<std::vector<int64_t>> tokens;
     std::vector<float> scores;
-    GenerationMetrics metrics;
+    PerfMetrics metrics;
 };
 
 /**
@@ -50,7 +50,7 @@ class DecodedResults {
 public:
     std::vector<std::string> texts;
     std::vector<float> scores;
-    GenerationMetrics metrics;
+    PerfMetrics metrics;
 
     // @brief Convert DecodedResults to a string.
     operator std::string() const {
diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp
new file mode 100644
index 0000000000..a11c4e0374
--- /dev/null
+++ b/src/cpp/include/openvino/genai/perf_metrics.hpp
@@ -0,0 +1,50 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <chrono>
+#include "openvino/genai/visibility.hpp"
+#include <memory>
+#include <vector>
+
+namespace ov {
+namespace genai {
+
+using TimePoint = std::chrono::steady_clock::time_point;
+
+struct PerfCounters;
+
+struct OPENVINO_GENAI_EXPORTS PerfMetrics {
+    // First token time.
+    float mean_ttft;
+    float std_ttft;
+
+    // Time per output token.
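+    // (TPOT) in milliseconds; each step's duration is normalized by its batch size.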
+    float mean_tpot;
+    float std_tpot;
+
+    float load_time;
+    float start_time;
+
+    float mean_generate_duration;
+    float mean_decoding_duration;
+    float mean_encoding_duration;
+
+    float mean_throughput;
+    float std_throughput;
+
+    size_t num_generated_tokens;
+    size_t num_input_tokens;
+
+    std::shared_ptr<PerfCounters> m_counters;
+    void evaluate(TimePoint start_time);
+
+    PerfMetrics operator+(const PerfMetrics& metrics) const;
+    PerfMetrics& operator+=(const PerfMetrics& right);
+
+
+};
+
+} // namespace genai
+} // namespace ov
diff --git a/src/cpp/src/generation_metrics.cpp b/src/cpp/src/generation_metrics.cpp
deleted file mode 100644
index 8ca8e0a07d..0000000000
--- a/src/cpp/src/generation_metrics.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright (C) 2023-2024 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-
-#include "openvino/genai/generation_metrics.hpp"
-#include <cmath>
-
-namespace {
-
-std::pair<float, float> calc_mean_and_std(const std::vector<float>& durations) {
-    float mean = std::accumulate(durations.begin(), durations.end(), 0.0f) / durations.size();
-
-    float sum_square_durations = std::accumulate(durations.begin(), durations.end(), 0.0f,
-        [](const float& acc, const float& duration) -> float {
-            return acc + duration * duration;
-        });
-    float std = std::sqrt(sum_square_durations / durations.size() - mean * mean);
-    return {mean, std};
-}
-
-} // namespace
-
-namespace ov {
-namespace genai {
-
-
-GenerationMetrics::GenerationMetrics(const TimePoints& tok_times, size_t batch_size) {
-    this->batch_size = batch_size;
-    durations = std::vector<float>(tok_times.size() - 1);
-    for (size_t i = 1; i < tok_times.size(); ++i) {
-        durations[i - 1] = std::chrono::duration_cast<std::chrono::milliseconds>(tok_times[i] - tok_times[i - 1]).count();
-    }
-    times_to_first_token.emplace_back(durations[0]);
-
-    std::tie(mean_tpot, std_tpot) = calc_mean_and_std(durations);
-    std::tie(mean_ttft, std_ttft) = calc_mean_and_std(times_to_first_token);
-}
-
-GenerationMetrics::GenerationMetrics(const std::vector<float>& durations_, const std::vector<float>& times_to_first_token_, size_t batch_size)
-    : durations(durations_), times_to_first_token(times_to_first_token_) {
-    this->batch_size = batch_size;
-    std::tie(mean_tpot, std_tpot) = calc_mean_and_std(durations);
-    std::tie(mean_ttft, std_ttft) = calc_mean_and_std(times_to_first_token);
-}
-
-GenerationMetrics GenerationMetrics::operator+(GenerationMetrics const& metrics) const {
-    std::vector<float> new_durations = durations;
-    std::vector<float> new_times_to_first_token = times_to_first_token;
-    new_durations.insert(new_durations.end(), metrics.durations.begin(), metrics.durations.end());
-    new_times_to_first_token.insert(new_times_to_first_token.end(), metrics.times_to_first_token.begin(), metrics.times_to_first_token.end());
-
-    return GenerationMetrics(new_durations, new_times_to_first_token);
-}
-
-std::pair<float, float> GenerationMetrics::get_tokens_per_sec() const {
-    auto mean_tps = 1000.0f * batch_size / mean_tpot;
-    auto std_tps = 1000.0f * std_tpot / (mean_tpot * mean_tpot);
-    return {mean_tps, std_tps};
-}
-
-
-} // namespace genai
-} // namespace ov
diff --git a/src/cpp/src/greedy_decoding.cpp b/src/cpp/src/greedy_decoding.cpp
index dad93a0e6e..0802b87e66 100644
--- a/src/cpp/src/greedy_decoding.cpp
+++ b/src/cpp/src/greedy_decoding.cpp
@@ -1,7 +1,8 @@
 // Copyright (C) 2023-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
-#include "openvino/genai/llm_pipeline.hpp"
+#include "openvino/genai/perf_metrics.hpp"
+#include "perf_counters.hpp"
 #include "utils.hpp"
 
 namespace ov {
 namespace genai {
@@ -22,11 +23,8 @@ EncodedResults greedy_decoding(
     size_t max_new_tokens = generation_config.get_max_new_tokens(prompt_len);
 
     EncodedResults results;
-    // Time before the first token generated as a reference point.
-    ov::genai::TimePoints tok_times;
-    tok_times.reserve(max_new_tokens);
-    tok_times.emplace_back(std::chrono::steady_clock::now());
-
+    auto& perf_counters = results.metrics.m_counters;
+
     results.scores.resize(running_batch_size);
     results.tokens.resize(running_batch_size);
     std::fill(results.scores.begin(), results.scores.end(), 0);
@@ -56,8 +54,8 @@ EncodedResults greedy_decoding(
         eos_met[batch] = (out_token == generation_config.eos_token_id);
         m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
     }
-    tok_times.emplace_back(std::chrono::steady_clock::now());
-
+    perf_counters->add_timestamp(running_batch_size);
+
     if (streamer && streamer->put(token_iter_results[0])) {
         return results;
     }
@@ -88,7 +86,7 @@ EncodedResults greedy_decoding(
            m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
        }
-       tok_times.emplace_back(std::chrono::steady_clock::now());
+       perf_counters->add_timestamp(running_batch_size);
 
        if (streamer && streamer->put(token_iter_results[0]))
            return results;
@@ -116,9 +114,8 @@ EncodedResults greedy_decoding(
         streamer->end();
     }
 
-    results.metrics = GenerationMetrics(tok_times);
     return results;
 }
 
 } //namespace genai
-} //namespace ov
\ No newline at end of file
+} //namespace ov
diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp
index 8695aeac02..4f5cb79f2a 100644
--- a/src/cpp/src/group_beam_searcher.cpp
+++ b/src/cpp/src/group_beam_searcher.cpp
@@ -362,14 +362,20 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
                                                std::optional<int32_t> selected_beam_idx) {
     OPENVINO_ASSERT(config.num_beams % config.num_beam_groups == 0, "number of beams should be divisible by number of groups");
-
-    // Initialize beam search
+    auto batch_size = input_ids.get_shape().at(0);
+    auto sequence_length = input_ids.get_shape().at(1);
+
+    // Initialize time metric counters.
+    // ov::genai::TimePoints tok_times;
+    // tok_times.reserve(config.get_max_new_tokens(sequence_length));
+    // tok_times.emplace_back(std::chrono::steady_clock::now());
+
+    // Initialize beam search.
     const int64_t* prompt_data = input_ids.data<const int64_t>();
     std::vector<std::vector<int64_t>> prompts;
     prompts.reserve(batch_size);
     for (size_t batch = 0; batch < batch_size; batch++) {
-        size_t sequence_length = input_ids.get_shape().at(1);
         size_t batch_offset = batch * sequence_length;
         const int64_t* prompt_start = prompt_data + batch_offset;
         prompts.push_back(std::vector<int64_t>{prompt_start, prompt_start + sequence_length});
@@ -389,7 +395,7 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
     lm.set_tensor("beam_idx", beam_idx);
 
     Parameters parameters{std::move(prompts)};
-    parameters.max_new_tokens = config.max_new_tokens;
+    parameters.max_new_tokens = config.get_max_new_tokens(sequence_length);
     parameters.eos_token_id = config.eos_token_id;
     parameters.n_groups = config.num_beam_groups;
     parameters.group_size = config.num_beams / config.num_beam_groups;
@@ -406,6 +412,8 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
         lm.infer();
 
         std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits"));
+        // tok_times.emplace_back(std::chrono::steady_clock::now());
+
         if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) {
             // Break the cycle before masks are extended in update_attention_mask_with_beams.
             // If generation is continued, attention_mask length should be equal to KV cache size.
@@ -462,7 +470,8 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
             results.tokens.push_back(std::move(beam->get().tokens));
         }
     }
-
+
+    // results.metrics = PerfCounters(tok_times);
     return {results, res_selected_beam_idx};
 }
 
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 918e744286..81f807c149 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -1,6 +1,7 @@
 // Copyright (C) 2023-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
+#include "perf_counters.hpp"
 #include <filesystem>
 #include <fstream>
 #include <variant>
@@ -9,7 +10,7 @@
 #include <nlohmann/json.hpp>
 #include "openvino/genai/generation_config.hpp"
 #include "openvino/genai/llm_pipeline.hpp"
-#include "openvino/genai/generation_metrics.hpp"
+#include "openvino/genai/perf_metrics.hpp"
 #include "llm_pipeline_base.hpp"
 #include "llm_pipeline_static.hpp"
 #include "utils.hpp"
@@ -111,8 +112,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        auto start_time = std::chrono::steady_clock::now();
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
-        EncodedInputs encoded_input;
+        TokenizedInputs encoded_input;
 
         if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
             encoded_input = m_tokenizer.encode(*input_vector);
@@ -144,9 +146,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 encoded_input = m_tokenizer.encode(prompt);
             }
         }
+        auto encode_stop_time = std::chrono::steady_clock::now();
+        auto encoded_results = generate(encoded_input, config, streamer);
 
-        auto encoded_results = generate(encoded_input, config, streamer);
+        auto decode_start_time = std::chrono::steady_clock::now();
         DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores};
+        auto decode_stop_time = std::chrono::steady_clock::now();
 
         if (is_chat_conversation) {
             // Tail of chat template is missing in KV cache.
@@ -155,9 +160,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_templated_chat_history.append(answer);
             m_history.push_back({{"role", "assistant"}, {"content", answer}});
         }
+
+        auto& metrics = encoded_results.metrics;
+        // metrics.tokenization_duration = std::chrono::duration_cast<std::chrono::milliseconds>(encode_stop_time - start_time).count();
+        // metrics.detokenization_duration = std::chrono::duration_cast<std::chrono::milliseconds>(decode_stop_time - decode_start_time).count();
 
-        decoded_results.metrics = std::move(encoded_results.metrics);
-        decoded_results.metrics.load_time = m_load_time_ms;
+        // auto stop_time = std::chrono::steady_clock::now();
+        // metrics.generate_durations.emplace_back(std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count());
+        decoded_results.metrics = std::move(metrics);
         return decoded_results;
     }
@@ -166,9 +176,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        auto start_time = std::chrono::steady_clock::now();
         ov::Tensor input_ids;
         ov::Tensor attention_mask;
-
         if (auto data = std::get_if<ov::Tensor>(&inputs)) {
             input_ids = *data;
             attention_mask = ov::genai::utils::init_attention_mask(input_ids);
@@ -256,6 +266,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         } else {
             m_is_cache_empty = false;
         }
+
+
+
+        auto& metrics = result.metrics;
+        // metrics.batch_size = batch_size;
+        // metrics.num_generated_tokens = (metrics.m_durations.size() + 1) * batch_size;
+        metrics.num_input_tokens = batch_size * input_ids.get_shape().at(1);
+        result.metrics = std::move(metrics);
         return result;
     }
diff --git a/src/cpp/src/perf_counters.cpp b/src/cpp/src/perf_counters.cpp
new file mode 100644
index 0000000000..c9dac6eca0
--- /dev/null
+++ b/src/cpp/src/perf_counters.cpp
@@ -0,0 +1,21 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "perf_counters.hpp"
+#include "openvino/genai/perf_metrics.hpp"
+#include "openvino/openvino.hpp"
+#include <chrono>
+#include <memory>
+#include <vector>
+
+namespace ov {
+namespace genai {
+
+void PerfCounters::add_timestamp(size_t batch_size) {
+    m_new_token_times.emplace_back(std::chrono::steady_clock::now());
+    m_batch_sizes.emplace_back(batch_size);
+}
+
+
+} // namespace genai
+} // namespace ov
diff --git a/src/cpp/src/perf_counters.hpp b/src/cpp/src/perf_counters.hpp
new file mode 100644
index 0000000000..7d33490205
--- /dev/null
+++ b/src/cpp/src/perf_counters.hpp
@@ -0,0 +1,44 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <chrono>
+#include <cstddef>
+#include <vector>
+
+namespace ov {
+namespace genai {
+
+struct PerfCounters {
+    std::vector<float> generate_durations;
+    std::vector<float> tokenization_duration;
+    std::vector<float> detokenization_duration;
+    size_t num_generated_tokens;
+    size_t num_input_tokens;
+
+    std::vector<size_t> m_batch_sizes;
+    std::vector<float> m_durations;
+    std::vector<float> m_times_to_first_token;
+    std::vector<std::chrono::steady_clock::time_point> m_new_token_times;
+    void add_timestamp(size_t batch_size);
+    // void add_gen_finish_timestamp(size_t batch_size);
+
+};
+
+// class StopWatch {
+//     TimePoint m_start;
+// public:
+//     StopWatch& start() {
+//         m_start = std::chrono::steady_clock::now();
+//         return *this;
+//     }
+
+//     float split() {
+//         std::chrono::steady_clock::time_point curr_time = std::chrono::steady_clock::now();
+//         return std::chrono::duration_cast<std::chrono::milliseconds>(curr_time - m_start).count();
+//     }
+// };
+
+} // namespace genai
+} // namespace ov
diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp
new file mode 100644
index 0000000000..4a8b1d76c6
--- /dev/null
+++ b/src/cpp/src/perf_metrics.cpp
@@ -0,0 +1,81 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "openvino/genai/perf_metrics.hpp"
+#include "perf_counters.hpp"
+#include "openvino/openvino.hpp"
+#include <cmath>
+#include <numeric>
+#include <tuple>
+
+namespace {
+
+std::pair<float, float> calc_mean_and_std(const std::vector<float>& durations) {
+    float mean = std::accumulate(durations.begin(), durations.end(), 0.0f) / durations.size();
+
+    float sum_square_durations = std::accumulate(durations.begin(), durations.end(), 0.0f,
+        [](const float& acc, const float& duration) -> float {
+            return acc + duration * duration;
+        });
+    float std = std::sqrt(sum_square_durations / durations.size() - mean * mean);
+    return {mean, std};
+}
+
+
+} // namespace
+
+namespace ov {
+namespace genai {
+
+void PerfMetrics::evaluate(TimePoint start_time) {
+
+    auto& tok_times = m_counters->m_new_token_times;
+    auto& batch_sizes = m_counters->m_batch_sizes;
+    m_counters->m_durations = std::vector<float>(tok_times.size());
+
+    auto ttft = std::chrono::duration_cast<std::chrono::milliseconds>(tok_times[0] - start_time).count();
+    m_counters->m_times_to_first_token.emplace_back(ttft);
+
+    for (size_t i = 0; i < tok_times.size(); ++i) {
+        m_counters->m_durations[i] = std::chrono::duration_cast<std::chrono::milliseconds>(tok_times[i] - start_time).count();
+        // If in 10 ms a batch of 5 new tokens is generated then TPOT is 10 ms / 5.
+        // todo: double check that it's valid for batch > 1.
+        m_counters->m_durations[i] /= batch_sizes[i];
+        start_time = tok_times[i];
+    }
+
+    std::tie(mean_tpot, std_tpot) = calc_mean_and_std(m_counters->m_durations);
+    std::tie(mean_ttft, std_ttft) = calc_mean_and_std(m_counters->m_times_to_first_token);
+}
+
+PerfMetrics PerfMetrics::operator+(const PerfMetrics& metrics) const {
+    PerfMetrics nm; // new metrics
+    nm.m_counters = m_counters;
+    auto& new_counters = nm.m_counters;
+
+    auto& new_durations = new_counters->m_durations;
+    auto& new_times_to_first_token = new_counters->m_times_to_first_token;
+
+    auto& counters_to_appnd = metrics.m_counters;
+    new_durations.insert(new_durations.end(), counters_to_appnd->m_durations.begin(), counters_to_appnd->m_durations.end());
+    new_times_to_first_token.insert(new_times_to_first_token.end(), counters_to_appnd->m_times_to_first_token.begin(), counters_to_appnd->m_times_to_first_token.end());
+
+    OPENVINO_ASSERT(metrics.load_time == load_time, "generation metrics can be accumulated only for the same pipeline");
+
+    std::tie(nm.mean_tpot, nm.std_tpot) = calc_mean_and_std(new_counters->m_durations);
+    std::tie(nm.mean_ttft, nm.std_ttft) = calc_mean_and_std(new_counters->m_times_to_first_token);
+
+    // todo: add tokenization statistics concatenation.
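+    // Note: mean/std TPOT and TTFT are recomputed from the concatenated raw
+    // duration vectors, so repeated accumulation does not average averages.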
+
+    return nm;
+}
+
+PerfMetrics& PerfMetrics::operator+=(const PerfMetrics& right) {
+    *this = *this + right;
+    return *this;
+}
+
+
+
+} // namespace genai
+} // namespace ov
diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp
index ac6b925dcb..501d0e86cf 100644
--- a/src/cpp/src/tokenizer.cpp
+++ b/src/cpp/src/tokenizer.cpp
@@ -323,6 +323,8 @@ class Tokenizer::TokenizerImpl {
 
         // Replace what jinja2cpp doesn't support
         std::pair<std::string, std::string> replace_str_map[] = {
+            {"{-", "{"},
+            {"{%-", "{%"},
             {"'}", "' }"},
             {"{'", "{ '"},
             {".strip()", ""}
diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp
index 25acc1c87f..446ef8549b 100644
--- a/src/cpp/src/utils.hpp
+++ b/src/cpp/src/utils.hpp
@@ -12,6 +12,20 @@ namespace ov {
 namespace genai {
 namespace utils {
 
+#include <chrono>
+#include <tuple>
+#include <utility>
+
+// Templated function to measure execution time of an object method.
+template <typename T, typename Ret, typename... Args>
+std::pair<Ret, float> execution_time_wrapper(T& instance, Ret(T::*method)(Args...), Args&&... args) {
+    auto start = std::chrono::steady_clock::now();
+    Ret result = (instance.*method)(std::forward<Args>(args)...);
+    auto end = std::chrono::steady_clock::now();
+    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+    return {result, duration};
+}
+
 Tensor init_attention_mask(const Tensor& position_ids);
 
 void print_tensor(const ov::Tensor& tensor);
diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp
index d7b2aab29c..c78c760b6c 100644
--- a/src/python/py_generate_pipeline.cpp
+++ b/src/python/py_generate_pipeline.cpp
@@ -21,6 +21,7 @@ using ov::genai::GenerationConfig;
 using ov::genai::GenerationResult;
 using ov::genai::LLMPipeline;
 using ov::genai::OptionalGenerationConfig;
+using ov::genai::PerfMetrics;
 using ov::genai::SchedulerConfig;
 using ov::genai::StopCriteria;
 using ov::genai::StreamerBase;
@@ -536,6 +537,19 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
         .def_readonly("scores", &DecodedResults::scores)
         .def("__str__", &DecodedResults::operator std::string);
 
+    py::class_<PerfMetrics>(m, "PerfMetrics")
+        .def(py::init<>())
+        .def_readonly("mean_generate_duration", &PerfMetrics::mean_generate_duration)
+        .def_readonly("mean_decoding_duration", &PerfMetrics::mean_decoding_duration)
+        .def_readonly("mean_encoding_duration", &PerfMetrics::mean_encoding_duration)
+        .def_readonly("mean_tpot", &PerfMetrics::mean_tpot)
+        .def_readonly("mean_ttft", &PerfMetrics::mean_ttft)
+        .def_readonly("std_tpot", &PerfMetrics::std_tpot)
+        .def_readonly("std_ttft", &PerfMetrics::std_ttft)
+        .def_readonly("load_time", &PerfMetrics::load_time)
+        .def("__add__", &PerfMetrics::operator+)
+        .def("__iadd__", &PerfMetrics::operator+=);
+
     py::class_<TokenizedInputs>(m, "TokenizedInputs")
         .def(py::init<ov::Tensor, ov::Tensor>())
         .def_readwrite("input_ids", &TokenizedInputs::input_ids)
diff --git a/tests/python_tests/ov_genai_test_utils.py b/tests/python_tests/ov_genai_test_utils.py
index 4ba71a1d48..5d038e65e2 100644
--- a/tests/python_tests/ov_genai_test_utils.py
+++ b/tests/python_tests/ov_genai_test_utils.py
@@ -81,6 +81,8 @@ def get_chat_templates():
     # but skips some models that currently are not processed correctly.
 
     skipped_models = {
+        "berkeley-nest/Starling-LM-7B-alpha",  # TODO: Need to enable and unskip, since it's present in continuous batching and has ~30 000 downloads.
+
         # These models fail even on HF so no need to check if applying chat matches.
         "vibhorag101/llama-2-13b-chat-hf-phr_mental_therapy",
         "codellama/CodeLlama-34b-Instruct-hf",