From 3d6cfa291ddf6369a9e4bc1c6301d607e802dcb7 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 14 Aug 2024 16:07:24 +0800
Subject: [PATCH] optimize minicpm v 2.5 (#11793)

---
 .../llm/src/ipex_llm/transformers/convert.py | 10 +++--
 .../ipex_llm/transformers/models/minicpmv.py | 41 ++++++++++++++++++-
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 2d36b5c54c1..34db3f0fbe7 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -749,7 +749,7 @@ def _optimize_pre(model, qtype=None):
         model.apply(merge_qkv)
     if model.config.model_type == "minicpmv":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
-        model.apply(merge_qkv)
+        model.vpm.apply(merge_qkv)
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             model.llm.config.model_type = "qwen2"
         _optimize_pre(model.llm, qtype=qtype)
@@ -1742,9 +1742,13 @@ def safe_bmm_fwd(*args, **kwargs):
         if model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
-            convert_forward(model, vpm_module.SiglipAttention, siglip_attention_forward)
+            convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
         elif model.vpm.config.model_type == "idefics2":
             # MiniCPM-V 2.5
-            pass
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            from ipex_llm.transformers.models.minicpmv import minicpmv_chat_wrapper
+            convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
+            minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
+            model.chat = MethodType(minicpmv_chat, model)
 
     return model
 
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 03d8d2f4075..f0118b6d701 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -18,11 +18,13 @@
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
+from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
 def merge_qkv(module: torch.nn.Module):
-    return merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "Idefics2VisionAttention")
 
 
 def siglip_attention_forward(
@@ -67,6 +69,43 @@ def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: t
     return scores
 
 
+def minicpmv_chat_wrapper(origin_chat):
+    def minicpmv_chat(
+        self,
+        image,
+        msgs,
+        tokenizer,
+        processor=None,
+        vision_hidden_states=None,
+        max_new_tokens=1024,
+        sampling=True,
+        max_inp_length=2048,
+        system_prompt='',
+        stream=False,
+        **kwargs
+    ):
+        if processor is None:
+            if getattr(self, "processor", None) is None:
+                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path,
+                                                                trust_remote_code=True)
+            processor = self.processor
+        return origin_chat(
+            self=self,
+            image=image,
+            msgs=msgs,
+            tokenizer=tokenizer,
+            processor=processor,
+            vision_hidden_states=vision_hidden_states,
+            max_new_tokens=max_new_tokens,
+            sampling=sampling,
+            max_inp_length=max_inp_length,
+            system_prompt=system_prompt,
+            stream=stream,
+            **kwargs
+        )
+    return minicpmv_chat
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
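
Note (not part of the patch): the sketch below only illustrates what the new
minicpmv_chat_wrapper buys a caller, namely that `processor` can be omitted from
`model.chat` because the wrapper loads and caches an AutoProcessor on first use.
The model id, the low-bit loading arguments, and the demo image path are
assumptions for illustration, not something this patch defines.

    # Illustrative usage sketch, assuming a MiniCPM-V 2.5 checkpoint and the
    # ipex_llm AutoModel loading path; arguments here are examples, not part of this patch.
    from PIL import Image
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModel

    model_path = "openbmb/MiniCPM-Llama3-V-2_5"   # hypothetical example id
    model = AutoModel.from_pretrained(model_path,
                                      load_in_low_bit="sym_int4",  # assumed low-bit setting
                                      trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    image = Image.open("demo.jpg").convert("RGB")  # example input image
    msgs = [{"role": "user", "content": "What is in this image?"}]

    # With this patch, no processor argument is needed: minicpmv_chat_wrapper
    # builds one from model.config._name_or_path on the first call and reuses it.
    response = model.chat(image=image, msgs=msgs, tokenizer=tokenizer)
    print(response)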