fix and optimize minicpm-v-2 #11799

Merged
python/llm/src/ipex_llm/transformers/convert.py (11 changes: 10 additions & 1 deletion)
@@ -1726,6 +1726,11 @@ def safe_bmm_fwd(*args, **kwargs):
     minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
     model.generate = MethodType(minicpmv_generate, model)
 
+    if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
+        # MiniCPM-V 2
+        model.llm.config.model_type = "minicpm"
+        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+        model.llm.config.model_type = "minicpmv"
     if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
         # MiniCPM-V 2.6
         model.llm.config.model_type = "qwen2"
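Note (reviewer illustration, not part of the diff): MiniCPM-V 2 is recognized here by its config fingerprint (hidden_size 2304, vocab_size 122753), and its inner language model is optimized by temporarily relabelling model.llm.config.model_type as "minicpm" so that _optimize_post takes the plain MiniCPM path, then restoring the original "minicpmv" label. A minimal sketch of that relabel-and-restore pattern, with illustrative names only:

from contextlib import contextmanager

@contextmanager
def as_model_type(config, temporary_type):
    # Temporarily relabel config.model_type, then always restore it.
    original_type = config.model_type
    config.model_type = temporary_type
    try:
        yield config
    finally:
        config.model_type = original_type

# Hypothetical usage mirroring the MiniCPM-V 2 branch above:
#     with as_model_type(model.llm.config, "minicpm"):
#         _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)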
@@ -1739,7 +1744,11 @@ def safe_bmm_fwd(*args, **kwargs):
 
     vpm_modeling_module_name = model.vpm.__class__.__module__
     vpm_module = importlib.import_module(vpm_modeling_module_name)
-    if model.vpm.config.model_type == "siglip":
+    if not hasattr(model.vpm, "config"):
+        # MiniCPM-V 2
+        from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
+        model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
+    elif model.vpm.config.model_type == "siglip":
         # MiniCPM-V 2.6
         from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
         convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
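For context (assumption drawn from the hunk above, not extra PR code): MiniCPM-V 2's vision tower is a timm VisionTransformer without a config attribute, so the new branch detects it with hasattr and binds minicpmv_get_vision_embedding onto the model instance with MethodType, while the SigLIP-based MiniCPM-V 2.6 keeps going through convert_forward. A self-contained sketch of the MethodType instance-patching pattern used here:

from types import MethodType

class DemoModel:
    scale = 2.0

def patched_embedding(self, values):
    # Bound like a normal method, so self is the DemoModel instance.
    return [v * self.scale for v in values]

model = DemoModel()
model.get_vision_embedding = MethodType(patched_embedding, model)  # instance-level override
print(model.get_vision_embedding([1.0, 2.0]))  # [2.0, 4.0]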
python/llm/src/ipex_llm/transformers/models/minicpmv.py (46 changes: 35 additions & 11 deletions)
@@ -15,18 +15,21 @@
 #
 
 
+import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
 from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
     merge_qkv_base(module, "Idefics2VisionAttention")
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def siglip_attention_forward(
     self,
     hidden_states: torch.Tensor,
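merge_qkv_base comes from ipex_llm.transformers.models.common and is applied to the SigLIP (MiniCPM-V 2.6) and Idefics2 (MiniCPM-V 2.5) attention classes. As a rough illustration of the general idea only, not ipex-llm's actual implementation, a q/k/v merge fuses the three projection layers into one linear so a single matmul yields Q, K and V together:

import torch

def fuse_qkv(q_proj, k_proj, v_proj):
    # Assumes the three projections share in_features and bias settings.
    out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
    fused = torch.nn.Linear(q_proj.in_features, out_features, bias=q_proj.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
        if fused.bias is not None:
            fused.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))
    return fused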
@@ -58,17 +61,7 @@ def siglip_attention_forward(
     return attn_output, attn_weights
 
 
-def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-    if scores.device.type == "xpu":
-        import xe_addons
-        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
-    else:
-        score = torch.gather(scores, 1, input_ids)
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        scores.scatter_(1, input_ids, score)
-    return scores
-
-
+# MiniCPM-V-2_5
 def minicpmv_chat_wrapper(origin_chat):
     def minicpmv_chat(
         self,
@@ -106,6 +99,37 @@ def minicpmv_chat(
     return minicpmv_chat
 
 
+# MiniCPM-V-2
+def minicpmv_get_vision_embedding(self, pixel_values):
+    res = []
+    dtype = self.dtype
+
+    def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
+        H, W = pixel_value.shape[-2:]
+        target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
+        vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
+
+        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+            vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+        return resampler(vision_embedding, target_size)
+
+    for pixel_value in pixel_values:
+        result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+        res.append(result)
+    return torch.vstack(res)
+
+
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    if scores.device.type == "xpu":
+        import xe_addons
+        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
+    else:
+        score = torch.gather(scores, 1, input_ids)
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores.scatter_(1, input_ids, score)
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
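The body of minicpmv_generate_wrapper is cut off here. On CPU the patched penalty reproduces the standard Hugging Face behaviour (negative scores are multiplied by the penalty, positive ones divided), while on XPU it calls into xe_addons in place. A hedged sketch of how such a patch could be installed for the duration of a generate() call; the actual wrapper in this PR may differ:

from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

def generate_with_patched_penalty(origin_generate):
    def generate(*inputs, **kwargs):
        original_call = RepetitionPenaltyLogitsProcessor.__call__
        # Route the repetition penalty through the XPU-aware implementation above.
        RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
        try:
            return origin_generate(*inputs, **kwargs)
        finally:
            RepetitionPenaltyLogitsProcessor.__call__ = original_call
    return generate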