Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang committed Aug 28, 2024
1 parent f57b0d5 commit ffdc410
Showing 1 changed file with 149 additions and 43 deletions.
192 changes: 149 additions & 43 deletions python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,25 @@ def __init__(

self.compile()

def mlp(self, hidden_states):
mm1 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
mm2 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
) # type: ignore[attr-defined]
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
if self.intermediate_size == 18944:
# for qwen2-7b
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=np.int8
)
else:
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
)
return hidden_states

def build_decoder(
self,
hidden_states,
Expand Down Expand Up @@ -734,54 +753,65 @@ def run_prefill(
input_layer_norm_weights = []
post_attn_layernorm_weights = []
layer_indexs = range(layer_start, layer_end)
for layer_idx in layer_indexs:
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp

weights = [
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
]

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

new_decoderlayer = FusedQwenLowBitDecoderlayer(
weights,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
cached_cos=cached_cos,
cached_sin=cached_sin,
layer_norm_0=layer_norm_0,
layer_norm_1=layer_norm_1,
q_bias=attn_layer.q_proj.bias.to(torch.float16),
k_bias=attn_layer.k_proj.bias.to(torch.float16),
v_bias=attn_layer.v_proj.bias.to(torch.float16),
layer_idx=layer_idx,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
max_seq_len=max_output_len,
transpose_value=transpose_value_cache,
)
if model.config.intermediate_size == 8960:
# for qwen2-1.5b
for layer_idx in layer_indexs:
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp

weights = [
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
]

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

new_decoderlayer = FusedQwenLowBitDecoderlayer(
weights,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
cached_cos=cached_cos,
cached_sin=cached_sin,
layer_norm_0=layer_norm_0,
layer_norm_1=layer_norm_1,
q_bias=attn_layer.q_proj.bias.to(torch.float16),
k_bias=attn_layer.k_proj.bias.to(torch.float16),
v_bias=attn_layer.v_proj.bias.to(torch.float16),
layer_idx=layer_idx,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
max_seq_len=max_output_len,
transpose_value=transpose_value_cache,
)

layer_weights.extend(weights)
input_layer_norm_weights.append(layer_norm_0)
post_attn_layernorm_weights.append(layer_norm_1)
model.model.layers[layer_idx] = new_decoderlayer
deocderlayers.append(new_decoderlayer)
layer_weights.extend(weights)
input_layer_norm_weights.append(layer_norm_0)
post_attn_layernorm_weights.append(layer_norm_1)
model.model.layers[layer_idx] = new_decoderlayer
deocderlayers.append(new_decoderlayer)

print("finish creating all decode layers in prefill")
result_queue.put("loading finish")

if model.config.intermediate_size == 18944:
# for qwen2-7b
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
from ipex_llm.transformers.npu_models.convert_mp import convert_forward
qwen2_attention_forward = generate_qwen2_attention_forward(max_seq_len=max_output_len,
transpose_value=transpose_value_cache)
convert_forward(model, Qwen2Attention, qwen2_attention_forward)
deocderlayers = model.model.layers

while True:

result = input_queue.get()
Expand Down Expand Up @@ -1053,3 +1083,79 @@ def qwen2_casullm_forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
import math


def generate_qwen2_attention_forward(max_seq_len, transpose_value):
def qwen2_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value,}

if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)


key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = None
if query_states.size(2) == key_states.size(2):
# first token
from intel_npu_acceleration_library.functional import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=q_len > 1 and bsz == 1,
)
else:
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return qwen2_attention_forward

0 comments on commit ffdc410

Please sign in to comment.