fix minicpm V 2.6 repeat output (#11753)
MeouSker77 authored Aug 9, 2024
1 parent 7e917d6 commit 93455aa
Showing 1 changed file with 25 additions and 2 deletions.
python/llm/src/ipex_llm/transformers/models/minicpmv.py

@@ -15,6 +15,30 @@
 #
 
 
+import torch
+from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
+
+
+# patched RepetitionPenaltyLogitsProcessor.__call__: run the scatter on CPU to fix repeated output
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    score = torch.gather(scores, 1, input_ids)
+
+    # if score < 0 then repetition penalty has to be
+    # multiplied to reduce the token probabilities
+    score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+    # ipex llm changes start: call scatter on CPU
+    device = scores.device
+    scores = scores.to('cpu')
+    input_ids = input_ids.to('cpu')
+    score = score.to('cpu')
+    scores.scatter_(1, input_ids, score)
+    scores = scores.to(device)
+    # ipex llm changes end
+
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         self,
@@ -30,8 +54,7 @@ def generate(
         decode_text=False,
         **kwargs
     ):
-        if kwargs.get("repetition_penalty", None) is not None:
-            kwargs["repetition_penalty"] = 1
+        RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
         return origin_generate(
             self=self,
             input_ids=input_ids,
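For reference, a minimal sketch (not part of this commit) of what the patched processor computes once installed. The penalty value and tensors are illustrative; on CPU the patched call matches the stock transformers behavior, only the device handling differs:

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
# the patched function is defined in the file changed above
from ipex_llm.transformers.models.minicpmv import patched_repetition_penalty_call

# install the patched __call__, as the wrapped generate() in the diff does
RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call

processor = RepetitionPenaltyLogitsProcessor(penalty=1.5)
input_ids = torch.tensor([[0, 2]])         # token ids generated so far
scores = torch.tensor([[2.0, 1.0, -1.0]])  # next-token logits
out = processor(input_ids, scores)
# token 0 (positive logit) is divided by the penalty:    2.0 / 1.5 ≈  1.33
# token 2 (negative logit) is multiplied by the penalty: -1.0 * 1.5 = -1.5
print(out)  # tensor([[ 1.3333,  1.0000, -1.5000]])

Moving the scatter_ to CPU adds a small host round-trip per decoding step, but it sidesteps the device-side in-place scatter that this commit associates with the repeated output.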
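And a hedged sketch of how such a wrapper is typically applied; the actual installation site is elsewhere in ipex-llm and not shown in this diff, and the model object below is hypothetical:

from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper

# hypothetical: wrap the generate method of an already-loaded MiniCPM-V model's class
cls = type(model)  # `model` is assumed to be loaded elsewhere
cls.generate = minicpmv_generate_wrapper(cls.generate)

# every later model.generate(...) now installs the patched repetition
# penalty first, then delegates via origin_generate(self=self, ...)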
