Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Jul 12, 2024
1 parent cb10902 commit def438f
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,6 @@ def qwen2_model_forward(

hidden_states = inputs_embeds

# ipex-llm changes
curr_device = decoder_layer.input_layernorm.weight.device
if attention_mask is not None:
attention_mask = attention_mask.to(curr_device)
if position_ids is not None:
position_ids = position_ids.to(curr_device)
# ipex-llm changes end

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
Expand All @@ -212,6 +204,13 @@ def qwen2_model_forward(
use_cache,
)
else:
# ipex-llm changes
curr_device = decoder_layer.input_layernorm.weight.device
if attention_mask is not None:
attention_mask = attention_mask.to(curr_device)
if position_ids is not None:
position_ids = position_ids.to(curr_device)
# ipex-llm changes end
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
Expand Down

0 comments on commit def438f

Please sign in to comment.