From 50a33a09ffc65b237815766eb0a194108f4b8e02 Mon Sep 17 00:00:00 2001
From: mzegla
Date: Thu, 1 Aug 2024 11:23:55 +0200
Subject: [PATCH] introduce finish reason

---
 .../openvino/genai/generation_handle.hpp |  7 +++++++
 src/cpp/src/sequence_group.hpp           | 19 +++++++++++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/cpp/include/openvino/genai/generation_handle.hpp b/src/cpp/include/openvino/genai/generation_handle.hpp
index 8d00ae0e9b..c18f11f3e8 100644
--- a/src/cpp/include/openvino/genai/generation_handle.hpp
+++ b/src/cpp/include/openvino/genai/generation_handle.hpp
@@ -32,6 +32,12 @@ struct EncodedGenerationResult {
     GenerationStatus m_status = GenerationStatus::RUNNING;
 };
 
+enum class GenerationFinishReason {
+    NONE = 0, // Default value, when generation is not yet finished
+    STOP = 1, // Generation finished naturally, by reaching end of sequence token
+    LENGTH = 2 // Generation finished by reaching max_new_tokens limit
+};
+
 struct GenerationResult {
     // request ID - obsolete when handle API is approved as handle will connect results with prompts.
     uint64_t m_request_id;
@@ -49,6 +55,7 @@ struct GenerationResult {
 struct GenerationOutput {
     std::vector<int64_t> generated_token_ids;
     float score;
+    GenerationFinishReason finish_reason;
 };
 
 using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index d5b9506b2c..db227a3436 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -33,6 +33,7 @@ class Sequence {
     uint64_t m_grouped_id;
     uint64_t m_id = _get_next_global_sequence_id();
     SequenceStatus m_status = SequenceStatus::RUNNING;
+    GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE;
     float m_cumulative_log_prob = 0.0f;
 
 public:
@@ -91,6 +92,14 @@ class Sequence {
         m_status = status;
     }
 
+    GenerationFinishReason get_finish_reason() const {
+        return m_finish_reason;
+    }
+
+    void set_finish_reason(GenerationFinishReason finish_reason) {
+        m_finish_reason = finish_reason;
+    }
+
     // appends new tokens to a generated part
     void append_token(int64_t token_id, float log_prob) {
         m_cumulative_log_prob += log_prob;
@@ -205,6 +214,12 @@ class SequenceGroup {
                 running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos) {
                 // stop sequence by max_new_tokens or EOS token
                 running_sequence->set_status(SequenceStatus::FINISHED);
+
+                if (running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos)
+                    running_sequence->set_finish_reason(GenerationFinishReason::STOP);
+                else if (m_sampling_params.max_new_tokens == generated_len)
+                    running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
+
                 dropped_seq_ids.push_back(running_sequence->get_id());
             }
         }
@@ -451,7 +466,8 @@ class SequenceGroup {
         for (auto& sequence: m_sequences) {
             GenerationOutput output;
             output.generated_token_ids = sequence->get_generated_ids();
-            output.score = sequence->get_beam_search_score(m_sampling_params);
+            output.score = m_sampling_params.is_beam_search() ? sequence->get_beam_search_score(m_sampling_params) : sequence->get_cumulative_log_probs();
+            output.finish_reason = sequence->get_finish_reason();
             outputs.emplace(sequence->get_grouped_id(), output);
         }
         m_generation_stream->push(outputs);
@@ -459,7 +475,6 @@ class SequenceGroup {
 
     void push_partial_outputs() {
         GenerationOutputs outputs;
-        // TODO: support streamimg for n seqs
         for (auto& sequence : m_sequences) {
             // todo: check seq.is_finished() to generate without several
             // or is it ok to use padding?
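
Usage note (a minimal sketch, not part of the patch): the new finish_reason field on GenerationOutput lets whatever consumes generation results report why a sequence stopped, for example as an OpenAI-style "finish_reason" string. The helper below is hypothetical; the include path is the header touched by this patch, while the ov::genai namespace qualification and the "stop"/"length" string mapping are assumptions rather than part of this change.

    // Hypothetical helper: translate the GenerationFinishReason introduced by this
    // patch into a human-readable string (OpenAI-style "stop" / "length" values).
    #include <string>

    #include "openvino/genai/generation_handle.hpp"

    std::string finish_reason_to_string(ov::genai::GenerationFinishReason reason) {
        switch (reason) {
        case ov::genai::GenerationFinishReason::STOP:
            return "stop";   // sequence ended with the EOS token
        case ov::genai::GenerationFinishReason::LENGTH:
            return "length"; // max_new_tokens limit was reached
        case ov::genai::GenerationFinishReason::NONE:
        default:
            return "";       // generation has not finished yet
        }
    }

A consumer that drains GenerationOutputs from the generation stream can pass each output.finish_reason through such a helper once the owning sequence reaches SequenceStatus::FINISHED.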