[Misc] benchmark: Add option to set max concurrency #9390

Merged: 5 commits, Oct 18, 2024
benchmarks/benchmark_serving.py (40 changes: 37 additions & 3 deletions)
@@ -397,6 +397,7 @@ async def benchmark(
     selected_percentile_metrics: List[str],
     selected_percentiles: List[str],
     ignore_eos: bool,
+    max_concurrency: Optional[int],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -445,9 +446,25 @@ async def benchmark(
         print("Profiler started")

     print(f"Traffic request rate: {request_rate}")
+    print(f"Maximum request concurrency: {max_concurrency}")

     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

+    # This can be used once the minimum Python version is 3.10 or higher,
+    # and it will simplify the code in limited_request_func.
+    # semaphore = (asyncio.Semaphore(max_concurrency)
+    #              if max_concurrency else contextlib.nullcontext())
+    semaphore = (asyncio.Semaphore(max_concurrency)
+                 if max_concurrency else None)
+
+    async def limited_request_func(request_func_input, pbar):
+        if semaphore is None:
+            return await request_func(request_func_input=request_func_input,
+                                      pbar=pbar)
+        async with semaphore:
+            return await request_func(request_func_input=request_func_input,
+                                      pbar=pbar)
+
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
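Note: the commented-out lines record a planned cleanup. From Python 3.10 on, contextlib.nullcontext works as an async context manager, so the unlimited case no longer needs a separate None branch. A minimal sketch of that future variant, assuming Python >= 3.10 (the stub request_func and the concurrency value are illustrative, not part of the PR):

    import asyncio
    import contextlib

    async def main() -> None:
        max_concurrency = 8  # illustrative; None would mean "no limit"

        # With an async-capable nullcontext, the unlimited case becomes a
        # no-op context manager, so one code path serves both branches.
        semaphore = (asyncio.Semaphore(max_concurrency)
                     if max_concurrency else contextlib.nullcontext())

        async def request_func(request_func_input, pbar):
            # stub standing in for the real per-backend request coroutine
            await asyncio.sleep(0.01)
            return request_func_input

        async def limited_request_func(request_func_input, pbar):
            async with semaphore:
                return await request_func(
                    request_func_input=request_func_input, pbar=pbar)

        await limited_request_func({"prompt": "hi"}, pbar=None)

    asyncio.run(main())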
@@ -463,8 +480,8 @@
                                               ignore_eos=ignore_eos)
         tasks.append(
             asyncio.create_task(
-                request_func(request_func_input=request_func_input,
-                             pbar=pbar)))
+                limited_request_func(request_func_input=request_func_input,
+                                     pbar=pbar)))
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

     if profile:
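Because every request is still wrapped in asyncio.create_task as soon as get_request yields it, the arrival pattern is unchanged; only execution is gated by the semaphore. A self-contained, standard-library-only demo of that behavior (all names and numbers made up for illustration):

    import asyncio

    MAX_CONCURRENCY = 2  # illustrative stand-in for --max-concurrency

    async def main() -> None:
        semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
        in_flight = 0
        peak = 0

        async def fake_request(i: int) -> None:
            nonlocal in_flight, peak
            async with semaphore:          # gates execution, not creation
                in_flight += 1
                peak = max(peak, in_flight)
                await asyncio.sleep(0.05)  # stand-in for server latency
                in_flight -= 1

        # all tasks are created up front, mirroring the benchmark loop
        tasks = [asyncio.create_task(fake_request(i)) for i in range(10)]
        await asyncio.gather(*tasks)
        print(f"peak concurrency: {peak}")  # prints 2, never more

    asyncio.run(main())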
@@ -680,6 +697,7 @@ def main(args: argparse.Namespace):
                 float(p) for p in args.metric_percentiles.split(",")
             ],
             ignore_eos=args.ignore_eos,
+            max_concurrency=args.max_concurrency,
         ))

     # Save config and results to json
@@ -709,13 +727,16 @@ def main(args: argparse.Namespace):
         # Traffic
         result_json["request_rate"] = (
             args.request_rate if args.request_rate < float("inf") else "inf")
+        result_json["max_concurrency"] = args.max_concurrency

         # Merge with benchmark result
         result_json = {**result_json, **benchmark_result}

         # Save to file
         base_model_id = model_id.split("/")[-1]
-        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"  #noqa
+        max_concurrency_str = (f"-concurrency{args.max_concurrency}"
+                               if args.max_concurrency is not None else "")
+        file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  #noqa
         if args.result_filename:
             file_name = args.result_filename
         if args.result_dir:
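The filename only grows a concurrency segment when the flag is actually set, so results produced without it keep their old names. A quick illustration of both cases (all values hypothetical):

    backend = "vllm"                  # hypothetical values throughout
    request_rate = 4.0
    base_model_id = "Llama-2-7b-hf"
    current_dt = "20241018-120000"

    for max_concurrency in (None, 64):
        max_concurrency_str = (f"-concurrency{max_concurrency}"
                               if max_concurrency is not None else "")
        print(f"{backend}-{request_rate}qps{max_concurrency_str}"
              f"-{base_model_id}-{current_dt}.json")

    # vllm-4.0qps-Llama-2-7b-hf-20241018-120000.json
    # vllm-4.0qps-concurrency64-Llama-2-7b-hf-20241018-120000.json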
@@ -766,6 +787,19 @@ def main(args: argparse.Namespace):
         default=None,
         help="Path to the sharegpt/sonnet dataset. "
         "Or the huggingface dataset ID if using HF dataset.")
+    parser.add_argument(
+        "--max-concurrency",
+        type=int,
+        default=None,
+        help="Maximum number of concurrent requests. This can be used "
+        "to help simulate an environment where a higher level component "
+        "is enforcing a maximum number of concurrent requests. While the "
+        "--request-rate argument controls the rate at which requests are "
+        "initiated, this argument will control how many are actually allowed "
+        "to execute at a time. This means that when used in combination, the "
+        "actual request rate may be lower than specified with --request-rate, "
+        "if the server is not processing requests fast enough to keep up.")
+
     parser.add_argument(
         "--model",
         type=str,
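Taken together, the flag decouples arrival rate from in-flight concurrency. A typical invocation (arguments illustrative, not taken from the PR) might look like:

    python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Llama-2-7b-hf --request-rate 16 --max-concurrency 64

Requests are still issued at roughly 16 per second, but at most 64 execute at once; if the server cannot keep up, the achieved rate falls below the requested one, exactly as the new help text warns.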