Skip to content

Commit

Permalink
[Frontend] Reapply "Factor out code for running uvicorn" (#7095)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Aug 5, 2024
1 parent 7b86e7c commit cc08fc7
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 82 deletions.
77 changes: 53 additions & 24 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,23 @@
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import asyncio
import json
import ssl
from typing import AsyncGenerator
from argparse import Namespace
from typing import Any, AsyncGenerator, Optional

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.api_server")

Expand Down Expand Up @@ -81,6 +83,53 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return JSONResponse(ret)


def build_app(args: Namespace) -> FastAPI:
global app

app.root_path = args.root_path
return app


async def init_app(
args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
app = build_app(args)

global engine

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER))

return app


async def run_server(args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs: Any) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

app = await init_app(args, llm_engine)

shutdown_task = await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)

await shutdown_task


if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
Expand All @@ -105,25 +154,5 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER)

app.root_path = args.root_path

logger.info("Available routes are:")
for route in app.routes:
if not hasattr(route, 'methods'):
continue
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)

uvicorn.run(app,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
asyncio.run(run_server(args))
46 changes: 46 additions & 0 deletions vllm/entrypoints/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import asyncio
import signal
from typing import Any

import uvicorn
from fastapi import FastAPI

from vllm.logger import init_logger

logger = init_logger(__name__)


async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
path = getattr(route, "path", None)

if methods is None or path is None:
continue

logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)

loop = asyncio.get_running_loop()

server_task = loop.create_task(server.serve())

def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()

async def dummy_shutdown() -> None:
pass

loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)

try:
await server_task
return dummy_shutdown()
except asyncio.CancelledError:
logger.info("Gracefully stopping http server")
return server.shutdown()
84 changes: 26 additions & 58 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import importlib
import inspect
import re
import signal
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
from multiprocessing import Process
from typing import AsyncIterator, Set

import fastapi
import uvicorn
from fastapi import APIRouter, Request
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
Expand All @@ -22,6 +20,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
Expand Down Expand Up @@ -71,7 +70,7 @@ def model_is_embedding(model_name: str) -> bool:


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def lifespan(app: FastAPI):

async def _force_log():
while True:
Expand Down Expand Up @@ -135,7 +134,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
router = APIRouter()


def mount_metrics(app: fastapi.FastAPI):
def mount_metrics(app: FastAPI):
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
Expand Down Expand Up @@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


def build_app(args):
app = fastapi.FastAPI(lifespan=lifespan)
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path

Expand Down Expand Up @@ -274,11 +273,10 @@ async def authentication(request: Request, call_next):
return app


async def build_server(
async def init_app(
async_engine_client: AsyncEngineClient,
args,
**uvicorn_kwargs,
) -> uvicorn.Server:
args: Namespace,
) -> FastAPI:
app = build_app(args)

if args.served_model_name is not None:
Expand Down Expand Up @@ -334,62 +332,31 @@ async def build_server(
)
app.root_path = args.root_path

logger.info("Available routes are:")
for route in app.routes:
if not hasattr(route, 'methods'):
continue
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)

config = uvicorn.Config(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)

return uvicorn.Server(config)
return app


async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

shutdown_task = None
async with build_async_engine_client(args) as async_engine_client:

server = await build_server(
async_engine_client,
args,
app = await init_app(async_engine_client, args)

shutdown_task = await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)

loop = asyncio.get_running_loop()

server_task = loop.create_task(server.serve())

def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()

loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)

try:
await server_task
except asyncio.CancelledError:
logger.info("Gracefully stopping http server")
shutdown_task = server.shutdown()

if shutdown_task:
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
# NB: Await server shutdown only after the backend context is exited
await shutdown_task


if __name__ == "__main__":
Expand All @@ -399,4 +366,5 @@ def signal_handler() -> None:
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()

asyncio.run(run_server(args))

0 comments on commit cc08fc7

Please sign in to comment.