diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 9981aea3c8b..33c6b83da69 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -740,13 +740,18 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
         model.apply(pre_process_attn_and_mlp)
     if model.config.model_type == "internvl_chat":
-        _optimize_pre(model.language_model)
+        _optimize_pre(model.language_model, qtype=qtype)
     if model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
     if model.config.model_type == "llama":
         from ipex_llm.transformers.models.llama import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "minicpmv":
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            model.llm.config.model_type = "qwen2"
+        _optimize_pre(model.llm, qtype=qtype)
+        model.llm.config.model_type = "minicpmv"
 
     return model
 
@@ -1747,5 +1752,15 @@ def safe_bmm_fwd(*args, **kwargs):
         convert_forward(model,
                         module.MiniCPMModel,
                         minicpm_model_forward)
+    elif model.config.model_type == "minicpmv":
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            model.llm.config.model_type = "qwen2"
+        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+        model.llm.config.model_type = "minicpmv"
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
+        minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
+        model.generate = MethodType(minicpmv_generate, model)
 
     return model
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
new file mode 100644
index 00000000000..340285ed193
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -0,0 +1,49 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+def minicpmv_generate_wrapper(origin_generate):
+    def generate(
+        self,
+        input_ids=None,
+        pixel_values=None,
+        tgt_sizes=None,
+        image_bound=None,
+        attention_mask=None,
+        tokenizer=None,
+        vision_hidden_states=None,
+        return_vision_hidden_states=False,
+        stream=False,
+        decode_text=False,
+        **kwargs
+    ):
+        if kwargs.get("repetition_penalty", None) is not None:
+            kwargs["repetition_penalty"] = 1
+        return origin_generate(
+            self=self,
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            tgt_sizes=tgt_sizes,
+            image_bound=image_bound,
+            attention_mask=attention_mask,
+            tokenizer=tokenizer,
+            vision_hidden_states=vision_hidden_states,
+            return_vision_hidden_states=return_vision_hidden_states,
+            stream=stream,
+            decode_text=decode_text,
+            **kwargs
+        )
+    return generate
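
Usage sketch (not part of the patch): with the changes above, loading a MiniCPM-V checkpoint through ipex-llm temporarily relabels its Qwen2-based language model (hidden_size 3584, vocab_size 151666) so the existing qwen2 optimizations apply, and installs the generate wrapper that neutralizes repetition_penalty. The checkpoint id below is an assumption taken from the upstream MiniCPM-V release, not something this patch defines.

    # Minimal sketch, assuming the openbmb/MiniCPM-V-2_6 checkpoint and its
    # remote code, which defines the MiniCPMV.generate wrapped by this patch.
    from ipex_llm.transformers import AutoModel
    from transformers import AutoTokenizer

    model_path = "openbmb/MiniCPM-V-2_6"
    # Low-bit loading runs _optimize_pre/_optimize_post, which take the new
    # "minicpmv" branches added in convert.py above.
    model = AutoModel.from_pretrained(model_path,
                                      load_in_low_bit="sym_int4",
                                      trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # model.generate is now minicpmv_generate_wrapper(MiniCPMV.generate) bound
    # to this instance, so any caller-supplied repetition_penalty is reset to 1.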