Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Groupwise prefill optimization #12291

Merged
merged 16 commits into from
Oct 30, 2024
47 changes: 34 additions & 13 deletions python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,19 +753,40 @@ def run_prefill(

weights = []

for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))

for l in attn_layer.o_proj_dq_list:
weights.append((l.weight, l.scale))
for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
Expand Down
156 changes: 35 additions & 121 deletions python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,60 +165,21 @@ def attention(self,
)
else:
hidden_states = self.unsqueeze(hidden_states, axis=0)
if mode == "prefill":
query_states_to_concat = []
key_states_to_concat = []
value_states_to_concat = []
for i in range(self.n_splits_linear):
sub_hidden_states = self.slice(hidden_states,
begin=[0, 0, i * groupsize],
end=[1, seq_len, (i + 1) * groupsize])
query_states_to_concat.append(
self.linear(
sub_hidden_states,
num_heads * head_dim,
groupsize,
bias=False,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0)
)
)
key_states_to_concat.append(
self.linear(
sub_hidden_states,
num_key_value_heads * head_dim,
groupsize,
bias=False,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0)
)
)
value_states_to_concat.append(
self.linear(
sub_hidden_states,
num_key_value_heads * head_dim,
groupsize,
bias=False,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0)
)
)
query_states = sum(query_states_to_concat)
key_states = sum(key_states_to_concat)
value_states = sum(value_states_to_concat)
else:
query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
hidden_size, self.n_splits_linear,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0))
key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
hidden_size, self.n_splits_linear,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0))
value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
hidden_size, self.n_splits_linear,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0))
query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
hidden_size, self.n_splits_linear,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))
key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
hidden_size, self.n_splits_linear,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))
value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
hidden_size, self.n_splits_linear,
wt_dtype=self.dtype,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))

if q_bias is not None:
query_states = query_states + q_bias
Expand Down Expand Up @@ -296,23 +257,10 @@ def attention(self,
attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
)
else:
if mode == "prefill":
attn_output_to_concat = []
for i in range(self.n_splits_linear):
sub_attn_output = self.slice(attn_output,
begin=[0, 0, i * groupsize],
end=[1, seq_len, (i + 1) * groupsize])
attn_output_to_concat.append(
self.linear(
sub_attn_output, hidden_size, groupsize, bias=False,
wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
)
)
attn_output = sum(attn_output_to_concat)
else:
attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
self.n_splits_linear, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0))
attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
self.n_splits_linear, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))

return attn_output, new_key_states, new_value_states

Expand Down Expand Up @@ -488,37 +436,14 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
else:
invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
if mode == "prefill":
gate_up_groupsize = self.hidden_size // self.n_splits_linear
mm1_to_concat = []
mm2_to_concat = []
for i in range(self.n_splits_linear):
sub_hidden_states = self.slice(hidden_states,
begin=[0, 0, i * gate_up_groupsize],
end=[1, seq_len, (i + 1) * gate_up_groupsize])
mm1_to_concat.append(
self.linear(
sub_hidden_states, self.intermediate_size, gate_up_groupsize,
bias=False,
wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
)
)
mm2_to_concat.append(
self.linear(
sub_hidden_states, self.intermediate_size, gate_up_groupsize,
bias=False,
wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
)
)
mm1 = sum(mm1_to_concat)
mm2 = sum(mm2_to_concat)
else:
mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
self.n_splits_linear, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0))
mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
self.n_splits_linear, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0))
mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
self.n_splits_linear, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))
mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
self.n_splits_linear, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]

if self.n_splits_down_proj == 1:
Expand All @@ -527,23 +452,10 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
)
else:
invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
if mode == "prefill":
down_groupsize = self.intermediate_size // self.n_splits_down_proj
hidden_states_to_concat = []
for i in range(self.n_splits_down_proj):
sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
end=[1, seq_len, (i + 1) * down_groupsize])
hidden_states_to_concat.append(
self.linear(
sub_mm1, self.hidden_size, down_groupsize, bias=False,
wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
)
)
hidden_states = sum(hidden_states_to_concat)
else:
hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
self.n_splits_down_proj, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0))
hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
self.n_splits_down_proj, wt_dtype=self.dtype,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))
return hidden_states

def layer_norm(self, hidden_states, layernorm_weight):
Expand Down Expand Up @@ -660,9 +572,11 @@ def dq_split_linear(self,
n_splits: int,
act_dtype: npt.DTypeLike = np.float16,
wt_dtype: npt.DTypeLike = np.float16,
scale_factor: bool = False):
scale_factor: bool = False,
is_prefill: bool = False):
op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
False, act_dtype, wt_dtype, scale_factor)
False, act_dtype, wt_dtype, scale_factor,
is_prefill=is_prefill)
self.linear_ops.append(op)
return op

Expand Down
46 changes: 33 additions & 13 deletions python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,20 +827,40 @@ def run_prefill(
mlp_layer = curr_layer.mlp

weights = []
if n_splits_linear == 1:
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list,
attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list,
mlp_layer.up_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))
weights.append((o.weight, o.scale))
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
else:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list):
weights.append((q.weight, q.scale))
weights.append((k.weight, k.scale))
weights.append((v.weight, v.scale))

for l in attn_layer.o_proj_dq_list:
weights.append((l.weight, l.scale))
for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
weights.append((g.weight, g.scale))
weights.append((u.weight, u.scale))
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
if n_splits_down_proj == 1:
for l in mlp_layer.down_proj_dq_list:
weights.append((l.weight, l.scale))
else:
l_weights = []
scales = []
for l in mlp_layer.down_proj_dq_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
Expand Down
Loading