Groupwise prefill optimization (#12291)
* except lm_head

* remove

* support gw lm_head

* update

* fix

* remove run.bat

* fix style

* support llama3

* slice -> split

* remove debug

* fix style

* add dpu
cyita authored Oct 30, 2024
1 parent 540eaeb commit 70037ad
Showing 3 changed files with 106 additions and 148 deletions.
python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py (38 additions, 14 deletions)
@@ -188,7 +188,10 @@ def __init__(
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])
 
         print("start compiling")
-        self.compile()
+        if mode == "prefill":
+            self.compile(npu_dpu_groups=6)
+        else:
+            self.compile()
 
     def build_decoder(
         self,
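The hunk above compiles the prefill graph with NPU DPU groups enabled while the decode graph keeps the default path. A minimal sketch of that dispatch, assuming a graph object with a `compile` method that accepts an optional `npu_dpu_groups` argument (the value 6 is taken from the diff; `compile_decoder_graph` is an illustrative name, not ipex-llm API):

```python
def compile_decoder_graph(graph, mode: str, npu_dpu_groups: int = 6):
    """Compile prefill graphs with grouped DPU execution, decode graphs as before."""
    if mode == "prefill":
        # prefill: request grouped DPU execution from the NPU compiler
        graph.compile(npu_dpu_groups=npu_dpu_groups)
    else:
        # decode: default compile path
        graph.compile()
```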
@@ -753,19 +756,40 @@ def run_prefill(
 
         weights = []
 
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
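For the multi-split case, each projection now contributes a single packed pair: the per-split quantized weights and scales are stacked along a new leading axis instead of being appended one pair per split. A hedged sketch of that packing, assuming each entry of a `*_proj_dq_list` exposes `.weight` and `.scale` tensors as in the diff:

```python
import torch

def pack_split_weights(dq_list):
    """Stack per-split (weight, scale) tensors along a new leading axis.

    Mirrors the else-branch above: one packed pair per projection rather than
    one pair per split. `dq_list` stands in for a *_proj_dq_list from the diff.
    """
    stacked_weights = torch.stack([l.weight for l in dq_list], dim=0)
    stacked_scales = torch.stack([l.scale for l in dq_list], dim=0)
    return stacked_weights, stacked_scales

# Hypothetical usage mirroring run_prefill:
# weights.append(pack_split_weights(attn_layer.q_proj_dq_list))
```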
python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py (35 additions, 121 deletions)
@@ -165,60 +165,21 @@ def attention(self,
             )
         else:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
-            if mode == "prefill":
-                query_states_to_concat = []
-                key_states_to_concat = []
-                value_states_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * groupsize],
-                                                   end=[1, seq_len, (i + 1) * groupsize])
-                    query_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    key_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    value_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                query_states = sum(query_states_to_concat)
-                key_states = sum(key_states_to_concat)
-                value_states = sum(value_states_to_concat)
-            else:
-                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
-                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                  hidden_size, self.n_splits_linear,
-                                                  wt_dtype=self.dtype,
-                                                  scale_factor=(self.group_size == 0))
-                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
+            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
+            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                              hidden_size, self.n_splits_linear,
+                                              wt_dtype=self.dtype,
+                                              scale_factor=(self.group_size == 0),
+                                              is_prefill=(mode == "prefill"))
+            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
 
         if q_bias is not None:
             query_states = query_states + q_bias
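The removed prefill branch computed each projection by slicing the hidden states along the hidden dimension, running one linear per split, and summing the partial outputs; the new code hands both modes to a single `dq_split_linear` call and only toggles `is_prefill`. A minimal PyTorch sketch of the identity the sliced path relies on, using plain float tensors rather than the quantized NPU ops in this file:

```python
import torch
import torch.nn.functional as F

def split_linear_reference(x, weight, n_splits):
    """Sum of per-slice linears over the input dimension equals the full linear.

    x: (batch, seq_len, in_features); weight: (out_features, in_features).
    """
    group = x.shape[-1] // n_splits
    partial = [
        F.linear(x[..., i * group:(i + 1) * group],
                 weight[:, i * group:(i + 1) * group])
        for i in range(n_splits)
    ]
    return sum(partial)

# Quick equivalence check on random data
x = torch.randn(1, 8, 64)
w = torch.randn(32, 64)
assert torch.allclose(split_linear_reference(x, w, 4), F.linear(x, w), atol=1e-4)
```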
@@ -296,23 +257,10 @@ def attention(self,
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
         else:
-            if mode == "prefill":
-                attn_output_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_attn_output = self.slice(attn_output,
-                                                 begin=[0, 0, i * groupsize],
-                                                 end=[1, seq_len, (i + 1) * groupsize])
-                    attn_output_to_concat.append(
-                        self.linear(
-                            sub_attn_output, hidden_size, groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                attn_output = sum(attn_output_to_concat)
-            else:
-                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                                   self.n_splits_linear, wt_dtype=self.dtype,
-                                                   scale_factor=(self.group_size == 0))
+            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                               self.n_splits_linear, wt_dtype=self.dtype,
+                                               scale_factor=(self.group_size == 0),
+                                               is_prefill=(mode == "prefill"))
 
         return attn_output, new_key_states, new_value_states
 
@@ -488,37 +436,14 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                gate_up_groupsize = self.hidden_size // self.n_splits_linear
-                mm1_to_concat = []
-                mm2_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * gate_up_groupsize],
-                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
-                    mm1_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    mm2_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                mm1 = sum(mm1_to_concat)
-                mm2 = sum(mm2_to_concat)
-            else:
-                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
-                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
+            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
+            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
 
         if self.n_splits_down_proj == 1:
@@ -527,23 +452,10 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             )
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                down_groupsize = self.intermediate_size // self.n_splits_down_proj
-                hidden_states_to_concat = []
-                for i in range(self.n_splits_down_proj):
-                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
-                                         end=[1, seq_len, (i + 1) * down_groupsize])
-                    hidden_states_to_concat.append(
-                        self.linear(
-                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                hidden_states = sum(hidden_states_to_concat)
-            else:
-                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                     scale_factor=(self.group_size == 0))
+            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                 scale_factor=(self.group_size == 0),
+                                                 is_prefill=(mode == "prefill"))
         return hidden_states
 
     def layer_norm(self, hidden_states, layernorm_weight):
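The MLP hunks apply the same substitution twice: the gate/up projections split the input along `hidden_size`, and the down projection splits it along `intermediate_size`. A hedged sketch of the SwiGLU computation those branches implement, reusing `split_linear_reference` from the earlier example; the weight tensors are plain illustrative tensors, not the file's quantized layers:

```python
import torch.nn.functional as F

def swiglu_mlp_reference(x, gate_w, up_w, down_w, n_splits_linear, n_splits_down_proj):
    # gate/up projections: split over hidden_size (the input dimension)
    mm1 = split_linear_reference(x, gate_w, n_splits_linear)
    mm2 = split_linear_reference(x, up_w, n_splits_linear)
    # swish(gate) * up, as in self.eltwise_mul(self.swish(mm1), mm2)
    intermediate = F.silu(mm1) * mm2
    # down projection: split over intermediate_size
    return split_linear_reference(intermediate, down_w, n_splits_down_proj)
```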
@@ -660,9 +572,11 @@ def dq_split_linear(self,
                         n_splits: int,
                         act_dtype: npt.DTypeLike = np.float16,
                         wt_dtype: npt.DTypeLike = np.float16,
-                        scale_factor: bool = False):
+                        scale_factor: bool = False,
+                        is_prefill: bool = False):
         op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
-                                     False, act_dtype, wt_dtype, scale_factor)
+                                     False, act_dtype, wt_dtype, scale_factor,
+                                     is_prefill=is_prefill)
         self.linear_ops.append(op)
         return op
 
python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py (33 additions, 13 deletions)
@@ -827,20 +827,40 @@ def run_prefill(
         mlp_layer = curr_layer.mlp
 
         weights = []
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
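The Qwen2 path packs the split weights the same way as the LLaMA change above. A hypothetical round-trip check for one packed entry built with `n_splits_linear > 1` (the per-split weight and scale shapes depend on the DQ linear implementation, so only the stacking axis is asserted):

```python
import torch

packed_weight, packed_scale = weights[0]                # first packed projection
per_split_weights = torch.unbind(packed_weight, dim=0)  # one tensor per split
per_split_scales = torch.unbind(packed_scale, dim=0)
assert len(per_split_weights) == n_splits_linear
assert len(per_split_scales) == n_splits_linear
```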