refactor to simplify following upgrade #12680

Merged · 1 commit · Jan 9, 2025
18 changes: 5 additions & 13 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1325,7 +1325,6 @@ def _optimize_post(model):
 modeling_module_name = model.__class__.__module__
 module = importlib.import_module(modeling_module_name)
 from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
-from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
 from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
 from ipex_llm.transformers.models.chatglm2 import mlp_forward
@@ -1338,9 +1337,7 @@ def _optimize_post(model):
 convert_forward(model,
                 module.ChatGLMModel,
                 chatglm2_model_forward)
-convert_forward(model,
-                module.RMSNorm,
-                chatglm_rms_norm_forward)
+convert_forward(model, module.RMSNorm, rms_norm_forward)
 convert_forward(model, module.MLP, mlp_forward)
 # for codegeex-nano
 if hasattr(model.config, "rope_ratio"):
@@ -1358,8 +1355,7 @@ def _optimize_post(model):
 # glm4 family
 modeling_module_name = model.__class__.__module__
 module = importlib.import_module(modeling_module_name)
-from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
-convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+convert_forward(model, module.RMSNorm, rms_norm_forward)

 if hasattr(model.transformer, "vision"):
     # glm4 vision family
@@ -1448,8 +1444,8 @@ def _optimize_post(model):
 elif model.config.model_type == "baichuan":
     modeling_module_name = model.__class__.__module__
     module = importlib.import_module(modeling_module_name)
-    from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
-    convert_forward(model, module.MLP, baichuan_mlp_forward)
+    convert_forward(model, module.RMSNorm, rms_norm_forward)
+    convert_forward(model, module.MLP, mlp_silu_forward)

     if model.config.hidden_size in [4096, 2048]:
         # baichuan-7B and baichuan2-7B
@@ -1458,7 +1454,6 @@ def _optimize_post(model):
 for i in range(len(model.model.layers)):
     setattr(model.model.layers[i].self_attn, "layer_idx", i)
 convert_forward(model, module.Attention, baichuan_attention_forward_7b)
-convert_forward(model, module.RMSNorm, rms_norm_forward)
 if model.config.vocab_size == 125696:
     # baichuan2-7B
     convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
@@ -1468,9 +1463,7 @@ def _optimize_post(model):
 elif model.config.hidden_size == 5120:
     # baichuan-13B and baichuan2-13B
     from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
-    from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
     convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
-    convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)

     if model.config.vocab_size == 125696:
         # baichaun2-13B
@@ -1565,7 +1558,6 @@ def _optimize_post(model):
 from ipex_llm.transformers.models.qwen import qwen_attention_forward
 from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
 from ipex_llm.transformers.models.qwen import qwen_mlp_forward
-from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 from ipex_llm.transformers.models.qwen import qwen_model_forward
 if model.config.max_position_embeddings == 8192 \
         and model.config.hidden_size == 4096:
@@ -1580,7 +1572,7 @@ def _optimize_post(model):
 )
 convert_forward(model,
                 module.RMSNorm,
-                chatglm_rms_norm_forward)
+                rms_norm_forward)
 convert_forward(model,
                 module.QWenMLP,
                 qwen_mlp_forward)
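For context, convert_forward is the patching helper used throughout this file: it is expected to walk the loaded model and rebind the forward method on every submodule of the given class, so a single shared implementation such as rms_norm_forward can serve several model families. A minimal sketch of that idea, not the exact ipex-llm implementation (convert_forward_sketch is an illustrative name):

import types

import torch


def convert_forward_sketch(model: torch.nn.Module, target_class: type, new_forward) -> None:
    # Recursively rebind `forward` on every instance of `target_class`,
    # so one shared forward replaces the model-specific copies this PR deletes.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)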
32 changes: 0 additions & 32 deletions python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -47,38 +47,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
     module.register_buffer("inv_freq", inv_freq, persistent=False)


-def baichuan_13b_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-    return self.weight * hidden_states.to(input_dtype)
-
-
-def baichuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    qtype = getattr(self.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        return self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            SILU, qtype
-        ))
-    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 def baichuan_model_7b_forward(
     self,
     input_ids: torch.LongTensor = None,
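The deleted baichuan_mlp_forward combined the standard SiLU-gated MLP with an XPU fusion fast path; the PR points Baichuan at the shared mlp_silu_forward from models/common.py instead. A sketch of the computation both paths reduce to, using the projection names from the removed code (the real shared helper presumably keeps the fused xe_linear route as well):

import torch
import torch.nn.functional as F


def mlp_silu_sketch(self, x: torch.Tensor) -> torch.Tensor:
    # Gated MLP used by Baichuan/LLaMA-style blocks:
    # down_proj(silu(gate_proj(x)) * up_proj(x))
    return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))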
15 changes: 2 additions & 13 deletions python/llm/src/ipex_llm/transformers/models/bert.py
@@ -36,24 +36,13 @@
 import torch
 from typing import Optional, Tuple
 from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+from ipex_llm.transformers.models.common import merge_linear
 from ipex_llm.utils.common import invalidInputError


 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, BertSelfAttention):
-        q_w = module.query.weight.data
-        k_w = module.key.weight.data
-        v_w = module.value.weight.data
-        q_b = module.query.bias.data
-        k_b = module.key.bias.data
-        v_b = module.value.bias.data
-        new_w = torch.cat([q_w, k_w, v_w], dim=0)
-        new_b = torch.cat([q_b, k_b, v_b], dim=-1)
-        qkv = torch.nn.Linear(0, 0, bias=True)
-        qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
-        qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
-        qkv.in_features = module.query.in_features
-        qkv.out_features = module.query.out_features * 3
+        qkv = merge_linear([module.query, module.key, module.value])
         module.qkv = qkv
         del module.query
         del module.key
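merge_linear, imported from models/common.py, replaces the hand-rolled concatenation above. Judging from the code it supersedes, it concatenates the weights and biases of several Linear layers that share an input so a single matmul produces q, k and v at once; a sketch under that assumption (merge_linear_sketch is an illustrative name, not the library function):

from typing import List

import torch


def merge_linear_sketch(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
    # Stack the output dimensions so one projection yields q, k and v together.
    new_weight = torch.cat([linear.weight.data for linear in linears], dim=0)
    new_bias = torch.cat([linear.bias.data for linear in linears], dim=-1)
    merged = torch.nn.Linear(linears[0].in_features, new_weight.size(0), bias=True)
    merged.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    merged.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    return merged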
28 changes: 0 additions & 28 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -33,34 +33,6 @@
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
-    go from (batch, num_key_value_heads, seqlen, head_dim) to
-    (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                           n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def chatglm_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-    return self.weight * hidden_states.to(input_dtype)
-
-
 def chatglm2_model_forward(
     self,
     input_ids,
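The deleted repeat_kv matches the grouped-query-attention helper that ships with Hugging Face modeling code, and its docstring notes it is equivalent to torch.repeat_interleave along the key/value head dimension. A quick check of that equivalence (shapes chosen arbitrarily for illustration):

import torch

batch, kv_heads, seqlen, head_dim, n_rep = 2, 4, 8, 16, 3
kv = torch.randn(batch, kv_heads, seqlen, head_dim)

# expand + reshape, as the deleted helper did
expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seqlen, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seqlen, head_dim)

# the documented equivalent
reference = torch.repeat_interleave(kv, repeats=n_rep, dim=1)

assert torch.equal(expanded, reference)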
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/common.py
@@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
     weight = self.weight
     if hasattr(self, "variance_epsilon"):
         eps = self.variance_epsilon
-    else:
+    elif hasattr(self, "epsilon"):
         eps = self.epsilon
+    else:
+        eps = self.eps

     if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
         import xe_addons
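This attribute fallback is what lets the per-model RMSNorm forwards above be deleted: Hugging Face-style layers store the constant as variance_epsilon, Baichuan-13B's RMSNorm as epsilon, and ChatGLM's as eps, while the arithmetic is identical. A sketch of the shared CPU path, mirroring the removed implementations (the real helper additionally dispatches to xe_addons.rms_norm on XPU, as shown above):

import torch


def rms_norm_sketch(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Pick up the epsilon under whichever name this RMSNorm variant uses.
    for name in ("variance_epsilon", "epsilon", "eps"):
        if hasattr(self, name):
            eps = getattr(self, name)
            break

    # Standard RMSNorm: x / sqrt(mean(x^2) + eps), scaled by the learned weight.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return self.weight * hidden_states.to(input_dtype)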