Skip to content

Commit

Permalink
LLM: unify memory optimization env variables. (#11549)
Browse the repository at this point in the history
* LLM: unify memory optimization env variables.

* fix comments.
  • Loading branch information
lalalapotter authored Jul 11, 2024
1 parent 51f2eff commit 70ab1a6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
8 changes: 5 additions & 3 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
optimize_lm_head = False
if is_lm_head(name, model_config, out_features):
model_type = getattr(model_config, "model_type", None)
if model_type in ["gptj", "llama", "qwen2"] and \
os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1":
optimize_lm_head = True
if model_type in ["gptj", "llama", "qwen2"]:
if os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) is not None:
optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1"
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
with init_empty_weights():
new_linear = None
is_gptq = is_gptq_linear(module)
Expand Down
2 changes: 2 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out
if not output_attentions:
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
elif query_states.dtype == torch.float16 and \
query_states.shape[2] >= 6800:
# split tensor for memory block limitation
Expand Down
2 changes: 2 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out
if not output_attentions:
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
elif query_states.dtype == torch.float16 and \
query_states.shape[2] >= 6300:
# split tensor for memory block limitation
Expand Down

0 comments on commit 70ab1a6

Please sign in to comment.