From 35932aab12c5a681b586819350a1712481c7272a Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 9 Aug 2024 16:30:58 +0800 Subject: [PATCH] fix minicpm V 2.6 repeat output --- .../ipex_llm/transformers/models/minicpmv.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py index 340285ed193..bdf9aa3a535 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py @@ -15,6 +15,30 @@ # +import torch +from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor + + +# todo +def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + score = torch.gather(scores, 1, input_ids) + + # if score < 0 then repetition penalty has to be + # multiplied to reduce the token probabilities + score = torch.where(score < 0, score * self.penalty, score / self.penalty) + + # ipex llm changes start: call scatter on CPU + device = scores.device + scores = scores.to('cpu') + input_ids = input_ids.to('cpu') + score = score.to('cpu') + scores.scatter_(1, input_ids, score) + scores = scores.to(device) + # ipex llm changes end + + return scores + + def minicpmv_generate_wrapper(origin_generate): def generate( self, @@ -30,8 +54,7 @@ def generate( decode_text=False, **kwargs ): - if kwargs.get("repetition_penalty", None) is not None: - kwargs["repetition_penalty"] = 1 + RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call return origin_generate( self=self, input_ids=input_ids,