From 8d277f2cbbc782f2a5541f8297705954c44a8563 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 22 Oct 2024 17:45:35 -0700 Subject: [PATCH] [Bugfix] Generate exactly input_len tokens in benchmark_throughput (#9592) Signed-off-by: Alvant --- benchmarks/benchmark_throughput.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 5cca92edb251b..24eb54e7b73bc 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -233,7 +233,16 @@ def main(args: argparse.Namespace): args.tokenizer, trust_remote_code=args.trust_remote_code) if args.dataset is None: # Synthesize a prompt with the given input length. - prompt = "hi" * (args.input_len - 1) + # As tokenizer may add additional tokens like BOS, we need to try + # different lengths to get the desired input length. + for i in range(-10, 10): + prompt = "hi " * (args.input_len + i) + tokenized_prompt = tokenizer(prompt).input_ids + if len(tokenized_prompt) == args.input_len: + break + else: + raise ValueError( + f"Failed to synthesize a prompt with {args.input_len} tokens.") requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)] else: