From 4b1407fa4edfd1ee685bb001df18e5361e768221 Mon Sep 17 00:00:00 2001 From: Irina Efode Date: Wed, 24 Jul 2024 19:26:52 +0400 Subject: [PATCH 1/4] [ CB ] Return only N results from read_all --- src/cpp/src/generation_handle.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index a0187025ec..d3971dbbaf 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -46,8 +46,11 @@ std::vector GenerationHandleImpl::read_all() { add_partial_result(partial_results, iteration_results); } - for (auto& partial_result: partial_results) { + for (auto& partial_result : partial_results) { results.push_back(partial_result.second); + if (results.size() == m_sampling_params.num_return_sequences) { + break; + } } return results; } From ea1a12e1c22962aeb7630ee18e1108298266488d Mon Sep 17 00:00:00 2001 From: Irina Efode Date: Tue, 5 Nov 2024 16:26:27 +0400 Subject: [PATCH 2/4] comments --- src/cpp/src/generation_handle.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index 8c61647592..cce5ba5dfd 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -69,9 +69,8 @@ std::vector GenerationHandleImpl::read_all() { for (auto& partial_result : partial_results) { results.push_back(partial_result.second); - if (results.size() == m_sampling_params.num_return_sequences) { - break; - } } + std::sort(results.begin(), results.end(), [](const GenerationOutput& lhs, const GenerationOutput& rhs) { return std::fabs(lhs.score) > std::fabs(rhs.score); }); + results.resize(m_sampling_params.num_return_sequences); return results; } From ad0d01e23f6e860d02044dba82ad0a778a79d5a2 Mon Sep 17 00:00:00 2001 From: Irina Efode Date: Tue, 5 Nov 2024 16:35:20 +0400 Subject: [PATCH 3/4] remove extra limitation" --- src/cpp/src/continuous_batching_impl.cpp | 8 +------- src/cpp/src/generation_handle.cpp | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index c56d02afef..52ff15430a 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -318,13 +318,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector generation_outputs = generation->read_all(); - std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) { - return r1.score > r2.score; - }); - - auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size()); - for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { - const auto& generation_output = generation_outputs[generation_output_idx]; + for (const auto& generation_output : generation_outputs) { result.m_generation_ids.push_back(std::move(generation_output.generated_ids)); result.m_scores.push_back(generation_output.score); } diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index cce5ba5dfd..6563ce8f5a 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -71,6 +71,6 @@ std::vector GenerationHandleImpl::read_all() { results.push_back(partial_result.second); } std::sort(results.begin(), results.end(), [](const GenerationOutput& lhs, const GenerationOutput& rhs) { return std::fabs(lhs.score) > std::fabs(rhs.score); }); - results.resize(m_sampling_params.num_return_sequences); + results.resize(std::min(m_sampling_params.num_return_sequences, results.size())); return results; } From 4ac62c74618086da9fd51bf799c3bcf632d0b5a3 Mon Sep 17 00:00:00 2001 From: Irina Efode Date: Thu, 7 Nov 2024 22:14:11 +0400 Subject: [PATCH 4/4] Update generation_handle.cpp --- src/cpp/src/generation_handle.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index 6563ce8f5a..a1dd467523 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -70,7 +70,7 @@ std::vector GenerationHandleImpl::read_all() { for (auto& partial_result : partial_results) { results.push_back(partial_result.second); } - std::sort(results.begin(), results.end(), [](const GenerationOutput& lhs, const GenerationOutput& rhs) { return std::fabs(lhs.score) > std::fabs(rhs.score); }); + std::sort(results.begin(), results.end(), [](const GenerationOutput& lhs, const GenerationOutput& rhs) { return lhs.score > rhs.score; }); results.resize(std::min(m_sampling_params.num_return_sequences, results.size())); return results; }