Skip to content

Commit

Permalink
Reuse GenerationConfig (#569)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wovchena authored Jul 5, 2024
1 parent 08154fa commit 6667c3d
Show file tree
Hide file tree
Showing 22 changed files with 209 additions and 323 deletions.
10 changes: 5 additions & 5 deletions samples/cpp/accuracy_sample/accuracy_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ int main(int argc, char* argv[]) try {
"What is OpenVINO?",
};

std::vector<GenerationConfig> sampling_params_examples {
GenerationConfig::beam_search(),
GenerationConfig::greedy(),
GenerationConfig::multinomial(),
std::vector<ov::genai::GenerationConfig> sampling_params_examples {
ov::genai::beam_search(),
ov::genai::greedy(),
ov::genai::multinomial(),
};

std::vector<std::string> prompts(num_prompts);
std::vector<GenerationConfig> sampling_params(num_prompts);
std::vector<ov::genai::GenerationConfig> sampling_params(num_prompts);

for (size_t request_id = 0; request_id < num_prompts; ++request_id) {
prompts[request_id] = prompt_examples[request_id % prompt_examples.size()];
Expand Down
6 changes: 3 additions & 3 deletions samples/cpp/throughput_benchmark/throughput_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AutoStartTimer {

struct Dataset {
std::vector<std::string> m_prompts;
std::vector<GenerationConfig> m_sampling_params;
std::vector<ov::genai::GenerationConfig> m_sampling_params;
std::vector<size_t> m_input_lens, m_output_lens;

size_t m_total_input_len = 0;
Expand All @@ -50,7 +50,7 @@ struct Dataset {
m_output_lens.reserve(size);
}

void push_data(std::string prompt, GenerationConfig sampling_params) {
void push_data(std::string prompt, ov::genai::GenerationConfig sampling_params) {
m_prompts.push_back(prompt);
m_sampling_params.push_back(sampling_params);
}
Expand Down Expand Up @@ -121,7 +121,7 @@ Dataset filtered_dataset(const std::string& models_path, const std::string& data
if (input_len > max_input_len || (input_len + output_len) > 2048)
continue;

GenerationConfig greedy_search = GenerationConfig::greedy();
ov::genai::GenerationConfig greedy_search = ov::genai::greedy();
greedy_search.max_new_tokens = std::min(max_output_len, output_len);

dataset.push_data(human_question, greedy_search);
Expand Down
1 change: 0 additions & 1 deletion src/cpp/continuous_batching/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ find_file(spda_to_pa_header sdpa_to_paged_attention.hpp
set(TARGET_NAME openvino_continuous_batching)

add_library(${TARGET_NAME} STATIC
src/generation_config.cpp
src/generation_handle.cpp
src/continuous_batching_pipeline.cpp
src/paged_attention_transformations.cpp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include "scheduler_config.hpp"
#include "openvino/genai/tokenizer.hpp"
#include "generation_config.hpp"
#include "openvino/genai/generation_config.hpp"
#include "generation_handle.hpp"

struct PipelineMetrics {
Expand All @@ -32,16 +32,16 @@ class ContinuousBatchingPipeline {

std::shared_ptr<ov::genai::Tokenizer> get_tokenizer();

GenerationConfig get_config() const;
ov::genai::GenerationConfig get_config() const;

PipelineMetrics get_metrics() const;

GenerationHandle add_request(uint64_t request_id, std::string prompt, GenerationConfig sampling_params);
GenerationHandle add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params);

void step();

bool has_non_finished_requests();

// more high level interface, which can process multiple prompts in continuous batching manner
std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, std::vector<GenerationConfig> sampling_params);
std::vector<GenerationResult> generate(const std::vector<std::string>& prompts, std::vector<ov::genai::GenerationConfig> sampling_params);
};
78 changes: 0 additions & 78 deletions src/cpp/continuous_batching/include/generation_config.hpp

This file was deleted.

6 changes: 3 additions & 3 deletions src/cpp/continuous_batching/include/generation_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <memory>
#include <unordered_map>

#include "generation_config.hpp"
#include "openvino/genai/generation_config.hpp"


enum class GenerationStatus {
Expand Down Expand Up @@ -42,10 +42,10 @@ class GenerationStream;

class GenerationHandleImpl {
std::shared_ptr<GenerationStream> m_generation_stream;
GenerationConfig m_sampling_params;
ov::genai::GenerationConfig m_sampling_params;

public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const GenerationConfig& sampling_params) :
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
m_generation_stream(generation_stream),
m_sampling_params(sampling_params) {};

Expand Down
14 changes: 7 additions & 7 deletions src/cpp/continuous_batching/src/continuous_batching_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ContinuousBatchingPipeline::Impl {

// TODO (mzegla): GenerationConfig is request specific object
// and pipeline only uses default rng_seed.
GenerationConfig m_generation_config;
ov::genai::GenerationConfig m_generation_config;

PipelineMetrics m_pipeline_metrics;

Expand Down Expand Up @@ -103,7 +103,7 @@ class ContinuousBatchingPipeline::Impl {
// read default generation config
}

GenerationConfig get_config() const {
ov::genai::GenerationConfig get_config() const {
return m_generation_config;
}

Expand All @@ -115,7 +115,7 @@ class ContinuousBatchingPipeline::Impl {
return m_tokenizer;
}

GenerationHandle add_request(uint64_t request_id, std::string prompt, GenerationConfig sampling_params) {
GenerationHandle add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params) {
sampling_params.set_eos_token_id(m_tokenizer->get_eos_token_id());
sampling_params.validate();

Expand Down Expand Up @@ -233,7 +233,7 @@ class ContinuousBatchingPipeline::Impl {
return !m_awaiting_requests.empty() || !m_requests.empty();
}

std::vector<GenerationResult> generate(const std::vector<std::string> prompts, std::vector<GenerationConfig> sampling_params) {
std::vector<GenerationResult> generate(const std::vector<std::string> prompts, std::vector<ov::genai::GenerationConfig> sampling_params) {
OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
OPENVINO_ASSERT(prompts.size() == sampling_params.size());

Expand Down Expand Up @@ -285,15 +285,15 @@ std::shared_ptr<ov::genai::Tokenizer> ContinuousBatchingPipeline::get_tokenizer(
return m_impl->get_tokenizer();
}

GenerationConfig ContinuousBatchingPipeline::get_config() const{
ov::genai::GenerationConfig ContinuousBatchingPipeline::get_config() const{
return m_impl->get_config();
}

PipelineMetrics ContinuousBatchingPipeline::get_metrics() const{
return m_impl->get_metrics();
}

GenerationHandle ContinuousBatchingPipeline::add_request(uint64_t request_id, std::string prompt, GenerationConfig sampling_params) {
GenerationHandle ContinuousBatchingPipeline::add_request(uint64_t request_id, std::string prompt, ov::genai::GenerationConfig sampling_params) {
return m_impl->add_request(request_id, prompt, sampling_params);
}

Expand All @@ -305,6 +305,6 @@ bool ContinuousBatchingPipeline::has_non_finished_requests() {
return m_impl->has_non_finished_requests();
}

std::vector<GenerationResult> ContinuousBatchingPipeline::generate(const std::vector<std::string>& prompts, std::vector<GenerationConfig> sampling_params) {
std::vector<GenerationResult> ContinuousBatchingPipeline::generate(const std::vector<std::string>& prompts, std::vector<ov::genai::GenerationConfig> sampling_params) {
return m_impl->generate(prompts, sampling_params);
}
105 changes: 0 additions & 105 deletions src/cpp/continuous_batching/src/generation_config.cpp

This file was deleted.

10 changes: 5 additions & 5 deletions src/cpp/continuous_batching/src/logit_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <algorithm>
#include <cmath>

#include "generation_config.hpp"
#include "openvino/genai/generation_config.hpp"

struct Token {
float m_log_prob = 0.;
Expand Down Expand Up @@ -277,7 +277,7 @@ class LogitProcessor {
size_t m_generated_tokens = 0;

public:
LogitProcessor(const GenerationConfig& sampling_params,
LogitProcessor(const ov::genai::GenerationConfig& sampling_params,
const LogitTransformers::TokenIds& input_ids) {
for (const auto& input_id : input_ids) {
m_unique_prompt_token_ids->insert(input_id);
Expand All @@ -289,7 +289,7 @@ class LogitProcessor {
);
}

if (sampling_params.is_multinomial() || sampling_params.is_greedy_sampling()) {
if (sampling_params.is_multinomial() || sampling_params.is_greedy_decoding()) {
if (sampling_params.repetition_penalty != 1.0f) {
std::shared_ptr<LogitTransformers::RepetitionPenaltyTransform> transformer =
std::shared_ptr<LogitTransformers::RepetitionPenaltyTransform>(new LogitTransformers::RepetitionPenaltyTransform(sampling_params.repetition_penalty));
Expand All @@ -304,9 +304,9 @@ class LogitProcessor {
m_logit_transformers.push_back(transformer);

}
if (sampling_params.frequence_penalty != 0.0f) {
if (sampling_params.frequency_penalty != 0.0f) {
std::shared_ptr<LogitTransformers::FrequencyPenaltyTransform> transformer =
std::shared_ptr<LogitTransformers::FrequencyPenaltyTransform>(new LogitTransformers::FrequencyPenaltyTransform(sampling_params.frequence_penalty));
std::shared_ptr<LogitTransformers::FrequencyPenaltyTransform>(new LogitTransformers::FrequencyPenaltyTransform(sampling_params.frequency_penalty));
transformer->set_unique_generated_token_ids(m_unique_generated_token_ids);
m_logit_transformers.push_back(transformer);
}
Expand Down
Loading

0 comments on commit 6667c3d

Please sign in to comment.