From cb100cb3bc7459bb489154937b3a076c5bd9f1d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mi=C5=82osz=20=C5=BBeglarski?= Date: Thu, 11 Jul 2024 16:50:27 +0200 Subject: [PATCH] [Continuous batching] Replace standard max_element call with custom loop for greedy sampling (#607) Searching for max element in a custom loop gives better performance than using std::max_element --- src/cpp/src/sampler.hpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index dc631c68ac..6390fc8725 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -219,8 +219,13 @@ class Sampler { } Token _greedy_sample(const std::vector& logit_vector) const { - auto out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; }); - return *out_token; + Token max_token{-std::numeric_limits::infinity() , 0}; + for (const auto& logit : logit_vector) { + if (logit.m_log_prob > max_token.m_log_prob) { + max_token = logit; + } + } + return max_token; } std::vector _multinomial_sample(const std::vector& logit_vector, size_t num_tokens_per_sequence) {