diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/baichuan2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/baichuan2.py index 96c77fb1923..53258002a66 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/baichuan2.py +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/baichuan2.py @@ -60,6 +60,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]], parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") parser.add_argument("--max-context-len", type=int, default=1024) parser.add_argument("--max-prompt-len", type=int, default=512) + parser.add_argument("--quantization_group_size", type=int, default=0) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-streaming", action="store_true", default=False) @@ -72,6 +73,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]], pipeline=True, max_context_len=args.max_context_len, max_prompt_len=args.max_prompt_len, + quantization_group_size=args.quantization_group_size, torch_dtype=torch.float16, attn_implementation="eager", transpose_value_cache=not args.disable_transpose_value_cache, diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index 59e6a2b97c7..c04bad7423e 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -50,6 +50,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from ipex_llm.transformers.npu_models.mp_models_base import run_model from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory +from ipex_llm.transformers.npu_models.common import reshape_lm_head_input +from transformers.modeling_outputs import CausalLMOutputWithPast +from torch.nn import CrossEntropyLoss class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory): @@ -75,12 +78,18 @@ def __init__( device: str = "NPU", rms_norm_eps, intermediate_size, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0 ): super().__init__(max_seq_len=max_seq_len, transpose_value=transpose_value, dtype=dtype, profile=profile, - device=device) + device=device, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size) self.max_seq_len = max_seq_len self.intermediate_size = intermediate_size self.dtype = dtype @@ -115,8 +124,7 @@ def __init__( attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64) else: - attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len), - dtype=np.int64) + attention_mask = None position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) # self.num_key_value_heads = num_key_value_heads @@ -178,6 +186,7 @@ def __init__( post_attention_layernorm_weight=post_attn_layernorm_weights[i], past_key=past_keys[i], past_value=past_values[i], + use_prefill_sdp=True, ) curr_key_values.append((new_key_states, new_value_states)) @@ -189,7 +198,10 @@ def __init__( new_value_states = self.convert_to_fp16(curr_key_values[i][1]) print("start compiling") - self.compile() + if mode == "prefill" and os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT", "0") != "1": + self.compile(npu_dpu_groups=6) + else: + self.compile() def attention(self, *, @@ -206,15 +218,23 @@ def 
attention(self, seq_len, q_bias=None, k_bias=None, - v_bias=None): + v_bias=None, + use_prefill_sdp=False): hidden_size = num_heads * head_dim + if self.n_splits_linear != 1: + hidden_states = self.unsqueeze(hidden_states, axis=0) + proj = self.linear( hidden_states, 3 * hidden_size, hidden_size, bias=False, - wt_dtype=self.dtype + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") ) + proj = self.reshape(proj, [-1, 3, hidden_size]) # b*s, 3, h proj = self.unsqueeze(proj, [0]) # b, s, 3, h proj = self.transpose(proj, [2, 1, 0, 3]) # 3, s, b, h @@ -224,8 +244,14 @@ def attention(self, key_states = self.reshape(proj[1, ...], [1, self.seq_len, num_heads, head_dim]) key_states = self.transpose(key_states, [0, 2, 1, 3]) value_states = self.reshape(proj[2, ...], [1, self.seq_len, num_heads, head_dim]) + + use_ov_sdp = (mode == "prefill") and use_prefill_sdp if self.transpose_value: - value_states = self.transpose(value_states, [0, 2, 3, 1]) + new_value_states = self.transpose(value_states, [0, 2, 3, 1]) + if use_ov_sdp: + value_states = self.transpose(value_states, [0, 2, 1, 3]) + else: + value_states = new_value_states else: value_states = self.transpose(value_states, [0, 2, 1, 3]) @@ -243,7 +269,6 @@ def attention(self, head_dim=head_dim, ) new_key_states = key_states - new_value_states = value_states if self.mode == "decode": key_states = self.concat(past_key, key_states, axis=-2) @@ -252,20 +277,31 @@ def attention(self, else: value_states = self.concat(past_value, value_states, axis=-2) - attn_weight = self.matmul(query_states, key_states, False, True) / ( - math.sqrt(self.head_dim)) - attention_mask = self.convert_to_fp16(attention_mask) - attn_weight = self.eltwise_add(attn_weight, attention_mask) - attn_weight = self.convert_to_fp32(attn_weight) - attn_weight = self.softmax(attn_weight, -1) - attn_weight = self.convert_to_fp16(attn_weight) - attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) + if use_ov_sdp: + value_states = self.convert_to_fp32(value_states) + key_states = self.convert_to_fp32(key_states) + query_states = self.convert_to_fp32(query_states) + attn_output = self.scaled_dot_product_attention( + query_states, key_states, value_states, None, True) + attn_output = self.convert_to_fp16(attn_output) + else: + attn_weight = self.matmul(query_states, key_states, False, True) / ( + math.sqrt(self.head_dim)) + attention_mask = self.convert_to_fp16(attention_mask) + attn_weight = self.eltwise_add(attn_weight, attention_mask) + attn_weight = self.convert_to_fp32(attn_weight) + attn_weight = self.softmax(attn_weight, -1) + attn_weight = self.convert_to_fp16(attn_weight) + attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) attn_output = self.transpose(attn_output, [0, 2, 1, 3]) attn_output = self.reshape(attn_output, [1, seq_len, hidden_size]) attn_output = self.linear( - attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype + attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") ) return attn_output, new_key_states, new_value_states @@ -278,6 +314,7 @@ def build_decoder( post_attention_layernorm_weight, past_key=None, past_value=None, + use_prefill_sdp=False, ): residual = hidden_states @@ -298,12 +335,13 @@ def build_decoder( num_heads=self.num_heads, head_dim=self.head_dim, seq_len=self.seq_len, + 
use_prefill_sdp=use_prefill_sdp, ) hidden_states = self.eltwise_add(residual, attn_output) residual = hidden_states hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, self.seq_len, self.mode) hidden_states = self.eltwise_add(residual, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) @@ -329,6 +367,9 @@ def __init__( max_seq_len: int = 1024, transpose_value: bool = False, do_print: bool = False, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0 ): super().__init__() @@ -338,6 +379,10 @@ def __init__( for w in parameters: if isinstance(w, tuple): # from QuantizedLinear op_parameters.append((w[0].numpy(), w[1].numpy())) + elif w.dtype in [torch.int8, torch.uint8]: # QuantizedLinear weight + op_parameters.append(w.numpy()) + elif isinstance(w, np.ndarray): # scale + op_parameters.append(w) else: op_parameters.append(w.to(torch.float16).numpy()) self.op_parameters = op_parameters @@ -346,6 +391,10 @@ def __init__( self.transpose_value = transpose_value if isinstance(parameters[0], tuple): np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + elif parameters[0].dtype == torch.int8: + np_dtype = np.int8 + elif parameters[0].dtype == torch.uint8: + np_dtype = np.uint8 else: # FP16 Linear np_dtype = np.float16 @@ -380,6 +429,9 @@ def __init__( mode="decode", transpose_value=self.transpose_value, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) self.backend_decoders.append(decoder) @@ -453,6 +505,9 @@ def __init__( intermediate_size, max_seq_len: int = 128, transpose_value: bool = False, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0 ): super().__init__() self.op_parameters = parameters @@ -481,6 +536,9 @@ def __init__( mode="prefill", transpose_value=self.transpose_value, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 @@ -507,7 +565,6 @@ def forward( backend_cls = self.backend_cls_prefill inputs = (hidden_states.to(torch.float16), - attention_mask.to(torch.int64), position_ids.to(torch.int64)) inputs += (self.layer_norm_0, self.layer_norm_1) hidden_states, past_key, past_value = run_model( @@ -557,22 +614,28 @@ def run_decode( head_dim = model.model.layers[layer_start].self_attn.head_dim rms_norm_eps = model.config.rms_norm_eps intermediate_size = model.config.intermediate_size + group_size = getattr(model.config, "group_size", 0) layer_weights = [] input_layer_norm_weights = [] post_attn_layernorm_weights = [] layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp - weights = [ - (attn_layer.W_pack.weight, attn_layer.W_pack.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] + weights = [] + for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + 
mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -599,6 +662,9 @@ def run_decode( max_seq_len=max_seq_len, transpose_value=transpose_value_cache, do_print=False, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) dist.barrier() @@ -754,23 +820,29 @@ def run_prefill( head_dim = model.model.layers[layer_start].self_attn.head_dim rms_norm_eps = model.config.rms_norm_eps intermediate_size = model.config.intermediate_size + group_size = getattr(model.config, "group_size", 0) deocderlayers = [] layer_weights = [] input_layer_norm_weights = [] post_attn_layernorm_weights = [] layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp - weights = [ - (attn_layer.W_pack.weight, attn_layer.W_pack.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] + weights = [] + for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -791,6 +863,9 @@ def run_prefill( intermediate_size=intermediate_size, max_seq_len=max_output_len, transpose_value=transpose_value_cache, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) layer_weights.extend(weights) @@ -1025,3 +1100,71 @@ def baichuan_fused_model_forward( ) return baichuan_fused_model_forward + + +def baichuan2_causal_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + 
past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + # ipex-llm change start + hidden_states = reshape_lm_head_input(hidden_states) + # ipex-llm change end + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + softmax_normalizer = shift_logits.max(-1).values ** 2 + z_loss = self.config.z_loss_weight * softmax_normalizer.mean() + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + z_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py index 45c12fc63e2..4bf492cbe0d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py @@ -75,10 +75,11 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"] + baichuan_attn_module_names = ["W_pack", "o_proj"] mlp_module_names = ["down_proj", "up_proj", "gate_proj"] if ( isinstance(module, (Qwen2Attention, LlamaAttention)) - or module.__class__.__name__ in ['MiniCPMAttention', 'Attention'] + or module.__class__.__name__ in ['MiniCPMAttention'] ): for name in attn_module_names: setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name, @@ -97,3 +98,10 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down n_splits=n_splits_mlp, load=load)) delattr(module, name) + elif module.__class__.__name__ == 'Attention' and module.config.model_type == 'baichuan': + # baichuan attention + for name in baichuan_attn_module_names: + setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name, + n_splits=n_splits_hidden_size, + load=load)) + delattr(module, name) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index f7e00bb804d..aba32ab433a 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -87,7 +87,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision, model.llm.config.model_type = "llama" model = model.llm - if model.config.model_type in ["qwen2", "llama", "minicpm"]: + if model.config.model_type in ["qwen2", "llama", "minicpm", "baichuan"]: from ipex_llm.transformers.npu_models.common import split_linears if quantization_group_size == 0: n_splits_linear = 1 @@ -245,6 +245,8 @@ def convert_baichuan( modeling_module_name = model.__class__.__module__ 
module = importlib.import_module(modeling_module_name) convert_forward(model, module.BaichuanModel, baichuan_model_forward) + from ipex_llm.transformers.npu_models.baichuan_mp import baichuan2_causal_forward + convert_forward(model, module.BaichuanForCausalLM, baichuan2_causal_forward) def convert_minicpm( @@ -392,7 +394,7 @@ def optimize_llm( if intra_pp is None: intra_pp = 2 if inter_pp is None: - inter_pp = 2 + inter_pp = 2 if group_size == 0 else 4 convert_baichuan(model, max_output_len=max_context_len, max_prompt_len=max_prompt_len, diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py index cf452199a87..fda6530bd9b 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py @@ -560,37 +560,13 @@ def run_decode( mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) @@ -844,38 +820,13 @@ def run_prefill( weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), - torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + 
mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py index c35f687a36b..e9fbfce1a97 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py @@ -540,37 +540,13 @@ def run_decode( mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) @@ -783,38 +759,13 @@ def run_prefill( mlp_layer = curr_layer.mlp weights = [] - - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) diff --git 
a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 13997d9f507..73666487333 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -138,47 +138,41 @@ def attention(self, use_prefill_sdp=False): hidden_size = num_heads * head_dim num_key_value_groups = num_heads // num_key_value_heads - if self.n_splits_linear == 1: - query_states = self.linear( - hidden_states, - num_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - ) + if self.n_splits_linear != 1: + hidden_states = self.unsqueeze(hidden_states, axis=0) - key_states = self.linear( - hidden_states, - num_key_value_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - ) + query_states = self.linear( + hidden_states, + num_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) - value_states = self.linear( - hidden_states, - num_key_value_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - ) - else: - hidden_states = self.unsqueeze(hidden_states, axis=0) - query_states = self.dq_split_linear(hidden_states, num_heads * head_dim, - hidden_size, self.n_splits_linear, - wt_dtype=self.dtype, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill")) - key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim, - hidden_size, self.n_splits_linear, - wt_dtype=self.dtype, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill")) - value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim, - hidden_size, self.n_splits_linear, - wt_dtype=self.dtype, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill")) + key_states = self.linear( + hidden_states, + num_key_value_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) + + value_states = self.linear( + hidden_states, + num_key_value_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) if q_bias is not None: query_states = query_states + q_bias @@ -263,15 +257,12 @@ def attention(self, attn_output = self.transpose(attn_output, [0, 2, 1, 3]) attn_output = self.reshape(attn_output, [1, seq_len, hidden_size]) - if self.n_splits_linear == 1: - attn_output = self.linear( - attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype - ) - else: - attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size, - self.n_splits_linear, wt_dtype=self.dtype, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill")) + attn_output = self.linear( + attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) return attn_output, new_key_states, new_value_states def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias): @@ -434,38 +425,26 @@ def feed_forward_sanm_decoder(self, x, w_1_bias, norm_weights, norm_bias): return w_2 def mlp(self, hidden_states, seq_len=-1, mode="prefill"): - if self.n_splits_linear == 1: - mm1 = self.linear( - 
hidden_states, self.intermediate_size, self.hidden_size, bias=False, - wt_dtype=self.dtype - ) - mm2 = self.linear( - hidden_states, self.intermediate_size, self.hidden_size, bias=False, - wt_dtype=self.dtype - ) # type: ignore[attr-defined] - mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] - else: - invalidInputError(seq_len > 0, "seq_len should be provided if use split linear") - mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size, - self.n_splits_linear, wt_dtype=self.dtype, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill")) - mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size, - self.n_splits_linear, wt_dtype=self.dtype, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill")) - mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] - - if self.n_splits_down_proj == 1: - hidden_states = self.linear( - mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype - ) - else: - invalidInputError(seq_len > 0, "seq_len should be provided if use split linear") - hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size, - self.n_splits_down_proj, wt_dtype=self.dtype, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill")) + mm1 = self.linear( + hidden_states, self.intermediate_size, self.hidden_size, bias=False, + wt_dtype=self.dtype, n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) + mm2 = self.linear( + hidden_states, self.intermediate_size, self.hidden_size, bias=False, + wt_dtype=self.dtype, n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) # type: ignore[attr-defined] + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + + hidden_states = self.linear( + mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype, + n_splits=self.n_splits_down_proj, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) return hidden_states def layer_norm(self, hidden_states, layernorm_weight): @@ -571,8 +550,26 @@ def create_input_op(self, shape, dtype=np.float16): self.input_ops.append(op) return op - def linear(self, *args, **kwargs): - op = super().linear(*args, **kwargs) + def linear(self, + input_node: ctypes._Pointer, + output_channels: int, + input_channels: int, + bias: Optional[bool] = False, + act_dtype: npt.DTypeLike = np.float16, + wt_dtype: npt.DTypeLike = np.float16, + n_splits: int = 1, + scale_factor: bool = True, + is_prefill: bool = False): + if n_splits == 1: + op = super().linear(input_node, output_channels, + input_channels, bias, act_dtype, + wt_dtype, scale_factor=scale_factor) + else: + op = super().dq_split_linear(input_node, n_splits, + output_channels, input_channels, + bias=bias, act_dtype=act_dtype, + wt_dtype=wt_dtype, scale_factor=scale_factor, + is_prefill=is_prefill) self.linear_ops.append(op) return op diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 501fb4aa87a..ab11c27b665 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -586,37 +586,13 @@ def run_decode( mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - 
attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) @@ -839,37 +815,13 @@ def run_prefill( mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py index 078490925ac..a035b1332dd 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/baichuan.py @@ -28,7 +28,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): vocab_size = model.config.vocab_size model_norm = model.model.norm lm_head = model.lm_head - weights = [(lm_head.weight, lm_head.scale)] + if n_splits_linear == 1: + weights = [(lm_head.weight, lm_head.scale)] + else: + lm_heads = lm_head.lm_heads + lm_head_weights = [] + scales = [] + for l in lm_heads: + lm_head_weights.append(l.weight) + 
scales.append(l.scale) + weights = [(torch.stack(lm_head_weights, axis=0), + torch.stack(scales, axis=0))] if isinstance(weights[0], tuple): np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 else: # FP16 Linear @@ -44,13 +54,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): dtype=np_dtype, model_norm_weight=model_norm.weight.to(torch.float16), vocab_size=vocab_size, + n_splits=n_splits_linear ) last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir) # save weights bins files - weight_numpy = [ - lm_head.weight.data.numpy(), lm_head.scale.data.numpy(), - ] + if n_splits_linear == 1: + weight_numpy = [ + lm_head.weight.data.numpy(), lm_head.scale.data.numpy(), + ] + else: + weight_numpy = [v.numpy() for v in weights[0]] for idx, weight in enumerate(weight_numpy): bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin") @@ -83,17 +97,15 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - weights = [ - (attn_layer.W_pack.weight, attn_layer.W_pack.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] - else: - # TODO - pass + for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -119,6 +131,9 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj mode="decode", transpose_value=transpose_value_cache, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, f"decoder_layer_{layer_idx}", diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py index 3cccb9fd422..3a9f81e2041 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py @@ -91,22 +91,22 @@ def __init__( self.head_dim = self.hidden_size // self.num_heads # define input, the order self.parameter matters - input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) + if n_splits == 1: + input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) + else: + input = self.create_input_op((1, self.batch_size, self.hidden_size)) hidden_states = input # model norm and lm head model_norm_weight = self.constant(model_norm_weight) hidden_states = self.layer_norm(hidden_states, model_norm_weight) - if n_splits == 1: - hidden_states = self.linear( - hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype - ) - else: - hidden_states = self.dq_split_linear( - hidden_states, self.vocab_size, self.hidden_size, n_splits, - wt_dtype=dtype, scale_factor=False - ) + + hidden_states = self.linear( + hidden_states, self.vocab_size, 
self.hidden_size, bias=False, wt_dtype=self.dtype, + n_splits=n_splits, + scale_factor=(n_splits == 1), + ) # define outputs hidden_states = self.convert_to_fp32(hidden_states) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py index 4e3674a791c..2e62418fa6f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py @@ -174,38 +174,13 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), - torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py index d4dbefdb3b0..e5939efcb97 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py @@ -109,37 +109,23 @@ def __init__( hidden_states = self.layer_norm(hidden_states, model_norm_weight) if vocab_size == 122753: # for MiniCPM-2B-sft-bf16 - if n_splits == 1: - hidden_states_1 = self.linear( - hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype - ) - hidden_states_2 = self.linear( - hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype - ) - else: - hidden_states_1 = self.dq_split_linear( - hidden_states, 73440, self.hidden_size, - n_splits=n_splits, wt_dtype=dtype, scale_factor=False - ) - hidden_states_2 = self.dq_split_linear( - hidden_states, 73440, self.hidden_size, - n_splits=n_splits, wt_dtype=dtype, scale_factor=False - ) + hidden_states_1 = self.linear( + hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype, + n_splits=n_splits, scale_factor=(n_splits == 1) + ) + hidden_states_2 = self.linear( + hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype, + n_splits=n_splits, scale_factor=(n_splits == 1) + ) hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313]) hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2) else: 
# for MiniCPM-1B-sft-bf16 - if n_splits == 1: - hidden_states = self.linear( - hidden_states, self.vocab_size, self.hidden_size, bias=False, - wt_dtype=self.dtype - ) - else: - hidden_states = self.dq_split_linear( - hidden_states, self.vocab_size, self.hidden_size, - n_splits=n_splits, wt_dtype=dtype, scale_factor=False - ) + hidden_states = self.linear( + hidden_states, self.vocab_size, self.hidden_size, bias=False, + wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1) + ) # define outputs hidden_states = self.convert_to_fp32(hidden_states) @@ -245,38 +231,13 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), - torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py index eb38ad7b107..645ad830e0d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py @@ -99,37 +99,13 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, mlp_layer = curr_layer.mlp weights = [] - if n_splits_linear == 1: - for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, - attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, - attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, - mlp_layer.up_proj_dq_list): - weights.append((q.weight, q.scale)) - weights.append((k.weight, k.scale)) - weights.append((v.weight, v.scale)) - weights.append((o.weight, o.scale)) - weights.append((g.weight, g.scale)) - weights.append((u.weight, u.scale)) - else: - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) - - if n_splits_down_proj == 1: - for l in mlp_layer.down_proj_dq_list: - weights.append((l.weight, l.scale)) - else: + for layer_list 
in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] - for l in mlp_layer.down_proj_dq_list: + for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
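
The one pattern this patch repeats in nearly every touched file (baichuan_mp.py, llama_mp.py, minicpm_mp.py, qwen2_mp.py, and the npu_pipeline_model converters) is the re-packing of each split projection: instead of passing one `(weight, scale)` pair per `QuantizedLinear`, the per-split shards in each `*_dq_list` are stacked into a single `(n_splits, ...)` weight tensor plus a matching stack of scales, so the NPU decoder factory always receives exactly one tensor pair per projection regardless of `quantization_group_size`. The sketch below is a minimal, self-contained illustration of that packing under assumed shapes, not the ipex-llm implementation: `MockQuantizedLinear` and `pack_split_projection` are hypothetical names introduced here, while the `.weight` / `.scale` attributes and the `*_dq_list` layout mirror the diff.

```python
# Hypothetical stand-in for ipex-llm's QuantizedLinear shards; only the
# .weight / .scale attributes are taken from the diff above.
from dataclasses import dataclass
import torch


@dataclass
class MockQuantizedLinear:
    weight: torch.Tensor   # int8 quantized weight shard
    scale: torch.Tensor    # per-output-channel dequantization scale


def pack_split_projection(dq_list):
    """Stack the shards of one split projection into (weights, scales),
    mirroring the torch.stack(...) pattern used in run_decode/run_prefill."""
    weights = torch.stack([l.weight for l in dq_list], dim=0)
    scales = torch.stack([l.scale for l in dq_list], dim=0)
    return weights, scales


if __name__ == "__main__":
    # Assumed toy sizes for illustration only.
    n_splits, out_features, in_features = 2, 8, 16
    w_pack_dq_list = [
        MockQuantizedLinear(
            weight=torch.randint(-128, 127,
                                 (out_features, in_features // n_splits),
                                 dtype=torch.int8),
            scale=torch.rand(out_features, dtype=torch.float16),
        )
        for _ in range(n_splits)
    ]
    packed_weight, packed_scale = pack_split_projection(w_pack_dq_list)
    # One (n_splits, out, in/n_splits) weight tensor and one (n_splits, out)
    # scale tensor per projection, i.e. one entry appended to `weights`
    # per layer_list in the loops above.
    print(packed_weight.shape, packed_scale.shape)
```

The packing works together with the other recurring change in the patch: `LLMBaseNNFactory.linear` now takes `n_splits`, `scale_factor`, and `is_prefill` and internally dispatches to `dq_split_linear` when `n_splits > 1`, which is why the per-model `if n_splits_linear == 1: ... else: ...` branches could be collapsed into a single code path.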