Skip to content

Commit

Permalink
[Misc] Remove deprecated arg for cuda graph capture (vllm-project#9864)
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 authored Oct 31, 2024
1 parent d087bf8 commit 3ea2dc2
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 23 deletions.
7 changes: 0 additions & 7 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ class ModelConfig:
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
If None, the user did not specify, so default to False.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode. Additionally for encoder-decoder models, if the
Expand Down Expand Up @@ -147,7 +144,6 @@ def __init__(
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
Expand Down Expand Up @@ -181,9 +177,6 @@ def __init__(
self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
if max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
Expand Down
10 changes: 0 additions & 10 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ class EngineArgs:
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: Optional[bool] = None
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
Expand Down Expand Up @@ -504,14 +503,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help='Always use eager-mode PyTorch. If False, '
'will use eager mode and CUDA graph in hybrid '
'for maximal performance and flexibility.')
parser.add_argument('--max-context-len-to-capture',
type=int,
default=EngineArgs.max_context_len_to_capture,
help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq-len-to-capture instead'
')')
parser.add_argument('--max-seq-len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
Expand Down Expand Up @@ -939,7 +930,6 @@ def create_model_config(self) -> ModelConfig:
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
Expand Down
5 changes: 0 additions & 5 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,6 @@ class LLM:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode. Additionally for encoder-decoder models, if the
Expand Down Expand Up @@ -152,7 +149,6 @@ def __init__(
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
Expand Down Expand Up @@ -193,7 +189,6 @@ def __init__(
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ def __init__(
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
# (max batch size to capture, max seq len to capture / block size).
self.graph_block_tables = np.zeros(
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
Expand Down

0 comments on commit 3ea2dc2

Please sign in to comment.