Support streaming for lookup generation #11922

Merged
46 changes: 39 additions & 7 deletions python/llm/src/ipex_llm/transformers/lookup.py
@@ -59,23 +59,40 @@ def generate(
 ):
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
-    if perf_mode == "1" and lookahead is None:
-        if inputs is not None:
-            if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                lookahead = 2 # default to 2 now
+
+    input_ids_shape = None
+    if inputs is not None:
+        input_ids_shape = inputs.shape
+    else:
+        input_ids = kwargs.get("input_ids", None)
+        if input_ids is not None:
+            input_ids_shape = input_ids.shape
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                    lookahead = 2 # default to 2 now
+                input_ids_shape = inputs_embeds.shape
+
+    if perf_mode == "1" and lookahead is None:
+        if input_ids_shape is not None and \
+                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+            lookahead = 2 # default to 2 now
 
     if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()
 
         if self.device.type == "cpu" and _enable_ipex:
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
-            kwargs.pop("max_matching_ngram_size")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif input_ids_shape is not None and input_ids_shape[0] > 1:
+            logger.warning("Prompt lookup is currently not supported with batch inference, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif kwargs.get("num_beams", None) not in [None, 1]:
+            logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
         else:
             # Do prompt lookup generation
             # If lookahead is provided, we will use lookup_generate instead of
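With this hunk, the performance-mode check no longer looks only at the positional inputs: the prompt shape is resolved from inputs, the input_ids keyword, or the inputs_embeds keyword, and the same input_ids_shape feeds the new batch-size guard above. A minimal usage sketch (the model id and prompt are placeholders; it assumes ipex-llm's transformers-style AutoModelForCausalLM):

import os

# Performance mode enables lookahead automatically for long prompts.
os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"

from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

long_prompt = "Summarize the following text. " + "lorem ipsum " * 200
ids = tokenizer(long_prompt, return_tensors="pt").input_ids

# The prompt length is now detected even when passed as a keyword,
# so a prompt above the lookup threshold still triggers lookahead = 2.
out = model.generate(input_ids=ids, max_new_tokens=32)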
@@ -94,6 +111,7 @@ def generate(
             return self.lookup_generate(inputs=inputs,
                                         num_output_tokens=lookahead,
                                         generation_config=generation_config,
+                                        streamer=streamer,
                                         logits_processor=logits_processor,
                                         stopping_criteria=stopping_criteria,
                                         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
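The streamer argument was previously accepted by generate but dropped on the prompt-lookup path; it is now forwarded to lookup_generate. A sketch of the resulting streaming call, reusing the model, tokenizer, and ids from the previous snippet:

from transformers import TextStreamer

# skip_prompt=True prints only newly generated text, not the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)

# lookahead routes generation through lookup_generate; accepted tokens
# are now pushed to the streamer chunk by chunk instead of only being
# returned once generation has finished.
out = model.generate(input_ids=ids,
                     lookahead=2,
                     streamer=streamer,
                     max_new_tokens=64)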
@@ -254,12 +272,19 @@ def lookup_generate(self,
                     num_output_tokens: int = 10,
                     max_matching_ngram_size: int = None,
                     generation_config: Optional[GenerationConfig] = None,
+                    streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
 
+    invalidInputError(input_ids.shape[0] == 1,
+                      "Prompt lookup is currently not supported with batch inference.")
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
+
     device_name = get_xpu_device_type(input_ids)
 
     candidates_generator = PromptLookupCandidateGenerator(
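lookup_generate now rejects batched prompts up front and, following the Hugging Face streamer contract, pushes the prompt ids to the streamer before decoding starts. Any object with put and end methods works; a hypothetical collector as a sketch (the class name is illustrative, not part of ipex-llm):

from transformers.generation.streamers import BaseStreamer

class TokenCollector(BaseStreamer):
    # Collects streamed token chunks instead of printing them.

    def __init__(self):
        self.chunks = []

    def put(self, value):
        # Called first with the prompt ids, then with each accepted
        # chunk of generated ids (already moved to the CPU).
        self.chunks.append(value)

    def end(self):
        # Called exactly once when generation finishes.
        print(f"received {len(self.chunks)} chunks")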
@@ -406,12 +431,19 @@ def lookup_generate(self,
                 first_eos_idx = out_idx
                 break
         if first_eos_idx > -1:
+            if streamer is not None:
+                streamer.put(output_ids[:(first_eos_idx + 1)].cpu())
             step -= (len(output_ids_list) - first_eos_idx - 1)
             break
+        if streamer is not None:
+            streamer.put(output_ids.cpu())
 
     step = min(step, max_new_tokens)
     e2e_toc = time.time()
     self.n_token_generated = step
     self.e2e_time_without_first = e2e_toc - e2e_tic
 
+    if streamer is not None:
+        streamer.end()
+
     return input_ids[:, : input_len + step]
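When an EOS token appears inside a drafted chunk, only the tokens up to and including that EOS are streamed, and step is reduced by the number of tokens drafted past it, so the final slice input_ids[:, : input_len + step] ends at the EOS. A worked sketch of that arithmetic with made-up values:

eos_token_id = 2
output_ids_list = [523, 904, 2, 17, 88]  # drafted chunk, EOS at index 2

# Locate the first EOS, as the loop above does.
first_eos_idx = next((i for i, t in enumerate(output_ids_list)
                      if t == eos_token_id), -1)
assert first_eos_idx == 2

# Only the tokens up to and including the EOS are streamed: [523, 904, 2].
streamed = output_ids_list[: first_eos_idx + 1]

# The two tokens drafted past the EOS are discounted from the step count.
step = 40                  # cumulative tokens counted so far, incl. this chunk
step -= len(output_ids_list) - first_eos_idx - 1
assert step == 38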