optimize minicpm v 2.5 (#11793)
MeouSker77 authored Aug 14, 2024
1 parent 356281c commit 3d6cfa2
Showing 2 changed files with 47 additions and 4 deletions.
10 changes: 7 additions & 3 deletions  python/llm/src/ipex_llm/transformers/convert.py

@@ -749,7 +749,7 @@ def _optimize_pre(model, qtype=None):
         model.apply(merge_qkv)
     if model.config.model_type == "minicpmv":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
-        model.apply(merge_qkv)
+        model.vpm.apply(merge_qkv)
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             model.llm.config.model_type = "qwen2"
             _optimize_pre(model.llm, qtype=qtype)
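The `merge_qkv` now applied to the vision tower (`model.vpm`) folds each attention module's separate q/k/v projections into a single linear layer, so one matmul serves all three. A minimal sketch of that idea, independent of the actual `merge_qkv_base` implementation in ipex-llm (`merge_qkv_sketch` and the attribute names are illustrative):

import torch

def merge_qkv_sketch(attn: torch.nn.Module) -> None:
    # Concatenate the weights of q_proj / k_proj / v_proj into one
    # qkv_proj linear layer (illustrative only; real modules may differ).
    q, k, v = attn.q_proj, attn.k_proj, attn.v_proj
    qkv = torch.nn.Linear(q.in_features,
                          q.out_features + k.out_features + v.out_features,
                          bias=q.bias is not None)
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    attn.qkv_proj = qkv
    del attn.q_proj, attn.k_proj, attn.v_proj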
@@ -1742,9 +1742,13 @@ def safe_bmm_fwd(*args, **kwargs):
         if model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
-            convert_forward(model, vpm_module.SiglipAttention, siglip_attention_forward)
+            convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
         elif model.vpm.config.model_type == "idefics2":
             # MiniCPM-V 2.5
-            pass
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            from ipex_llm.transformers.models.minicpmv import minicpmv_chat_wrapper
+            convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
+            minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
+            model.chat = MethodType(minicpmv_chat, model)

     return model
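The `idefics2` branch above re-binds `chat` on the model instance with `types.MethodType`, so the wrapped function receives `self` like a normal method. A self-contained sketch of that wrap-and-rebind pattern (`Model` and `chat_wrapper` are illustrative stand-ins, not the real MiniCPMV classes):

from types import MethodType

class Model:
    def chat(self, prompt):
        return f"original chat: {prompt}"

def chat_wrapper(origin_chat):
    def wrapped_chat(self, prompt):
        # a hook could normalize arguments here before delegating,
        # as minicpmv_chat_wrapper does with the processor
        return origin_chat(self, prompt)
    return wrapped_chat

model = Model()
model.chat = MethodType(chat_wrapper(Model.chat), model)
print(model.chat("hi"))  # -> original chat: hi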
41 changes: 40 additions & 1 deletion  python/llm/src/ipex_llm/transformers/models/minicpmv.py

@@ -18,11 +18,13 @@
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
+from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor


 def merge_qkv(module: torch.nn.Module):
-    return merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "Idefics2VisionAttention")


 def siglip_attention_forward(
@@ -67,6 +69,43 @@ def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: t
     return scores


+def minicpmv_chat_wrapper(origin_chat):
+    def minicpmv_chat(
+        self,
+        image,
+        msgs,
+        tokenizer,
+        processor=None,
+        vision_hidden_states=None,
+        max_new_tokens=1024,
+        sampling=True,
+        max_inp_length=2048,
+        system_prompt='',
+        stream=False,
+        **kwargs
+    ):
+        if processor is None:
+            if getattr(self, "processor", None) is None:
+                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path,
+                                                               trust_remote_code=True)
+            processor = self.processor
+        return origin_chat(
+            self=self,
+            image=image,
+            msgs=msgs,
+            tokenizer=tokenizer,
+            processor=processor,
+            vision_hidden_states=vision_hidden_states,
+            max_new_tokens=max_new_tokens,
+            sampling=sampling,
+            max_inp_length=max_inp_length,
+            system_prompt=system_prompt,
+            stream=stream,
+            **kwargs
+        )
+    return minicpmv_chat
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
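With the new wrapper in place, MiniCPM-V 2.5's `chat` no longer needs an explicit processor: on the first call it builds an `AutoProcessor` from `config._name_or_path` and caches it on the model. A rough usage sketch, assuming the model is loaded through ipex-llm's `AutoModel`; the checkpoint id, image path, and `load_in_4bit=True` flag are assumptions, not taken from this commit:

from PIL import Image
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModel

model_path = "openbmb/MiniCPM-Llama3-V-2_5"  # assumed checkpoint id
model = AutoModel.from_pretrained(model_path, load_in_4bit=True,
                                  trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

image = Image.open("demo.jpg").convert("RGB")  # assumed local image
msgs = [{"role": "user", "content": "What is in this image?"}]

# No processor= argument: minicpmv_chat_wrapper creates and caches one
# from model.config._name_or_path on the first call.
response = model.chat(image=image, msgs=msgs, tokenizer=tokenizer)
print(response)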
