From d487669495f9b7b860822afc3f6f1524f165f50e Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 15:33:46 +0800
Subject: [PATCH 1/9] Support streaming for lookup generation

---
 python/llm/src/ipex_llm/transformers/lookup.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index c6fe4847073..0a09702c3af 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -97,6 +97,7 @@ def generate(
             logits_processor=logits_processor,
             stopping_criteria=stopping_criteria,
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            streamer=streamer,
             **kwargs)
 
     return original_generate(self,
@@ -254,11 +255,15 @@ def lookup_generate(self,
                     num_output_tokens: int = 10,
                     max_matching_ngram_size: int = None,
                     generation_config: Optional[GenerationConfig] = None,
+                    streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
 
     device_name = get_xpu_device_type(input_ids)
 
@@ -390,6 +395,8 @@ def lookup_generate(self,
                                                accept_rate)
 
         input_ids = torch.cat((input_ids, output_ids), dim=-1)
+        if streamer is not None:
+            streamer.put(output_ids.cpu())
         candidates_generator.update_look_up_table(input_ids)
 
         step += output_ids.size(1)
@@ -414,4 +421,7 @@ def lookup_generate(self,
         self.n_token_generated = step
         self.e2e_time_without_first = e2e_toc - e2e_tic
 
+    if streamer is not None:
+        streamer.end()
+
     return input_ids[:, : input_len + step]

From 66b511cac79b207412f9d7a506edb9a5e26c8ece Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 15:35:48 +0800
Subject: [PATCH 2/9] Small update

---
 python/llm/src/ipex_llm/transformers/lookup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 0a09702c3af..c7640571a3d 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -94,10 +94,10 @@ def generate(
         return self.lookup_generate(inputs=inputs,
                                     num_output_tokens=lookahead,
                                     generation_config=generation_config,
+                                    streamer=streamer,
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-                                    streamer=streamer,
                                     **kwargs)
 
     return original_generate(self,

From c8ef546cd84d9d292c2b2108ee737178489ecbda Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 15:37:08 +0800
Subject: [PATCH 3/9] Style fixes

---
 python/llm/src/ipex_llm/transformers/lookup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index c7640571a3d..2dfdb0334c9 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -261,7 +261,7 @@ def lookup_generate(self,
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
-
+
     if streamer is not None:
         streamer.put(input_ids.cpu())
 
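With patches 1-3 applied, lookup_generate drives a streamer through the same contract as the stock transformers generate loop: one put() call with the prompt ids, one put() per accepted chunk of output ids, then a final end(). A minimal consumer sketch; the class below is illustrative only and assumes just the BaseStreamer interface from transformers, it is not part of the patches:

    # Illustrative sketch: a custom streamer showing the calls lookup_generate
    # now makes (put() with the prompt ids first, put() per accepted chunk,
    # then end() when generation finishes).
    from transformers.generation.streamers import BaseStreamer

    class TokenCountingStreamer(BaseStreamer):
        def __init__(self):
            self.chunks = []
            self.n_tokens = 0

        def put(self, value):
            # value is a CPU tensor of token ids, matching the
            # streamer.put(...cpu()) calls added by the patch above
            self.chunks.append(value)
            self.n_tokens += value.numel()

        def end(self):
            print(f"done, {self.n_tokens} tokens streamed")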
From bf3146b8db34dc145c52a1b5a13e329e724187ea Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 17:25:20 +0800
Subject: [PATCH 4/9] Add original generate fallback for batch inference and
 beam search; support input length threshold judgment for direct input_ids
 input

---
 .../llm/src/ipex_llm/transformers/lookup.py   | 30 +++++++++++++++----
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 2dfdb0334c9..e35d02d1a8e 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -59,15 +59,22 @@ def generate(
 ):
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
-    if perf_mode == "1" and lookahead is None:
-        if inputs is not None:
-            if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                lookahead = 2  # default to 2 now
+
+    if inputs is not None:
+        input_ids_shape = inputs.shape
+    else:
+        input_ids = kwargs.get("input_ids", None)
+        if input_ids is not None:
+            input_ids_shape = input_ids.shape
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
-                if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-                    lookahead = 2  # default to 2 now
+                input_ids_shape = inputs_embeds.shape
+
+    if perf_mode == "1" and lookahead is None:
+        if input_ids_shape.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+            lookahead = 2  # default to 2 now
+
     if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()
@@ -76,6 +83,14 @@ def generate(
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
             kwargs.pop("max_matching_ngram_size")
+        elif input_ids_shape[0] > 1:
+            logger.warning("Prompt lookup is currently not supported with batch inference, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size")
+        elif kwargs.get("num_beams", None) != 1:
+            logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
+                           "fallback to original generate.")
+            kwargs.pop("max_matching_ngram_size")
         else:
             # Do prompt lookup generation
             # If lookahead is provided, we will use lookup_generate instead of
@@ -262,6 +277,9 @@ def lookup_generate(self,
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
 
+    invalidInputError(input_ids.shape[0] > 1,
+                      "Prompt lookup is currently not supported with batch inference.")
+
     if streamer is not None:
         streamer.put(input_ids.cpu())

From c8d4d57a3b736acc5c0bd78221fdb17fcf011fb7 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 17:32:09 +0800
Subject: [PATCH 5/9] Fix lookup stream generate with eos token

---
 python/llm/src/ipex_llm/transformers/lookup.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index e35d02d1a8e..ee3e87f4c23 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -413,8 +413,6 @@ def lookup_generate(self,
                                                accept_rate)
 
         input_ids = torch.cat((input_ids, output_ids), dim=-1)
-        if streamer is not None:
-            streamer.put(output_ids.cpu())
         candidates_generator.update_look_up_table(input_ids)
 
         step += output_ids.size(1)
@@ -431,8 +429,11 @@ def lookup_generate(self,
                     first_eos_idx = out_idx
                     break
         if first_eos_idx > -1:
+            if streamer is not None:
+                streamer.put(output_ids[:first_eos_idx].cpu())
             step -= (len(output_ids_list) - first_eos_idx - 1)
             break
+        streamer.put(output_ids.cpu())
     step = min(step, max_new_tokens)
 
     e2e_toc = time.time()
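Taken together with the follow-up fixes below (patches 6, 7 and 9), the guards introduced here converge on a simple dispatch rule; the function below is a paraphrase for illustration, with hypothetical names, not code from the patches:

    # Paraphrase of the final fallback logic; names are illustrative.
    def should_use_lookup(lookahead, input_ids_shape, num_beams, cpu_with_ipex):
        if not lookahead:
            return False                  # lookup not requested or enabled
        if cpu_with_ipex:
            return False                  # unsupported on CPU with IPEX
        if input_ids_shape is not None and input_ids_shape[0] > 1:
            return False                  # no batch inference
        if num_beams not in (None, 1):
            return False                  # no beam search
        return True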
From 53bd743da18d3de6ea24bc79b472d1ded5098d80 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 18:26:30 +0800
Subject: [PATCH 6/9] Small fixes

---
 python/llm/src/ipex_llm/transformers/lookup.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index ee3e87f4c23..cb13add8e76 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -72,7 +72,7 @@ def generate(
                 input_ids_shape = inputs_embeds.shape
 
     if perf_mode == "1" and lookahead is None:
-        if input_ids_shape.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+        if input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
             lookahead = 2  # default to 2 now
 
     if lookahead:
@@ -87,7 +87,7 @@ def generate(
             logger.warning("Prompt lookup is currently not supported with batch inference, "
                            "fallback to original generate.")
             kwargs.pop("max_matching_ngram_size")
-        elif kwargs.get("num_beams", None) != 1:
+        elif kwargs.get("num_beams", None) not in [None, 1]:
             logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
                            "fallback to original generate.")
             kwargs.pop("max_matching_ngram_size")
@@ -277,7 +277,7 @@ def lookup_generate(self,
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
                                               **sampling_kwargs)
 
-    invalidInputError(input_ids.shape[0] > 1,
+    invalidInputError(input_ids.shape[0] == 1,
                       "Prompt lookup is currently not supported with batch inference.")
 
     if streamer is not None:
@@ -433,7 +433,8 @@ def lookup_generate(self,
                 streamer.put(output_ids[:first_eos_idx].cpu())
             step -= (len(output_ids_list) - first_eos_idx - 1)
             break
-        streamer.put(output_ids.cpu())
+        if streamer is not None:
+            streamer.put(output_ids.cpu())
     step = min(step, max_new_tokens)
 
     e2e_toc = time.time()

From 11752a9a6efe2cea41a78464f4075107e7daf3e4 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 18:29:07 +0800
Subject: [PATCH 7/9] Small fix

---
 python/llm/src/ipex_llm/transformers/lookup.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index cb13add8e76..6e4f1afe897 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -82,15 +82,15 @@ def generate(
         if self.device.type == "cpu" and _enable_ipex:
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
-            kwargs.pop("max_matching_ngram_size")
+            kwargs.pop("max_matching_ngram_size", None)
         elif input_ids_shape[0] > 1:
             logger.warning("Prompt lookup is currently not supported with batch inference, "
                            "fallback to original generate.")
-            kwargs.pop("max_matching_ngram_size")
+            kwargs.pop("max_matching_ngram_size", None)
         elif kwargs.get("num_beams", None) not in [None, 1]:
             logger.warning("Prompt lookup is currently not supported with num_beams != 1, "
                            "fallback to original generate.")
-            kwargs.pop("max_matching_ngram_size")
+            kwargs.pop("max_matching_ngram_size", None)
         else:
             # Do prompt lookup generation
             # If lookahead is provided, we will use lookup_generate instead of
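The defaulted pop() added in patch 7 matters because dict.pop(key) raises KeyError when the key is absent, which is exactly the situation when a caller never passed max_matching_ngram_size:

    # Quick illustration of why the default argument is needed.
    kwargs = {}                                          # option never passed
    # kwargs.pop("max_matching_ngram_size")              # would raise KeyError
    value = kwargs.pop("max_matching_ngram_size", None)  # returns None instead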
From e0eb5ff203a4b55dc35d84e8c97545bb7ff90077 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 18:31:18 +0800
Subject: [PATCH 8/9] index fix

---
 python/llm/src/ipex_llm/transformers/lookup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 6e4f1afe897..570ff18e27a 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -430,7 +430,7 @@ def lookup_generate(self,
                     break
         if first_eos_idx > -1:
             if streamer is not None:
-                streamer.put(output_ids[:first_eos_idx].cpu())
+                streamer.put(output_ids[:(first_eos_idx + 1)].cpu())
             step -= (len(output_ids_list) - first_eos_idx - 1)
             break
         if streamer is not None:

From 80f49d7bd30f4b88534951f48e71cf0dcca52dab Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Mon, 26 Aug 2024 19:04:14 +0800
Subject: [PATCH 9/9] Small fix

---
 python/llm/src/ipex_llm/transformers/lookup.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 570ff18e27a..c70558c7fd2 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -60,6 +60,7 @@ def generate(
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
 
+    input_ids_shape = None
     if inputs is not None:
         input_ids_shape = inputs.shape
     else:
@@ -72,7 +73,8 @@ def generate(
                 input_ids_shape = inputs_embeds.shape
 
     if perf_mode == "1" and lookahead is None:
-        if input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+        if input_ids_shape is not None and \
+                input_ids_shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
             lookahead = 2  # default to 2 now
 
     if lookahead:
@@ -85,7 +85,7 @@ def generate(
             logger.warning("Prompt lookup is currently not supported on CPU with IPEX, "
                            "fallback to original generate.")
             kwargs.pop("max_matching_ngram_size", None)
-        elif input_ids_shape[0] > 1:
+        elif input_ids_shape is not None and input_ids_shape[0] > 1:
             logger.warning("Prompt lookup is currently not supported with batch inference, "
                            "fallback to original generate.")
            kwargs.pop("max_matching_ngram_size", None)
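End-to-end, the series makes streaming work with lookup generation for single-prompt, single-beam runs. A usage sketch under stated assumptions: the model path and prompt are placeholders, and the loading calls follow ipex-llm's usual transformers-style API:

    import os
    from transformers import AutoTokenizer, TextStreamer
    from ipex_llm.transformers import AutoModelForCausalLM

    # Optional: performance mode turns lookahead on automatically for prompts
    # longer than PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD (patches 4, 6, 9).
    os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                                 trust_remote_code=True)

    prompt = "Once upon a time"  # placeholder prompt
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids  # batch size 1 only

    # The streamer receives the prompt ids first, then each accepted chunk
    # (eos-truncated per patches 5 and 8), and end() on completion.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    output = model.generate(input_ids,
                            lookahead=2,       # enable prompt-lookup generation
                            max_new_tokens=64,
                            streamer=streamer)

With TextIteratorStreamer in place of TextStreamer, the same call can feed a consumer thread or a chat UI token by token.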