[Frontend] Add progress reporting to run_batch.py (vllm-project#8060)
Co-authored-by: Adam Lugowski <[email protected]>
alugowski and Adam Lugowski authored Sep 9, 2024
1 parent 08287ef commit 58fcc85
Showing 1 changed file with 48 additions and 6 deletions.
54 changes: 48 additions & 6 deletions vllm/entrypoints/openai/run_batch.py
@@ -1,9 +1,11 @@
import asyncio
from io import StringIO
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable, List, Optional

import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -78,6 +80,38 @@ def parse_args():
    return parser.parse_args()


# Explicitly use a plain text format with a newline at the end. This hides
# the progress bar animation, but avoids garbled output when Ray or
# multiprocessing wraps each line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501


class BatchProgressTracker:

    def __init__(self):
        self._total = 0
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        self._total += 1

    def completed(self):
        if self._pbar:
            self._pbar.update()

    def pbar(self) -> tqdm:
        enable_tqdm = not torch.distributed.is_initialized(
        ) or torch.distributed.get_rank() == 0
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          mininterval=5,
                          disable=not enable_tqdm,
                          bar_format=_BAR_FORMAT)
        return self._pbar


async def read_file(path_or_url: str) -> str:
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        async with aiohttp.ClientSession() as session, \
@@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None:


async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput) -> BatchRequestOutput:
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    response = await serving_engine_func(request.body)

    if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
@@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable,
    else:
        raise ValueError("Request must not be sent in stream mode")

    tracker.completed()
    return batch_output


@@ -164,6 +200,9 @@ async def main(args):
        request_logger=request_logger,
    )

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
@@ -178,16 +217,19 @@
        if request.url == "/v1/chat/completions":
            response_futures.append(
                run_request(openai_serving_chat.create_chat_completion,
                            request))
                            request, tracker))
            tracker.submitted()
        elif request.url == "/v1/embeddings":
            response_futures.append(
                run_request(openai_serving_embedding.create_embedding,
                            request))
                run_request(openai_serving_embedding.create_embedding, request,
                            tracker))
            tracker.submitted()
        else:
raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint.")

    responses = await asyncio.gather(*response_futures)
    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    output_buffer = StringIO()
    for response in responses:
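The change boils down to one pattern: count every request as it is submitted, create a single tqdm bar only once all coroutines have been scheduled (so the total is known), and tick the bar from inside each coroutine as it finishes. The snippet below is a minimal, self-contained sketch of that pattern for reference, not the vllm code itself: the fake_request coroutine, its sleep, and the request count are invented for illustration, and the torch.distributed rank-0 check from pbar() is omitted.

import asyncio
import random
from typing import Optional

from tqdm import tqdm

# Same idea as _BAR_FORMAT above: newline-terminated updates instead of an
# in-place animation, so prefixed or wrapped log output stays readable.
_BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | "
               "{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n")


class ProgressTracker:

    def __init__(self):
        self._total = 0
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        self._total += 1

    def completed(self):
        if self._pbar:
            self._pbar.update()

    def pbar(self) -> tqdm:
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          bar_format=_BAR_FORMAT)
        return self._pbar


async def fake_request(i: int, tracker: ProgressTracker) -> int:
    # Stand-in for run_request(): do the work, then tick the bar.
    await asyncio.sleep(random.random())
    tracker.completed()
    return i


async def main():
    tracker = ProgressTracker()
    futures = []
    for i in range(20):
        futures.append(fake_request(i, tracker))
        tracker.submitted()
    # The bar is created only after every request has been counted, so its
    # total is correct; the with-block closes it when gather() returns.
    with tracker.pbar():
        results = await asyncio.gather(*futures)
    print(len(results), "requests finished")


if __name__ == "__main__":
    asyncio.run(main())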

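For context on what main() iterates over: the input file is newline-delimited JSON, one request object per line, and request.url drives the dispatch shown above. A hypothetical input line could be built as follows; the field names follow the OpenAI-style batch format that BatchRequestInput models, and the custom_id, model name, and message content are placeholders rather than values taken from the commit.

import json

# One line of a hypothetical input file for run_batch.py. All values below are
# illustrative; only "url" values of /v1/chat/completions or /v1/embeddings
# pass the dispatch in main().
example_line = json.dumps({
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "example-model",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
})

with open("batch_input.jsonl", "w") as f:
    f.write(example_line + "\n")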