Commit

Enable Qwen padding mlp to 256 to support batch_forward (#12030)
* Enable padding mlp

* padding to 256

* update style
hzjane authored and gc-fu committed Sep 10, 2024
1 parent f5c55cd commit 3dc33de
Showing 1 changed file with 30 additions and 0 deletions.
python/llm/src/ipex_llm/vllm/xpu/model_convert.py (+30 −0)
@@ -226,6 +226,9 @@ def _ipex_llm_load_model(self) -> None:
            scheduler_config=self.scheduler_config,
            cache_config=self.cache_config,
        )
        if "qwen" in self.model_config.model.lower() and \
                self.model.model.layers[0].mlp.down_proj.input_size_per_partition % 256 != 0:
            self.model.apply(padding_mlp)
        from ipex_llm import optimize_model
        import os
        not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
@@ -250,3 +253,30 @@ def _ipex_llm_load_model(self) -> None:
            self.model_memory_usage / float(2**30))

    return _ipex_llm_load_model


def padding_mlp(module: torch.nn.Module):
    if isinstance(module, Qwen2MLP):
        hidden_size = module.down_proj.output_size
        # per-partition intermediate size (the full size divided across tensor-parallel ranks)
        intermediate_size = module.down_proj.input_size_per_partition
        padding_size = 256
        # round up to the next multiple of padding_size
        padding_intermediate_size = \
            (intermediate_size + padding_size - 1) // padding_size * padding_size
        if intermediate_size % padding_size == 0:
            return
        gate_up_weight = module.gate_up_proj.weight.data
        new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
                                         dtype=gate_up_weight.dtype, device=gate_up_weight.device)
        # gate_up_proj stores the gate and up weights concatenated, so pad each half separately
        new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
        new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
        module.gate_up_proj.output_size_per_partition = padding_intermediate_size * 2
        module.gate_up_proj.weight = torch.nn.Parameter(new_gate_up_weight, requires_grad=False)

        down_weight = module.down_proj.weight.data
        new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
                                      dtype=down_weight.dtype, device=down_weight.device)
        new_down_weight[:, :intermediate_size] = down_weight
        module.down_proj.input_size_per_partition = padding_intermediate_size
        module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
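
For illustration, below is a minimal self-contained sketch of the same padding scheme with toy sizes. The dimensions are hypothetical, not real Qwen values; in the commit they come from the model config and the tensor-parallel partitioning. It rounds the intermediate size up to the next multiple of 256, pads each half of the merged gate_up weight separately, zero-pads the input columns of the down weight, and checks that the padded MLP matches the unpadded one:

import torch
import torch.nn.functional as F

# Hypothetical toy sizes, not real Qwen dimensions.
hidden_size, intermediate_size, padding_size = 64, 100, 256
# Round up to the next multiple of padding_size: 100 -> 256.
padded = (intermediate_size + padding_size - 1) // padding_size * padding_size

gate_up = torch.randn(intermediate_size * 2, hidden_size)  # merged [gate; up] weight
down = torch.randn(hidden_size, intermediate_size)

# Pad each half of the merged gate_up weight separately, as padding_mlp does.
new_gate_up = torch.zeros(padded * 2, hidden_size)
new_gate_up[:intermediate_size] = gate_up[:intermediate_size]                 # gate half
new_gate_up[padded:padded + intermediate_size] = gate_up[intermediate_size:]  # up half

# Zero-pad the down projection's input columns.
new_down = torch.zeros(hidden_size, padded)
new_down[:, :intermediate_size] = down

def mlp(x, gate_up_w, down_w, inter):
    gate, up = (x @ gate_up_w.t()).split(inter, dim=-1)
    return (F.silu(gate) * up) @ down_w.t()

x = torch.randn(3, hidden_size)
ref = mlp(x, gate_up, down, intermediate_size)
out = mlp(x, new_gate_up, new_down, padded)
print(torch.allclose(ref, out, rtol=1e-4, atol=1e-4))  # True

Because silu(0) = 0 and the padded down-projection columns are zero, the extra positions contribute nothing to the output: only the weight shapes change, giving the batch_forward path 256-aligned MLP dimensions without altering results.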
