diff --git a/tests/compile/piecewise/piecewise_compilation_config.json b/tests/compile/piecewise/piecewise_compilation_config.json deleted file mode 100644 index 798a34e8dd92d..0000000000000 --- a/tests/compile/piecewise/piecewise_compilation_config.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "use_cudagraph": true, - "non_cudagraph_ops": ["silly.attention"], - "cudagraph_copy_inputs": true -} \ No newline at end of file diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 45f56cbbd4b16..0e40e3b4ebc96 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -2,7 +2,6 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. """ -import os import torch from torch import nn @@ -11,7 +10,7 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.plugins import set_current_vllm_config from vllm.utils import direct_register_custom_op @@ -77,12 +76,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_simple_piecewise_compile(): - directory = os.path.dirname(__file__) - config = os.path.join(directory, "piecewise_compilation_config.json") - os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) - - vllm_config = VllmConfig() + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + non_cudagraph_ops=["silly.attention"], + cudagraph_copy_inputs=True, + )) with set_current_vllm_config(vllm_config): model = SillyModel(vllm_config=vllm_config, prefix='') @@ -109,6 +108,3 @@ def test_simple_piecewise_compile(): output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) - - # clean up to avoid side effects for other tests - del os.environ["VLLM_TORCH_COMPILE_CONFIG"] diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 8032304e95806..356d119a40334 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -6,7 +6,6 @@ if the config `tractable_init` is set to True. Otherwise, the weights are initialized randomly with a fixed seed. 
""" -import os from dataclasses import dataclass from typing import Optional, Tuple @@ -18,7 +17,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import CompilationConfig, CompilationLevel, VllmConfig -from vllm.plugins import set_compilation_config, set_current_vllm_config +from vllm.plugins import set_current_vllm_config from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -254,23 +253,17 @@ def run_model(llama_config, split_attn: bool = False) -> torch.Tensor: if use_compile: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str( - CompilationLevel.PIECEWISE) - + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + ) if split_attn: - set_compilation_config( - CompilationConfig( - use_cudagraph=True, - non_cudagraph_ops=["silly.attention"], - )) - else: - set_compilation_config(CompilationConfig(use_cudagraph=True, )) + compilation_config.non_cudagraph_ops = ["silly.attention"] else: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str( - CompilationLevel.NO_COMPILATION) - set_compilation_config(None) + compilation_config = CompilationConfig( + level=CompilationLevel.NO_COMPILATION, ) - vllm_config = VllmConfig() + vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): model = LlamaModel(config=llama_config, vllm_config=vllm_config, @@ -288,10 +281,6 @@ def run_model(llama_config, input_ids[:2].zero_() output = model(input_ids[:2], positions[:2]) - # manual cleanup - del os.environ["VLLM_TORCH_COMPILE_LEVEL"] - set_compilation_config(None) - output = output.cpu() if llama_config.tractable_init: @@ -361,7 +350,6 @@ def test_toy_llama(): @torch.inference_mode def benchmark(): - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) from triton.testing import do_bench # similar to llama 3.1-8B @@ -387,15 +375,16 @@ def benchmark(): for piecewise in [False, True]: if piecewise: - set_compilation_config( - CompilationConfig( - use_cudagraph=True, - non_cudagraph_ops=["silly.attention"], - )) + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + non_cudagraph_ops=["silly.attention"], + ) else: - set_compilation_config(None) + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, ) - vllm_config = VllmConfig() + vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): model = LlamaModel(config=llama_config, vllm_config=vllm_config, diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 08747ebc58b75..c0db2e78824be 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting): final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \ ["-tp", str(tp_size)] + all_args: List[List[str]] = [] all_envs: List[Optional[Dict[str, str]]] = [] for level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE, ]: - all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)}) + all_args.append(final_args + ["-O", str(level)]) + all_envs.append({}) # inductor will change the output, so we only compare if the output # is close, not exactly the same. 
compare_all_settings( - model, [final_args] * 2, + model, + all_args, all_envs, method=method if method != "generate" else "generate_close") all_envs.clear() + all_args.clear() for level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE, ]: - all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)}) + all_args.append(final_args + ["-O", str(level)]) + all_envs.append({}) if level != CompilationLevel.DYNAMO_ONCE and not fullgraph: # "DYNAMO_ONCE" will always use fullgraph all_envs[-1][ "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore - compare_all_settings(model, [final_args] * 3, all_envs, method=method) + compare_all_settings(model, all_args, all_envs, method=method) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 729f10676888b..078c6bf9ea1df 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationLevel +from vllm.config import CompilationConfig, CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ @@ -65,7 +65,6 @@ def check_full_graph_support(model, optimization_level, tp_size=1): # make sure these models can be captured in full graph mode - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level) os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # The base meta llama uses too much memory. @@ -86,6 +85,7 @@ def check_full_graph_support(model, enforce_eager=True, tensor_parallel_size=tp_size, disable_custom_all_reduce=True, + compilation_config=CompilationConfig(level=optimization_level), **model_kwargs) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index c3219bc50646b..c54e30995da49 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,4 +1,3 @@ -import os from typing import List import pytest @@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation): ]) def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) vllm_config = VllmConfig(compilation_config=CompilationConfig( - custom_ops=env.split(","))) + level=torch_level, custom_ops=env.split(","))) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 941abe17a3378..65bee85e7a1ea 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -1,24 +1,47 @@ import glob import os -import runpy import tempfile import depyf -from vllm.config import CompilationLevel - -# disable custom dispatcher, let Dynamo takes over -# all the control -os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS) +from vllm.config import CompilationConfig, CompilationLevel temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): - cur_dir = os.path.dirname(__file__) - parent_dir = os.path.dirname(cur_dir) - root_dir = os.path.dirname(parent_dir) - example_file = os.path.join(root_dir, "examples", - "offline_inference_tpu.py") - runpy.run_path(example_file) + from vllm import LLM, SamplingParams + + prompts = [ + "A robot may not injure a human being", + "It is only with the heart that one can see rightly;", + "The greatest glory in living lies not in never
falling,", + ] + answers = [ + " or, through inaction, allow a human being to come to harm.", + " what is essential is invisible to the eye.", + " but in rising every time we fall.", + ] + N = 1 + # Currently, top-p sampling is disabled. `top_p` should be 1.0. + sampling_params = SamplingParams(temperature=0.7, + top_p=1.0, + n=N, + max_tokens=16) + + # Set `enforce_eager=True` to avoid ahead-of-time compilation. + # In real workloads, `enforace_eager` should be `False`. + + # disable custom dispatcher, let Dynamo takes over + # all the control + llm = LLM(model="google/gemma-2b", + enforce_eager=True, + compilation_config=CompilationConfig( + level=CompilationLevel.DYNAMO_AS_IS)) + outputs = llm.generate(prompts, sampling_params) + for output, answer in zip(outputs, answers): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert generated_text.startswith(answer) compiled_code = sorted( glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 53b10c06135a1..df348258efcba 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -13,7 +13,9 @@ def test_custom_dispatcher(): compare_two_settings( "google/gemma-2b", - arg1=["--enforce-eager"], - arg2=["--enforce-eager"], - env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)}, - env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)}) + arg1=["--enforce-eager", "-O", + str(CompilationLevel.DYNAMO_ONCE)], + arg2=["--enforce-eager", "-O", + str(CompilationLevel.DYNAMO_AS_IS)], + env1={}, + env2={}) diff --git a/vllm/config.py b/vllm/config.py index ea9ec43cc5a15..e69cbd3eb402a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2174,8 +2174,14 @@ class CompilationConfig(BaseModel): enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr + @classmethod + def from_cli(cls, cli_value: str) -> "CompilationConfig": + """Parse the CLI value for the compilation config.""" + if cli_value in ["0", "1", "2", "3"]: + return cls(level=int(cli_value)) + return CompilationConfig.model_validate_json(cli_value) + def model_post_init(self, __context: Any) -> None: - self.level = envs.VLLM_TORCH_COMPILE_LEVEL count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") @@ -2249,26 +2255,6 @@ def init_during_runtime(self): "inductor_specialize_for_cudagraph_no_more_than is None") self.compile_sizes = self.inductor_compile_sizes - @staticmethod - def select_and_init_config() -> "CompilationConfig": - """The order of selecting config is: - 1. Use the config specified in environment variable. - 2. Use the config specified in plugins. - 3. Use the default config. 
- """ - config_path = envs.VLLM_TORCH_COMPILE_CONFIG - if config_path is not None: - with open(config_path) as json_file: - config = CompilationConfig.model_validate_json( - json_file.read()) - else: - from vllm.plugins import get_compilation_config - predefined_config = get_compilation_config() - config = predefined_config if predefined_config is not None else ( - CompilationConfig()) - - return config - @dataclass class VllmConfig: @@ -2354,8 +2340,19 @@ def __post_init__(self): self.model_config, self.load_config) if self.compilation_config is None: - self.compilation_config = CompilationConfig.select_and_init_config( - ) + self.compilation_config = CompilationConfig() + if envs.VLLM_USE_V1: + # NOTE(woosuk): Currently, we use inductor because the piecewise + # CUDA graphs do not work properly with the custom CUDA kernels. + # FIXME(woosuk): Disable inductor to reduce the compilation time + # and avoid any potential issues with the inductor. + self.compilation_config.custom_ops = ["none"] + self.compilation_config.use_cudagraph = True + self.compilation_config.non_cudagraph_ops = [ + "vllm.unified_v1_flash_attention" + ] + self.compilation_config.use_inductor = True + self.compilation_config.enable_fusion = False current_platform.check_and_update_config(self) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ee4b6addfd466..a3ae1889774f3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,12 +8,13 @@ import torch import vllm.envs as envs -from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, - DeviceConfig, HfOverrides, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PoolerConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig, TaskOption, - TokenizerPoolConfig, VllmConfig) +from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, + DecodingConfig, DeviceConfig, HfOverrides, LoadConfig, + LoadFormat, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TaskOption, TokenizerPoolConfig, + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -189,6 +190,7 @@ class EngineArgs: override_neuron_config: Optional[Dict[str, Any]] = None override_pooler_config: Optional[PoolerConfig] = None + compilation_config: Optional[CompilationConfig] = None def __post_init__(self): if not self.tokenizer: @@ -868,6 +870,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Override or set the pooling method in the embedding model. " "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'") + parser.add_argument('--compilation-config', + '-O', + type=CompilationConfig.from_cli, + default=None, + help='torch.compile configuration for the model.' + 'When it is a number (0, 1, 2, 3), it will be ' + 'interpreted as the optimization level.\n' + 'NOTE: level 0 is the default level without ' + 'any optimization. level 1 and 2 are for internal ' + 'testing only. 
level 3 is the recommended level ' 'for production.\n' 'To specify the full compilation config, ' 'use a JSON string.') + return parser @classmethod @@ -1142,6 +1158,7 @@ def create_engine_config(self) -> VllmConfig: decoding_config=decoding_config, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, + compilation_config=self.compilation_config, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e72dc81f35b67..2a5eaf1340762 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -262,7 +262,8 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s, pooler_config=%r)", + "mm_processor_kwargs=%s, pooler_config=%r, " + "compilation_config=%r)", VLLM_VERSION, model_config.model, speculative_config, @@ -297,6 +298,7 @@ def __init__( use_cached_outputs, model_config.mm_processor_kwargs, model_config.pooler_config, + vllm_config.compilation_config, ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config diff --git a/vllm/envs.py b/vllm/envs.py index 716e835a555f1..853c49bc4dbc1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -67,8 +67,6 @@ VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False - VLLM_TORCH_COMPILE_LEVEL: int = 0 - VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False @@ -209,12 +207,6 @@ def get_default_config_root(): "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), - "VLLM_TORCH_COMPILE_LEVEL": - lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), - - # Path to the config file for torch compile - "VLLM_TORCH_COMPILE_CONFIG": - lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), # local rank of the process in the distributed setting, used to determine # the GPU device id diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 9057afb6514e4..2a7ca9fb8c576 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,4 +1,3 @@ -import os from typing import TYPE_CHECKING import torch @@ -40,7 +39,8 @@ def inference_mode(cls): def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.config import CompilationLevel compilation_config = vllm_config.compilation_config - if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + if compilation_config.level == CompilationLevel.NO_COMPILATION: + # TPU does not support NO_COMPILATION compilation_config.level = CompilationLevel.DYNAMO_ONCE assert compilation_config.level < CompilationLevel.PIECEWISE,\ "TPU does not support Inductor."
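
Taken together, the config.py, arg_utils.py, envs.py, and tpu.py hunks above replace the VLLM_TORCH_COMPILE_LEVEL and VLLM_TORCH_COMPILE_CONFIG environment variables with a single --compilation-config / -O flag parsed by the new CompilationConfig.from_cli classmethod. The snippet below is a minimal sketch of the two accepted value forms, assuming a vLLM build with this patch applied; the JSON payload mirrors the piecewise_compilation_config.json file deleted at the top of the diff, with an explicit level added.

from vllm.config import CompilationConfig, CompilationLevel

# "-O 3": a bare digit is interpreted as the optimization level.
cfg_level = CompilationConfig.from_cli("3")
assert cfg_level.level == CompilationLevel.PIECEWISE  # PIECEWISE == 3

# "-O '<json>'": any other value is parsed as a full JSON config.
cfg_json = CompilationConfig.from_cli(
    '{"level": 3, "use_cudagraph": true, '
    '"non_cudagraph_ops": ["silly.attention"], '
    '"cudagraph_copy_inputs": true}')
assert cfg_json.use_cudagraph and cfg_json.cudagraph_copy_inputs

On the engine command line, the same two forms are written as -O 3 or -O '{"level": 3, ...}', and EngineArgs forwards the parsed object into VllmConfig through the new compilation_config field added in create_engine_config.
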
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 05a9739d99e71..dc183dbfc9b96 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -5,7 +5,7 @@ import vllm.envs as envs if TYPE_CHECKING: - from vllm.config import CompilationConfig, VllmConfig + from vllm.config import VllmConfig logger = logging.getLogger(__name__) @@ -54,18 +54,6 @@ def load_general_plugins(): logger.exception("Failed to load plugin %s", plugin.name) -_compilation_config: Optional["CompilationConfig"] = None - - -def set_compilation_config(config: Optional["CompilationConfig"]): - global _compilation_config - _compilation_config = config - - -def get_compilation_config() -> Optional["CompilationConfig"]: - return _compilation_config - - _current_vllm_config: Optional["VllmConfig"] = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d60f93a44f6dd..1f9b544637bf7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,13 +8,12 @@ import torch.nn as nn from vllm.compilation.compile_context import set_compile_context -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs -from vllm.plugins import set_compilation_config from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, is_pin_memory_available) @@ -508,20 +507,6 @@ def execute_model( return model_runner_output def load_model(self) -> None: - if self.use_cuda_graph: - # NOTE(woosuk): Currently, we use inductor because the piecewise - # CUDA graphs do not work properly with the custom CUDA kernels. - # FIXME(woosuk): Disable inductor to reduce the compilation time - # and avoid any potential issues with the inductor. - set_compilation_config( - CompilationConfig( - custom_ops=["none"], - use_cudagraph=True, - non_cudagraph_ops=["vllm.unified_v1_flash_attention"], - use_inductor=True, - enable_fusion=False, - )) - logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) @@ -562,9 +547,8 @@ def profile_run(self) -> None: def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( - "Skipping CUDA graph capture. Please set " - "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.", - CompilationLevel.PIECEWISE) + "Skipping CUDA graph capture. Please add " + "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) return start_time = time.perf_counter()
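
With set_compilation_config and get_compilation_config removed from vllm.plugins, the compilation config now flows only through EngineArgs and VllmConfig, and the V1 defaults formerly set in GPUModelRunner.load_model are applied in VllmConfig.__post_init__ instead. Below is a minimal sketch of the equivalent offline usage through the public LLM entrypoint, mirroring tests/compile/utils.py and tests/tpu/test_compilation.py above; the model name and prompt are examples taken from those tests, not requirements.

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel

# compilation_config is forwarded through EngineArgs into VllmConfig,
# replacing the old VLLM_TORCH_COMPILE_LEVEL environment variable.
llm = LLM(model="google/gemma-2b",
          enforce_eager=True,
          compilation_config=CompilationConfig(
              level=CompilationLevel.DYNAMO_ONCE))
outputs = llm.generate(["A robot may not injure a human being"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)

Serving the same model through the API server uses the CLI spelling instead, e.g. --enforce-eager -O 2 (DYNAMO_ONCE), as exercised by tests/tpu/test_custom_dispatcher.py above.
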