[NPU] Support l0 Llama groupwise #12276

Merged: 11 commits, Oct 28, 2024 (diff shows changes from 10 commits)
@@ -52,6 +52,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--max-prompt-len", type=int, default=960)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

@@ -63,6 +64,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
pipeline=True,
max_context_len=args.max_context_len,
max_prompt_len=args.max_prompt_len,
quantization_group_size=args.quantization_group_size,
torch_dtype=torch.float16,
attn_implementation="eager",
transpose_value_cache=not args.disable_transpose_value_cache)
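
A minimal sketch of the loader call that results from the two hunks above: the example script gains a --quantization_group_size option (default 0) and forwards it to from_pretrained. The model path and the load_in_low_bit/optimize_model settings below are illustrative assumptions that do not appear in this diff; a group size of 0 keeps the previous single-split behavior.

# Sketch only, not part of the PR.
import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # assumed checkpoint
    load_in_low_bit="sym_int4",        # assumed low-bit format
    optimize_model=True,               # assumed
    pipeline=True,
    max_context_len=1024,
    max_prompt_len=960,
    quantization_group_size=64,        # new in this PR; 0 disables groupwise splitting
    torch_dtype=torch.float16,
    attn_implementation="eager",
    transpose_value_cache=True,
)
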
5 changes: 3 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -186,7 +186,7 @@ def from_pretrained(cls, *args, **kwargs):
"max_prompt_len": max_prompt_len,
"inter_pp": inter_pp,
"intra_pp": intra_pp,
"transpose_value_cache": transpose_value_cache,
"transpose_value_cache": transpose_value_cache
}
model = cls.optimize_npu_model(*args, **optimize_kwargs)
else:
@@ -260,7 +260,8 @@ def optimize_npu_model(cls, *args, **kwargs):
convert_llm(llm,
kv_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache)
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size)

return model

@@ -30,6 +30,7 @@
from ipex_llm.utils.common import invalidInputError
import tempfile
import numpy as np
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead


def generate(
@@ -225,7 +226,14 @@ def update_names_of_IR_and_export_blob(model, model_name, dir):
def convert_llm(model: torch.nn.Module,
kv_len: int,
max_prompt_len: int,
transpose_value_cache: bool):
transpose_value_cache: bool,
group_size: int):
if group_size == 0:
n_splits_linear = 1
n_splits_down_proj = 1
else:
n_splits_linear = model.config.hidden_size // group_size
n_splits_down_proj = model.config.intermediate_size // group_size
if model.config.model_type == "llama":
from ipex_llm.transformers.npu_models.convert_mp import convert_llama
convert_llama(model,
@@ -247,7 +255,17 @@ def convert_llm(model: torch.nn.Module,
vocab_size = model.config.vocab_size
model_norm = model.model.norm
lm_head = model.lm_head
weights = [(lm_head.weight, lm_head.scale)]
if n_splits_linear == 1:
weights = [(lm_head.weight, lm_head.scale)]
else:
lm_heads = lm_head.lm_heads
lm_head_weights = []
scales = []
for i in range(n_splits_linear):
lm_head_weights.append(lm_heads[i].weight)
scales.append(lm_heads[i].scale)
weights = [(torch.stack(lm_head_weights, axis=0),
torch.stack(scales, axis=0))]
if isinstance(weights[0], tuple):
np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
else: # FP16 Linear
@@ -264,13 +282,17 @@ def convert_llm(model: torch.nn.Module,
dtype=np_dtype,
model_norm_weight=model_norm.weight.to(torch.float16),
vocab_size=vocab_size,
n_splits=n_splits_linear
)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)

# save weights bins files
weight_numpy = [
lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
]
if n_splits_linear == 1:
weight_numpy = [
lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
]
else:
weight_numpy = [v.numpy() for v in weights[0]]

for idx, weight in enumerate(weight_numpy):
bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
@@ -295,20 +317,41 @@ def convert_llm(model: torch.nn.Module,
mlp_layer = curr_layer.mlp

weights = []
for q, k, v, o, g, u, d in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
weights.append((d.weight, d.scale))
if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0),
torch.stack(scales, axis=0)))

if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -336,6 +379,9 @@ def convert_llm(model: torch.nn.Module,
mode="decode",
transpose_value=transpose_value_cache,
dtype=np_dtype,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
)
rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
"decoder_layer",
@@ -370,6 +416,9 @@ def convert_llm(model: torch.nn.Module,
invalidInputError(False,
"Now we only support Llama2 for pipeline running.")

if isinstance(model.lm_head, SlicedLMHead):
model.lm_head.get_fused_lm_head()

# patch generate function
import types
model.generate = types.MethodType(generate, model)
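
Taken together, the conversion changes above derive the split counts from the group size and, when splitting is active, stack every split's quantized weight and scale into a single tensor pair per linear layer. A worked example of the arithmetic, assuming Llama-2-7B dimensions (hidden_size 4096, intermediate_size 11008) and a group size of 64; the stacked shapes in the comments additionally assume that each split covers a contiguous slice of the input features, which this diff does not state explicitly.

# Worked example only; any other model config follows the same arithmetic.
hidden_size = 4096
intermediate_size = 11008
group_size = 64

n_splits_linear = hidden_size // group_size           # 64 splits: q/k/v/o, gate, up, lm_head
n_splits_down_proj = intermediate_size // group_size  # 172 splits: down_proj

# Stacking the per-split tensors yields one (weight, scale) pair per layer whose
# leading dimension is the split count, e.g. for the sliced LM head:
#   torch.stack(lm_head_weights, axis=0) -> (64, vocab_size, hidden_size // 64)  (assumed)
#   torch.stack(scales, axis=0)          -> (64, vocab_size)                     (assumed)
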
13 changes: 10 additions & 3 deletions python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -36,6 +36,7 @@ def __init__(
transpose_value: bool = False,
profile: bool = False,
device: str = "NPU",
n_splits: int = 1,
):
super().__init__(max_seq_len=max_seq_len,
transpose_value=transpose_value,
@@ -64,9 +65,15 @@ def __init__(
# model norm and lm head
model_norm_weight = self.constant(model_norm_weight)
hidden_states = self.layer_norm(hidden_states, model_norm_weight)
hidden_states = self.linear(
hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
if n_splits == 1:
hidden_states = self.linear(
hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
else:
hidden_states = self.dq_split_linear(
hidden_states, self.vocab_size, self.hidden_size, n_splits,
wt_dtype=dtype, scale_factor=False
)

# define outputs
hidden_states = self.convert_to_fp32(hidden_states)
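
For reference, a dequantize-and-accumulate sketch of what a split linear layer computes over the stacked tensors built in convert_llm above. This is an assumption about the semantics of dq_split_linear (each split owns a contiguous slice of the input features plus its own per-output-channel scales, which is what makes the overall quantization groupwise), not the NPU implementation.

import torch

def split_linear_reference(x, stacked_weight, stacked_scale):
    # x:              (..., in_features), float
    # stacked_weight: (n_splits, out_features, in_features // n_splits), int8 (assumed layout)
    # stacked_scale:  (n_splits, out_features), float (assumed layout)
    n_splits, out_features, split_size = stacked_weight.shape
    x_slices = x.reshape(*x.shape[:-1], n_splits, split_size)
    out = torch.zeros(*x.shape[:-1], out_features, dtype=x.dtype, device=x.device)
    for i in range(n_splits):
        # Dequantize this split with its own scales, then accumulate its partial product.
        w = stacked_weight[i].to(x.dtype) * stacked_scale[i].unsqueeze(-1)
        out = out + x_slices[..., i, :] @ w.transpose(0, 1)
    return out

# Example: 4096 input features split into 64 groups of 64, projecting to 11008 outputs.
x = torch.randn(2, 4096)
w = torch.randint(-128, 127, (64, 11008, 64), dtype=torch.int8)
s = torch.rand(64, 11008) * 0.01
y = split_linear_reference(x, w, s)   # y.shape == (2, 11008)
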