diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index d37d462352a..93f1ff36448 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -188,7 +188,10 @@ def __init__(
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])

         print("start compiling")
-        self.compile()
+        if mode == "prefill":
+            self.compile(npu_dpu_groups=6)
+        else:
+            self.compile()

     def build_decoder(
         self,
@@ -753,19 +756,40 @@ def run_prefill(

         weights = []
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 48967bf968d..1550d6837f6 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -165,60 +165,21 @@ def attention(self,
             )
         else:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
-            if mode == "prefill":
-                query_states_to_concat = []
-                key_states_to_concat = []
-                value_states_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * groupsize],
-                                                   end=[1, seq_len, (i + 1) * groupsize])
-                    query_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    key_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    value_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                query_states = sum(query_states_to_concat)
-                key_states = sum(key_states_to_concat)
-                value_states = sum(value_states_to_concat)
-            else:
-                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
-                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                  hidden_size, self.n_splits_linear,
-                                                  wt_dtype=self.dtype,
-                                                  scale_factor=(self.group_size == 0))
-                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
+            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
+            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                              hidden_size, self.n_splits_linear,
+                                              wt_dtype=self.dtype,
+                                              scale_factor=(self.group_size == 0),
+                                              is_prefill=(mode == "prefill"))
+            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))

         if q_bias is not None:
             query_states = query_states + q_bias
@@ -296,23 +257,10 @@ def attention(self,
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
         else:
-            if mode == "prefill":
-                attn_output_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_attn_output = self.slice(attn_output,
-                                                 begin=[0, 0, i * groupsize],
-                                                 end=[1, seq_len, (i + 1) * groupsize])
-                    attn_output_to_concat.append(
-                        self.linear(
-                            sub_attn_output, hidden_size, groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                attn_output = sum(attn_output_to_concat)
-            else:
-                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                                   self.n_splits_linear, wt_dtype=self.dtype,
-                                                   scale_factor=(self.group_size == 0))
+            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                               self.n_splits_linear, wt_dtype=self.dtype,
+                                               scale_factor=(self.group_size == 0),
+                                               is_prefill=(mode == "prefill"))

         return attn_output, new_key_states, new_value_states
@@ -488,37 +436,14 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                gate_up_groupsize = self.hidden_size // self.n_splits_linear
-                mm1_to_concat = []
-                mm2_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * gate_up_groupsize],
-                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
-                    mm1_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    mm2_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                mm1 = sum(mm1_to_concat)
-                mm2 = sum(mm2_to_concat)
-            else:
-                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
-                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
+            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
+            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]

         if self.n_splits_down_proj == 1:
@@ -527,23 +452,10 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             )
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                down_groupsize = self.intermediate_size // self.n_splits_down_proj
-                hidden_states_to_concat = []
-                for i in range(self.n_splits_down_proj):
-                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
-                                         end=[1, seq_len, (i + 1) * down_groupsize])
-                    hidden_states_to_concat.append(
-                        self.linear(
-                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                hidden_states = sum(hidden_states_to_concat)
-            else:
-                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                     scale_factor=(self.group_size == 0))
+            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                 scale_factor=(self.group_size == 0),
+                                                 is_prefill=(mode == "prefill"))
         return hidden_states

     def layer_norm(self, hidden_states, layernorm_weight):
@@ -660,9 +572,11 @@ def dq_split_linear(self,
                        n_splits: int,
                        act_dtype: npt.DTypeLike = np.float16,
                        wt_dtype: npt.DTypeLike = np.float16,
-                        scale_factor: bool = False):
+                        scale_factor: bool = False,
+                        is_prefill: bool = False):
        op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
-                                     False, act_dtype, wt_dtype, scale_factor)
+                                     False, act_dtype, wt_dtype, scale_factor,
+                                     is_prefill=is_prefill)
        self.linear_ops.append(op)
        return op
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index f6952af2f7f..8459ddf5efe 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -827,20 +827,40 @@ def run_prefill(
         mlp_layer = curr_layer.mlp

         weights = []
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
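
Note (not part of the patch): when n_splits_linear or n_splits_down_proj is greater than 1, run_prefill now stacks the per-split weight/scale tensors along a new leading axis so that each fused NPU linear still receives a single (weights, scales) entry. Below is a minimal, self-contained sketch of that packing pattern only; DummyDQLinear and the tensor shapes are made-up stand-ins for the real *_dq_list modules.

    import torch

    class DummyDQLinear:
        """Hypothetical stand-in for one split of a low-bit linear:
        an int8 weight plus its per-output-channel scale."""
        def __init__(self, out_features, in_features):
            self.weight = torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8)
            self.scale = torch.rand(out_features, dtype=torch.float16)

    def pack_split_weights(dq_list):
        """Mirror the run_prefill packing: an unsplit layer contributes one
        (weight, scale) pair; split layers are stacked along a new axis 0 so the
        fused op still receives a single pair."""
        if len(dq_list) == 1:
            l = dq_list[0]
            return [(l.weight, l.scale)]
        l_weights = [l.weight for l in dq_list]
        scales = [l.scale for l in dq_list]
        return [(torch.stack(l_weights, dim=0), torch.stack(scales, dim=0))]

    # Example: a hypothetical 4096 -> 11008 gate_proj split along the input
    # dimension into two 2048-wide chunks.
    gate_proj_dq_list = [DummyDQLinear(11008, 2048) for _ in range(2)]
    (packed_weight, packed_scale), = pack_split_weights(gate_proj_dq_list)
    print(packed_weight.shape)  # torch.Size([2, 11008, 2048])
    print(packed_scale.shape)   # torch.Size([2, 11008])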