diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 027c2217380..9800cfea70d 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -737,11 +737,12 @@ def _prepare_generate_args_4_45(self, inputs, generation_config, streamer=None, ) # 9. prepare logits processors and stopping criteria + prefix_allowed_tokens_fn = kwargs.pop("prefix_allowed_tokens_fn", None) prepared_logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=None, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, device=inputs_tensor.device, model_kwargs=model_kwargs,