diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index c56d02afef..52ff15430a 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -318,13 +318,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector generation_outputs = generation->read_all(); - std::sort(generation_outputs.begin(), generation_outputs.end(), [=] (GenerationOutput& r1, GenerationOutput& r2) { - return r1.score > r2.score; - }); - - auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size()); - for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) { - const auto& generation_output = generation_outputs[generation_output_idx]; + for (const auto& generation_output : generation_outputs) { result.m_generation_ids.push_back(std::move(generation_output.generated_ids)); result.m_scores.push_back(generation_output.score); } diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index 31c110d961..a1dd467523 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -67,8 +67,10 @@ std::vector GenerationHandleImpl::read_all() { add_partial_result(partial_results, iteration_results); } - for (auto& partial_result: partial_results) { + for (auto& partial_result : partial_results) { results.push_back(partial_result.second); } + std::sort(results.begin(), results.end(), [](const GenerationOutput& lhs, const GenerationOutput& rhs) { return lhs.score > rhs.score; }); + results.resize(std::min(m_sampling_params.num_return_sequences, results.size())); return results; }