
Commit

fix glm4 error
cyita authored and Oscilloscope98 committed Nov 15, 2024
1 parent a2ca7b2 commit 1b36bf4
Showing 2 changed files with 10 additions and 5 deletions.
python/llm/src/ipex_llm/transformers/models/chatglm4.py (8 additions, 3 deletions)
@@ -76,9 +76,14 @@ def chatglm4_model_forward(
     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or\
                 (past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(inputs_embeds,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
+            if self.config.hidden_size == 4096:
+                full_attention_mask = self.get_masks(input_ids,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
+            else:
+                full_attention_mask = self.get_masks(inputs_embeds,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
 
     # ipex-llm changes begin
     # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
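
The chatglm4.py change routes GLM-4 checkpoints (identified by hidden_size == 4096) to build the full attention mask from input_ids rather than inputs_embeds, presumably because that model revision's get_masks helper expects token ids; other ChatGLM variants keep the original embeddings-based path. A minimal standalone sketch of the same dispatch follows; only get_masks, hidden_size, and the tensor names come from the diff, the wrapper function itself is illustrative:

def build_full_attention_mask(model, input_ids, inputs_embeds,
                              past_key_values, attention_mask):
    # Mirrors the dispatch added in chatglm4_model_forward (illustrative only).
    if model.config.hidden_size == 4096:
        # GLM-4: build the mask from the token ids
        return model.get_masks(input_ids, past_key_values,
                               padding_mask=attention_mask)
    # Other ChatGLM variants: keep the original embeddings-based path
    return model.get_masks(inputs_embeds, past_key_values,
                           padding_mask=attention_mask)
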
python/llm/src/ipex_llm/transformers/speculative.py (2 additions, 2 deletions)
@@ -510,7 +510,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
             for k, v in past_key_values
         ]
     elif self.config.model_type == "chatglm":
-        if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+        if self.config.num_layers in [28, 40] and hasattr(self.config, 'rope_ratio'):
             past_key_values = [
                 (k[:, :, :-(new_cache_size), :],
                  v[:, :, :-(new_cache_size), :])

@@ -768,7 +768,7 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
         forward_args["attention_mask"] = cur_attention_mask
 
     if self.config.model_type == "chatglm":
-        if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+        if self.config.num_layers in [28, 40] and hasattr(self.config, 'rope_ratio'):
             past_key_value_len = past_key_values[0][0].shape[2]
         else:
             past_key_value_len = past_key_values[0][0].shape[0]
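
The speculative.py change widens the layer-count check that selects the GLM-4-style KV-cache handling: 28-layer ChatGLM variants with a rope_ratio config entry now also take the branch that crops the cache along dim 2 and reads the past length from shape[2] during verification. A rough standalone sketch of the crop dispatch under those assumptions; the cache layouts are inferred from the indexing above, and the helper and config class here are hypothetical:

import torch

def crop_chatglm_kv(config, past_key_values, new_cache_size):
    # Sketch of the dispatch this commit widens (not the repository's code).
    if config.num_layers in [28, 40] and hasattr(config, 'rope_ratio'):
        # GLM-4-style cache: sequence length sits at dim 2
        return [(k[:, :, :-new_cache_size, :], v[:, :, :-new_cache_size, :])
                for k, v in past_key_values]
    # Older ChatGLM cache: sequence length sits at dim 0
    return [(k[:-new_cache_size], v[:-new_cache_size])
            for k, v in past_key_values]

class _Cfg:  # stand-in config exposing only the attributes the check reads
    num_layers = 28
    rope_ratio = 500

kv = [(torch.zeros(1, 2, 16, 8), torch.zeros(1, 2, 16, 8))]
cropped = crop_chatglm_kv(_Cfg(), kv, 4)
assert cropped[0][0].shape[2] == 12  # four speculative tokens dropped
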
