Skip to content

Commit

Permalink
introduce finish reason
Browse files Browse the repository at this point in the history
  • Loading branch information
mzegla committed Aug 2, 2024
1 parent 47fbb5e commit 50a33a0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/cpp/include/openvino/genai/generation_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ struct EncodedGenerationResult {
GenerationStatus m_status = GenerationStatus::RUNNING;
};

/// Reason why generation of a sequence ended.
/// The numeric values are spelled out explicitly so they stay stable.
enum class GenerationFinishReason {
    /// Default value, used while generation has not finished yet.
    NONE = 0,
    /// Generation finished naturally, by reaching the end-of-sequence token.
    STOP = 1,
    /// Generation finished by reaching the max_new_tokens limit.
    LENGTH = 2,
};

struct GenerationResult {
// request ID - obsolete when handle API is approved as handle will connect results with prompts.
uint64_t m_request_id;
Expand All @@ -49,6 +55,7 @@ struct GenerationResult {
/// Output produced for a single sequence of a generation request.
///
/// All scalar members carry default initializers so that a default-constructed
/// instance never holds indeterminate values, even if a producer fills in only
/// some of the fields.
struct GenerationOutput {
    /// Token ids generated so far for the sequence.
    std::vector<int64_t> generated_token_ids;
    /// Score attached to the sequence by the producer (a beam search score or
    /// the cumulative log probability, depending on the sampling mode).
    float score = 0.0f;
    /// Why generation ended; stays NONE while the sequence is not finished.
    GenerationFinishReason finish_reason = GenerationFinishReason::NONE;
};

using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;
Expand Down
19 changes: 17 additions & 2 deletions src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Sequence {
uint64_t m_grouped_id;
uint64_t m_id = _get_next_global_sequence_id();
SequenceStatus m_status = SequenceStatus::RUNNING;
GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE;
float m_cumulative_log_prob = 0.0f;

public:
Expand Down Expand Up @@ -91,6 +92,14 @@ class Sequence {
m_status = status;
}

// Returns the finish reason stored for this sequence.
// m_finish_reason is initialized to GenerationFinishReason::NONE and only
// changes through set_finish_reason().
GenerationFinishReason get_finish_reason() const {
return m_finish_reason;
}

// Stores the reason why generation of this sequence ended.
// Overwrites any previously stored value; no validation is performed.
void set_finish_reason(GenerationFinishReason finish_reason) {
m_finish_reason = finish_reason;
}

// appends new tokens to a generated part
void append_token(int64_t token_id, float log_prob) {
m_cumulative_log_prob += log_prob;
Expand Down Expand Up @@ -205,6 +214,12 @@ class SequenceGroup {
running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos) {
// stop sequence by max_new_tokens or EOS token
running_sequence->set_status(SequenceStatus::FINISHED);

if (running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos)
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
else if (m_sampling_params.max_new_tokens == generated_len)
running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);

dropped_seq_ids.push_back(running_sequence->get_id());
}
}
Expand Down Expand Up @@ -451,15 +466,15 @@ class SequenceGroup {
for (auto& sequence: m_sequences) {
GenerationOutput output;
output.generated_token_ids = sequence->get_generated_ids();
output.score = sequence->get_beam_search_score(m_sampling_params);
output.score = m_sampling_params.is_beam_search() ? sequence->get_beam_search_score(m_sampling_params) : sequence->get_cumulative_log_probs();
output.finish_reason = sequence->get_finish_reason();
outputs.emplace(sequence->get_grouped_id(), output);
}
m_generation_stream->push(outputs);
}

void push_partial_outputs() {
GenerationOutputs outputs;
// TODO: support streaming for n seqs
for (auto& sequence : m_sequences) {
// todo: check seq.is_finished() to generate without several </s>
// or is it ok to use padding?
Expand Down

0 comments on commit 50a33a0

Please sign in to comment.