From f1fff59e6918ea1b8fafa06d60e8adbc5bb73ecd Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 14 Aug 2024 16:04:31 +0800
Subject: [PATCH] fix and optimize minicpm v 2

---
 .../llm/src/ipex_llm/transformers/convert.py  | 11 ++++-
 .../ipex_llm/transformers/models/minicpmv.py  | 46 ++++++++++++++-----
 2 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 34db3f0fbe7..4123035ec87 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1726,6 +1726,11 @@ def safe_bmm_fwd(*args, **kwargs):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
+        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
+            # MiniCPM-V 2
+            model.llm.config.model_type = "minicpm"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             # MiniCPM-V 2.6
             model.llm.config.model_type = "qwen2"
@@ -1739,7 +1744,11 @@ def safe_bmm_fwd(*args, **kwargs):
         vpm_modeling_module_name = model.vpm.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)
 
-        if model.vpm.config.model_type == "siglip":
+        if not hasattr(model.vpm, "config"):
+            # MiniCPM-V 2
+            from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
+            model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
+        elif model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
             convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index f0118b6d701..15bf61d4009 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -15,6 +15,7 @@
 #
 
 
+import math
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
@@ -22,11 +23,13 @@
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
     merge_qkv_base(module, "Idefics2VisionAttention")
 
 
+# MiniCPM-V-2_5 and MiniCPM-V-2_6
 def siglip_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -58,17 +61,7 @@ def siglip_attention_forward(
     return attn_output, attn_weights
 
 
-def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
-    if scores.device.type == "xpu":
-        import xe_addons
-        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
-    else:
-        score = torch.gather(scores, 1, input_ids)
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        scores.scatter_(1, input_ids, score)
-    return scores
-
-
+# MiniCPM-V-2_5
 def minicpmv_chat_wrapper(origin_chat):
     def minicpmv_chat(
         self,
@@ -106,6 +99,37 @@ def minicpmv_chat(
     return minicpmv_chat
 
 
+# MiniCPM-V-2
+def minicpmv_get_vision_embedding(self, pixel_values):
+    res = []
+    dtype = self.dtype
+
+    def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
+        H, W = pixel_value.shape[-2:]
+        target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
+        vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
+
+        if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+            vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+        return resampler(vision_embedding, target_size)
+
+    for pixel_value in pixel_values:
+        result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+        res.append(result)
+    return torch.vstack(res)
+
+
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    if scores.device.type == "xpu":
+        import xe_addons
+        xe_addons.repetition_penalty_logits_process_inplaced(scores, input_ids, self.penalty)
+    else:
+        score = torch.gather(scores, 1, input_ids)
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores.scatter_(1, input_ids, score)
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
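
Reviewer note, not part of the patch: the convert.py hunks above rely on two instance-level dispatch tricks for MiniCPM-V 2. The language model is optimized by temporarily remapping model.llm.config.model_type to "minicpm" so the existing MiniCPM path in _optimize_post is reused, and the vision path is patched by binding minicpmv_get_vision_embedding onto the model with MethodType, since the MiniCPM-V 2 vision tower has no `config` attribute. Below is a minimal, self-contained sketch of that pattern with dummy stand-ins; optimize_llm and dummy_get_vision_embedding are hypothetical placeholders, not ipex-llm APIs.

# Sketch only: dummy objects standing in for the real MiniCPM-V 2 model.
from types import MethodType, SimpleNamespace


def optimize_llm(llm):
    # Hypothetical stand-in for ipex_llm's _optimize_post(model.llm, ...).
    print("optimizing llm as model_type=" + llm.config.model_type)


def dummy_get_vision_embedding(self, pixel_values):
    # Hypothetical stand-in for the minicpmv_get_vision_embedding added by this patch.
    print("patched vision embedding called with", len(pixel_values), "image(s)")


model = SimpleNamespace(
    config=SimpleNamespace(hidden_size=2304, vocab_size=122753),
    llm=SimpleNamespace(config=SimpleNamespace(model_type="minicpmv")),
    vpm=object(),  # stand-in for MiniCPM-V 2's vision tower, which has no `config` attribute
)

if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
    # MiniCPM-V 2: optimize the LLM under its underlying architecture name, then restore the tag
    model.llm.config.model_type = "minicpm"
    optimize_llm(model.llm)
    model.llm.config.model_type = "minicpmv"

if not hasattr(model.vpm, "config"):
    # MiniCPM-V 2: replace get_vision_embedding on this instance only
    model.get_vision_embedding = MethodType(dummy_get_vision_embedding, model)

model.get_vision_embedding(["img"])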