[NPU] Support Baichuan groupwise & gw code refactor (#12337)
* support minicpm 1b & qwen 1.5b gw

* support minicpm 1b

* baichuan part

* update

* support minicpm 1b & qwen 1.5b gw

* support minicpm 1b

* baichuan part

* update

* update

* update

* baichuan support

* code refactor

* remove code

* fix style

* address comments

* revert
cyita authored Nov 8, 2024
1 parent 812d5cc commit b2e69a8
Showing 13 changed files with 367 additions and 434 deletions.
@@ -60,6 +60,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)

@@ -72,6 +73,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
pipeline=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
quantization_group_size=args.quantization_group_size,
torch_dtype=torch.float16,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache,
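Taken together, the example-script changes thread the new --quantization_group_size flag into from_pretrained. A minimal usage sketch is below; the model id, load_in_low_bit value, and the group size of 64 are placeholders/assumptions rather than values taken from this PR (0 keeps the existing channel-wise behaviour).

```python
# Hedged usage sketch: model id, low-bit format and group size are placeholders.
import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",      # placeholder model id
    load_in_low_bit="sym_int4",            # assumed low-bit format
    optimize_model=True,
    pipeline=True,
    max_context_len=1024,
    max_prompt_len=512,
    quantization_group_size=64,            # new flag: >0 enables groupwise quantization
    torch_dtype=torch.float16,
    attn_implementation="eager",
    transpose_value_cache=True,
    trust_remote_code=True,
)
```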
209 changes: 176 additions & 33 deletions python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/common.py
@@ -75,10 +75,11 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
baichuan_attn_module_names = ["W_pack", "o_proj"]
mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
if (
isinstance(module, (Qwen2Attention, LlamaAttention))
or module.__class__.__name__ in ['MiniCPMAttention', 'Attention']
or module.__class__.__name__ in ['MiniCPMAttention']
):
for name in attn_module_names:
setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
@@ -97,3 +98,10 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
n_splits=n_splits_mlp,
load=load))
delattr(module, name)
elif module.__class__.__name__ == 'Attention' and module.config.model_type == 'baichuan':
# baichuan attention
for name in baichuan_attn_module_names:
setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
n_splits=n_splits_hidden_size,
load=load))
delattr(module, name)
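The new branch exists because Baichuan's attention does not expose separate q_proj/k_proj/v_proj modules: it fuses them into a single W_pack projection, so only W_pack and o_proj are split. The snippet below is an illustration of that fused layout, not code from this repository; the hidden size is a placeholder.

```python
# Illustration only: Baichuan fuses Q/K/V into one W_pack projection, which is why
# the attention split list is ["W_pack", "o_proj"] instead of q/k/v/o.
import torch

hidden_size = 4096                          # placeholder
w_pack = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)

x = torch.randn(1, 8, hidden_size)          # (batch, seq, hidden)
q, k, v = w_pack(x).chunk(3, dim=-1)        # the three chunks play the role of q/k/v projections
```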
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -87,7 +87,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
model.llm.config.model_type = "llama"
model = model.llm

if model.config.model_type in ["qwen2", "llama", "minicpm"]:
if model.config.model_type in ["qwen2", "llama", "minicpm", "baichuan"]:
from ipex_llm.transformers.npu_models.common import split_linears
if quantization_group_size == 0:
n_splits_linear = 1
@@ -245,6 +245,8 @@ def convert_baichuan(
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(model, module.BaichuanModel, baichuan_model_forward)
from ipex_llm.transformers.npu_models.baichuan_mp import baichuan2_causal_forward
convert_forward(model, module.BaichuanForCausalLM, baichuan2_causal_forward)


def convert_minicpm(
@@ -392,7 +394,7 @@ def optimize_llm(
if intra_pp is None:
intra_pp = 2
if inter_pp is None:
inter_pp = 2
inter_pp = 2 if group_size == 0 else 4
convert_baichuan(model,
max_output_len=max_context_len,
max_prompt_len=max_prompt_len,
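Two things change for Baichuan here: convert_baichuan also swaps in baichuan2_causal_forward at the CausalLM level, and groupwise runs (group_size != 0) get a deeper inter-stage pipeline (inter_pp = 4 instead of 2). For readers unfamiliar with the convert_forward helper used above, a rough sketch of what such a helper does is shown below; it is an approximation, not the repository's implementation.

```python
# Rough approximation of a convert_forward-style helper (not the repository's
# exact code): walk the module tree and rebind `forward` on every instance of the
# target class, so BaichuanModel / BaichuanForCausalLM run the NPU-optimized
# forwards registered in convert_baichuan.
import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
```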
69 changes: 10 additions & 59 deletions python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -560,37 +560,13 @@ def run_decode(
mlp_layer = curr_layer.mlp

weights = []
if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
@@ -844,38 +820,13 @@ def run_prefill(

weights = []

if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0),
torch.stack(scales, axis=0)))

if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
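In both run_decode and run_prefill above, the n_splits_linear / n_splits_down_proj special cases are removed: every projection's dq_list (down_proj included) is now gathered into stacked (weight, scale) tensors. minicpm_mp.py below receives the same refactor. Pieced together from the added lines, the gathering loop now reads roughly as follows; attn_layer and mlp_layer come from the enclosing function, and indentation is approximate.

```python
# Reconstructed from the diff above: one uniform loop replaces the previous
# channel-wise vs. groupwise branches. attn_layer and mlp_layer are the current
# decoder layer's attention and MLP submodules.
weights = []
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                   attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
                   mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
                   mlp_layer.down_proj_dq_list]:
    l_weights = []
    scales = []
    for l in layer_list:
        l_weights.append(l.weight)
        scales.append(l.scale)
    # With a single split this stacks one tensor, so the old n_splits == 1
    # fast path becomes a degenerate case of the general loop.
    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
```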
69 changes: 10 additions & 59 deletions python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
@@ -540,37 +540,13 @@ def run_decode(
mlp_layer = curr_layer.mlp

weights = []
if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
@@ -783,38 +759,13 @@ def run_prefill(
mlp_layer = curr_layer.mlp

weights = []

if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))