From 5fd6412347b1c1049ead1aa46297291a944d2aec Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Fri, 23 Aug 2024 16:18:55 +0800
Subject: [PATCH 1/5] Update IPEX_LLM_PERFORMANCE_MODE with input length threshold

---
 python/llm/src/ipex_llm/transformers/lookup.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index e86c05b1642..6dd3987cb0c 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -40,6 +40,9 @@
 original_generate = GenerationMixin.generate
 query_group_size = 16
 
+# may tune it with more tested data
+PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100
+
 
 @torch.no_grad()
 def generate(
@@ -54,10 +57,12 @@ def generate(
     streamer: Optional["BaseStreamer"] = None,
     **kwargs,
 ):
+    device_name = get_xpu_device_type(inputs)
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
     if perf_mode == "1" and lookahead is None:
-        lookahead = 2  # default to 2 now
+        if device_name != 'mtl' or inputs.shape[-1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+            lookahead = 2  # default to 2 now
     if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()

From e631ecd0ca6b91f300615360fdd8510265ef7153 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Fri, 23 Aug 2024 17:44:32 +0800
Subject: [PATCH 2/5] Update based on comments. And add judgement for inputs_embeds

---
 python/llm/src/ipex_llm/transformers/lookup.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 6dd3987cb0c..c6fe4847073 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -57,12 +57,17 @@ def generate(
     streamer: Optional["BaseStreamer"] = None,
     **kwargs,
 ):
-    device_name = get_xpu_device_type(inputs)
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
     if perf_mode == "1" and lookahead is None:
-        if device_name != 'mtl' or inputs.shape[-1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
-            lookahead = 2  # default to 2 now
+        if inputs is not None:
+            if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+                lookahead = 2  # default to 2 now
+        else:
+            inputs_embeds = kwargs.get("inputs_embeds", None)
+            if inputs_embeds is not None:
+                if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
+                    lookahead = 2  # default to 2 now
     if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()

From 0f34656c96fe3ac0a5c3bffbe036036b654b75e5 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Fri, 23 Aug 2024 18:09:51 +0800
Subject: [PATCH 3/5] Fix for benchmarking purposes

---
 python/llm/src/ipex_llm/transformers/lookup.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index c6fe4847073..0ec784da301 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -59,16 +59,23 @@ def generate(
 ):
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
+    use_update_candidate_strategy = True
     if perf_mode == "1" and lookahead is None:
         if inputs is not None:
             if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
                 lookahead = 2  # default to 2 now
+            else:
+                lookahead = 0
+                use_update_candidate_strategy = False
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
                 if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
                     lookahead = 2  # default to 2 now
-    if lookahead:
+                else:
+                    lookahead = 0
+                    use_update_candidate_strategy = False
+    if lookahead is not None:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()
 
@@ -97,6 +104,7 @@ def generate(
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                    use_update_candidate_strategy=use_update_candidate_strategy,
                                     **kwargs)
 
     return original_generate(self,
@@ -152,7 +160,7 @@ def __init__(
         self.min_candidates = 0
         self.lookup_table = {}
 
-        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
+        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens >= 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
     def init_look_up_table(self,
@@ -255,6 +263,7 @@ def lookup_generate(self,
                     max_matching_ngram_size: int = None,
                     generation_config: Optional[GenerationConfig] = None,
                     attention_mask=None,
+                    use_update_candidate_strategy=True,
                     **sampling_kwargs):
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
@@ -385,7 +394,7 @@ def lookup_generate(self,
             accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
             self.accept_rate.append(accept_rate)
             # Update the candidate generation strategy if needed
-            if device_name != 'mtl':
+            if device_name != 'mtl' and use_update_candidate_strategy:
                 candidates_generator.update_candidate_strategy(candidate_length, n_matches,
                                                                accept_rate)
 

From 0d0914486f29dd3efb18a46f9c6fbd09cb922c2f Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Fri, 23 Aug 2024 19:34:57 +0800
Subject: [PATCH 4/5] Update based on comments

---
 python/llm/src/ipex_llm/transformers/lookup.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 0ec784da301..6b20930484f 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -59,23 +59,16 @@ def generate(
 ):
     lookahead = kwargs.pop("lookahead", None)
     perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
-    use_update_candidate_strategy = True
     if perf_mode == "1" and lookahead is None:
         if inputs is not None:
             if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
                 lookahead = 2  # default to 2 now
-            else:
-                lookahead = 0
-                use_update_candidate_strategy = False
         else:
             inputs_embeds = kwargs.get("inputs_embeds", None)
             if inputs_embeds is not None:
                 if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
                     lookahead = 2  # default to 2 now
-                else:
-                    lookahead = 0
-                    use_update_candidate_strategy = False
-    if lookahead is not None:
+    if lookahead:
         from ipex_llm.transformers.convert import get_enable_ipex
         _enable_ipex = get_enable_ipex()
 
@@ -104,7 +97,6 @@ def generate(
                                     logits_processor=logits_processor,
                                     stopping_criteria=stopping_criteria,
                                     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-                                    use_update_candidate_strategy=use_update_candidate_strategy,
                                     **kwargs)
 
     return original_generate(self,
@@ -160,7 +152,7 @@ def __init__(
         self.min_candidates = 0
         self.lookup_table = {}
 
-        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens >= 0,
+        invalidInputError(self.max_matching_ngram_size > 0 and self.num_output_tokens > 0,
                           "Invalid max_matching_ngram_size or num_output_tokens")
 
     def init_look_up_table(self,
@@ -394,7 +386,7 @@ def lookup_generate(self,
             accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
             self.accept_rate.append(accept_rate)
             # Update the candidate generation strategy if needed
-            if device_name != 'mtl' and use_update_candidate_strategy:
+            if device_name != 'mtl':
                 candidates_generator.update_candidate_strategy(candidate_length, n_matches,
                                                                accept_rate)
 

From 1379be2174731ac669b6fd0219de4945a26d8fd8 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Fri, 23 Aug 2024 19:47:08 +0800
Subject: [PATCH 5/5] Small fix

---
 python/llm/src/ipex_llm/transformers/lookup.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index 6b20930484f..c6fe4847073 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -255,7 +255,6 @@ def lookup_generate(self,
                     max_matching_ngram_size: int = None,
                     generation_config: Optional[GenerationConfig] = None,
                     attention_mask=None,
-                    use_update_candidate_strategy=True,
                     **sampling_kwargs):
     input_ids, generation_config, logits_processor, stopping_criteria, \
         model_kwargs = _prepare_generate_args(self, inputs, generation_config,
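
Taken as a whole (the benchmarking-only changes from PATCH 3/5 are reverted again in PATCH 4/5 and 5/5), the series makes IPEX_LLM_PERFORMANCE_MODE=1 enable lookahead generation (lookahead = 2) only when the prompt reaches PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD (100) tokens, whether the prompt arrives as inputs or as inputs_embeds. The standalone Python sketch below mirrors that final gating logic for illustration only; resolve_lookahead is a hypothetical helper name, not part of the ipex-llm API.

    import os
    from typing import Optional

    import torch

    # Mirrors the constant introduced in PATCH 1/5; the value may be tuned with more tested data.
    PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD = 100


    def resolve_lookahead(inputs: Optional[torch.Tensor] = None,
                          lookahead: Optional[int] = None,
                          **kwargs) -> Optional[int]:
        # Hypothetical helper: a caller-supplied lookahead is returned untouched; otherwise,
        # when IPEX_LLM_PERFORMANCE_MODE=1, lookahead=2 is enabled only for prompts whose
        # sequence length (dim 1 of inputs or inputs_embeds) reaches the threshold.
        perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
        if perf_mode == "1" and lookahead is None:
            if inputs is not None:
                if inputs.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
                    lookahead = 2  # default to 2 now
            else:
                inputs_embeds = kwargs.get("inputs_embeds", None)
                if inputs_embeds is not None:
                    if inputs_embeds.shape[1] >= PERFORMANCE_MODE_LOOKUP_INPUT_THRESHOLD:
                        lookahead = 2  # default to 2 now
        return lookahead


    if __name__ == "__main__":
        os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"
        short_prompt = torch.ones(1, 32, dtype=torch.long)     # 32 tokens, below threshold
        long_prompt = torch.ones(1, 256, dtype=torch.long)     # 256 tokens, above threshold
        print(resolve_lookahead(short_prompt))                          # None -> ordinary generate path
        print(resolve_lookahead(long_prompt))                           # 2 -> lookup/lookahead path
        print(resolve_lookahead(inputs_embeds=torch.zeros(1, 256, 8)))  # 2 via the inputs_embeds branch

Note that PATCH 1/5 additionally gated on the device type ('mtl'), but that check was dropped from generate() in PATCH 2/5 in favour of the pure length threshold shown above.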