Skip to content

Commit

Permalink
update minicpm.py (#11517)
Browse files — browse the repository at this point in the history
* update minicpm

* meet code review
Loading branch information
qiuxin2012 authored Jul 5, 2024
1 parent 24de13f commit a31f2cb
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions python/llm/src/ipex_llm/transformers/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,9 @@ def minicpm_attention_forward_original(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0
Expand Down / Expand Up
@@ -313,7 +310,6 @@ def minicpm_attention_forward_original(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
self.layer_idx > 0 and \
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import xe_addons
attn_output = xe_addons.sdp(query_states, key_states, value_states,
Expand Down / Expand Up
@@ -450,12 +446,9 @@ def minicpm_attention_forward_quantized(
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0
Expand Down

0 comments on commit a31f2cb

Please sign in to comment.