diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index c6fe4847073..c70558c7fd2 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -59,15 +59,24 @@ def generate(
 ):
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
-    if perf_mode == "1" and lookahead is None:
-        if inputs is not None:
-            if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                lookahead = 2 # default to 2 now
+
+    input_ids_shape = None
+    if inputs is not None:
+        input_ids_shape = inputs.shape
+    else:
+        input_ids = kwargs.get("input_ids", None)
+        if input_ids is not None:
+            input_ids_shape = input_ids.shape
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                    lookahead = 2 # default to 2 now
+                input_ids_shape = inputs_embeds.shape
+
+    if perf_mode == "1" and lookahead is None:
+        if input_ids_shape is not None and \
+                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+            lookahead = 2 # default to 2 now
+
     if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()
@@ -75,7 +84,15 @@ def generate(
         if self.device.type == "cpu" and _enable_ipex:
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
-            kwargs.pop("max_matching_ngram_size")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif input_ids_shape is not None and input_ids_shape[0] > 1:
+            logger.warning("Prompt lookup is currently not supported with batch inference, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
+        elif kwargs.get("num_beams", None) not in [None, 1]:
+            logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size", None)
         else:
             # Do prompt lookup generation
             # If lookahead is provided, we will use lookup_generate instead of
@@ -94,6 +111,7 @@ def generate(
             return self.lookup_generate(inputs=inputs,
                                         num_output_tokens=lookahead,
                                         generation_config=generation_config,
+                                        streamer=streamer,
                                         logits_processor=logits_processor,
                                         stopping_criteria=stopping_criteria,
                                         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
@@ -254,12 +272,19 @@ def lookup_generate(self,
                     num_output_tokens: int = 10,
                     max_matching_ngram_size: int = None,
                     generation_config: Optional[GenerationConfig] = None,
+                    streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
 
+    invalidInputError(input_ids.shape[0] == 1,
+                      "Prompt lookup is currently not supported with batch inference.")
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
+
     device_name = get_xpu_device_type(input_ids)
 
     candidates_generator = PromptLookupCandidateGenerator(
@@ -406,12 +431,19 @@ def lookup_generate(self,
                         first_eos_idx = out_idx
                         break
                 if first_eos_idx > -1:
+                    if streamer is not None:
+                        streamer.put(output_ids[:(first_eos_idx + 1)].cpu())
                     step -= (len(output_ids_list) - first_eos_idx - 1)
                     break
 
+        if streamer is not None:
+            streamer.put(output_ids.cpu())
     step = min(step, max_new_tokens)
     e2e_toc = time.time()
     self.n_token_generated = step
     self.e2e_time_without_first = e2e_toc - e2e_tic
 
+    if streamer is not None:
+        streamer.end()
+
     return input_ids[:, : input_len + step]
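
Below is a minimal usage sketch, not part of the patch itself, showing how the streamer path added by this diff could be exercised. It assumes ipex-llm with this change applied and uses the standard transformers TextStreamer; the model path, prompt, and generation settings are illustrative placeholders, and device placement (e.g. moving the model to "xpu") is omitted for brevity.

# Illustrative sketch only, not part of the diff above. Assumes ipex-llm with
# this change applied; model path and prompt are placeholders.
import torch
from transformers import AutoTokenizer, TextStreamer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)

# TextStreamer prints tokens as streamer.put() is called and flushes on streamer.end().
streamer = TextStreamer(tokenizer, skip_prompt=True)
input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids

with torch.inference_mode():
    # lookahead routes generate() through lookup_generate(), which now
    # forwards the streamer (batch size and num_beams must both be 1).
    output = model.generate(input_ids,
                            max_new_tokens=64,
                            lookahead=2,
                            streamer=streamer)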