Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Continuous batching] Late token vector initialization in sampling #649

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/generation_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {

public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
m_generation_stream(generation_stream),
m_generation_stream(std::move(generation_stream)),
m_sampling_params(sampling_params) {};

~GenerationHandleImpl();
Expand Down
4 changes: 2 additions & 2 deletions src/cpp/src/block_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class BlockManager {
}

bool can_append_slots(SequenceGroup::CPtr seq_group) {
return required_blocks_count(seq_group) <= m_allocator.num_free_blocks();
return required_blocks_count(std::move(seq_group)) <= m_allocator.num_free_blocks();
}

size_t required_blocks_count(SequenceGroup::CPtr seq_group) {
Expand Down Expand Up @@ -503,7 +503,7 @@ class BlockManager {
// write information about block forking for later usage in CacheManager
copy_blocks_map[last_block->get_index()].push_back(new_block->get_index());
// release `last_block` usage
m_allocator.free(last_block);
m_allocator.free(std::move(last_block));
} else {
// we are the only users of this block
if (m_enable_prefix_caching) {
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/generation_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class GenerationStream {
}

void push(GenerationOutputs outputs) {
m_output_queue.push(outputs);
m_output_queue.push(std::move(outputs));
}

// Retrieving vector of pairs <sequence_id, token_id> as we can generate multiple outputs for a single prompt
Expand Down
135 changes: 85 additions & 50 deletions src/cpp/src/logit_processor.hpp
Wovchena marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,38 @@ struct Token {
Token() = default;
};

struct Logits {
float * m_data = nullptr;
size_t m_size;
// Late initialized for top_p or top_k transforms
std::vector<Token> m_vector;

Logits(float* data, size_t size): m_data(data), m_size(size) {}


void initialize_vector() {
OPENVINO_ASSERT(m_vector.size() == 0, "Logits vector already initialized");
m_vector.reserve(m_size);
Wovchena marked this conversation as resolved.
Show resolved Hide resolved
for (size_t i = 0; i < m_size; i++)
m_vector.emplace_back(m_data[i], i);
}

bool is_vector_initialized() const {
return m_vector.size() > 0;
}

void resize(size_t new_size) {
m_size = new_size;
m_vector.resize(new_size);
Wovchena marked this conversation as resolved.
Show resolved Hide resolved
}
};

namespace LogitTransformers {
using TokenIds = std::vector<int64_t>;

class ILogitTransformer {
public:
virtual void apply(std::vector<Token>& logits) = 0;
virtual void apply(Logits& logits) = 0;

virtual bool is_applicable(size_t generated_tokens_cnt = 0) {
return true;
Expand All @@ -32,11 +58,15 @@ class TopPFilter : public ILogitTransformer {
public:
TopPFilter(double top_p) : m_top_p(top_p) {}

void apply(std::vector<Token>& logits) override {
std::sort(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
void apply(Logits& logits) override {
if (!logits.is_vector_initialized()) {
// Initialize and sort vector
logits.initialize_vector();
std::sort(logits.m_vector.begin(), logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
}
float probability_sum = 0.0f;
size_t nucleus_size = 0;
for (const auto& probability : logits) {
for (const auto& probability : logits.m_vector) {
probability_sum += probability.m_log_prob;
nucleus_size += 1;
if (probability_sum > m_top_p) break;
Expand All @@ -52,10 +82,18 @@ class TopKFilter : public ILogitTransformer {
public:
TopKFilter(size_t top_k) : m_top_k(top_k) {}

void apply(std::vector<Token>& logits) override {
std::sort(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
size_t top_k = logits.size() >= m_top_k ? m_top_k : logits.size();
logits.resize(top_k);
// If this transform is used along with top_p, it should be applied after it since top_p sorts entire vector and top_k does it only partially
void apply(Logits& logits) override {

if (m_top_k >= logits.m_size)
mzegla marked this conversation as resolved.
Show resolved Hide resolved
return;

if (!logits.is_vector_initialized()) {
// Initialize and partially sort vector
logits.initialize_vector();
std::partial_sort(logits.m_vector.begin(), logits.m_vector.begin() + m_top_k, logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; });
}
logits.resize(m_top_k);
}

protected:
Expand All @@ -66,18 +104,23 @@ class TemperatureLogitTransform : public ILogitTransformer {
public:
TemperatureLogitTransform(double temperature) : m_temperature(temperature) {};

void apply(std::vector<Token>& logits) override {
auto max_prob_token = std::max_element(logits.begin(), logits.end(), [](const Token& lhs, const Token& rhs) { return lhs.m_log_prob < rhs.m_log_prob; });
float max_logit = max_prob_token->m_log_prob;

std::for_each(logits.begin(), logits.end(), [max_logit, this](Token& val) {val.m_log_prob = expf((val.m_log_prob - max_logit) / this->m_temperature);});
void apply(Logits& logits) override {
float max_logit = -std::numeric_limits<float>::infinity();
for (size_t i = 0; i < logits.m_size; i++) {
if (logits.m_data[i] > max_logit) {
max_logit = logits.m_data[i];
}
}

float norm_sum = 0.0;
for (const auto& val : logits) {
norm_sum += val.m_log_prob;
for (size_t i = 0; i < logits.m_size; i++) {
logits.m_data[i] = expf((logits.m_data[i] - max_logit) / this->m_temperature);
norm_sum += logits.m_data[i];
}

std::for_each(logits.begin(), logits.end(), [norm_sum](Token& val) {val.m_log_prob /= norm_sum;});
for (size_t i = 0; i < logits.m_size; i++) {
logits.m_data[i] /= norm_sum;
}
}

protected:
Expand Down Expand Up @@ -118,32 +161,28 @@ class RepetitionPenaltyTransform : public IPenaltyTransformer {
m_penalty = repetition_penalty;
};

void apply(std::vector<Token>& logits) override {
size_t vocab_size = logits.size();
void apply(Logits& logits) override {
size_t vocab_size = logits.m_size;
for (const auto& prompt_id : *m_unique_prompt_token_ids) {
OPENVINO_ASSERT((prompt_id >= 0) && (prompt_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[prompt_id].m_index == prompt_id, "input_logits must have original index order");
auto logit_value = logits[prompt_id].m_log_prob;
if (logit_value >= 0) {
logits[prompt_id].m_log_prob /= m_penalty;
if (logits.m_data[prompt_id] >= 0) {
logits.m_data[prompt_id] /= m_penalty;
} else {
logits[prompt_id].m_log_prob *= m_penalty;
logits.m_data[prompt_id] *= m_penalty;
};
}
for (const auto& input_id_pair : *m_unique_generated_token_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
auto logit_value = logits[input_id].m_log_prob;
if (logit_value >= 0) {
logits[input_id].m_log_prob /= m_penalty;
if (logits.m_data[input_id] >= 0) {
logits.m_data[input_id] /= m_penalty;
} else {
logits[input_id].m_log_prob *= m_penalty;
logits.m_data[input_id] *= m_penalty;
};
}
}

void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
void apply(Logits& logits, const TokenIds& input_ids) {
set_unique_prompt_token_ids(nullptr);
extract_generated_tokens(input_ids);
apply(logits);
Expand All @@ -166,10 +205,10 @@ class EOSPenaltyTransform : public ILogitTransformer {
EOSPenaltyTransform(size_t eos_token_id, size_t min_generated_tokens) :
m_eos_token_id(eos_token_id), m_applicable_tensor_len(min_generated_tokens) {}

void apply(std::vector<Token>& logits) override {
// Since EOS penalty is applied early, the token vector is not sorted
void apply(Logits& logits) override {
// Since EOS penalty is applied early, the token vector is not initialized yet
// and we can assume element order matches token ids.
logits[m_eos_token_id].m_log_prob = 0.f;
logits.m_data[m_eos_token_id] = 0.f;
}


Expand All @@ -188,22 +227,20 @@ class FrequencyPenaltyTransform : public IPenaltyTransformer {
m_penalty = value;
};

void apply(std::vector<Token>& logits) override {
size_t vocab_size = logits.size();
void apply(Logits& logits) override {
size_t vocab_size = logits.m_size;
for (const auto& input_id_pair : *m_unique_generated_token_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
auto logit_value = logits[input_id].m_log_prob;
if (logit_value >= 0) {
logits[input_id].m_log_prob -= m_penalty * input_id_pair.second;
if (logits.m_data[input_id] >= 0) {
logits.m_data[input_id] -= m_penalty * input_id_pair.second;
} else {
logits[input_id].m_log_prob += m_penalty * input_id_pair.second;
logits.m_data[input_id] += m_penalty * input_id_pair.second;
};
}
}

void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
void apply(Logits& logits, const TokenIds& input_ids) {
extract_generated_tokens(input_ids);
apply(logits);
}
Expand All @@ -215,22 +252,20 @@ class PresencePenaltyTransform : public IPenaltyTransformer {
m_penalty = value;
};

void apply(std::vector<Token>& logits) override {
size_t vocab_size = logits.size();
void apply(Logits& logits) override {
size_t vocab_size = logits.m_size;
for (const auto& input_id_pair : *m_unique_generated_token_ids) {
const auto& input_id = input_id_pair.first;
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(logits[input_id].m_index == input_id, "input_logits must have original index order");
auto logit_value = logits[input_id].m_log_prob;
if (logit_value >= 0) {
logits[input_id].m_log_prob -= m_penalty;
if (logits.m_data[input_id] >= 0) {
logits.m_data[input_id] -= m_penalty;
} else {
logits[input_id].m_log_prob += m_penalty;
logits.m_data[input_id] += m_penalty;
};
}
}

void apply(std::vector<Token>& logits, const TokenIds& input_ids) {
void apply(Logits& logits, const TokenIds& input_ids) {
extract_generated_tokens(input_ids);
apply(logits);
}
Expand Down Expand Up @@ -286,14 +321,14 @@ class LogitProcessor {
if (sampling_params.top_p != 1.0f) {
m_logit_transformers.emplace_back(new LogitTransformers::TopPFilter(sampling_params.top_p));
}
if (sampling_params.top_k > 0) {
if (sampling_params.top_k > 0 && sampling_params.top_k < std::numeric_limits<size_t>::max()) {
m_logit_transformers.emplace_back(new LogitTransformers::TopKFilter(sampling_params.top_k));
}
}
}
}

void apply(std::vector<Token>& logits) {
void apply(Logits& logits) {
for (const auto& transformer : m_logit_transformers) {
if (transformer->is_applicable(m_generated_tokens)) {
transformer->apply(logits);
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ModelRunner {
SchedulerConfig m_scheduler_config;
public:
ModelRunner(ov::InferRequest request, const SchedulerConfig& scheduler_config) :
m_request(request),
m_request(std::move(request)),
m_scheduler_config(scheduler_config) { }

ov::InferRequest get_infer_request() const {
Expand Down
46 changes: 27 additions & 19 deletions src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct Beam {
float m_score = -std::numeric_limits<float>::infinity();

Beam(Sequence::Ptr sequence)
: m_sequence(sequence) { }
: m_sequence(std::move(sequence)) { }

size_t get_generated_len() const {
return m_sequence->get_generated_len();
Expand Down Expand Up @@ -203,40 +203,49 @@ class GroupBeamSearcher {

class Sampler {

std::vector<Token> _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1) {
ov::Shape logits_shape = logits.get_shape();
size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
OPENVINO_ASSERT(batch_idx <= batch_size);
size_t batch_offset = batch_idx * seq_len * vocab_size;
size_t sequence_offset = (seq_len - 1) * vocab_size;
const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;
float* logits_data = logits.data<float>() + batch_offset + sequence_offset;

std::vector<Token> logit_vector(vocab_size);
for (size_t i = 0; i < logit_vector.size(); i++) {
logit_vector[i] = Token(logits_data[i], i);
}
return logit_vector;
return Logits{logits_data, vocab_size};
}

Token _greedy_sample(const std::vector<Token>& logit_vector) const {
Token max_token{-std::numeric_limits<float>::infinity() , 0};
for (const auto& logit : logit_vector) {
if (logit.m_log_prob > max_token.m_log_prob) {
max_token = logit;
Token _greedy_sample(const Logits& logits) const {
// For greedy sampling we do not expect sorting or shrinking considered tokens
// so we can operate directly on the data buffer
float max_value = -std::numeric_limits<float>::infinity();
size_t max_index = 0;
for (size_t i = 0; i < logits.m_size; ++i) {
if (logits.m_data[i] > max_value) {
max_value = logits.m_data[i];
max_index = i;
}
}
return max_token;
return Token(logits.m_data[max_index], max_index);
}

std::vector<Token> _multinomial_sample(const std::vector<Token>& logit_vector, size_t num_tokens_per_sequence) {
std::vector<float> multinomial_weights(logit_vector.size());
for (size_t i = 0; i < logit_vector.size(); i++) multinomial_weights[i] = logit_vector[i].m_log_prob;
std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence) {
// If top_p or top_k was applied we use sorted vector, if not we go with original buffer.
std::vector<float> multinomial_weights;
multinomial_weights.reserve(logits.m_size);
if (logits.is_vector_initialized())
for (auto& logit: logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob);
else
multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size);

auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end()); // equivalent to multinomial with number of trials == 1

std::vector<Token> out_tokens;
for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) {
size_t element_to_pick = dist(rng_engine);
out_tokens.push_back(logit_vector[element_to_pick]);
if (logits.is_vector_initialized())
out_tokens.push_back(logits.m_vector[element_to_pick]);
else
out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick);
}
return out_tokens;
}
Expand Down Expand Up @@ -296,7 +305,6 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
for (size_t running_sequence_id = 0; running_sequence_id < num_running_sequences; ++running_sequence_id) {
auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id);
logit_processor.apply(logit_vector);

Token sampled_token_id;
if (sampling_params.is_greedy_decoding()) {
sampled_token_id = _greedy_sample(logit_vector);
Expand Down
Loading
Loading