From 098b4a94db5fc426c0a0c39d1733a788b75078d0 Mon Sep 17 00:00:00 2001
From: Lei Wen
Date: Fri, 3 May 2024 18:52:58 +0800
Subject: [PATCH] fix async mode

---
 tests/spec_decode/e2e/conftest.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index b09e57d86d5e1..7342c4dda5a9e 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -51,7 +51,7 @@ def __init__(
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        self.engine_args = AsyncEngineArgs(
+        engine_args = AsyncEngineArgs(
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
@@ -72,6 +72,8 @@ def __init__(
             **kwargs,
         )
         self.request_counter = Counter()
+        self.llm_engine = AsyncLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.LLM_CLASS)
 
     def generate(
         self,
@@ -84,9 +86,6 @@ def generate(
         multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
 
-        llm_engine = AsyncLLMEngine.from_engine_args(
-            self.engine_args, usage_context=UsageContext.LLM_CLASS)
-
         if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
@@ -107,8 +106,8 @@ def generate(
 
         async def get_output(prompt, sampling_param) -> str:
             request_id = random_uuid()
-            results_generator = llm_engine.generate(prompt, sampling_param,
-                                                    request_id)
+            results_generator = self.llm_engine.generate(
+                prompt, sampling_param, request_id)
             final_output = None
             async for request_output in results_generator:
                 final_output = request_output
@@ -180,7 +179,8 @@ def get_output_from_llm_generator(
     tokens = []
     token_ids = []
     for llm in llm_generator():
-        if (llm.llm_engine.speculative_config is not None and
+        if (not isinstance(llm, AsyncLLM)
+                and llm.llm_engine.speculative_config is not None and
                 llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
             assert ("set_ngram_window_size" in dir(
                 llm.llm_engine.model_executor.driver_worker.proposer_worker))
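
For context, the pattern this patch moves to looks roughly like the sketch below: the AsyncLLMEngine is constructed once in __init__ and every generate() call only drives the already-built engine's async result generator, instead of re-creating (and re-loading) the engine per call. This is a simplified, hypothetical rewrite of the test fixture, not the exact conftest.py code; the wrapper class name is made up, and the import paths assume the vLLM tree this patch targets.

    # Minimal sketch of the "build once, reuse per request" pattern (assumed
    # simplification of the AsyncLLM test fixture; class name is hypothetical).
    import asyncio

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.usage.usage_lib import UsageContext
    from vllm.utils import random_uuid


    class AsyncLLMWrapper:

        def __init__(self, model: str, **kwargs) -> None:
            engine_args = AsyncEngineArgs(model=model, **kwargs)
            # Built once here; constructing it inside generate() would reload
            # the model on every call.
            self.llm_engine = AsyncLLMEngine.from_engine_args(
                engine_args, usage_context=UsageContext.LLM_CLASS)

        def generate(self, prompts, sampling_params: SamplingParams):

            async def get_output(prompt):
                request_id = random_uuid()
                final_output = None
                # AsyncLLMEngine.generate yields intermediate RequestOutputs;
                # keep only the last (final) one.
                async for request_output in self.llm_engine.generate(
                        prompt, sampling_params, request_id):
                    final_output = request_output
                return final_output

            loop = asyncio.get_event_loop()
            return loop.run_until_complete(
                asyncio.gather(*(get_output(p) for p in prompts)))

The guard added in get_output_from_llm_generator follows the same idea: the async wrapper exposes an AsyncLLMEngine rather than a synchronous LLMEngine, so the ngram proposer-worker introspection is only attempted when the generator did not yield an AsyncLLM instance.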