diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 4baacc3609522..e3848c3c27adc 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -458,6 +458,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
 
@@ -471,6 +472,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
 
@@ -487,6 +489,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -498,6 +501,9 @@ async def add_request_async(
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
+        if priority != 0 and not self.scheduler_config.policy == "priority":
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
 
         if arrival_time is None:
             arrival_time = time.time()
@@ -521,6 +527,7 @@ async def add_request_async(
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
+            priority=priority,
         )
 
     async def check_health_async(self) -> None:
@@ -871,6 +878,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -885,6 +893,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -902,6 +911,7 @@ async def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
@@ -919,6 +929,11 @@ async def add_request(
                     "error that caused the background loop to stop "
                     "(AsyncEngineDeadError).")
 
+        if (priority != 0
+                and not self.engine.scheduler_config.policy == "priority"):
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
+
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
@@ -927,7 +942,9 @@ async def add_request(
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
+        )
 
         return stream.generator()
 
@@ -938,7 +955,8 @@ async def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
@@ -955,6 +973,8 @@ async def generate(
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                                             for generation, if any.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.
 
         Yields:
             The output `RequestOutput` objects from the LLMEngine
@@ -1010,6 +1030,7 @@ async def generate(
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index c5b46ef26cda3..3f6c6e8ae3570 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -839,7 +839,7 @@ def add_request(
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
 
-        if priority > 0 and not self.scheduler_config.policy == "priority":
+        if priority != 0 and not self.scheduler_config.policy == "priority":
             raise ValueError(f"Got priority {priority} but "
                              "Priority scheduling is not enabled.")
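Usage note (not part of the diff): a minimal sketch of how a caller could exercise the new `priority` argument through `AsyncLLMEngine.generate`, assuming the engine is constructed with the priority scheduling policy that backs `scheduler_config.policy`. The `scheduling_policy` engine argument, the model name, the request ids, and the claim that lower values are served first are assumptions for illustration, not part of this change.

```python
# Hypothetical usage sketch: submit two requests with different priorities.
# Assumes AsyncEngineArgs exposes scheduling_policy="priority"; model name
# and request ids are placeholders.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model="facebook/opt-125m",     # placeholder model
            scheduling_policy="priority",  # backs scheduler_config.policy
        ))

    async def run_one(prompt: str, request_id: str, priority: int) -> None:
        # With the default "fcfs" policy, a non-zero priority here would now
        # raise the ValueError added in this diff.
        final = None
        async for output in engine.generate(
                prompt,
                SamplingParams(max_tokens=16),
                request_id,
                priority=priority,
        ):
            final = output
        print(request_id, final.outputs[0].text if final else "")

    # Assumption: under the priority policy, lower values are served first.
    await asyncio.gather(
        run_one("Interactive user query:", "chat-0", priority=0),
        run_one("Background batch job:", "bulk-0", priority=10),
    )


if __name__ == "__main__":
    asyncio.run(main())
```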