From f0c26772d613cc1a31c7c1491484aef41a706996 Mon Sep 17 00:00:00 2001 From: Anastasiia Pnevskaia Date: Mon, 15 Jul 2024 19:20:26 +0200 Subject: [PATCH] Clear beam search info when generate() is finished. (#630) Port of PR: https://github.com/openvinotoolkit/openvino.genai/pull/615 --- src/cpp/src/continuous_batching_pipeline.cpp | 1 + src/cpp/src/sampler.hpp | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/src/cpp/src/continuous_batching_pipeline.cpp b/src/cpp/src/continuous_batching_pipeline.cpp index 27c183ddd8..ddfebc5926 100644 --- a/src/cpp/src/continuous_batching_pipeline.cpp +++ b/src/cpp/src/continuous_batching_pipeline.cpp @@ -61,6 +61,7 @@ class ContinuousBatchingPipeline::Impl { for (const auto& sequence: request->get_sequences()) { m_scheduler->free_sequence(sequence->get_id()); } + m_sampler->clear_beam_search_info(request->get_request_id()); requests_iterator = m_requests.erase(requests_iterator); } else { requests_iterator++; diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 095c795a42..dc631c68ac 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -247,6 +247,8 @@ class Sampler { SamplerOutput sample(std::vector & sequence_groups, ov::Tensor logits); void set_seed(size_t seed) { rng_engine.seed(seed); } + + void clear_beam_search_info(uint64_t request_id); }; SamplerOutput Sampler::sample(std::vector & sequence_groups, ov::Tensor logits) { @@ -578,4 +580,8 @@ void GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, SamplerOutp } } } + +void Sampler::clear_beam_search_info(uint64_t request_id) { + m_beam_search_info.erase(request_id); +} }