From b104dc2951dbe2670c594c2fb778edd69068848e Mon Sep 17 00:00:00 2001 From: Jeremy Arnold Date: Sat, 19 Oct 2024 19:04:16 +0000 Subject: [PATCH] Make benchmarks use EngineArgs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update the benchmark scripts to directly use the CLI arguments provided by EngineArgs instead of duplicating a subset of these arguments in each benchmark script. Currently the CLI arguments are duplicated, forcing changes to be made in multiple locations and resulting in some useful vLLM options not being exposed in the scripts.  For example, the --num-scheduler-steps option is currently available in benchmark_throughput.py but not benchmark_latency.py, making it difficult to understand the latency impacts of this option.  As another example, the benchmark_prioritization.py script appears to be broken currently because it was not updated to expose the --scheduling-policy option which is required for enabling priority. These maintenance challenges are eliminated by using EngineArgs.add_cli_args to add support for all engine arguments directly, and then passing these options to the engine initialization. One minor change in behavior is that when benchmark_throughput.py runs in async mode it no longer includes hard-coded settings for worker_use_ray=False (which is deprecated anyway) and disable_log_requests=True (but the user now has the option to pass --disable-log-requests on the command-line). Similarly, benchmark_prefix_caching no longer has hard-coded values for trust_remote_code=True and enforce_eager=True, but these may now be passed on the command-line. --- benchmarks/benchmark_latency.py | 155 +--------------- benchmarks/benchmark_prefix_caching.py | 24 +-- benchmarks/benchmark_prioritization.py | 134 +------------- benchmarks/benchmark_throughput.py | 237 ++----------------------- 4 files changed, 38 insertions(+), 512 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index ea1a7788f621d..0a14aedd5feba 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,5 +1,6 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse +import dataclasses import json import time from pathlib import Path @@ -10,43 +11,19 @@ from tqdm import tqdm from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs +from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser def main(args: argparse.Namespace): print(args) + engine_args = EngineArgs.from_cli_args(args) + # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. - llm = LLM( - model=args.model, - speculative_model=args.speculative_model, - num_speculative_tokens=args.num_speculative_tokens, - speculative_draft_tensor_parallel_size=\ - args.speculative_draft_tensor_parallel_size, - tokenizer=args.tokenizer, - quantization=args.quantization, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - max_model_len=args.max_model_len, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - quantization_param_path=args.quantization_param_path, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - enable_chunked_prefill=args.enable_chunked_prefill, - download_dir=args.download_dir, - block_size=args.block_size, - gpu_memory_utilization=args.gpu_memory_utilization, - load_format=args.load_format, - distributed_executor_backend=args.distributed_executor_backend, - otlp_traces_endpoint=args.otlp_traces_endpoint, - enable_prefix_caching=args.enable_prefix_caching, - ) + llm = LLM(**dataclasses.asdict(engine_args)) sampling_params = SamplingParams( n=args.n, @@ -125,19 +102,6 @@ def run_to_completion(profile_dir: Optional[str] = None): parser = FlexibleArgumentParser( description='Benchmark the latency of processing a single batch of ' 'requests till completion.') - parser.add_argument('--model', type=str, default='facebook/opt-125m') - parser.add_argument('--speculative-model', type=str, default=None) - parser.add_argument('--num-speculative-tokens', type=int, default=None) - parser.add_argument('--speculative-draft-tensor-parallel-size', - '-spec-draft-tp', - type=int, - default=None) - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--output-len', type=int, default=128) parser.add_argument('--batch-size', type=int, default=8) @@ -154,45 +118,6 @@ def run_to_completion(profile_dir: Optional[str] = None): type=int, default=30, help='Number of iterations to run.') - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--enforce-eager', - action='store_true', - help='enforce eager mode and disable CUDA graph') - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. ' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') parser.add_argument( '--profile', action='store_true', @@ -203,78 +128,12 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help=('path to save the pytorch profiler output. Can be visualized ' 'with ui.perfetto.dev or Tensorboard.')) - parser.add_argument("--device", - type=str, - default="auto", - choices=DEVICE_OPTIONS, - help='device type for vLLM execution') - parser.add_argument('--block-size', - type=int, - default=16, - help='block size of key/value cache') - parser.add_argument( - '--enable-chunked-prefill', - action='store_true', - help='If True, the prefill requests can be chunked based on the ' - 'max_num_batched_tokens') - parser.add_argument("--enable-prefix-caching", - action='store_true', - help="Enable automatic prefix caching") - parser.add_argument( - "--ray-workers-use-nsight", - action='store_true', - help="If specified, use nsight to profile ray workers", - ) - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') parser.add_argument( '--output-json', type=str, default=None, help='Path to save the latency results in JSON format.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], - default=None, - help='Backend to use for distributed serving. When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--otlp-traces-endpoint', - type=str, - default=None, - help='Target URL to which OpenTelemetry traces will be sent.') + + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index a354358e43aa3..1aac029992dbf 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -25,6 +25,7 @@ --input-length-range 128:256 """ +import dataclasses import json import random import time @@ -33,6 +34,7 @@ from transformers import PreTrainedTokenizerBase from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs from vllm.utils import FlexibleArgumentParser try: @@ -129,12 +131,9 @@ def main(args): filtered_datasets = [(PROMPT, prompt_len, args.output_len) ] * args.num_prompts - llm = LLM(model=args.model, - tokenizer_mode='auto', - trust_remote_code=True, - enforce_eager=True, - tensor_parallel_size=args.tensor_parallel_size, - enable_prefix_caching=args.enable_prefix_caching) + engine_args = EngineArgs.from_cli_args(args) + + llm = LLM(**dataclasses.asdict(engine_args)) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) @@ -162,18 +161,11 @@ def main(args): parser = FlexibleArgumentParser( description= 'Benchmark the performance with or without automatic prefix caching.') - parser.add_argument('--model', - type=str, - default='baichuan-inc/Baichuan2-13B-Chat') parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset.") - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='enable prefix caching') parser.add_argument('--num-prompts', type=int, default=1, @@ -190,9 +182,7 @@ def main(args): default='128:256', help='Range of input lengths for sampling prompts,' 'specified as "min:max" (e.g., "128:256").') - parser.add_argument("--seed", - type=int, - default=0, - help='Random seed for reproducibility') + + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 8843e3a927a01..e0c9e6a6db502 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -1,5 +1,6 @@ """Benchmark offline prioritization.""" import argparse +import dataclasses import json import random import time @@ -7,7 +8,8 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser def sample_requests( @@ -62,46 +64,11 @@ def sample_requests( def run_vllm( requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, n: int, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - gpu_memory_utilization: float = 0.9, - download_dir: Optional[str] = None, + engine_args: EngineArgs, ) -> float: from vllm import LLM, SamplingParams - llm = LLM( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - disable_log_stats=False, - ) + llm = LLM(**dataclasses.asdict(engine_args)) # Add the requests to the engine. prompts = [] @@ -142,16 +109,8 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.trust_remote_code, - args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, - args.enable_chunked_prefill, - args.max_num_batched_tokens, - args.gpu_memory_utilization, args.download_dir) + elapsed_time = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args)) else: raise ValueError(f"Unknown backend: {args.backend}") total_num_tokens = sum(prompt_len + output_len @@ -173,7 +132,7 @@ def main(args: argparse.Namespace): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument("--backend", type=str, choices=["vllm", "hf", "mii"], @@ -191,13 +150,6 @@ def main(args: argparse.Namespace): default=None, help="Output length for each request. Overrides the " "output length from the dataset.") - parser.add_argument("--model", type=str, default="facebook/opt-125m") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, @@ -206,81 +158,13 @@ def main(args: argparse.Namespace): type=int, default=200, help="Number of prompts to process.") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument("--enforce-eager", - action="store_true", - help="enforce eager execution") - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. ' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') - parser.add_argument( - "--device", - type=str, - default="cuda", - choices=["cuda", "cpu"], - help='device type for vLLM execution, supporting CUDA and CPU.') - parser.add_argument( - "--enable-prefix-caching", - action='store_true', - help="enable automatic prefix caching for vLLM backend.") - parser.add_argument("--enable-chunked-prefill", - action='store_true', - help="enable chunked prefill for vLLM backend.") - parser.add_argument('--max-num-batched-tokens', - type=int, - default=None, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') parser.add_argument( '--output-json', type=str, default=None, help='Path to save the throughput results in JSON format.') + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e26706af606b0..5cca92edb251b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,5 +1,6 @@ """Benchmark offline inference throughput.""" import argparse +import dataclasses import json import random import time @@ -11,10 +12,9 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.sampling_params import BeamSearchParams from vllm.utils import FlexibleArgumentParser, merge_async_iterators @@ -67,53 +67,11 @@ def sample_requests( def run_vllm( requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, n: int, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - distributed_executor_backend: Optional[str], - gpu_memory_utilization: float = 0.9, - num_scheduler_steps: int = 1, - download_dir: Optional[str] = None, - load_format: str = EngineArgs.load_format, - disable_async_output_proc: bool = False, + engine_args: EngineArgs, ) -> float: from vllm import LLM, SamplingParams - llm = LLM( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - load_format=load_format, - num_scheduler_steps=num_scheduler_steps, - disable_async_output_proc=disable_async_output_proc, - ) + llm = LLM(**dataclasses.asdict(engine_args)) # Add the requests to the engine. prompts: List[str] = [] @@ -155,56 +113,11 @@ def run_vllm( async def run_vllm_async( requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, n: int, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - distributed_executor_backend: Optional[str], - gpu_memory_utilization: float = 0.9, - num_scheduler_steps: int = 1, - download_dir: Optional[str] = None, - load_format: str = EngineArgs.load_format, - disable_async_output_proc: bool = False, + engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, ) -> float: from vllm import SamplingParams - engine_args = AsyncEngineArgs( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - load_format=load_format, - num_scheduler_steps=num_scheduler_steps, - disable_async_output_proc=disable_async_output_proc, - worker_use_ray=False, - disable_log_requests=True, - ) async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: @@ -328,23 +241,17 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - run_args = [ - requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, - args.trust_remote_code, args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.distributed_executor_backend, - args.gpu_memory_utilization, args.num_scheduler_steps, - args.download_dir, args.load_format, args.disable_async_output_proc - ] - if args.async_engine: - run_args.append(args.disable_frontend_multiprocessing) - elapsed_time = uvloop.run(run_vllm_async(*run_args)) + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + )) else: - elapsed_time = run_vllm(*run_args) + elapsed_time = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args)) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -391,13 +298,6 @@ def main(args: argparse.Namespace): default=None, help="Output length for each request. Overrides the " "output length from the dataset.") - parser.add_argument("--model", type=str, default="facebook/opt-125m") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, @@ -406,123 +306,15 @@ def main(args: argparse.Namespace): type=int, default=1000, help="Number of prompts to process.") - parser.add_argument("--seed", type=int, default=0) parser.add_argument("--hf-max-batch-size", type=int, default=None, help="Maximum batch size for HF backend.") - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument("--enforce-eager", - action="store_true", - help="enforce eager execution") - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. ' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') - parser.add_argument("--device", - type=str, - default="auto", - choices=DEVICE_OPTIONS, - help='device type for vLLM execution') - parser.add_argument( - "--num-scheduler-steps", - type=int, - default=1, - help="Maximum number of forward steps per scheduler call.") - parser.add_argument( - "--enable-prefix-caching", - action='store_true', - help="Enable automatic prefix caching for vLLM backend.") - parser.add_argument("--enable-chunked-prefill", - action='store_true', - help="enable chunked prefill for vLLM backend.") - parser.add_argument('--max-num-batched-tokens', - type=int, - default=None, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') parser.add_argument( '--output-json', type=str, default=None, help='Path to save the throughput results in JSON format.') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], - default=None, - help='Backend to use for distributed serving. When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') - parser.add_argument( - "--disable-async-output-proc", - action='store_true', - default=False, - help="Disable async output processor for vLLM backend.") parser.add_argument("--async-engine", action='store_true', default=False, @@ -531,6 +323,7 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable decoupled async engine frontend.") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model