diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 66941442c8c9c..672382717d119 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -5,21 +5,23 @@
 We are also not going to accept PRs modifying this file, please change
 `vllm/entrypoints/openai/api_server.py` instead.
 """
-
+import asyncio
 import json
 import ssl
-from typing import AsyncGenerator
+from argparse import Namespace
+from typing import Any, AsyncGenerator, Optional
 
-import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.launcher import serve_http
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
+from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger("vllm.entrypoints.api_server")
 
@@ -81,6 +83,53 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return JSONResponse(ret)
 
 
+def build_app(args: Namespace) -> FastAPI:
+    global app
+
+    app.root_path = args.root_path
+    return app
+
+
+async def init_app(
+    args: Namespace,
+    llm_engine: Optional[AsyncLLMEngine] = None,
+) -> FastAPI:
+    app = build_app(args)
+
+    global engine
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.API_SERVER))
+
+    return app
+
+
+async def run_server(args: Namespace,
+                     llm_engine: Optional[AsyncLLMEngine] = None,
+                     **uvicorn_kwargs: Any) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    app = await init_app(args, llm_engine)
+
+    shutdown_task = await serve_http(
+        app,
+        host=args.host,
+        port=args.port,
+        log_level=args.log_level,
+        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+        ssl_keyfile=args.ssl_keyfile,
+        ssl_certfile=args.ssl_certfile,
+        ssl_ca_certs=args.ssl_ca_certs,
+        ssl_cert_reqs=args.ssl_cert_reqs,
+        **uvicorn_kwargs,
+    )
+
+    await shutdown_task
+
+
 if __name__ == "__main__":
     parser = FlexibleArgumentParser()
     parser.add_argument("--host", type=str, default=None)
@@ -105,25 +154,5 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     parser.add_argument("--log-level", type=str, default="debug")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.API_SERVER)
-
-    app.root_path = args.root_path
-
-    logger.info("Available routes are:")
-    for route in app.routes:
-        if not hasattr(route, 'methods'):
-            continue
-        methods = ', '.join(route.methods)
-        logger.info("Route: %s, Methods: %s", route.path, methods)
-
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level=args.log_level,
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-                ssl_keyfile=args.ssl_keyfile,
-                ssl_certfile=args.ssl_certfile,
-                ssl_ca_certs=args.ssl_ca_certs,
-                ssl_cert_reqs=args.ssl_cert_reqs)
+    asyncio.run(run_server(args))
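This half of the change turns the demo server's `__main__`-only startup into importable `build_app` / `init_app` / `run_server` helpers, with `run_server` accepting an optional pre-built `AsyncLLMEngine`. A minimal sketch of that injection point, assuming the standard engine CLI flags; the argv and model name below are illustrative, not part of this diff:

```python
import asyncio
import ssl

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.api_server import run_server
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

# Rebuild the same CLI surface the __main__ block defines, but parse an
# explicit argv so this can run from a test or a wrapper script.
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--log-level", type=str, default="debug")
parser.add_argument("--root-path", type=str, default=None)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs", type=str, default=None)
parser.add_argument("--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE))
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m"])  # illustrative

# Build the engine up front; run_server reuses it via init_app instead of
# constructing a second one from the CLI args.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs.from_cli_args(args),
    usage_context=UsageContext.API_SERVER)

asyncio.run(run_server(args, llm_engine=engine))
```

Passing `llm_engine` skips engine construction inside `init_app`, which is what makes the demo server reusable against a shared engine.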
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
new file mode 100644
index 0000000000000..00826762f76a1
--- /dev/null
+++ b/vllm/entrypoints/launcher.py
@@ -0,0 +1,46 @@
+import asyncio
+import signal
+from typing import Any
+
+import uvicorn
+from fastapi import FastAPI
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
+    logger.info("Available routes are:")
+    for route in app.routes:
+        methods = getattr(route, "methods", None)
+        path = getattr(route, "path", None)
+
+        if methods is None or path is None:
+            continue
+
+        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
+
+    config = uvicorn.Config(app, **uvicorn_kwargs)
+    server = uvicorn.Server(config)
+
+    loop = asyncio.get_running_loop()
+
+    server_task = loop.create_task(server.serve())
+
+    def signal_handler() -> None:
+        # prevents the uvicorn signal handler from exiting early
+        server_task.cancel()
+
+    async def dummy_shutdown() -> None:
+        pass
+
+    loop.add_signal_handler(signal.SIGINT, signal_handler)
+    loop.add_signal_handler(signal.SIGTERM, signal_handler)
+
+    try:
+        await server_task
+        return dummy_shutdown()
+    except asyncio.CancelledError:
+        logger.info("Gracefully stopping http server")
+        return server.shutdown()
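The launcher replaces `uvicorn.run()` with a cancellable task: loop-level SIGINT/SIGTERM handlers take precedence over uvicorn's own signal handling, and `serve_http` returns the `server.shutdown()` coroutine unawaited so the caller decides when HTTP connections actually drain. A self-contained sketch of the same pattern with plain uvicorn/FastAPI; the endpoint and names are illustrative, not from the diff:

```python
import asyncio
import signal

import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/ping")
async def ping() -> dict:
    return {"pong": True}


async def main() -> None:
    server = uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=8080))
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.serve())

    # Loop-level handlers override uvicorn's own signal handling, so
    # shutdown stays under the caller's control rather than uvicorn's.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, server_task.cancel)

    try:
        await server_task
    except asyncio.CancelledError:
        # Anything that must be torn down before the HTTP drain (for
        # vLLM, the engine backend) would run here, before this await.
        await server.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

Returning the shutdown coroutine rather than awaiting it inside `serve_http` is what lets the OpenAI entrypoint below exit the engine-client context first and drain HTTP last.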
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index e330ee81f7e44..a0190f3d66b10 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -2,15 +2,13 @@
 import importlib
 import inspect
 import re
-import signal
+from argparse import Namespace
 from contextlib import asynccontextmanager
 from http import HTTPStatus
 from multiprocessing import Process
 from typing import AsyncIterator, Set
 
-import fastapi
-import uvicorn
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -22,6 +20,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import AsyncEngineClient
+from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 # yapf conflicts with isort for this block
@@ -71,7 +70,7 @@ def model_is_embedding(model_name: str) -> bool:
 
 
 @asynccontextmanager
-async def lifespan(app: fastapi.FastAPI):
+async def lifespan(app: FastAPI):
 
     async def _force_log():
         while True:
@@ -135,7 +134,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
 router = APIRouter()
 
 
-def mount_metrics(app: fastapi.FastAPI):
+def mount_metrics(app: FastAPI):
     # Add prometheus asgi middleware to route /metrics requests
     metrics_route = Mount("/metrics", make_asgi_app())
     # Workaround for 307 Redirect for /metrics
@@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
         return JSONResponse(content=generator.model_dump())
 
 
-def build_app(args):
-    app = fastapi.FastAPI(lifespan=lifespan)
+def build_app(args: Namespace) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
 
@@ -274,11 +273,10 @@ async def authentication(request: Request, call_next):
     return app
 
 
-async def build_server(
+async def init_app(
     async_engine_client: AsyncEngineClient,
-    args,
-    **uvicorn_kwargs,
-) -> uvicorn.Server:
+    args: Namespace,
+) -> FastAPI:
     app = build_app(args)
 
     if args.served_model_name is not None:
@@ -334,62 +332,31 @@ async def build_server(
     )
     app.root_path = args.root_path
 
-    logger.info("Available routes are:")
-    for route in app.routes:
-        if not hasattr(route, 'methods'):
-            continue
-        methods = ', '.join(route.methods)
-        logger.info("Route: %s, Methods: %s", route.path, methods)
-
-    config = uvicorn.Config(
-        app,
-        host=args.host,
-        port=args.port,
-        log_level=args.uvicorn_log_level,
-        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-        ssl_keyfile=args.ssl_keyfile,
-        ssl_certfile=args.ssl_certfile,
-        ssl_ca_certs=args.ssl_ca_certs,
-        ssl_cert_reqs=args.ssl_cert_reqs,
-        **uvicorn_kwargs,
-    )
-
-    return uvicorn.Server(config)
+    return app
 
 
 async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
-    shutdown_task = None
     async with build_async_engine_client(args) as async_engine_client:
-
-        server = await build_server(
-            async_engine_client,
-            args,
+        app = await init_app(async_engine_client, args)
+
+        shutdown_task = await serve_http(
+            app,
+            host=args.host,
+            port=args.port,
+            log_level=args.uvicorn_log_level,
+            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+            ssl_keyfile=args.ssl_keyfile,
+            ssl_certfile=args.ssl_certfile,
+            ssl_ca_certs=args.ssl_ca_certs,
+            ssl_cert_reqs=args.ssl_cert_reqs,
             **uvicorn_kwargs,
         )
 
-        loop = asyncio.get_running_loop()
-
-        server_task = loop.create_task(server.serve())
-
-        def signal_handler() -> None:
-            # prevents the uvicorn signal handler to exit early
-            server_task.cancel()
-
-        loop.add_signal_handler(signal.SIGINT, signal_handler)
-        loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
-        try:
-            await server_task
-        except asyncio.CancelledError:
-            logger.info("Gracefully stopping http server")
-            shutdown_task = server.shutdown()
-
-    if shutdown_task:
-        # NB: Await server shutdown only after the backend context is exited
-        await shutdown_task
+    # NB: Await server shutdown only after the backend context is exited
+    await shutdown_task
 
 
 if __name__ == "__main__":
@@ -399,4 +366,5 @@ def signal_handler() -> None:
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
     args = parser.parse_args()
+
     asyncio.run(run_server(args))
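With `build_server` collapsed into `init_app`, `run_server` still forwards `**uvicorn_kwargs` through `serve_http` into `uvicorn.Config`, and `shutdown_task` is awaited only after the `build_async_engine_client` context exits, preserving the previous ordering without the `shutdown_task = None` sentinel. A sketch of the programmatic entrypoint; the `access_log` knob is a stock `uvicorn.Config` option assumed here, not something this diff adds, and the argv is illustrative:

```python
import asyncio

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m"])  # illustrative

# Extra keyword arguments pass through serve_http into uvicorn.Config
# unchanged, so uvicorn can be tuned without adding new CLI flags.
asyncio.run(run_server(args, access_log=False))
```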