From 6b048b79c1bbf9ba32d2f9fc69a5fcd48e49996c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 22 Oct 2024 10:52:47 -0700 Subject: [PATCH] generate exactly input_len tokens in benchmark_throughput --- benchmarks/benchmark_throughput.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e26706af606b0..f161011ad4d5e 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -320,7 +320,16 @@ def main(args: argparse.Namespace): args.tokenizer, trust_remote_code=args.trust_remote_code) if args.dataset is None: # Synthesize a prompt with the given input length. - prompt = "hi" * (args.input_len - 1) + # As tokenizer may add additional tokens like BOS, we need to try + # different lengths to get the desired input length. + for i in range(-10, 10): + prompt = "hi " * (args.input_len + i) + tokenized_prompt = tokenizer(prompt).input_ids + if len(tokenized_prompt) == args.input_len: + break + else: + raise ValueError( + f"Failed to synthesize a prompt with {args.input_len} tokens.") requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)] else: