
Use queue for finished requests (vllm-project#957)
Yard1 authored Sep 6, 2023
1 parent df905c7 commit 5b8c160
Showing 2 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -156,8 +156,8 @@ def generate_greedy(
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params)
-        return [(output_ids[0], output_str[0]) for output_ids, output_str in
-                outputs]
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
18 changes: 11 additions & 7 deletions vllm/engine/async_llm_engine.py
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -152,7 +152,7 @@ def __init__(self,
 
         # Request id -> stream.
         self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: Set[str] = set()
+        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
@@ -194,12 +194,14 @@ async def engine_step(self):
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")
                 self.request_streams[request_id].finish()
-                self.finished_requests.add(request_id)
+                self.finished_requests.put_nowait(request_id)
 
-        await self._engine_abort(self.finished_requests)
-        for request_id in self.finished_requests:
+        finished_request = set()
+        while not self.finished_requests.empty():
+            finished_request.add(self.finished_requests.get_nowait())
+        await self._engine_abort(finished_request)
+        for request_id in finished_request:
             del self.request_streams[request_id]
-        self.finished_requests.clear()
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
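
The core of this hunk is the drain step: finished request ids are enqueued with put_nowait and, once per engine step, drained into a local set before their streams are aborted and deleted. Ids enqueued after the drain simply wait for the next step, whereas the old code's blanket clear() would also discard ids added while _engine_abort was awaited. A minimal, self-contained sketch of that pattern (streams, mark_finished, and step are illustrative names, not vLLM APIs):

import asyncio

# Illustrative stand-ins for the engine's per-request streams.
streams = {"req-1": "stream-1", "req-2": "stream-2"}

# Finished request ids are enqueued by producers and drained once per step.
finished: asyncio.Queue = asyncio.Queue()


def mark_finished(request_id: str) -> None:
    # put_nowait never blocks on an unbounded asyncio.Queue.
    finished.put_nowait(request_id)


async def step() -> None:
    # Drain only what is already queued into a local set.
    done = set()
    while not finished.empty():
        done.add(finished.get_nowait())
    # Ids enqueued from this point on are handled on the next step.
    for request_id in done:
        del streams[request_id]


async def main() -> None:
    mark_finished("req-1")
    await step()
    print(sorted(streams))  # ['req-2']


asyncio.run(main())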
@@ -226,6 +228,8 @@ async def add_request(
f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.")

if request_id in self.request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
self.request_streams[request_id] = stream

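The two added lines above make duplicate submissions fail fast. A small sketch of the behaviour under that guard (AsyncStream here is a bare stand-in and register is an illustrative helper, not vLLM's API):

from typing import Dict


class AsyncStream:
    """Bare stand-in for vllm's AsyncStream."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id


request_streams: Dict[str, AsyncStream] = {}


def register(request_id: str) -> AsyncStream:
    # Mirrors the added guard: a duplicate id raises instead of silently
    # replacing the stream that is already registered.
    if request_id in request_streams:
        raise KeyError(f"Request {request_id} already exists.")
    stream = AsyncStream(request_id)
    request_streams[request_id] = stream
    return stream


register("abc-123")
try:
    register("abc-123")
except KeyError as exc:
    print(exc)  # 'Request abc-123 already exists.'
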
@@ -316,7 +320,7 @@ def _abort(self, request_id: str) -> None:
logger.info(f"Aborted request {request_id}.")

self.request_streams[request_id].finish()
self.finished_requests.add(request_id)
self.finished_requests.put_nowait(request_id)

async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
