[Feature] vLLM CLI (vllm-project#5090)
Co-authored-by: simon-mo <[email protected]>
2 people authored and dtrifiro committed Jul 17, 2024
1 parent 86b85b0 commit 1cbca2b
Showing 7 changed files with 223 additions and 36 deletions.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_serving.py
@@ -2,8 +2,8 @@
On the server side, run one of the following commands:
vLLM OpenAI API server
python -m vllm.entrypoints.openai.api_server \
--model <your_model> --swap-space 16 \
vllm serve <your_model> \
--swap-space 16 \
--disable-log-requests
(TGI backend)
2 changes: 1 addition & 1 deletion docs/source/serving/openai_compatible_server.md
@@ -109,7 +109,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)

```{argparse}
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:func: create_parser_for_docs
:prog: -m vllm.entrypoints.openai.api_server
```

5 changes: 5 additions & 0 deletions setup.py
@@ -488,4 +488,9 @@ def _read_requirements(filename: str) -> List[str]:
},
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
package_data=package_data,
entry_points={
"console_scripts": [
"vllm=vllm.scripts:main",
],
},
)
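
The new `entry_points` block registers a `vllm` console script that dispatches to `vllm.scripts:main`. The diff for `vllm/scripts.py` itself is not shown on this page, so the following is only a rough sketch of what such an entry point could look like, assuming a `serve` subcommand that reuses `make_arg_parser` and `run_server`; the subcommand layout and the `model_tag` positional argument are illustrative guesses, not code from the commit.

```python
# Hypothetical sketch of a `vllm` console entry point (not the actual
# vllm/scripts.py from this commit). It assumes a single `serve` subcommand
# that reuses the OpenAI server's argument parser and run_server().
import argparse

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser


def serve(args: argparse.Namespace) -> None:
    # Treat the positional model tag as if it had been passed via --model.
    args.model = args.model_tag
    run_server(args)


def main() -> None:
    parser = FlexibleArgumentParser(description="vLLM CLI")
    subparsers = parser.add_subparsers(required=True, dest="subcommand")

    serve_parser = subparsers.add_parser(
        "serve", help="Start the vLLM OpenAI-compatible API server.")
    serve_parser.add_argument("model_tag", type=str,
                              help="The model tag to serve.")
    serve_parser = make_arg_parser(serve_parser)
    serve_parser.set_defaults(dispatch_function=serve)

    args = parser.parse_args()
    args.dispatch_function(args)


if __name__ == "__main__":
    main()
```
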
6 changes: 4 additions & 2 deletions tests/utils.py
@@ -14,7 +14,7 @@
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import get_open_port, is_hip
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip

if is_hip():
from amdsmi import (amdsmi_get_gpu_vram_usage,
@@ -57,7 +57,9 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:

cli_args = cli_args + ["--port", str(get_open_port())]

parser = make_arg_parser()
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(cli_args)
self.host = str(args.host or 'localhost')
self.port = int(args.port)
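
Because `make_arg_parser` no longer constructs its own parser (see the `cli_args.py` change further down), callers such as this test helper build a `FlexibleArgumentParser` first and pass it in. A minimal sketch of that caller-side pattern, with placeholder CLI arguments rather than values from the test suite:

```python
# Caller-side pattern after the signature change: construct the parser, then
# let make_arg_parser populate it. Model name and port are placeholders.
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--port", "8000"])

host = str(args.host or "localhost")  # --host defaults to None
port = int(args.port)
```
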
78 changes: 50 additions & 28 deletions vllm/entrypoints/openai/api_server.py
@@ -8,7 +8,7 @@

import fastapi
import uvicorn
from fastapi import Request
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -35,10 +35,14 @@
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5 # seconds

logger = init_logger(__name__)
engine: AsyncLLMEngine
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
Expand All @@ -64,35 +68,23 @@ async def _force_log():
yield


app = fastapi.FastAPI(lifespan=lifespan)


def parse_args():
parser = make_arg_parser()
return parser.parse_args()

router = APIRouter()

# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
router.routes.append(route)


@app.get("/health")
@router.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
return Response(status_code=200)


@app.post("/tokenize")
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_completion.create_tokenize(request)
if isinstance(generator, ErrorResponse):
@@ -103,7 +95,7 @@ async def tokenize(request: TokenizeRequest):
return JSONResponse(content=generator.model_dump())


@app.post("/detokenize")
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_completion.create_detokenize(request)
if isinstance(generator, ErrorResponse):
@@ -114,19 +106,19 @@ async def detokenize(request: DetokenizeRequest):
return JSONResponse(content=generator.model_dump())


@app.get("/v1/models")
@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump())


@app.get("/version")
@router.get("/version")
async def show_version():
ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver)


@app.post("/v1/chat/completions")
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
@@ -142,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return JSONResponse(content=generator.model_dump())


@app.post("/v1/completions")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
@@ -156,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


@app.post("/v1/embeddings")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
request, raw_request)
@@ -167,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


if __name__ == "__main__":
args = parse_args()
def build_app(args):
app = fastapi.FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path

app.add_middleware(
CORSMiddleware,
@@ -178,6 +172,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
allow_headers=args.allowed_headers,
)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)

if token := envs.VLLM_API_KEY or args.api_key:

@app.middleware("http")
@@ -203,6 +203,12 @@ async def authentication(request: Request, call_next):
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")

return app


def run_server(args, llm_engine=None):
app = build_app(args)

logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

@@ -211,10 +217,12 @@ async def authentication(request: Request, call_next):
else:
served_model_names = [args.model]

engine_args = AsyncEngineArgs.from_cli_args(args)
global engine, engine_args

engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

event_loop: Optional[asyncio.AbstractEventLoop]
try:
@@ -230,6 +238,10 @@
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())

global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding

openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role,
@@ -258,3 +270,13 @@
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)


if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
run_server(args)
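
This refactor separates app construction (`build_app`) from startup (`run_server`) and adds an optional `llm_engine` parameter, so a caller can hand over a pre-built `AsyncLLMEngine` instead of letting the server create one from the CLI args. A minimal sketch of that programmatic use, with placeholder model and port values:

```python
# Sketch: launch the OpenAI-compatible server with a pre-built engine via the
# new llm_engine hook. Model name and port below are placeholders.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(
    FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server."))
args = parser.parse_args(["--model", "facebook/opt-125m", "--port", "8000"])

# Build the engine ourselves, then pass it in; run_server() then skips its
# own AsyncLLMEngine.from_engine_args() call.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs.from_cli_args(args),
    usage_context=UsageContext.OPENAI_API_SERVER)
run_server(args, llm_engine=engine)
```
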
10 changes: 7 additions & 3 deletions vllm/entrypoints/openai/cli_args.py
@@ -34,9 +34,7 @@ def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, adapter_list)


def make_arg_parser():
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
default=None,
@@ -133,3 +131,9 @@ def make_arg_parser():

parser = AsyncEngineArgs.add_cli_args(parser)
return parser


def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
return make_arg_parser(parser_for_docs)
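
Since callers now supply the parser, `create_parser_for_docs` is the zero-argument helper that the Sphinx `{argparse}` directive points at (see the `openai_compatible_server.md` change above); its `prog` string mirrors the `python -m` invocation. A small sketch of how it can be exercised, with the `print` calls purely for illustration:

```python
# Sketch: build the documentation parser and render its help text.
from vllm.entrypoints.openai.cli_args import create_parser_for_docs

docs_parser = create_parser_for_docs()
print(docs_parser.prog)           # "-m vllm.entrypoints.openai.api_server"
print(docs_parser.format_help())  # full listing of server and engine flags
```
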
(Diff for the remaining changed file not loaded.)
