From fcb474820d03e080c846487effbec6f9bd7c42c0 Mon Sep 17 00:00:00 2001 From: Zijie Li Date: Tue, 17 Dec 2024 01:01:17 -0500 Subject: [PATCH] [NPU] support asym_int4 for llama (#12556) * add llama-imatrix * fix bugs in llama.py * style fix --- .../transformers/npu_models/llama_mp.py | 41 +++++-- .../transformers/npu_pipeline_model/llama.py | 115 ++++++++++++++---- 2 files changed, 124 insertions(+), 32 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py index 69f618c888a..67de043337d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py @@ -72,6 +72,7 @@ def __init__( group_size: int = 0, cos_len: int = 1, keep_position_ids=True, + asym: bool = False, ): super().__init__(max_seq_len=max_seq_len, transpose_value=transpose_value, @@ -80,7 +81,8 @@ def __init__( device=device, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size) + group_size=group_size, + asym=asym) self.max_seq_len = max_seq_len self.intermediate_size = intermediate_size self.dtype = dtype @@ -278,7 +280,8 @@ def __init__( do_print: bool = False, n_splits_linear: int = 1, n_splits_down_proj: int = 1, - group_size: int = 0 + group_size: int = 0, + asym: bool = False, ): super().__init__() @@ -286,8 +289,10 @@ def __init__( op_parameters = [] for w in parameters: - if isinstance(w, tuple): # from QuantizedLinear + if isinstance(w, tuple) and not asym: # from QuantizedLinear op_parameters.append((w[0].numpy(), w[1].numpy())) + elif isinstance(w, tuple) and asym: # from QuantizedLinear + op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy())) elif w.dtype in [torch.int8, torch.uint8]: # QuantizedLinear weight op_parameters.append(w.numpy()) elif isinstance(w, np.ndarray): # scale @@ -341,7 +346,8 @@ def __init__( dtype=np_dtype, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + asym=asym, ) self.backend_decoders.append(decoder) @@ -427,6 +433,7 @@ def __init__( n_splits_down_proj: int = 1, group_size: int = 0, cos_len: int = 1, + asym: bool = False, ): super().__init__() self.op_parameters = parameters @@ -460,6 +467,7 @@ def __init__( n_splits_down_proj=n_splits_down_proj, group_size=group_size, cos_len=cos_len, + asym=asym, ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 @@ -555,6 +563,7 @@ def run_decode( layer_indexs = range(layer_start, layer_end) n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + asym = getattr(model.config, "asym", False) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn @@ -567,10 +576,17 @@ def run_decode( mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] + zeros = [] for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) @@ -603,7 +619,8 @@ def run_decode( do_print=False, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + asym=asym, ) dist.barrier() @@ -814,6 +831,7 @@ def run_prefill( layer_indexs = range(layer_start, layer_end) n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + asym = getattr(model.config, "asym", False) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn @@ -827,10 +845,18 @@ def run_prefill( mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] + zeros = [] for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), + torch.stack(scales, axis=0))) if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) @@ -859,6 +885,7 @@ def run_prefill( n_splits_down_proj=n_splits_down_proj, group_size=group_size, cos_len=cos_len, + asym=asym, ) layer_weights.extend(weights) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py index 435ba4ff5f1..03700b053cc 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py @@ -130,17 +130,31 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir, vocab_size = model.config.vocab_size model_norm = model.model.norm lm_head = model.lm_head + asym = getattr(model.config, "asym", False) if n_splits_linear == 1: - weights = [(lm_head.weight, lm_head.scale)] + asym = lm_head.qtype == "asym_int4_rtn" + if asym: + weights = [(lm_head.weight, lm_head.scale, lm_head.zero)] + else: + weights = [(lm_head.weight, lm_head.scale)] else: lm_heads = lm_head.lm_heads + asym = lm_heads[0].qtype == "asym_int4_rtn" lm_head_weights = [] scales = [] - for i in range(n_splits_linear): - lm_head_weights.append(lm_heads[i].weight) - scales.append(lm_heads[i].scale) - weights = [(torch.stack(lm_head_weights, axis=0), - torch.stack(scales, axis=0))] + zeros = [] + for l in lm_heads: + lm_head_weights.append(l.weight) + scales.append(l.scale) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights = [(torch.stack(lm_head_weights, axis=0), + torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))] + else: + weights = [(torch.stack(lm_head_weights, axis=0), + torch.stack(scales, axis=0))] if isinstance(weights[0], tuple): np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 else: # FP16 Linear @@ -156,16 +170,23 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir, dtype=np_dtype, model_norm_weight=model_norm.weight.to(torch.float16), vocab_size=vocab_size, - n_splits=n_splits_linear + n_splits=n_splits_linear, + asym=asym ) last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir, True, False) # save weights bins files if n_splits_linear == 1: - weight_numpy = [ - lm_head.weight.data.numpy(), lm_head.scale.data.numpy(), - ] + if not asym: + weight_numpy = [ + lm_head.weight.data.numpy(), lm_head.scale.data.numpy(), + ] + else: + weight_numpy = [ + lm_head.weight.data.numpy(), lm_head.scale.data.numpy(), + lm_head.zero.data.numpy() + ] else: weight_numpy = [v.numpy() for v in weights[0]] @@ -234,6 +255,7 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, head_dim = model.model.layers[0].self_attn.head_dim intermediate_size = model.config.intermediate_size rms_norm_eps = model.config.rms_norm_eps + asym = getattr(model.config, "asym", False) from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer curr_layer = model.model.layers[layer_idx] @@ -247,10 +269,17 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] + zeros = [] for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): # llama-2-7B & llama-3-8B @@ -299,7 +328,8 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, n_splits_down_proj=n_splits_down_proj, group_size=group_size, cos_len=input_len, - keep_position_ids=keep_position_ids + keep_position_ids=keep_position_ids, + asym=asym ) rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, @@ -329,11 +359,24 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, layer_norm_0.data.numpy().tofile(input_lm_bin_file) layer_norm_1.data.numpy().tofile(post_lm_bin_file) st_idx = 8 - for idx, (weight, scale) in enumerate(weights): - bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") - weight.numpy().tofile(bin_file) - bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") - scale.numpy().tofile(bin_file) + if not asym: + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + else: + for idx, (weight, scale, zero) in enumerate(weights): + bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*3}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin") + scale.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin") + zero.numpy().tofile(bin_file) + del single_decoder @@ -347,6 +390,7 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow rms_norm_eps = model.config.rms_norm_eps layer_num = len(model.model.layers) fused_layer_num = layer_num // fused_layers + asym = getattr(model.config, "asym", False) from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer for i in range(fused_layers): @@ -370,10 +414,17 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] + zeros = [] for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"): # llama-2-7B & llama-3-8B @@ -397,12 +448,25 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow layer_norm_1.data.numpy().tofile(post_lm_bin_file) st_idx = 5 # 6, 7 are past k/v - for idx, (weight, scale) in enumerate(weights): - bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin") - weight.numpy().tofile(bin_file) - bin_file = os.path.join(weight_dir, - f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") - scale.numpy().tofile(bin_file) + if not asym: + for idx, (weight, scale) in enumerate(weights): + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*2}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin") + scale.numpy().tofile(bin_file) + else: + for idx, (weight, scale, zero) in enumerate(weights): + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*3}.bin") + weight.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin") + scale.numpy().tofile(bin_file) + bin_file = os.path.join(weight_dir, + f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin") + zero.numpy().tofile(bin_file) if isinstance(weights[0], tuple): np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 @@ -426,7 +490,8 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow dtype=np_dtype, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + asym=asym ) update_names_of_IR_and_export_blob(fused_decoder, f"decoder_layer_{i}",