diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d3e06cca618df..d6bf18c82e465 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -29,22 +29,23 @@ def sample_requests( dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - if fixed_output_len is not None: - output_len = fixed_output_len - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + # Shuffle the dataset. + random.shuffle(dataset) - # Filter out too long sequences. + # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue @@ -53,9 +54,7 @@ def sample_requests( continue filtered_dataset.append((prompt, prompt_len, output_len)) - # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) - return sampled_requests + return filtered_dataset def run_vllm(