This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Use a context manager for the server
dbarbuzzi committed May 29, 2024
1 parent f687019 commit 32ab964
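The change replaces the previous shell-string subprocess.Popen start plus manual kill() in a finally block with a Server context manager that launches the vLLM API server, waits for its /health endpoint, and tears the process down on exit. For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of how a context manager centralizes that cleanup (the names below are illustrative only and are not code from this commit):

import subprocess
import sys


class ManagedProcess:
    """Hypothetical illustration of the context-manager pattern adopted here."""

    def __init__(self, cmd):
        self.cmd = cmd
        self.proc = None

    def __enter__(self):
        # Startup logic runs when the with-block is entered.
        self.proc = subprocess.Popen(self.cmd)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Teardown runs whether or not the with-body raised, replacing
        # per-call-site try/finally bookkeeping.
        if self.proc and self.proc.poll() is None:
            self.proc.kill()


# The child process is killed even if the benchmark body raises.
with ManagedProcess([sys.executable, "-c", "import time; time.sleep(30)"]):
    pass  # run warmup / benchmark here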
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions neuralmagic/benchmarks/run_benchmark_serving.py
@@ -1,13 +1,17 @@
import argparse
import itertools
import json
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import List, NamedTuple, Optional
from tempfile import TemporaryFile
from typing import Dict, List, NamedTuple, Optional

import requests

from ...tests.utils.logging import log_banner, make_logger
from ..tools.call_cmd import call_cmd
from .common import (benchmark_configs, download_model,
                     max_model_length_from_model_id, script_args_to_cla)
@@ -17,6 +21,55 @@
BENCH_SERVER_PORT = 9000


class Server:

    def __init__(self, args: Dict, max_ready_wait: int = 600):
        self.logger = make_logger("nm-vllm-server")
        self.cmd = [sys.executable, "-m", "vllm.entrypoints.api_server"]
        for k, v in args.items():
            self.cmd.extend([f"--{k}", str(v)])
        self.max_ready_wait = max_ready_wait
        self.proc = None
        self.output_file = TemporaryFile()

    def __enter__(self):
        log_banner(self.logger, "server startup command", shlex.join(self.cmd))
        self.proc = subprocess.Popen(self.cmd,
                                     stderr=subprocess.STDOUT,
                                     stdout=self.output_file.fileno())
        self._wait_for_server_ready()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if self.proc and self.proc.poll() is None:
            self.logger.info("killing server")
            self.proc.kill()

        if exc_type is None:
            return  # only log if an exception occurred

        self.output_file.seek(0)
        self.output = self.output_file.read()
        self.output_file.close()

        log_banner(self.logger, "server output", self.output)

    def _wait_for_server_ready(self):
        self.logger.info("waiting for server to become ready")
        start = time.time()
        while time.time() - start < self.max_ready_wait:
            try:
                if requests.get(
                        f"http://{BENCH_SERVER_HOST}:{BENCH_SERVER_PORT}/health",
                        timeout=10).status_code == 200:
                    break
            except Exception as e:
                if self.proc and self.proc.poll() is not None:
                    raise RuntimeError("server exited unexpectedly") from e
            time.sleep(0.5)
        else:
            raise RuntimeError("server failed to start in time")


def get_tensor_parallel_size(config: NamedTuple) -> int:

num_tp_directives = [
@@ -63,14 +116,7 @@ def run_benchmark_serving_script(config: NamedTuple,
    assert config.script_name == 'benchmark_serving'

    def run_bench(server_cmd: str, bench_cmd: List[str], model: str) -> None:
        try:
            # start server
            server_process = subprocess.Popen("exec " + server_cmd, shell=True)
            if not is_server_running(BENCH_SERVER_HOST, BENCH_SERVER_PORT):
                raise ValueError(
                    f"Aborting bench run with : server-cmd {server_cmd} , "
                    f"bench-cmd {bench_cmd}. Reason: Cannot start Server")

        with Server(server_cmd):
            # server warmup
            warmup_server(server_host=BENCH_SERVER_HOST,
                          server_port=BENCH_SERVER_PORT,
@@ -79,10 +125,6 @@ def run_bench(server_cmd: str, bench_cmd: List[str], model: str) -> None:

            # run bench
            call_cmd(bench_cmd, stdout=None, stderr=None)
        finally:
            # kill the server
            assert server_process is not None
            server_process.kill()

    tensor_parallel_size = get_tensor_parallel_size(config)

@@ -120,10 +162,6 @@ def run_bench(server_cmd: str, bench_cmd: List[str], model: str) -> None:
        if sparsity:
            server_args["sparsity"] = sparsity

        server_cmd = "python3 -m vllm.entrypoints.api_server " + \
            " ".join([f"--{k} {v}"
                      for k, v in server_args.items()])

        for script_args in script_args_to_cla(config):

            description = (f"{config.description}\n" +
@@ -153,7 +191,7 @@ def run_bench(server_cmd: str, bench_cmd: List[str], model: str) -> None:
f"{tensor_parallel_size}"
])

            run_bench(server_cmd, bench_cmd, model)
            run_bench(server_args, bench_cmd, model)


if __name__ == '__main__':
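After this change, run_bench drives the server entirely through the context manager. A minimal usage sketch of the new Server helper, written as if it lived in the same module (the argument values are illustrative; the real dict is the server_args built inside run_benchmark_serving_script):

# Hypothetical stand-alone usage of the Server class added above.
server_args = {
    "model": "facebook/opt-125m",  # illustrative model id
    "host": BENCH_SERVER_HOST,
    "port": BENCH_SERVER_PORT,
}

with Server(server_args):
    # __enter__ has returned, so /health answered 200 within max_ready_wait.
    # Run warmup and the benchmark here; __exit__ kills the server and, if an
    # exception escaped this block, dumps the captured server output.
    ...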
