From 0d4d39a151794d40d69732b406528afaa8630658 Mon Sep 17 00:00:00 2001
From: sbalandi
Date: Fri, 15 Nov 2024 20:07:09 +0000
Subject: [PATCH] Fix wrong logits processing when slice matmul is not applied

---
 src/cpp/src/lm_encoding.cpp | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 644aa369c6..93d97bcfe1 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -105,6 +105,20 @@ std::pair get_lm_encoded_results(
 
     auto logits = m_llm.get_tensor("logits");
 
+    // if slice matmul is not applied, keep only the last token's logits
+    size_t vocab_size = logits.get_shape().back();
+    if (!m_embedding.has_value()) {
+        ov::Tensor new_logits = ov::Tensor(logits.get_element_type(), {batch_size, 1, vocab_size});
+        size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size;
+
+        for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+            size_t batch_offset = batch_idx * logits.get_shape().at(1) * vocab_size;
+            const float* logits_data = logits.data<float>() + batch_offset + sequence_offset;
+            std::copy(logits_data, logits_data + vocab_size, new_logits.data<float>() + batch_idx * vocab_size);
+        }
+        logits = new_logits;
+    }
+
     int64_t sequence_len = logits.get_shape().at(1);
     for (auto& sequence_group : sequence_groups)
         sequence_group->schedule_tokens(sequence_len);
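
Background for the fix, with a standalone sketch: when the slice-matmul optimization is not applied, the model emits logits for every input position, shape [batch, seq_len, vocab], while sampling only needs the last position, shape [batch, 1, vocab]. The snippet below reproduces the copy loop above using std::vector in place of ov::Tensor; all names, shapes, and values here are made up for illustration and are not the library's API.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Illustrative shapes: 2 sequences, 3 positions each, 4-entry vocabulary.
    const size_t batch_size = 2, seq_len = 3, vocab_size = 4;

    // Fill [batch, seq_len, vocab] logits with recognizable values 0..23.
    std::vector<float> logits(batch_size * seq_len * vocab_size);
    for (size_t i = 0; i < logits.size(); ++i)
        logits[i] = static_cast<float>(i);

    // Keep only each sequence's last position, as the patch does.
    std::vector<float> new_logits(batch_size * vocab_size);
    const size_t sequence_offset = (seq_len - 1) * vocab_size;
    for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
        const size_t batch_offset = batch_idx * seq_len * vocab_size;
        const float* src = logits.data() + batch_offset + sequence_offset;
        std::copy(src, src + vocab_size, new_logits.data() + batch_idx * vocab_size);
    }

    // Prints "8 9 10 11 20 21 22 23": the last vocab_size values per sequence.
    for (float v : new_logits)
        std::cout << v << ' ';
    std::cout << '\n';
}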