Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Jul 12, 2024
1 parent 740c914 commit 278b1b6
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 5 deletions.
1 change: 1 addition & 0 deletions samples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_subdirectory(cpp/greedy_causal_lm)
add_subdirectory(cpp/multinomial_causal_lm)
add_subdirectory(cpp/prompt_lookup_decoding_lm)
add_subdirectory(cpp/speculative_decoding_lm)
add_subdirectory(cpp/benchmark_vanilla_genai)

install(FILES requirements.txt DESTINATION samples
COMPONENT cpp_samples_genai)
Expand Down
25 changes: 25 additions & 0 deletions samples/cpp/benchmark_vanilla_genai/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


find_package(OpenVINOGenAI REQUIRED PATHS
"${CMAKE_BINARY_DIR}" # Reuse the package from the build.
${OpenVINO_DIR} # GenAI may be installed alogside OpenVINO.
)

FetchContent_Declare(cxxopts
URL https://github.com/jarro2783/cxxopts/archive/refs/tags/v3.1.1.tar.gz
URL_HASH SHA256=523175f792eb0ff04f9e653c90746c12655f10cb70f1d5e6d6d9491420298a08)
FetchContent_MakeAvailable(cxxopts)

add_executable(benchmark_vanilla_genai benchmark_vanilla_genai.cpp)
target_link_libraries(benchmark_vanilla_genai PRIVATE openvino::genai cxxopts::cxxopts)
set_target_properties(benchmark_vanilla_genai PROPERTIES
COMPILE_PDB_NAME benchmark_vanilla_genai
# Ensure out of box LC_RPATH on macOS with SIP
INSTALL_RPATH_USE_LINK_PATH ON)
# target_compile_features(benchmark_vanilla_genai PRIVATE cxx_std_11)
install(TARGETS benchmark_vanilla_genai
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
EXCLUDE_FROM_ALL)
2 changes: 2 additions & 0 deletions samples/cpp/benchmark_vanilla_genai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# benchmark OpenVINO GenAI sample

65 changes: 65 additions & 0 deletions samples/cpp/benchmark_vanilla_genai/benchmark_vanilla_genai.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "openvino/genai/llm_pipeline.hpp"
#include <cxxopts.hpp>

