From 12a86c9f0c65323678939211c491538906fb57c8 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 4 Apr 2024 00:41:05 -0700 Subject: [PATCH] [Bugfix] Fix args in benchmark_serving (#3836) Co-authored-by: Roger Wang --- benchmarks/benchmark_serving.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index bc7812ed4119e..6054df439fa57 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -110,7 +110,9 @@ def sample_sonnet_requests( prefix_len: int, tokenizer: PreTrainedTokenizerBase, ) -> List[Tuple[str, str, int, int]]: - assert input_len > prefix_len, "input_len must be greater than prefix_len." + assert ( + input_len > prefix_len + ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." # Load the dataset. with open(dataset_path) as f: @@ -131,8 +133,9 @@ def sample_sonnet_requests( base_message, add_generation_prompt=True, tokenize=False) base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids) - assert (input_len > base_prompt_offset - ), f"Please set 'args.input-len' higher than {base_prompt_offset}." + assert ( + input_len > base_prompt_offset + ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}." num_input_lines = round( (input_len - base_prompt_offset) / average_poem_len) @@ -140,7 +143,7 @@ def sample_sonnet_requests( # prompt are fixed poem lines. assert ( prefix_len > base_prompt_offset - ), f"Please set 'args.prefix-len' higher than {base_prompt_offset}." + ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}." num_prefix_lines = round( (prefix_len - base_prompt_offset) / average_poem_len) @@ -373,9 +376,9 @@ def main(args: argparse.Namespace): input_requests = sample_sonnet_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, - input_len=args.input_len, - output_len=args.output_len, - prefix_len=args.prefix_len, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) input_requests = [(prompt, prompt_len, output_len) @@ -388,9 +391,9 @@ def main(args: argparse.Namespace): input_requests = sample_sonnet_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, - input_len=args.input_len, - output_len=args.output_len, - prefix_len=args.prefix_len, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) input_requests = [(prompt_formatted, prompt_len, output_len)