Performance mode strategy update for input_embeds input (intel-analyt…
Oscilloscope98 authored and cranechu0131 committed Sep 9, 2024
1 parent 6a8a563 commit 6e37dea
Showing 1 changed file with 10 additions and 7 deletions: python/llm/src/ipex_llm/transformers/lookup.py
@@ -60,21 +60,24 @@ def generate(
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)

-    input_ids_shape = None
+    input_tensor_shape = None
+    is_inputs_embeds = False
     if inputs is not None:
-        input_ids_shape = inputs.shape
+        input_tensor_shape = inputs.shape
     else:
         input_ids = kwargs.get("input_ids", None)
         if input_ids is not None:
-            input_ids_shape = input_ids.shape
+            input_tensor_shape = input_ids.shape
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                input_ids_shape = inputs_embeds.shape
+                is_inputs_embeds = True
+                input_tensor_shape = inputs_embeds.shape

     if perf_mode == "1" and lookahead is None:
-        if input_ids_shape is not None and \
-                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+        if input_tensor_shape is not None and \
+                input_tensor_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD \
+                and not is_inputs_embeds:
             lookahead = 2  # default to 2 now

     if lookahead:
@@ -85,7 +88,7 @@ def generate(
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
             kwargs.pop("max_matching_ngram_size", None)
-        elif input_ids_shape is not None and input_ids_shape[0] > 1:
+        elif input_tensor_shape is not None and input_tensor_shape[0] > 1:
             logger.warning("Prompt lookup is currently not supported with batch inference, "
                            "fallback to original generate.")
             kwargs.pop("max_matching_ngram_size", None)
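The net effect of the change: when IPEX_LLM_PERFORMANCE_MODE=1 and no explicit lookahead is passed, the default lookahead of 2 is now applied only to sufficiently long token-id prompts; prompts supplied as inputs_embeds no longer trigger it, presumably because prompt lookup matches n-grams against token ids, which are not available when only embeddings are given. Below is a minimal standalone sketch of that decision logic. The helper name resolve_lookahead, the threshold value 128, and the tensor shapes in the demo are illustrative assumptions, not part of the ipex_llm code; only the branch structure mirrors the diff above.

import os

import torch

# Illustrative value only; the real constant is defined in
# python/llm/src/ipex_llm/transformers/lookup.py.
PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 128


def resolve_lookahead(lookahead=None, inputs=None, input_ids=None, inputs_embeds=None):
    """Mirror of the updated shape/flag detection shown in the diff above."""
    perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)

    input_tensor_shape = None
    is_inputs_embeds = False
    if inputs is not None:
        input_tensor_shape = inputs.shape
    elif input_ids is not None:
        input_tensor_shape = input_ids.shape
    elif inputs_embeds is not None:
        is_inputs_embeds = True
        input_tensor_shape = inputs_embeds.shape

    # Performance mode defaults lookahead to 2 only for long token-id prompts;
    # embedding inputs now opt out of that default.
    if perf_mode == "1" and lookahead is None:
        if input_tensor_shape is not None \
                and input_tensor_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD \
                and not is_inputs_embeds:
            lookahead = 2
    return lookahead


if __name__ == "__main__":
    os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"
    ids = torch.ones(1, 256, dtype=torch.long)        # (batch, seq_len)
    embeds = torch.randn(1, 256, 4096)                # (batch, seq_len, hidden)
    print(resolve_lookahead(input_ids=ids))           # 2    -> lookup generation path
    print(resolve_lookahead(inputs_embeds=embeds))    # None -> original generate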
