Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic glm-edge support #12531

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,10 @@ def _optimize_pre(model, qtype=None):
model.apply(split_mlp)
elif model.config.num_layers in [40, 28]:
model.apply(split_mlp)
elif model.config.model_type == "glm":
from ipex_llm.transformers.models.glm import merge_qkv, split_mlp
model.apply(merge_qkv)
model.apply(split_mlp)

return model

Expand Down Expand Up @@ -1487,7 +1491,19 @@ def _optimize_post(model, lightweight_bmm=False):
# workaround glm4-9b fp16 overflow
from ipex_llm.transformers.models.chatglm4 import chatglm4_block_forward
convert_forward(model, module.GLMBlock, chatglm4_block_forward)

elif model.config.model_type == "glm":
# glm-edge series
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.common import rms_norm_forward
from ipex_llm.transformers.models.common import mlp_silu_forward
from ipex_llm.transformers.models.glm import glm_attention_forward
from ipex_llm.transformers.models.glm import glm_model_forward_wrapper
convert_forward(model, module.GlmRMSNorm, rms_norm_forward)
convert_forward(model, module.GlmMLP, mlp_silu_forward)
convert_forward(model, module.GlmAttention, glm_attention_forward)
glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
convert_forward(model, module.GlmModel, glm_model_forward)
elif "mpt" in model.config.model_type:
if model.config.architectures is not None:
modeling_module_name = model.__class__.__module__
Expand Down
202 changes: 202 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/glm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm/modeling_glm.py
#
# which is licensed under Apache License 2.0:
#
# Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from typing import Optional, Tuple
from transformers.cache_utils import Cache
from transformers.models.glm.modeling_glm import GlmAttention, GlmMLP
from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache


def merge_qkv(module: torch.nn.Module):
merge_qkv_base(module, GlmAttention)


def split_mlp(module: torch.nn.Module):
if isinstance(module, GlmMLP):
gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0)

gate_proj = torch.nn.Linear(0, 0, bias=False)
gate_proj.weight = torch.nn.Parameter(gate_weight, requires_grad=False)
gate_proj.in_features = gate_weight.size(1)
gate_proj.out_features = gate_weight.size(0)

up_proj = torch.nn.Linear(0, 0, bias=False)
up_proj.weight = torch.nn.Parameter(up_weight, requires_grad=False)
up_proj.in_features = up_weight.size(1)
up_proj.out_features = up_weight.size(0)

module.gate_proj = gate_proj
module.up_proj = up_proj

del module.gate_up_proj

# rename activation function
module.act_fn = module.activation_fn


def glm_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
**kwargs,
):
bsz, q_len, _ = hidden_states.size()

qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=1)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)

kv_seq_len = key_states.size(-2)
if attention_mask is not None: # no matter the length, we just slice it
attention_mask = attention_mask[:, :, :, : kv_seq_len]

if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if use_quantizekv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons
if use_quantizekv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
else:
if use_quantizekv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) * self.scaling

if attention_mask is not None:
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value


def glm_model_forward_wrapper(origin_forward):
def glm_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs,
):
# ipex-llm changes start
# IPEX-LLM OPT: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache or inputs.device.type == 'xpu'
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads //
self.config.num_key_value_heads)

if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
# ipex-llm changes end

return origin_forward(
self=self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**flash_attn_kwargs,
)

return glm_model_forward
Loading