Skip to content

Commit

Permalink
Support streaming for lookup generation (#11922)
Browse files Browse the repository at this point in the history
* Support streaming for lookup generation

* Small update

* Style fixes

* Add original generate fallback for batch inference and beam search; support input length threshold judgement for direct input with input_ids

* Fix lookup stream generate with eos token

* Small fixes

* Small fix

* index fix

* Small fix
  • Loading branch information
Oscilloscope98 authored Aug 26, 2024
1 parent a0bbd8e commit c1d07bc
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions python/llm/src/ipex_llm/transformers/lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,40 @@ def generate(
):
lookahead = kwargs.pop("lookahead", None)
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
if perf_mode == "1" and lookahead is None:
if inputs is not None:
if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
lookahead = 2 # default to 2 now

input_ids_shape = None
if inputs is not None:
input_ids_shape = inputs.shape
else:
input_ids = kwargs.get("input_ids", None)
if input_ids is not None:
input_ids_shape = input_ids.shape
else:
inputs_embeds = kwargs.get("inputs_embeds", None)
if inputs_embeds is not None:
if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
lookahead = 2 # default to 2 now
input_ids_shape = inputs_embeds.shape

if perf_mode == "1" and lookahead is None:
if input_ids_shape is not None and \
input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
lookahead = 2 # default to 2 now

if lookahead:
from ipex_llm.transformers.convert import get_enable_ipex
_enable_ipex = get_enable_ipex()

if self.device.type == "cpu" and _enable_ipex:
logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
"fallback to original generate.")
kwargs.pop("max_matching_ngram_size")
kwargs.pop("max_matching_ngram_size", None)
elif input_ids_shape is not None and input_ids_shape[0] > 1:
logger.warning("Prompt lookup is currently not supported with batch inference, "
"fallback to original generate.")
kwargs.pop("max_matching_ngram_size", None)
elif kwargs.get("num_beams", None) not in [None, 1]:
logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
"fallback to original generate.")
kwargs.pop("max_matching_ngram_size", None)
else:
# Do prompt lookup generation
# If lookahead is provided, we will use lookup_generate instead of
Expand All @@ -94,6 +111,7 @@ def generate(
return self.lookup_generate(inputs=inputs,
num_output_tokens=lookahead,
generation_config=generation_config,
streamer=streamer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
Expand Down Expand Up @@ -254,12 +272,19 @@ def lookup_generate(self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
generation_config: Optional[GenerationConfig] = None,
streamer: Optional["BaseStreamer"] = None,
attention_mask=None,
**sampling_kwargs):
input_ids, generation_config, logits_processor, stopping_criteria, \
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
**sampling_kwargs)

invalidInputError(input_ids.shape[0] == 1,
"Prompt lookup is currently not supported with batch inference.")

if streamer is not None:
streamer.put(input_ids.cpu())

device_name = get_xpu_device_type(input_ids)

candidates_generator = PromptLookupCandidateGenerator(
Expand Down Expand Up @@ -406,12 +431,19 @@ def lookup_generate(self,
first_eos_idx = out_idx
break
if first_eos_idx > -1:
if streamer is not None:
streamer.put(output_ids[:(first_eos_idx + 1)].cpu())
step -= (len(output_ids_list) - first_eos_idx - 1)
break
if streamer is not None:
streamer.put(output_ids.cpu())

step = min(step, max_new_tokens)
e2e_toc = time.time()
self.n_token_generated = step
self.e2e_time_without_first = e2e_toc - e2e_tic

if streamer is not None:
streamer.end()

return input_ids[:, : input_len + step]

0 comments on commit c1d07bc

Please sign in to comment.