int main(int argc, char* argv[]) try {
cxxopts::Options options("benchmark_vanilla_genai", "Help command");

options.add_options()
("p,prompt", "Prompt", cxxopts::value<std::string>()->default_value("The Sky is blue because"))
("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
("nw,num_warmup", "Number of warmup iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
("d,device", "device", cxxopts::value<std::string>()->default_value("CPU"))
("h,help", "Print usage");

cxxopts::ParseResult result;
try {
result = options.parse(argc, argv);
} catch (const cxxopts::exceptions::exception& e) {
std::cout << e.what() << "\n\n";
std::cout << options.help() << std::endl;
return EXIT_FAILURE;
}

if (result.count("help")) {
std::cout << options.help() << std::endl;
return EXIT_SUCCESS;
}

std::string prompt = result["prompt"].as<std::string>();
const std::string model_path = result["model"].as<std::string>();
std::string device = result["device"].as<std::string>();
size_t num_warmup = result["num_warmup"].as<size_t>();
size_t num_iter = result["num_iter"].as<size_t>();

ov::genai::GenerationConfig config;
config.max_new_tokens = 100;

ov::genai::LLMPipeline pipe(model_path, device);

for (size_t i = 0; i < num_warmup; i++)
pipe.generate(prompt, config);

ov::genai::GenerationMetrics metrics;
for (size_t i = 0; i < num_iter; i++) {
ov::genai::DecodedResults res = pipe.generate(prompt, config);
metrics += res.metrics;
metrics.load_time = res.metrics.load_time;
}

std::cout << "Load time: " << metrics.load_time << " ms" << std::endl;
std::cout << "ttft: " << metrics.mean_ttft << " ± " << metrics.std_ttft << " ms" << std::endl;
std::cout << "tpot: " << metrics.mean_tpot << " ± " << metrics.std_tpot << " ms" << std::endl;
std::cout << "Tokens/s: " << metrics.get_tokens_per_sec().first << std::endl;

return 0;
} catch (const std::exception& error) {
std::cerr << error.what() << '\n';
return EXIT_FAILURE;
} catch (...) {
std::cerr << "Non-exception object thrown\n";
return EXIT_FAILURE;
}
40 changes: 40 additions & 0 deletions src/cpp/include/openvino/genai/generation_metrics.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <chrono>
#include <numeric>
#include <vector>
#include <cmath>

namespace ov {
namespace genai {

using TimePoints = std::vector<std::chrono::steady_clock::time_point>;

struct GenerationMetrics {
GenerationMetrics() = default;

GenerationMetrics(const TimePoints& tok_times, size_t batch_size = 1);
GenerationMetrics(const std::vector<float>& durations, const std::vector<float>& times_to_first_token, size_t batch_size = 1);

// First token time.
float mean_ttft;
float std_ttft;
std::vector<float> times_to_first_token;

// Time per output token.
float mean_tpot;
float std_tpot;
std::vector<float> durations;

std::pair<float, float> get_tokens_per_sec() const;
size_t batch_size;
float load_time;

GenerationMetrics operator+=(GenerationMetrics const& metrics) const;
};

} // namespace genai
} // namespace ov
4 changes: 4 additions & 0 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

#include <optional>
#include <variant>
#include <chrono>

#include "openvino/core/any.hpp"
#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/tokenizer.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/generation_metrics.hpp"

namespace ov {
namespace genai {
Expand All @@ -34,6 +36,7 @@ class EncodedResults {
public:
std::vector<std::vector<int64_t>> tokens;
std::vector<float> scores;
GenerationMetrics metrics;
};

/**
Expand All @@ -47,6 +50,7 @@ class DecodedResults {
public:
std::vector<std::string> texts;
std::vector<float> scores;
GenerationMetrics metrics;

// @brief Convert DecodedResults to a string.
operator std::string() const {
Expand Down
62 changes: 62 additions & 0 deletions src/cpp/src/generation_metrics.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "openvino/genai/generation_metrics.hpp"
#include <tuple>

namespace {

std::pair<float, float> calc_mean_and_std(const std::vector<float>& durations) {
float mean = std::accumulate(durations.begin(), durations.end(), 0.0f) / durations.size();

float sum_square_durations = std::accumulate(durations.begin(), durations.end(), 0.0f,
[](const float& acc, const float& duration) -> float {
return acc + duration * duration;
});
float std = std::sqrt(sum_square_durations / durations.size() - mean * mean);
return {mean, std};
}

} // namespace

namespace ov {
namespace genai {


GenerationMetrics::GenerationMetrics(const TimePoints& tok_times, size_t batch_size) {
this->batch_size = batch_size;
durations = std::vector<float>(tok_times.size() - 1);
for (size_t i = 1; i < tok_times.size(); ++i) {
durations[i - 1] = std::chrono::duration_cast<std::chrono::milliseconds>(tok_times[i] - tok_times[i - 1]).count();
}
times_to_first_token.emplace_back(durations[0]);

std::tie(mean_tpot, std_tpot) = calc_mean_and_std(durations);
std::tie(mean_ttft, std_ttft) = calc_mean_and_std(times_to_first_token);
}

GenerationMetrics::GenerationMetrics(const std::vector<float>& durations_, const std::vector<float>& times_to_first_token_, size_t batch_size)
: durations(durations_), times_to_first_token(times_to_first_token_) {
this->batch_size = batch_size;
std::tie(mean_tpot, std_tpot) = calc_mean_and_std(durations);
std::tie(mean_ttft, std_ttft) = calc_mean_and_std(times_to_first_token);
}

GenerationMetrics GenerationMetrics::operator+=(GenerationMetrics const& metrics) const {
std::vector<float> new_durations = durations;
std::vector<float> new_times_to_first_token = times_to_first_token;
new_durations.insert(new_durations.end(), metrics.durations.begin(), metrics.durations.end());
new_times_to_first_token.insert(new_times_to_first_token.end(), metrics.times_to_first_token.begin(), metrics.times_to_first_token.end());

return GenerationMetrics(new_durations, new_times_to_first_token);
}

std::pair<float, float> GenerationMetrics::get_tokens_per_sec() const {
auto mean_tps = 1000.0f * batch_size / mean_tpot;
auto std_tps = 1000.0f * std_tpot / (mean_tpot * mean_tpot);
return {mean_tps, std_tps};
}


} // namespace genai
} // namespace ov
17 changes: 14 additions & 3 deletions src/cpp/src/greedy_decoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ EncodedResults greedy_decoding(
const size_t batch_size = prompts_shape[0];
size_t running_batch_size = batch_size;
size_t prompt_len = prompts_shape[1];
size_t max_new_tokens = generation_config.get_max_new_tokens(prompt_len);

EncodedResults results;
// Time before the first token generated as a reference point.
ov::genai::TimePoints tok_times;
tok_times.reserve(max_new_tokens);
tok_times.emplace_back(std::chrono::steady_clock::now());

results.scores.resize(running_batch_size);
results.tokens.resize(running_batch_size);
std::fill(results.scores.begin(), results.scores.end(), 0);

m_model_runner.set_tensor("input_ids", input_ids);
m_model_runner.set_tensor("attention_mask", attention_mask);
if (position_ids.has_value())
Expand All @@ -50,6 +56,8 @@ EncodedResults greedy_decoding(
eos_met[batch] = (out_token == generation_config.eos_token_id);
m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
}
tok_times.emplace_back(std::chrono::steady_clock::now());

if (streamer && streamer->put(token_iter_results[0])) {
return results;
}
Expand All @@ -58,8 +66,8 @@ EncodedResults greedy_decoding(
if (!generation_config.ignore_eos && all_are_eos)
return results;

size_t max_tokens = generation_config.get_max_new_tokens(prompt_len);
for (size_t i = 0; i < max_tokens - 1; ++i) {

for (size_t i = 0; i < max_new_tokens - 1; ++i) {
if (position_ids.has_value())
utils::update_position_ids(m_model_runner.get_tensor("position_ids"), m_model_runner.get_tensor("attention_mask"));
m_model_runner.set_tensor("attention_mask", utils::extend_attention(m_model_runner.get_tensor("attention_mask")));
Expand All @@ -80,6 +88,7 @@ EncodedResults greedy_decoding(

m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
}
tok_times.emplace_back(std::chrono::steady_clock::now());

if (streamer && streamer->put(token_iter_results[0]))
return results;
Expand All @@ -106,6 +115,8 @@ EncodedResults greedy_decoding(
if (streamer) {
streamer->end();
}

results.metrics = GenerationMetrics(tok_times);
return results;
}

Expand Down
10 changes: 8 additions & 2 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <openvino/openvino.hpp>
#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/llm_pipeline.hpp"
#include "openvino/genai/generation_metrics.hpp"
#include "llm_pipeline_base.hpp"
#include "llm_pipeline_static.hpp"
#include "utils.hpp"
Expand Down Expand Up @@ -155,6 +156,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
m_history.push_back({{"role", "assistant"}, {"content", answer}});
}

decoded_results.metrics = std::move(encoded_results.metrics);
decoded_results.metrics.load_time = m_load_time_ms;
return decoded_results;
}

Expand Down Expand Up @@ -253,7 +256,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
} else {
m_is_cache_empty = false;
}

return result;
}

Expand Down Expand Up @@ -350,6 +352,7 @@ ov::genai::LLMPipeline::LLMPipeline(
const std::string& device,
const ov::AnyMap& plugin_config
) {

if (device == "NPU") {
m_pimpl = make_unique<StaticLLMPipeline>(std::filesystem::path(model_path), tokenizer, device, plugin_config);
} else {
Expand All @@ -361,12 +364,15 @@ ov::genai::LLMPipeline::LLMPipeline(
const std::string& path,
const std::string& device,
const ov::AnyMap& config
) {
) {
auto start_time = std::chrono::steady_clock::now();
if (device == "NPU") {
m_pimpl = make_unique<StaticLLMPipeline>(std::filesystem::path(path), device, config);
} else {
m_pimpl = make_unique<StatefulLLMPipeline>(std::filesystem::path(path), device, config);
}
auto stop_time = std::chrono::steady_clock::now();
m_pimpl->m_load_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
}

ov::genai::GenerationConfig ov::genai::LLMPipeline::get_generation_config() const {
Expand Down
2 changes: 2 additions & 0 deletions src/cpp/src/llm_pipeline_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class LLMPipelineImplBase {

Tokenizer m_tokenizer;
GenerationConfig m_generation_config;

float m_load_time_ms = 0;
};

} // namespace genai
Expand Down

0 comments on commit 278b1b6

Please sign in to comment.