Skip to content

Commit

Permalink
[Continuous batching] Replace standard max_element call with custom l…
Browse files Browse the repository at this point in the history
…oop for greedy sampling (#607)

Searching for max element in a custom loop gives better performance than
using std::max_element
  • Loading branch information
mzegla authored Jul 11, 2024
1 parent 048d439 commit 740c914
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,13 @@ class Sampler {
}

Token _greedy_sample(const std::vector<Token>& logit_vector) const {
auto out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; });
return *out_token;
Token max_token{-std::numeric_limits<float>::infinity() , 0};
for (const auto& logit : logit_vector) {
if (logit.m_log_prob > max_token.m_log_prob) {
max_token = logit;
}
}
return max_token;
}

std::vector<Token> _multinomial_sample(const std::vector<Token>& logit_vector, size_t num_tokens_per_sequence) {
Expand Down

0 comments on commit 740c914

Please sign in to comment.