From 5837bc00148ceeedd788c549ebd1287accbaa6e1 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 16 Jul 2024 18:16:30 +0800
Subject: [PATCH] fix chatglm3 npu output (#11590)

---
 .../transformers/npu_models/chatglm.py | 51 +++++++++----------
 1 file changed, 24 insertions(+), 27 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/chatglm.py b/python/llm/src/ipex_llm/transformers/npu_models/chatglm.py
index 69579053a6e..d9eba2775c4 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/chatglm.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/chatglm.py
@@ -64,7 +64,16 @@ def chatglm2_model_forward(
         rotary_pos_emb = rotary_pos_emb[position_ids]
     else:
         rotary_pos_emb = rotary_pos_emb[None, :seq_length]
-    rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+    # ipex-llm change start: change rope cache shape
+    # rotary_pos_emb: [bsz, seq_len, rot_dim//2, 2]
+    cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
+    cos = cos.squeeze(0).unsqueeze(1)
+    sin = sin.squeeze(0).unsqueeze(1)
+    cos = cos.repeat_interleave(2, dim=-1)
+    sin = sin.repeat_interleave(2, dim=-1)
+    # cos, sin: [bsz, 1, seq_len, rot_dim]
+    rotary_pos_emb = (cos, sin)
+    # ipex-llm change end
 
     # ipex-llm changes begin:
     # generate `causal_mask` and replace `full_attention_mask` with it
@@ -76,14 +85,6 @@ def chatglm2_model_forward(
                                   dtype=inputs_embeds.dtype, device=inputs_embeds.device)
         mask_value = torch.finfo(inputs_embeds.dtype).min
         causal_mask.masked_fill_(full_attention_mask, mask_value)
-    elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
-        full_attention_mask = self.get_masks(input_ids,
-                                             past_key_values,
-                                             padding_mask=attention_mask)
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
     else:
         causal_mask = None
 
@@ -174,24 +175,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 @torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
-    # x: [sq, b, np, hn]
-    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-    rot_dim = rope_cache.shape[-2] * 2
+def rotate_every_two(x: torch.Tensor):
+    x1 = x[:, :, :, ::2]
+    x2 = x[:, :, :, 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: Tuple[torch.Tensor]) -> torch.Tensor:
+    # x: [bsz, n_head, seq_len, head_dim]
+    cos, sin = rope_cache
+    rot_dim = cos.size(-1)
     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-    # truncate to support variable sizes
-    rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
-        ],
-        -1,
-    )
-    x_out2 = x_out2.flatten(3)
-    return torch.cat((x_out2, x_pass), dim=-1)
+    x_out = x * cos + rotate_every_two(x) * sin
+    return torch.cat([x_out, x_pass], dim=-1)
 
 
 def chatglm2_attention_forward(
@@ -246,7 +243,7 @@ def chatglm2_attention_forward(
             key_states,
             value_states,
             attn_mask=attention_mask,
-            is_causal=q_len > 1 and bsz == 1,
+            is_causal=attention_mask is None and q_len > 1 and bsz == 1,
         )
         attn_weights = None
     else:
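
As a sanity check on the RoPE rewrite above, the standalone sketch below compares the removed pairwise formulation of apply_rotary_pos_emb with the new (cos, sin) cache plus rotate_every_two path and asserts that they agree. It is illustration only, not part of the patch: the tensor sizes and the toy theta values are made up, and the helpers are re-declared locally rather than imported from ipex_llm; the only assumption carried over from the original code is that the ChatGLM rope cache stores (cos, sin) pairs in its last dimension.

import torch

# Toy sizes for the check (hypothetical, not taken from the patch).
bsz, n_head, seq_len, head_dim, rot_dim = 2, 4, 8, 16, 8

# ChatGLM-style rope cache: [seq_len, rot_dim//2, 2] holding (cos, sin) pairs.
theta = torch.rand(seq_len, rot_dim // 2)
cache = torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)

# Query/key tensor in the layout the removed code used: [seq_len, bsz, n_head, head_dim].
x = torch.randn(seq_len, bsz, n_head, head_dim)

# Old pairwise formulation (mirrors the removed apply_rotary_pos_emb).
x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
xshaped = x_rot.reshape(seq_len, bsz, n_head, rot_dim // 2, 2)
rc = cache.view(seq_len, 1, 1, rot_dim // 2, 2)
out_old = torch.stack(
    [
        xshaped[..., 0] * rc[..., 0] - xshaped[..., 1] * rc[..., 1],
        xshaped[..., 1] * rc[..., 0] + xshaped[..., 0] * rc[..., 1],
    ],
    -1,
).flatten(3)
out_old = torch.cat((out_old, x_pass), dim=-1)

# New (cos, sin) formulation (mirrors the added model-forward code and helpers).
rope = cache[None].expand(bsz, -1, -1, -1)            # [bsz, seq_len, rot_dim//2, 2]
cos, sin = rope.permute(3, 0, 1, 2).chunk(2, dim=0)   # each [1, bsz, seq_len, rot_dim//2]
cos = cos.squeeze(0).unsqueeze(1).repeat_interleave(2, dim=-1)  # [bsz, 1, seq_len, rot_dim]
sin = sin.squeeze(0).unsqueeze(1).repeat_interleave(2, dim=-1)

def rotate_every_two(t):
    # output[..., 2i] = -t[..., 2i+1], output[..., 2i+1] = t[..., 2i]
    t1, t2 = t[..., ::2], t[..., 1::2]
    return torch.stack((-t2, t1), dim=-1).flatten(-2)

x_new = x.permute(1, 2, 0, 3)                         # [bsz, n_head, seq_len, head_dim]
xr, xp = x_new[..., :rot_dim], x_new[..., rot_dim:]
out_new = torch.cat([xr * cos + rotate_every_two(xr) * sin, xp], dim=-1)

# The two paths should agree up to the layout permutation.
assert torch.allclose(out_old.permute(1, 2, 0, 3), out_new, atol=1e-6)
print("old and new RoPE formulations match")

The check passes because repeat_interleave duplicates each cos/sin entry for the even/odd element pair that rotate_every_two swaps and negates, which is exactly the per-pair 2x2 rotation the removed code computed explicitly.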