diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b50eed1c8c722..d58f621d36b86 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,18 +1,18 @@ import multiprocessing -import os import pytest import torch from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, ncclGetUniqueId) +from vllm.utils import update_environment_variables def distributed_run(fn, world_size): number_of_processes = world_size processes = [] for i in range(number_of_processes): - env = os.environ.copy() + env = {} env['RANK'] = str(i) env['LOCAL_RANK'] = str(i) env['WORLD_SIZE'] = str(number_of_processes) @@ -32,8 +32,7 @@ def update_env(fn): # so we need to pass the environment variables as arguments # and update the environment variables in the function def wrapper(env): - import os - os.environ.update(env) + update_environment_variables(env) fn() return wrapper diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 04d4ed83976d0..febae42b84549 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,55 +1,28 @@ import pickle -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_ip, is_hip, set_cuda_visible_devices -from vllm.worker.worker import Worker +from vllm.utils import get_ip, is_hip +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) try: import ray - class RayWorkerVllm: + class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" - def __init__(self, init_cached_hf_modules=False) -> None: - if init_cached_hf_modules: - from transformers.dynamic_module_utils import init_hf_modules - init_hf_modules() - self._worker: Optional[Worker] = None + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # Since the compiled DAG runs a main execution # in a different thread that calls cuda.set_device. # The flag indicates is set_device is called on # that thread. self.compiled_dag_cuda_device_set = False - def init_worker(self, worker_init_fn: Callable[[], Worker]): - self._worker = worker_init_fn() - - @property - def worker(self) -> Worker: - assert self._worker is not None - return self._worker - - def __getattr__(self, name): - return getattr(self.worker, name) - - def execute_method(self, method, *args, **kwargs): - try: - executor = getattr(self, method) - return executor(*args, **kwargs) - except Exception as e: - # exceptions in ray worker may cause deadlock - # see https://github.com/vllm-project/vllm/issues/3455 - # print the error and inform the user to solve the error - msg = (f"Error executing method {method}. " - "This might cause deadlock in distributed execution.") - logger.exception(msg) - raise e - def get_node_ip(self) -> str: return get_ip() @@ -58,9 +31,6 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def set_cuda_visible_devices(self, device_ids) -> None: - set_cuda_visible_devices(device_ids) - def execute_model_compiled_dag_remote(self, ignored): """Used only when compiled DAG is enabled.""" import torch @@ -77,7 +47,7 @@ def execute_model_compiled_dag_remote(self, ignored): "For distributed inference, please install Ray with " "`pip install ray`.") ray = None # type: ignore - RayWorkerVllm = None # type: ignore + RayWorkerWrapper = None # type: ignore def initialize_ray_cluster( diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5f859fdc9c078..5a43f1fc28a84 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -1,17 +1,16 @@ import asyncio -import copy import os import pickle from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple -from vllm.engine.ray_utils import RayWorkerVllm, ray +from vllm.engine.ray_utils import RayWorkerWrapper, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async, set_cuda_visible_devices) + make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -74,9 +73,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. - self.driver_dummy_worker: RayWorkerVllm = None + self.driver_dummy_worker: RayWorkerWrapper = None # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerVllm] = [] + self.workers: List[RayWorkerWrapper] = [] if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( @@ -97,13 +96,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + )(RayWorkerWrapper).remote( + worker_module_name="vllm.worker.worker", + worker_class_name="Worker", + ) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: # If the worker is on the same node as the driver, we use it # as the resource holder for the driver process. self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name="vllm.worker.worker", + worker_class_name="Worker", + ) else: # Else, added to the list of workers. self.workers.append(worker) @@ -115,82 +121,56 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "GPU node.") # Get the set of GPU IDs used on each node. - driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) - worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) node_workers = defaultdict(list) node_gpus = defaultdict(list) - node_workers[driver_node_id].append(0) - node_gpus[driver_node_id].extend(driver_gpu_ids) - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, - start=1): + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): node_workers[node_id].append(i) node_gpus[node_id].extend(gpu_ids) for node_id, gpu_ids in node_gpus.items(): node_gpus[node_id] = sorted(gpu_ids) # Set CUDA_VISIBLE_DEVICES for the driver and workers. - set_cuda_visible_devices(node_gpus[driver_node_id]) - for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): - worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + all_args_to_update_environment_variables = [] + for (node_id, _) in worker_node_and_gpu_ids: + all_args_to_update_environment_variables.append([{ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])) + }]) + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - model_config = copy.deepcopy(self.model_config) - parallel_config = copy.deepcopy(self.parallel_config) - scheduler_config = copy.deepcopy(self.scheduler_config) - load_config = copy.deepcopy(self.load_config) - device_config = copy.deepcopy(self.device_config) - lora_config = copy.deepcopy(self.lora_config) - cache_config = copy.deepcopy(self.cache_config) - vision_language_config = copy.deepcopy(self.vision_language_config) - - # Initialize the actual workers with the Worker class. - for rank, (worker, (node_id, _)) in enumerate( - zip(self.workers, worker_node_and_gpu_ids), - start=1, - ): + def collect_arg_helper_func(**kwargs): + # avoid writing `{"name": value}` manually + return kwargs + + init_worker_all_kwargs = [] + + # Initialize the actual workers inside worker wrapper. + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ): local_rank = node_workers[node_id].index(rank) - worker.init_worker.remote( - lambda rank=rank, local_rank=local_rank: Worker( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, + init_worker_all_kwargs.append( + collect_arg_helper_func( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=lora_config, - vision_language_config=vision_language_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=rank == 0, )) - - # Initialize the driver worker with the Worker class. - driver_rank = 0 - driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=driver_local_rank, - rank=driver_rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - load_config=self.load_config, - is_driver_worker=True, - ) + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") self._run_workers( @@ -279,13 +259,35 @@ def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, + driver_args: Optional[Tuple[Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, + all_args: Optional[List[List[Any]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: - """Runs the given method on all workers.""" + """Runs the given method on all workers. + all_args and all_kwargs are used to pass heterogeneous arguments, + i.e. different arguments for each worker. + """ + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # for mypy type checking + assert driver_args is not None + assert driver_kwargs is not None + if all_args is None: + all_args = [driver_args] + [args] * len(self.workers) + if all_kwargs is None: + all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers) + + # for mypy type checking + assert all_args is not None + assert all_kwargs is not None if max_concurrent_workers: raise NotImplementedError( @@ -299,8 +301,10 @@ def _run_workers( else: # Start the ray workers first. ray_worker_outputs = [ - worker.execute_method.remote(method, *args, **kwargs) - for worker in self.workers + worker.execute_method.remote(method, *worker_args, + **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_args[1:], all_kwargs[1:]) ] if driver_args is None: @@ -309,9 +313,13 @@ def _run_workers( driver_kwargs = kwargs # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) - + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *all_args[0], **all_kwargs[0]) + else: + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *all_args[0], **all_kwargs[0])) # Get the results of the ray workers. if self.workers: if use_ray_compiled_dag: @@ -386,8 +394,12 @@ async def _run_workers_async( driver_kwargs = kwargs # Run the driver worker asynchronously. - driver_executor = make_async(getattr(self.driver_worker, method)) - coros.append(driver_executor(*driver_args, **driver_kwargs)) + def helper(): + return self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + + driver_executor = make_async(helper) + coros.append(driver_executor()) # Run the ray workers asynchronously. for worker in self.workers: diff --git a/vllm/utils.py b/vllm/utils.py index aad62516ad1b9..e132575e7bf81 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -271,8 +271,12 @@ def get_open_port() -> int: return s.getsockname()[1] -def set_cuda_visible_devices(device_ids: List[int]) -> None: - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) +def update_environment_variables(envs: Dict[str, str]): + for k, v in envs.items(): + if k in os.environ: + logger.warning(f"Overwriting environment variable {k} " + f"from '{os.environ[k]}' to '{v}'") + os.environ[k] = v def chunk_list(lst, chunk_size): @@ -505,3 +509,11 @@ def merge_dicts(dict1: Dict[Any, List[Any]], merged_dict[key].extend(value) return dict(merged_dict) + + +def init_cached_hf_modules(): + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index afc4a1e1f4630..8468ace5a2fdc 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -138,7 +138,10 @@ def __init__( self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.model_runner = CPUModelRunner(model_config, parallel_config, scheduler_config, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 142c6c97f5194..d0e6aaed180e6 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -29,6 +29,10 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.model_runner = NeuronModelRunner(model_config, parallel_config, scheduler_config, device_config) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e2b47530d41e4..b021866965401 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -60,6 +60,10 @@ def __init__( if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.vision_language_config = vision_language_config if self.vision_language_config: assert not self.lora_config, ( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a92f5aea76059..309aa6256acea 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,8 +1,14 @@ +import importlib +import os from abc import ABC, abstractmethod from typing import Dict, List, Tuple +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import update_environment_variables + +logger = init_logger(__name__) class WorkerBase(ABC): @@ -82,3 +88,53 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> List[int]: raise ValueError(f"{type(self)} does not support LoRA") + + +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + """ + + def __init__(self, + worker_module_name=None, + worker_class_name=None) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker = None + + def update_environment_variables(self, envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, *args, **kwargs): + """ + Actual initialization of the worker class. + Arguments are passed to the worker class constructor. + """ + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + self.worker = worker_class(*args, **kwargs) + + def execute_method(self, method, *args, **kwargs): + try: + if hasattr(self, method): + executor = getattr(self, method) + else: + executor = getattr(self.worker, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e