Commit 5837bc0

fix chatglm3 npu output (#11590)
MeouSker77 authored Jul 16, 2024
1 parent 06930ab commit 5837bc0
Showing 1 changed file with 24 additions and 27 deletions.
python/llm/src/ipex_llm/transformers/npu_models/chatglm.py (51 changes: 24 additions & 27 deletions)
@@ -64,7 +64,16 @@ def chatglm2_model_forward(
         rotary_pos_emb = rotary_pos_emb[position_ids]
     else:
         rotary_pos_emb = rotary_pos_emb[None, :seq_length]
-    rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+    # ipex-llm change start: change rope cache shape
+    # rotary_pos_emb: [bsz, seq_len, rot_dim//2, 2]
+    cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
+    cos = cos.squeeze(0).unsqueeze(1)
+    sin = sin.squeeze(0).unsqueeze(1)
+    cos = cos.repeat_interleave(2, dim=-1)
+    sin = sin.repeat_interleave(2, dim=-1)
+    # cos, sin: [bsz, 1, seq_len, rot_dim]
+    rotary_pos_emb = (cos, sin)
+    # ipex-llm change end
 
     # ipex-llm changes begin:
     # generate `causal_mask` and replace `full_attention_mask` with it
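
For reference, a minimal standalone sketch of the rope-cache reshape added above (sizes are made up for illustration): it turns the ChatGLM-style cache of shape [bsz, seq_len, rot_dim//2, 2] into the (cos, sin) pair of shape [bsz, 1, seq_len, rot_dim] that the rewritten apply_rotary_pos_emb further down expects.

import torch

# Hypothetical sizes, only to check the shapes produced by the hunk above.
bsz, seq_len, rot_dim = 2, 5, 8

# ChatGLM-style rope cache: [..., 0] holds cos values, [..., 1] holds sin values.
rotary_pos_emb = torch.randn(bsz, seq_len, rot_dim // 2, 2)

cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)  # 2 x [1, bsz, seq_len, rot_dim//2]
cos = cos.squeeze(0).unsqueeze(1)                              # [bsz, 1, seq_len, rot_dim//2]
sin = sin.squeeze(0).unsqueeze(1)
cos = cos.repeat_interleave(2, dim=-1)                         # [bsz, 1, seq_len, rot_dim]
sin = sin.repeat_interleave(2, dim=-1)

assert cos.shape == sin.shape == (bsz, 1, seq_len, rot_dim)
# After repeat_interleave, each cos/sin value is duplicated for both elements of its pair.
assert torch.equal(cos[..., 0::2], cos[..., 1::2])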
@@ -76,14 +85,6 @@ def chatglm2_model_forward(
                                   dtype=inputs_embeds.dtype, device=inputs_embeds.device)
         mask_value = torch.finfo(inputs_embeds.dtype).min
         causal_mask.masked_fill_(full_attention_mask, mask_value)
-    elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
-        full_attention_mask = self.get_masks(input_ids,
-                                             past_key_values,
-                                             padding_mask=attention_mask)
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
     else:
         causal_mask = None

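The branch kept in the hunk above builds an additive float mask from a boolean full_attention_mask via masked_fill_; below is a small illustrative sketch of that pattern with toy sizes (the values and mask shape here are assumptions, not taken from the model).

import torch

batch_size, seq_length, kv_length = 1, 4, 4
dtype = torch.float16

# Boolean mask: True marks positions that must not be attended to (a causal pattern here).
full_attention_mask = torch.triu(torch.ones(seq_length, kv_length), diagonal=1).bool()
full_attention_mask = full_attention_mask.expand(batch_size, 1, seq_length, kv_length)

causal_mask = torch.zeros([batch_size, 1, seq_length, kv_length], dtype=dtype)
mask_value = torch.finfo(dtype).min
causal_mask.masked_fill_(full_attention_mask, mask_value)  # blocked positions get a large negative bias
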
@@ -174,24 +175,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
-@torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
-    # x: [sq, b, np, hn]
-    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-    rot_dim = rope_cache.shape[-2] * 2
+def rotate_every_two(x: torch.Tensor):
+    x1 = x[:, :, :, ::2]
+    x2 = x[:, :, :, 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: Tuple[torch.Tensor]) -> torch.Tensor:
+    # x: [bsz, n_head, seq_len, head_dim]
+    cos, sin = rope_cache
+    rot_dim = cos.size(-1)
     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-    # truncate to support variable sizes
-    rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
-        ],
-        -1,
-    )
-    x_out2 = x_out2.flatten(3)
-    return torch.cat((x_out2, x_pass), dim=-1)
+    x_out = x * cos + rotate_every_two(x) * sin
+    return torch.cat([x_out, x_pass], dim=-1)
 
 
 def chatglm2_attention_forward(
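
The rewritten apply_rotary_pos_emb above swaps the old pairwise (real, imag) rotation for an elementwise cos/sin form built on rotate_every_two. As a sanity check, a standalone sketch (shapes and values are illustrative) showing the two formulations agree numerically when cos/sin are repeat_interleave'd as in chatglm2_model_forward:

import torch


def rotate_every_two(x: torch.Tensor):
    # Same helper as in the hunk above: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


bsz, n_head, seq_len, rot_dim = 2, 4, 6, 8
x = torch.randn(bsz, n_head, seq_len, rot_dim)
angles = torch.randn(bsz, 1, seq_len, rot_dim // 2)
cos_half, sin_half = angles.cos(), angles.sin()
cos = cos_half.repeat_interleave(2, dim=-1)  # [bsz, 1, seq_len, rot_dim]
sin = sin_half.repeat_interleave(2, dim=-1)

# New formulation: elementwise rotation of adjacent pairs.
out_new = x * cos + rotate_every_two(x) * sin

# Old-style formulation: adjacent pairs rotated as complex numbers.
xr = x.reshape(bsz, n_head, seq_len, rot_dim // 2, 2)
out_old = torch.stack(
    [xr[..., 0] * cos_half - xr[..., 1] * sin_half,
     xr[..., 1] * cos_half + xr[..., 0] * sin_half],
    dim=-1,
).flatten(-2)

assert torch.allclose(out_new, out_old, atol=1e-6)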
@@ -246,7 +243,7 @@ def chatglm2_attention_forward(
             key_states,
             value_states,
             attn_mask=attention_mask,
-            is_causal=q_len > 1 and bsz == 1,
+            is_causal=attention_mask is None and q_len > 1 and bsz == 1,
         )
         attn_weights = None
     else:
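
The one-line change above only requests is_causal when no explicit mask is passed, since torch.nn.functional.scaled_dot_product_attention rejects an explicit attn_mask combined with is_causal=True. A minimal sketch of the guarded call pattern (toy tensors, not the module's real inputs):

import torch
import torch.nn.functional as F

bsz, n_head, q_len, head_dim = 1, 2, 4, 16

query_states = torch.randn(bsz, n_head, q_len, head_dim)
key_states = torch.randn(bsz, n_head, q_len, head_dim)
value_states = torch.randn(bsz, n_head, q_len, head_dim)

attention_mask = None  # the NPU path above leaves causal_mask as None in the common case

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    # Only let SDPA apply its own causal mask when no explicit mask is given.
    is_causal=attention_mask is None and q_len > 1 and bsz == 1,
)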

0 comments on commit 5837bc0
