Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix qwen's position_ids no enough #10572

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,6 @@ def merge_qk_proj_func(module):
from ipex_llm.transformers.models.bert import merge_qkv
model.apply(merge_qkv)
if model.config.model_type == "qwen":
position_ids = torch.arange(0, model.config.max_position_embeddings)
rope_base = model.config.rotary_emb_base
from accelerate.big_modeling import init_empty_weights

Expand Down Expand Up @@ -625,7 +624,6 @@ def split_qkv_proj_func(module):
module.q_proj = q_proj
module.k_proj = k_proj
module.v_proj = v_proj
module.position_ids = position_ids
module.rope_base = rope_base
del module.c_attn
model.apply(split_qkv_proj_func)
Expand Down
13 changes: 7 additions & 6 deletions python/llm/src/ipex_llm/transformers/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def qwen_attention_forward_original(
device = hidden_states.device
# for flash attention
original_dtype = hidden_states.dtype
position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids
rotary_pos_emb_list = rotary_pos_emb_list[:-1]

use_fuse_rope = should_use_fuse_rope(self, hidden_states)
qtype_check = decoding_fast_path_qtype_check(self.q_proj)
Expand All @@ -147,8 +149,6 @@ def qwen_attention_forward_original(
cache_v = cache_v.transpose(1, 2)

kv_seq_len = cache_k.shape[-2]
self.position_ids = self.position_ids.to(device)
position_ids = self.position_ids[kv_seq_len]
base = self.rope_base
if is_enough_kv_cache_room(layer_past, kv_seq_len):
new_cache_k, new_cache_v = extend_kv_cache(bsz,
Expand Down Expand Up @@ -182,7 +182,7 @@ def qwen_attention_forward_original(
# query = self._split_heads(query, self.num_heads, self.head_dim)
# key = self._split_heads(key, self.num_heads, self.head_dim)
# value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb_list is not None:
if len(rotary_pos_emb_list) != 0:
cur_len = query.shape[1]
if len(rotary_pos_emb_list) == 1:
rotary_pos_emb = rotary_pos_emb_list[0]
Expand Down Expand Up @@ -332,6 +332,8 @@ def qwen_attention_forward_quantized(

bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids
rotary_pos_emb_list = rotary_pos_emb_list[:-1]

use_fuse_rope = should_use_fuse_rope(self, hidden_states)
# qtype_check = decoding_fast_path_qtype_check(self.q_proj)
Expand All @@ -349,7 +351,6 @@ def qwen_attention_forward_quantized(
device=device
)

position_ids = self.position_ids[self.kv_seq_len].to(device)
base = self.rope_base

args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
Expand Down Expand Up @@ -599,7 +600,7 @@ def qwen_model_forward(
if self.use_cache_quantization:
past_length = past_key_values[0][0][0].size(2)
else:
past_length = past_key_values[0][0].size(-2)
past_length = past_key_values[0][0].size(1)
if position_ids is None:
position_ids = torch.arange(
past_length,
Expand Down Expand Up @@ -651,7 +652,7 @@ def qwen_model_forward(
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
rotary_pos_emb_list = [
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
]
] + [position_ids]

hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
Expand Down
Loading