diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 095c795a42..ba4aa4543b 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -219,8 +219,13 @@ class Sampler { } Token _greedy_sample(const std::vector& logit_vector) const { - auto out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; }); - return *out_token; + Token max_token{0.0, 0}; + for (const auto& logit : logit_vector) { + if (logit.m_log_prob > max_token.m_log_prob) { + max_token = logit; + } + } + return max_token; } std::vector _multinomial_sample(const std::vector& logit_vector, size_t num_tokens_per_sequence) {