optimize minicpm v 2_6 first token perf (#11770)
MeouSker77 authored Aug 13, 2024
1 parent 841dbcd commit a1eb793
Showing 3 changed files with 47 additions and 1 deletion.
7 changes: 7 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -748,6 +748,8 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.llama import merge_qkv
         model.apply(merge_qkv)
     if model.config.model_type == "minicpmv":
+        from ipex_llm.transformers.models.minicpmv import merge_qkv
+        model.apply(merge_qkv)
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             model.llm.config.model_type = "qwen2"
         _optimize_pre(model.llm, qtype=qtype)
@@ -1763,4 +1765,9 @@ def safe_bmm_fwd(*args, **kwargs):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
+        modeling_module_name = model.vpm.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+        convert_forward(model, module.SiglipAttention, siglip_attention_forward)
+
     return model
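
For context: `convert_forward` itself is defined elsewhere in convert.py and is outside this diff. A minimal sketch of what such a helper plausibly does (the name `convert_forward_sketch` and its body are assumptions for illustration, not ipex-llm's implementation): walk the module tree and rebind the replacement function as each matching module's `forward`.

from types import MethodType

import torch


def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
    # Rebind new_forward on every instance of target_class found in the model.
    for submodule in model.modules():
        if submodule.__class__ == target_class:
            submodule.forward = MethodType(new_forward, submodule)
    return model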
5 changes: 4 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/common.py
@@ -37,7 +37,10 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
 
 
 def merge_qkv_base(module: torch.nn.Module, attention_class):
-    if isinstance(module, attention_class):
+    if (
+        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
+        or not isinstance(attention_class, str) and isinstance(module, attention_class)
+    ):
         qkv_proj = merge_linear([
             module.q_proj,
             module.k_proj,
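The switch from `isinstance(module, attention_class)` to a class-name comparison lets callers pass the attention class as a string: the MiniCPM-V vision tower's `SiglipAttention` lives in the model's own remote-code module (hence the `importlib.import_module(model.vpm.__class__.__module__)` lookup in convert.py above), so there is no fixed class to import for an isinstance check. `merge_linear` is defined just above this hunk and is not shown; below is a minimal sketch assuming it simply concatenates the projections' weights and biases into one larger Linear (the name `merge_linear_sketch` is illustrative, not the actual implementation).

from typing import List

import torch


def merge_linear_sketch(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
    # Stack the output dimensions of q/k/v into a single projection.
    weight = torch.cat([linear.weight.data for linear in linears], dim=0)
    has_bias = linears[0].bias is not None
    merged = torch.nn.Linear(weight.size(1), weight.size(0), bias=has_bias)
    merged.weight.data = weight
    if has_bias:
        merged.bias.data = torch.cat([linear.bias.data for linear in linears], dim=0)
    return merged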
36 changes: 36 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -16,9 +16,45 @@
 
 
 import torch
+from typing import Optional
+from ipex_llm.transformers.models.common import merge_qkv_base
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
+
+
+def merge_qkv(module: torch.nn.Module):
+    return merge_qkv_base(module, "SiglipAttention")
+
+
+def siglip_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    output_attentions: Optional[bool] = False,
+):
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.chunk(3, dim=1)
+
+    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
+
+    attn_output = self.out_proj(attn_output)
+
+    return attn_output, attn_weights
 
 
 def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
     if scores.device.type == "xpu":
         import xe_addons
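The merged-QKV forward above is where the first-token gain comes from: during image prefill the vision encoder runs one large projection GEMM instead of three, and the chunked result is laid out exactly like the separate q/k/v projections. A standalone sanity check with toy sizes (plain PyTorch, illustrative only, not part of ipex-llm):

import torch

embed_dim, num_heads, seq_len = 64, 4, 7
head_dim = embed_dim // num_heads

q_proj = torch.nn.Linear(embed_dim, embed_dim)
k_proj = torch.nn.Linear(embed_dim, embed_dim)
v_proj = torch.nn.Linear(embed_dim, embed_dim)

# Merge the three projections: concatenate weights and biases along the output dim.
qkv_proj = torch.nn.Linear(embed_dim, embed_dim * 3)
qkv_proj.weight.data = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
qkv_proj.bias.data = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)


def split_heads(states):
    # (bsz, seq, embed_dim) -> (bsz, num_heads, seq, head_dim)
    return states.view(1, seq_len, num_heads, head_dim).transpose(1, 2)


x = torch.randn(1, seq_len, embed_dim)

# Fused path, mirroring siglip_attention_forward: one GEMM, then chunk by heads.
qkv = qkv_proj(x).view(1, seq_len, num_heads * 3, head_dim).transpose(1, 2)
query, key, value = qkv.chunk(3, dim=1)

assert torch.allclose(query, split_heads(q_proj(x)), atol=1e-5)
assert torch.allclose(key, split_heads(k_proj(x)), atol=1e-5)
assert torch.allclose(value, split_heads(v_proj(x)), atol=1e-5)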
