New convert support for C++ NPU #12430

Merged 6 commits on Nov 22, 2024
@@ -63,7 +63,7 @@
transpose_value_cache=not args.disable_transpose_value_cache,
mixed_precision=True,
trust_remote_code=True,
- compile_full_model=True,
+ convert_model=True,
save_directory=save_dir)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
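
For readers skimming the diff, a minimal end-to-end sketch of the new conversion flow follows. Assumptions: the example script imports AutoModelForCausalLM from ipex_llm.transformers.npu_model, and load_in_low_bit, model_path and save_dir are illustrative values not taken from this hunk; only the keyword arguments shown above come from the source.

from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

model_path = "Qwen/Qwen2-1.5B-Instruct"   # illustrative model id
save_dir = "./qwen2-npu-converted"        # directory the converted artifacts are written to

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="sym_int4",            # assumed low-bit setting, not part of this diff
    transpose_value_cache=True,
    mixed_precision=True,
    trust_remote_code=True,
    convert_model=True,                    # new flag introduced by this PR (replaces compile_full_model)
    save_directory=save_dir,               # triggers convert_llm_for_deploy and saves blobs/bins here
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.save_pretrained(save_dir)
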
8 changes: 4 additions & 4 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -134,7 +134,7 @@ def from_pretrained(cls, *args, **kwargs):
mixed_precision = kwargs.pop('mixed_precision', False)
quantization_group_size = kwargs.pop("quantization_group_size", 0)
mock_device = kwargs.pop('device', None) # For mock on CPU
- compile_full_model = kwargs.pop('compile_full_model', False)
+ convert_model = kwargs.pop('convert_model', False)
save_directory = kwargs.pop('save_directory', None)

invalidInputError(
@@ -202,7 +202,7 @@ def from_pretrained(cls, *args, **kwargs):
"inter_pp": inter_pp,
"intra_pp": intra_pp,
"transpose_value_cache": transpose_value_cache,
- "compile_full_model": compile_full_model,
+ "convert_model": convert_model,
"save_directory": save_directory,
}
model = cls.optimize_npu_model(*args, **optimize_kwargs)
@@ -241,7 +241,7 @@ def optimize_npu_model(cls, *args, **kwargs):
inter_pp = kwargs.pop("inter_pp", None)
intra_pp = kwargs.pop("intra_pp", None)
transpose_value_cache = kwargs.pop("transpose_value_cache", True)
- compile_full_model = kwargs.pop('compile_full_model', False)
+ convert_model = kwargs.pop('convert_model', False)
save_directory = kwargs.pop('save_directory', None)

if hasattr(model, "llm"):
@@ -280,7 +280,7 @@ def optimize_npu_model(cls, *args, **kwargs):
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size,
- compile_full_model=compile_full_model,
+ convert_model=convert_model,
save_directory=save_directory)
model.save_low_bit = types.MethodType(save_low_bit, model)
return model
@@ -193,7 +193,7 @@ def convert_llm(model: torch.nn.Module,
max_prompt_len: int,
transpose_value_cache: bool,
group_size: int,
- compile_full_model: bool=False,
+ convert_model: bool=False,
save_directory: str=None):
# whether to set layernorm weight as const
layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
@@ -203,6 +203,16 @@
else:
n_splits_linear = model.config.hidden_size // group_size
n_splits_down_proj = model.config.intermediate_size // group_size
if convert_model:
convert_llm_for_deploy(model,
kv_len,
max_prompt_len,
transpose_value_cache,
n_splits_linear,
n_splits_down_proj,
group_size,
save_directory)
return 0
if model.config.model_type == "llama":
with tempfile.TemporaryDirectory() as temp_dir:
weight_dir = os.path.join(temp_dir, "model_weights")
@@ -340,7 +350,7 @@ def convert_llm(model: torch.nn.Module,
from .qwen import convert_qwen_layer, convert_lm_head_and_embedding
first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, n_splits_linear,
temp_dir, weight_dir,
- compile_full_model)
+ convert_model)

param_list = []
for layer_idx in range(0, layer_num):
@@ -350,11 +360,6 @@
with Pool() as pool:
result = pool.starmap(convert_qwen_layer, param_list)

- if compile_full_model:
-     convert_qwen_layer(model, 0, n_splits_linear, n_splits_down_proj,
-                        temp_dir, weight_dir, transpose_value_cache, max_prompt_len,
-                        group_size, layernorm_const, "prefill")

# Prefill Runner
from ipex_llm.transformers.npu_models.convert_mp import convert_qwen
convert_qwen(model,
@@ -403,3 +408,48 @@ def convert_llm(model: torch.nn.Module,
import types
model.generate = types.MethodType(generate, model)
return model


def convert_llm_for_deploy(model: torch.nn.Module,
kv_len: int,
max_prompt_len: int,
transpose_value_cache: bool,
n_splits_linear: int,
n_splits_down_proj: int,
group_size: int,
save_directory: str=None):
os.mkdir(save_directory)
weight_dir = os.path.join(save_directory, "model_weights")
os.mkdir(weight_dir)

if model.config.model_type == "qwen2":
layernorm_const = True
if model.config.hidden_size == 1536:
# Qwen2-1.5B-Instruct
fused_layers = 1
else:
fused_layers = 2
update_dict = {"kv_len": kv_len,
"num_head": model.model.layers[0].self_attn.num_heads,
"head_dim": model.model.layers[0].self_attn.head_dim,
"transpose_value_cache": transpose_value_cache,
"max_prompt_len": max_prompt_len,
"layernorm_const": layernorm_const,
"group_size": group_size,
"fused_layers": fused_layers}
model.config.update(update_dict)
model.config.save_pretrained(save_directory)

from .qwen import convert_qwen_layer, convert_fused_qwen_layer
from .qwen import convert_lm_head_and_embedding
# save fused_layers blobs of fused decoder layers
convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
save_directory, weight_dir, transpose_value_cache, kv_len,
group_size, layernorm_const, "decode")
# save blob of single prefill layer
convert_qwen_layer(model, 0, n_splits_linear, n_splits_down_proj,
save_directory, weight_dir, transpose_value_cache, max_prompt_len,
group_size, layernorm_const, "prefill")
# save blob of lmhead and bin of embedding
convert_lm_head_and_embedding(model, n_splits_linear,
save_directory, weight_dir, True)
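
Note that convert_llm_for_deploy also records the runtime parameters on model.config and persists them via save_pretrained, so downstream consumers (e.g. the C++ NPU pipeline) can recover them from the generated config.json. A minimal sketch for inspecting those fields, assuming the standard Hugging Face config.json written by save_pretrained and the key names from update_dict above (the directory path is illustrative):

import json
import os

save_directory = "./qwen2-npu-converted"   # same directory passed to convert_llm_for_deploy

# config.json is written by model.config.save_pretrained(save_directory)
with open(os.path.join(save_directory, "config.json")) as f:
    cfg = json.load(f)

# Keys added by update_dict in the qwen2 branch above
for key in ("kv_len", "num_head", "head_dim", "transpose_value_cache",
            "max_prompt_len", "layernorm_const", "group_size", "fused_layers"):
    print(key, "=", cfg.get(key))
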
125 changes: 118 additions & 7 deletions python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
@@ -23,7 +23,7 @@


def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
- compile_full_model=False):
+ convert_model=False):
num_heads = model.model.layers[0].self_attn.num_heads
head_dim = model.model.layers[0].self_attn.head_dim
rms_norm_eps = model.config.rms_norm_eps
@@ -60,7 +60,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
)

last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, f"lm_head",
- temp_dir, True, True)
+ temp_dir, True, False)

# save weights bins files
if not isinstance(lm_head, SlicedLMHead):
@@ -83,11 +83,13 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
dtype=np.float16,
input_length=1,
)
-    first_blob_path = update_names_of_IR_and_export_blob(new_embedding, f"embedding",
-                                                          temp_dir, True, keep_ir=True)
-    if compile_full_model:
+    if convert_model:
        bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
        embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
        first_blob_path = True
    else:
        first_blob_path = update_names_of_IR_and_export_blob(new_embedding, f"embedding",
                                                              temp_dir, True, keep_ir=True)
return first_blob_path, last_blob_path
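
As a quick aside on the convert_model branch above: the embedding table is dumped as raw float16 values, so it can be read back for a sanity check with numpy.fromfile. The shape below is an assumption based on the usual (vocab_size, hidden_size) embedding layout, with illustrative Qwen2-1.5B-Instruct dimensions and an illustrative directory path:

import numpy as np

weight_dir = "./qwen2-npu-converted/model_weights"   # weight_dir created by convert_llm_for_deploy
vocab_size, hidden_size = 151936, 1536               # illustrative Qwen2-1.5B-Instruct sizes

# model_embedding_input_0.bin holds embedding_layer.weight cast to float16
emb = np.fromfile(f"{weight_dir}/model_embedding_input_0.bin", dtype=np.float16)
emb = emb.reshape(vocab_size, hidden_size)
print(emb.shape, emb.dtype)
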


@@ -138,8 +140,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
else:
input_len = kv_len
decoder_name = "decoder_layer_prefill"
-        compile = False
-        keep_ir = True
+        compile = True
+        keep_ir = False
single_decoder = LowBitQwenMultiDecoderlayer(
[1, input_len, num_heads * head_dim],
input_layernorm_weights=None,
@@ -190,3 +192,112 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
scale.numpy().tofile(bin_file)

del single_decoder


def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
save_dir, weight_dir, transpose_value_cache, kv_len, group_size,
layernorm_const, mode="decode"):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
intermediate_size = model.config.intermediate_size
rms_norm_eps = model.config.rms_norm_eps
layer_num = len(model.model.layers)
fused_layer_num = layer_num // fused_layers

from ipex_llm.transformers.npu_models.qwen2_mp import LowBitQwenMultiDecoderlayer
for i in range(fused_layers):
layer_start = i * fused_layer_num
layer_end = min((i + 1) * fused_layer_num, layer_num)
layer_weights = []
input_layer_norm_weights = []
post_attn_layernorm_weights = []
q_biases = []
k_biases = []
v_biases = []
layer_indexs = range(layer_start, layer_end)
n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
for layer_idx in layer_indexs:
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp

weights = []
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

layer_weights.extend(weights)
input_layer_norm_weights.append(layer_norm_0)
post_attn_layernorm_weights.append(layer_norm_1)
q_biases.append(attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16))
k_biases.append(attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16))
v_biases.append(attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16))

# save weight
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 5
# 5 / 6 / 7 are bias
q_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx}.bin")
k_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+1}.bin")
v_bias_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+2}.bin")
q_biases[-1].data.numpy().tofile(q_bias_bin_file)
k_biases[-1].data.numpy().tofile(k_bias_bin_file)
v_biases[-1].data.numpy().tofile(v_bias_bin_file)
# 6, 7 are past k/v
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+3+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir,
f"model_{layer_idx}_input_{st_idx+3+idx*2+1}.bin")
scale.numpy().tofile(bin_file)

if isinstance(weights[0], tuple):
np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
else: # FP16 Linear
np_dtype = np.float16

fused_decoder = LowBitQwenMultiDecoderlayer(
[1, 1, num_heads * head_dim],
input_layernorm_weights=input_layer_norm_weights,
post_attn_layernorm_weights=post_attn_layernorm_weights,
q_biases=q_biases,
k_biases=k_biases,
v_biases=v_biases,
cached_cos=cached_cos,
cached_sin=cached_sin,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
num_layers=fused_layer_num,
max_seq_len=kv_len,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
mode=mode,
transpose_value=transpose_value_cache,
dtype=np_dtype,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
)
update_names_of_IR_and_export_blob(fused_decoder,
f"decoder_layer_{i}",
save_dir,
compile_blob=True,
keep_ir=False)
return 0
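
To make the partitioning concrete, a small standalone sketch of how the decoder layers are split across the fused blobs and how the per-layer weight/scale bin files are indexed (the layer count is an illustrative value; the index arithmetic mirrors the st_idx logic above):

# Illustrative only: reproduce the layer ranges and bin-file indices used above.
layer_num = 28        # e.g. Qwen2-1.5B-Instruct has 28 decoder layers (illustrative)
fused_layers = 1      # 1 when hidden_size == 1536, otherwise 2 (per the qwen2 branch)
fused_layer_num = layer_num // fused_layers

for i in range(fused_layers):
    layer_start = i * fused_layer_num
    layer_end = min((i + 1) * fused_layer_num, layer_num)
    print(f"decoder_layer_{i}: layers {layer_start}..{layer_end - 1}")

# Per-layer bin files: input_3/input_4 are the layernorms, input_5..7 the q/k/v biases,
# then the 7 weight/scale pairs (q, k, v, o, gate, up, down) start at index st_idx + 3 = 8.
st_idx = 5
for idx, name in enumerate(["q", "k", "v", "o", "gate", "up", "down"]):
    print(f"{name}: model_<layer>_input_{st_idx + 3 + idx * 2}.bin (weight), "
          f"model_<layer>_input_{st_idx + 3 + idx * 2 + 1}.bin (scale)")
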