diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index bf2293b61d9..afc12f4bfda 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -361,8 +361,8 @@ def merge_qkv(module: torch.nn.Module): def padding_mlp(module: torch.nn.Module): # for qwen 1.5 14B if isinstance(module, Qwen2MLP): - hidden_size = module.hidden_size - intermediate_size = module.intermediate_size + hidden_size = module.gate_proj.weight.shape[1] + intermediate_size = module.gate_proj.weight.shape[0] padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256 if intermediate_size % 256 == 0: return @@ -371,21 +371,24 @@ def padding_mlp(module: torch.nn.Module): new_gate_weight = torch.zeros([padding_intermediate_size, hidden_size], dtype=gate_weight.dtype, device=gate_weight.device) new_gate_weight[:intermediate_size, :] = gate_weight - module.gate_proj.out_features = padding_intermediate_size + if hasattr(module.gate_proj, 'out_features'): + module.gate_proj.out_features = padding_intermediate_size module.gate_proj.weight = torch.nn.Parameter(new_gate_weight, requires_grad=False) up_weight = module.up_proj.weight.data new_up_weight = torch.zeros([padding_intermediate_size, hidden_size], dtype=up_weight.dtype, device=up_weight.device) new_up_weight[:intermediate_size, :] = up_weight - module.up_proj.out_features = padding_intermediate_size + if hasattr(module.gate_proj, 'out_features'): + module.up_proj.out_features = padding_intermediate_size module.up_proj.weight = torch.nn.Parameter(new_up_weight, requires_grad=False) down_weight = module.down_proj.weight.data new_down_weight = torch.zeros([hidden_size, padding_intermediate_size], dtype=down_weight.dtype, device=down_weight.device) new_down_weight[:, :intermediate_size] = down_weight - module.down_proj.in_features = padding_intermediate_size + if hasattr(module.gate_proj, 'out_features'): + module.down_proj.in_features = padding_intermediate_size module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)