diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 02c673c96fd9a..03a2b1157652b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,6 +2,7 @@ from collections import defaultdict import os import time +import pickle from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union) @@ -30,6 +31,11 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 +# If the env var is set, it uses the Ray's compiled DAG API +# which optimizes the control plane overhead. +# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. +USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -124,6 +130,10 @@ def __init__( self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC) + self.forward_dag = None + if USE_RAY_COMPILED_DAG: + self.forward_dag = self._compiled_ray_dag() + def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) @@ -806,7 +816,8 @@ def step(self) -> List[RequestOutput]: "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, "blocks_to_copy": scheduler_outputs.blocks_to_copy, - }) + }, + use_ray_compiled_dag=USE_RAY_COMPILED_DAG) # Only the driver worker returns the sampling results. output = all_outputs[0] @@ -966,6 +977,7 @@ def _run_workers( driver_args: Optional[List[Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, max_concurrent_workers: Optional[int] = None, + use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers.""" @@ -974,11 +986,16 @@ def _run_workers( raise NotImplementedError( "max_concurrent_workers is not supported yet.") - # Start the ray workers first. 
- ray_worker_outputs = [ - worker.execute_method.remote(method, *args, **kwargs) - for worker in self.workers - ] + if use_ray_compiled_dag: + # Right now, compiled DAG can only accept a single + # input. TODO(sang): Fix it. + output_channels = self.forward_dag.execute(1) + else: + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in self.workers + ] if driver_args is None: driver_args = args @@ -991,6 +1008,37 @@ def _run_workers( # Get the results of the ray workers. if self.workers: - ray_worker_outputs = ray.get(ray_worker_outputs) + if use_ray_compiled_dag: + try: + ray_worker_outputs = [ + pickle.loads(chan.begin_read()) + for chan in output_channels + ] + finally: + # Has to call end_read in order to reuse the DAG. + for chan in output_channels: + chan.end_read() + else: + ray_worker_outputs = ray.get(ray_worker_outputs) return [driver_worker_output] + ray_worker_outputs + + def _compiled_ray_dag(self): + import pkg_resources + required_version = "2.9" + current_version = pkg_resources.get_distribution("ray").version + if current_version < required_version: + raise ValueError(f"Ray version {required_version} or greater is " + f"required, but found {current_version}") + + from ray.dag import MultiOutputNode, InputNode + assert self.parallel_config.worker_use_ray + + # Right now, compiled DAG requires at least 1 arg. We send + # a dummy value for now. It will be fixed soon. 
+    with InputNode() as input_data: +        forward_dag = MultiOutputNode([ +            worker.execute_model_compiled_dag_remote.bind(input_data) +            for worker in self.workers +        ]) +        return forward_dag.experimental_compile() diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index afbc33ed19a0c..bbcbbdfea2f00 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,3 +1,5 @@ +import pickle + from typing import Optional, List, Tuple, TYPE_CHECKING from vllm.config import ParallelConfig @@ -18,6 +20,11 @@ def __init__(self, init_cached_hf_modules=False) -> None: from transformers.dynamic_module_utils import init_hf_modules init_hf_modules() self.worker = None +            # Since the compiled DAG runs the main execution +            # in a different thread that calls cuda.set_device, +            # this flag indicates if set_device is called on +            # that thread. +            self.compiled_dag_cuda_device_set = False def init_worker(self, worker_init_fn): self.worker = worker_init_fn() @@ -40,6 +47,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: def set_cuda_visible_devices(self, device_ids) -> None: set_cuda_visible_devices(device_ids) +        def execute_model_compiled_dag_remote(self, ignored): +            """Used only when compiled DAG is enabled.""" +            import torch +            if not self.compiled_dag_cuda_device_set: +                torch.cuda.set_device(self.worker.device) +                self.compiled_dag_cuda_device_set = True + +            output = self.worker.execute_model() +            output = pickle.dumps(output) +            return output + except ImportError as e: logger.warning(f"Failed to import Ray with {e!r}. " "For distributed inference, please install Ray with "