From 66813d997aae8d2711b48425acb7ebe5cb19715c Mon Sep 17 00:00:00 2001 From: Archit Patke Date: Tue, 24 Sep 2024 21:50:50 -0500 Subject: [PATCH] [Core] Adding Priority Scheduling (#5958) Signed-off-by: Amit Garg --- benchmarks/benchmark_prioritization.py | 295 +++++++++++++++++++++++++ vllm/config.py | 6 +- vllm/core/scheduler.py | 77 +++++++ vllm/engine/llm_engine.py | 24 +- vllm/entrypoints/llm.py | 12 +- vllm/sequence.py | 4 + 6 files changed, 410 insertions(+), 8 deletions(-) create mode 100644 benchmarks/benchmark_prioritization.py diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py new file mode 100644 index 0000000000000..0ba29fabca59b --- /dev/null +++ b/benchmarks/benchmark_prioritization.py @@ -0,0 +1,295 @@ +"""Benchmark offline prioritization.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + + #Select a equi-probable random priority + priority = 0 if random.random() < 0.5 else 1 + + filtered_dataset.append((prompt, prompt_len, output_len, priority)) + + return filtered_dataset + + +def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + gpu_memory_utilization: float = 0.9, + download_dir: Optional[str] = None, +) -> float: + from vllm import LLM, SamplingParams + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + disable_log_stats=False, + ) + + # Add the requests to the engine. + prompts = [] + sampling_params = [] + priority = [] + for prompt, _, output_len, _priority in requests: + prompts.append(prompt) + priority.append(_priority) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + start = time.perf_counter() + llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. 
+ prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len, priority in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=200, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' 
+ 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + main(args) diff --git a/vllm/config.py b/vllm/config.py index 562564bbfa032..308f29a3dc371 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -961,7 +961,7 @@ class SchedulerConfig: workers instead of an entire data. It should be enabled only when SPMD worker architecture is enabled. I.e., VLLM_USE_RAY_SPMD_WORKER=1 - + policy: The scheduling policy to use. "fcfs" (default) or "priority". """ def __init__(self, @@ -977,7 +977,8 @@ def __init__(self, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, multi_step_stream_outputs: bool = False, - send_delta_data: bool = False) -> None: + send_delta_data: bool = False, + policy: str = "fcfs") -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: # It is the values that have the best balance between ITL @@ -1019,6 +1020,7 @@ def __init__(self, self.num_scheduler_steps = num_scheduler_steps self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data + self.policy = policy self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3fa95f57b737..b707d87c3af83 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -766,6 +766,79 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: else: return prompt_limit + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """ Get the priority of the sequence group. 
+ Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. + """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + False, budget) + + #Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + #Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens and can_allocate == AllocStatus.OK + and budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + #Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens = self._get_num_new_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget) + budget.subtract_num_batched_tokens(vseq_group.request_id, + num_running_tokens) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + #Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out, + PreemptionMode.RECOMPUTE) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + #Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + def _schedule_prefills( self, budget: SchedulingBudget, @@ -917,6 +990,10 @@ def _schedule_default(self) -> SchedulerOutputs: curr_loras, enable_chunking=False) + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bd7b3250e31af..c341b236003a3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -631,6 +631,7 @@ def _add_processed_request( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, ) -> None: self._validate_model_inputs(processed_inputs) # Create the sequences. 
@@ -661,7 +662,8 @@ def _add_processed_request( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -670,7 +672,8 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -695,6 +698,7 @@ def add_request( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: """Add a request to the engine's request pool. @@ -713,6 +717,8 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. Details: - Set arrival_time to the current time if it is None. @@ -741,6 +747,11 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + + if priority > 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: arrival_time = time.time() @@ -760,6 +771,7 @@ def add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, + priority=priority, ) def _create_sequence_group_with_sampling( @@ -772,6 +784,7 @@ def _create_sequence_group_with_sampling( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -798,7 +811,8 @@ def _create_sequence_group_with_sampling( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group @@ -811,6 +825,7 @@ def _create_sequence_group_with_pooling( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -823,7 +838,8 @@ def _create_sequence_group_with_pooling( lora_request=lora_request, pooling_params=pooling_params, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cd10eda8c212c..77ae7b088398a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -320,7 +320,8 @@ def generate( lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, - GuidedDecodingRequest]] = None + 
GuidedDecodingRequest]] = None, + priority: Optional[List[int]] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -339,6 +340,8 @@ def generate( lora_request: LoRA request to use for generation, if any. prompt_adapter_request: Prompt Adapter request to use for generation, if any. + priority: The priority of the requests, if any. + Only applicable when priority scheduling policy is enabled. Returns: A list of ``RequestOutput`` objects containing the @@ -379,7 +382,8 @@ def generate( params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - guided_options=guided_options_request) + guided_options=guided_options_request, + priority=priority) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -782,6 +786,7 @@ def _validate_and_add_requests( lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[List[int]] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -811,6 +816,7 @@ def _validate_and_add_requests( lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, ) def _add_request( @@ -819,6 +825,7 @@ def _add_request( params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( @@ -827,6 +834,7 @@ def _add_request( params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority, ) def _add_guided_processor( diff --git a/vllm/sequence.py b/vllm/sequence.py index b32e1aebe17be..fda7ef87749a1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -646,6 +646,7 @@ class SequenceGroup: unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request. + priority: User-defined priority of the request. """ def __init__( @@ -660,9 +661,11 @@ def __init__( encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: self.request_id = request_id self.seqs = seqs + self.arrival_time = arrival_time self.is_single_seq = len(seqs) == 1 self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -680,6 +683,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers + self.priority = priority self.cached_request_output = None
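
For context, a minimal, hypothetical usage sketch of the new `priority` argument added to `LLM.generate()` by this patch. It assumes the engine is running with `SchedulerConfig(policy="priority")`; the plumbing that exposes that flag through the engine/CLI arguments is not part of this diff, and `add_request()` rejects a non-zero priority unless the priority policy is enabled. Lower values mean higher priority, since the scheduler sorts on `(priority, arrival_time)` in ascending order.

    from vllm import LLM, SamplingParams

    # Usage sketch (not part of this patch). Assumes the engine was built with
    # SchedulerConfig(policy="priority"); how that flag is exposed through the
    # engine arguments is not shown in this diff.
    llm = LLM(model="facebook/opt-125m")

    prompts = [
        "Write a haiku about GPUs.",
        "Explain KV-cache reuse in one sentence.",
    ]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # One priority value per prompt. Lower values win: the scheduler orders
    # requests by (priority, arrival_time) ascending, so the second prompt is
    # scheduled first and may preempt lower-priority running requests.
    outputs = llm.generate(prompts, sampling_params, priority=[1, 0])
    for output in outputs:
        print(output.outputs[0].text)

With the priority policy enabled, `_schedule_priority_preemption()` re-sorts the waiting and running queues each scheduling step and, when the highest-priority waiting group cannot be allocated, preempts (via recomputation) running groups whose `(priority, arrival_time)` key is larger, returning them to the waiting queue.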