Skip to content

Commit

Permalink
optimize qwen2 gpu memory usage again
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Jun 26, 2024
1 parent a45ceac commit 9dc8119
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,21 @@ def merge_qkv(module: torch.nn.Module):

del module.q_proj, module.k_proj, module.v_proj

# Qwen2 uses pre-computed rope table to accelerate rope,
# original `cos_cached` and `sin_cached` are added by `register_buffer`,
# so they will move to xpu during `model.to('xpu')`.
# But gpu fuse kernel doesn't need this rope table, only cpu needs them,
# so delete them then add them with `=`, so that they will be pinned on CPU,
# this can save about 0.5GB gpu memory usage when running Qwen2
if hasattr(module.rotary_emb, "cos_cached"):
cos_cached = module.rotary_emb.cos_cached
del module.rotary_emb.cos_cached
module.rotary_emb.cos_cached = cos_cached
if hasattr(module.rotary_emb, "sin_cached"):
sin_cached = module.rotary_emb.sin_cached
del module.rotary_emb.sin_cached
module.rotary_emb.sin_cached = sin_cached


def padding_mlp(module: torch.nn.Module):
# for qwen 1.5 14B
Expand Down Expand Up @@ -422,6 +437,7 @@ def qwen2_attention_forward(
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = cos.to(device), sin.to(device)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

Expand Down

0 comments on commit 9dc8119

Please sign in to comment.