Skip to content

Commit

Permalink
set reason for partial push also
Browse files Browse the repository at this point in the history
  • Loading branch information
mzegla committed Aug 2, 2024
1 parent 50a33a0 commit 21e680e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/cpp/src/generation_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ void add_partial_result(std::unordered_map<uint64_t, GenerationOutput>& partial_
} else {
partial_result_iter->second.generated_token_ids.push_back(iteration_result.second.generated_token_ids[0]);
partial_result_iter->second.score = iteration_result.second.score;
partial_result_iter->second.finish_reason = iteration_result.second.finish_reason;
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class Sequence {
OPENVINO_ASSERT(m_generated_ids.size());
output.score = get_cumulative_log_probs();
output.generated_token_ids = std::vector<int64_t> {m_generated_ids.back()};
output.finish_reason = get_finish_reason();
return output;
}

Expand Down Expand Up @@ -215,10 +216,11 @@ class SequenceGroup {
// stop sequence by max_new_tokens or EOS token
running_sequence->set_status(SequenceStatus::FINISHED);

if (running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos)
if (running_sequence->get_generated_ids().back() == m_sampling_params.eos_token_id && !m_sampling_params.ignore_eos) {
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
else if (m_sampling_params.max_new_tokens == generated_len)
} else if (m_sampling_params.max_new_tokens == generated_len) {
running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
}

dropped_seq_ids.push_back(running_sequence->get_id());
}
Expand Down

0 comments on commit 21e680e

Please sign in to comment.