Fix shape error when run qwen1.5-14b using deepspeed autotp (#11420)
plusbang authored Jun 25, 2024 · 1 parent 3b23de6 · commit aacc1fd
Showing 1 changed file with 8 additions and 5 deletions.
python/llm/src/ipex_llm/transformers/models/qwen2.py (8 additions, 5 deletions)
@@ -361,8 +361,8 @@ def merge_qkv(module: torch.nn.Module):
 def padding_mlp(module: torch.nn.Module):
     # for qwen 1.5 14B
     if isinstance(module, Qwen2MLP):
-        hidden_size = module.hidden_size
-        intermediate_size = module.intermediate_size
+        hidden_size = module.gate_proj.weight.shape[1]
+        intermediate_size = module.gate_proj.weight.shape[0]
         padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256
         if intermediate_size % 256 == 0:
             return
@@ -371,21 +371,24 @@ def padding_mlp(module: torch.nn.Module):
         new_gate_weight = torch.zeros([padding_intermediate_size, hidden_size],
                                       dtype=gate_weight.dtype, device=gate_weight.device)
         new_gate_weight[:intermediate_size, :] = gate_weight
-        module.gate_proj.out_features = padding_intermediate_size
+        if hasattr(module.gate_proj, 'out_features'):
+            module.gate_proj.out_features = padding_intermediate_size
         module.gate_proj.weight = torch.nn.Parameter(new_gate_weight, requires_grad=False)

         up_weight = module.up_proj.weight.data
         new_up_weight = torch.zeros([padding_intermediate_size, hidden_size],
                                     dtype=up_weight.dtype, device=up_weight.device)
         new_up_weight[:intermediate_size, :] = up_weight
-        module.up_proj.out_features = padding_intermediate_size
+        if hasattr(module.gate_proj, 'out_features'):
+            module.up_proj.out_features = padding_intermediate_size
         module.up_proj.weight = torch.nn.Parameter(new_up_weight, requires_grad=False)

         down_weight = module.down_proj.weight.data
         new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
                                       dtype=down_weight.dtype, device=down_weight.device)
         new_down_weight[:, :intermediate_size] = down_weight
-        module.down_proj.in_features = padding_intermediate_size
+        if hasattr(module.gate_proj, 'out_features'):
+            module.down_proj.in_features = padding_intermediate_size
         module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)


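Why the commit helps (a context note, not part of the commit itself): under DeepSpeed AutoTP the Qwen2MLP projections are sharded across ranks, and their torch.nn.Linear layers are replaced by DeepSpeed's tensor-parallel wrappers. Two things then break: module.intermediate_size still reports the full config value (13696 for Qwen1.5-14B) while each rank's gate_proj.weight holds only 13696 / world_size rows, and the wrapper layers may not expose out_features / in_features at all. The fix therefore reads both sizes from the weight shard itself and guards every attribute write with hasattr (checking gate_proj in all three guards suffices, since AutoTP wraps all three projections with the same layer type). Below is a minimal, self-contained sketch of the situation; ShardedLinear is an illustrative stand-in, not a real DeepSpeed class, and the numbers assume Qwen1.5-14B split over 2 ranks.

import torch


class ShardedLinear:
    # Illustrative stand-in for a tensor-parallel linear layer: it carries
    # this rank's weight shard but, unlike torch.nn.Linear, exposes no
    # out_features / in_features attributes.
    def __init__(self, weight):
        self.weight = torch.nn.Parameter(weight, requires_grad=False)


hidden_size, full_intermediate, world_size = 5120, 13696, 2
rows_on_this_rank = full_intermediate // world_size          # 6848

gate_proj = ShardedLinear(torch.zeros(rows_on_this_rank, hidden_size))

# Old code trusted module.intermediate_size (13696 from the config), which
# no longer matches the 6848-row shard on this rank -> shape error.
# Fixed code derives both sizes from the weight actually present:
intermediate_size = gate_proj.weight.shape[0]                # 6848
hidden = gate_proj.weight.shape[1]                           # 5120

# Round the sharded size up to the next multiple of 256: 6848 -> 6912.
padded = (intermediate_size + 256 - 1) // 256 * 256

new_weight = torch.zeros(padded, hidden, dtype=gate_proj.weight.dtype)
new_weight[:intermediate_size, :] = gate_proj.weight.data

# Guard the attribute write: a plain nn.Linear has out_features,
# the tensor-parallel wrapper may not.
if hasattr(gate_proj, 'out_features'):
    gate_proj.out_features = padded
gate_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)

print(tuple(gate_proj.weight.shape))                         # (6912, 5120)

The padding itself is behavior-preserving: the extra gate/up rows yield silu(0) * 0 = 0 activations, and the matching zero columns of down_proj map those to zero, so the MLP output is unchanged while the intermediate dimension becomes a multiple of 256 (6848 rounds up to 6912).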
