Skip to content

Commit

Permalink
Support OpenAI API server in benchmark_serving.py (vllm-project#2172)
Browse files Browse the repository at this point in the history
  • Loading branch information
hmellor authored Jan 19, 2024
1 parent c3149dc commit bacad9c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 32 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,6 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*

# Benchmark dataset
*.json
80 changes: 48 additions & 32 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import aiohttp
import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer

Expand All @@ -40,15 +41,10 @@ def sample_requests(
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
Expand Down Expand Up @@ -98,6 +94,7 @@ async def get_request(

async def send_request(
backend: str,
model: str,
api_url: str,
prompt: str,
prompt_len: int,
Expand All @@ -120,6 +117,8 @@ async def send_request(
"ignore_eos": True,
"stream": False,
}
if model is not None:
pload["model"] = model
elif backend == "tgi":
assert not use_beam_search
params = {
Expand All @@ -137,7 +136,8 @@ async def send_request(
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
async with session.post(api_url, headers=headers, json=pload) as response:
async with session.post(api_url, headers=headers,
json=pload) as response:
chunks = []
async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk)
Expand All @@ -155,6 +155,7 @@ async def send_request(

async def benchmark(
backend: str,
model: str,
api_url: str,
input_requests: List[Tuple[str, int, int]],
best_of: int,
Expand All @@ -164,25 +165,27 @@ async def benchmark(
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
task = asyncio.create_task(send_request(backend, api_url, prompt,
prompt_len, output_len,
best_of, use_beam_search))
task = asyncio.create_task(
send_request(backend, model, api_url, prompt, prompt_len,
output_len, best_of, use_beam_search))
tasks.append(task)
await asyncio.gather(*tasks)
await tqdm.gather(*tasks)


def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)

api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

benchmark_start_time = time.perf_counter()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
args.use_beam_search, args.request_rate))
asyncio.run(
benchmark(args.backend, args.model, api_url, input_requests,
args.best_of, args.use_beam_search, args.request_rate))
benchmark_end_time = time.perf_counter()
benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s")
Expand All @@ -196,38 +199,51 @@ def main(args: argparse.Namespace):
for prompt_len, output_len, latency in REQUEST_LATENCY
])
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
avg_per_output_token_latency = np.mean([
latency / output_len
for _, output_len, latency in REQUEST_LATENCY
])
avg_per_output_token_latency = np.mean(
[latency / output_len for _, output_len, latency in REQUEST_LATENCY])
print("Average latency per output token: "
f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument("--backend", type=str, default="vllm",
parser.add_argument("--backend",
type=str,
default="vllm",
choices=["vllm", "tgi"])
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--dataset", type=str, required=True,
parser.add_argument("--endpoint", type=str, default="/generate")
parser.add_argument("--model", type=str, default=None)
parser.add_argument("--dataset",
type=str,
required=True,
help="Path to the dataset.")
parser.add_argument("--tokenizer", type=str, required=True,
parser.add_argument("--tokenizer",
type=str,
required=True,
help="Name or path of the tokenizer.")
parser.add_argument("--best-of", type=int, default=1,
parser.add_argument("--best-of",
type=int,
default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.")
"returns the best one.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--request-rate", type=float, default=float("inf"),
parser.add_argument("--request-rate",
type=float,
default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.")
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code', action='store_true',
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)

0 comments on commit bacad9c

Please sign in to comment.