Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] replace narrow-usage RayWorkerVllm to general WorkerWrapper to reduce code duplication #4024

Merged
merged 40 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
f7a6356
replace narrow-usage set_cuda_visible_devices to general update_environment_variables…
youkaichao Apr 12, 2024
bbdfc69
add warning when env is overwritten
youkaichao Apr 12, 2024
1e62614
use logger.warning
youkaichao Apr 12, 2024
37eb344
fix env copy
youkaichao Apr 12, 2024
6f64b48
avoid overwritten warning in ray
youkaichao Apr 12, 2024
0499106
fix lint
youkaichao Apr 12, 2024
d26672f
allow heterogeneous args in _run_workers; move update_environment_variables…
youkaichao Apr 12, 2024
3a01337
unified init worker
youkaichao Apr 12, 2024
c85d040
fix recursion
youkaichao Apr 12, 2024
5e49b98
on the fly local rank calculation
youkaichao Apr 12, 2024
37ed6c9
post update kwargs
youkaichao Apr 12, 2024
b654ee2
add remote
youkaichao Apr 12, 2024
e11448e
fix update_environment_variables in ray worker
youkaichao Apr 12, 2024
97e6601
use staticmethod
youkaichao Apr 12, 2024
fd2cbe2
fix dummy worker local_rank
youkaichao Apr 12, 2024
a8d7504
fix dummy worker rank
youkaichao Apr 12, 2024
e659635
add WorkerWrapperBase
youkaichao Apr 12, 2024
778fb3f
add all_args to _run_workers
youkaichao Apr 12, 2024
d295107
refactor
youkaichao Apr 12, 2024
7ca22a4
fix dangling self
youkaichao Apr 12, 2024
5f6c8f3
fix execute_method in driver worker
youkaichao Apr 12, 2024
13de66e
withdraw changes in many workers
youkaichao Apr 12, 2024
32ef3bb
no need for init_worker in workerbase
youkaichao Apr 12, 2024
221f626
unify worker_node_and_gpu_ids
youkaichao Apr 12, 2024
0087773
use id rather than ip
youkaichao Apr 12, 2024
36a185e
unify init
youkaichao Apr 12, 2024
95ca917
fix lint
youkaichao Apr 12, 2024
ea5f2a5
finish todo
youkaichao Apr 12, 2024
d10ca88
rename to RayWorkerWrapper
youkaichao Apr 12, 2024
a164219
Merge remote-tracking branch 'origin' into update_env
youkaichao Apr 16, 2024
eb27be9
fix mypy typing
youkaichao Apr 16, 2024
74deb44
move init hf decision to each worker
youkaichao Apr 16, 2024
3bd2c98
use quotes to address white space in env var values
youkaichao Apr 16, 2024
21be004
add docstring
youkaichao Apr 16, 2024
1aee6a0
add config
youkaichao Apr 16, 2024
4337ac6
Merge remote-tracking branch 'origin' into update_env
youkaichao Apr 17, 2024
40d4560
fix _run_workers_async
youkaichao Apr 17, 2024
2509db4
move duplicate code to utils
youkaichao Apr 17, 2024
d1bda36
add docstring
youkaichao Apr 17, 2024
1e30d89
use docstring
youkaichao Apr 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
Expand All @@ -32,8 +33,7 @@ def update_env(fn):
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
import os
os.environ.update(env)
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
update_environment_variables(env)
fn()

return wrapper
Expand Down
8 changes: 4 additions & 4 deletions vllm/engine/ray_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pickle
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
from vllm.utils import get_ip, is_hip, update_environment_variables

logger = init_logger(__name__)

Expand Down Expand Up @@ -52,8 +52,8 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids

def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)
def update_environment_variables(self, envs: Dict[str, str]) -> None:
    """Forward ``envs`` to the ``update_environment_variables`` helper.

    The helper is the module-level function imported from ``vllm.utils``;
    this method only delegates to it. Exposing it as a method lets a caller
    that holds a handle to this worker invoke it on the worker itself
    (for example through ``.remote(...)``).

    Args:
        envs: Mapping of environment variable names to the string values
            to set.
    """
    update_environment_variables(envs)

def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
Expand Down
13 changes: 10 additions & 3 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async, set_cuda_visible_devices)
make_async, update_environment_variables)

if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
Expand Down Expand Up @@ -134,9 +134,16 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
node_gpus[node_id] = sorted(gpu_ids)

# Set CUDA_VISIBLE_DEVICES for the driver and workers.
set_cuda_visible_devices(node_gpus[driver_node_id])
# The value is the comma-separated list of the node's GPU ids.
update_environment_variables({
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[driver_node_id]))
})
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
worker.update_environment_variables.remote({
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id]))
})

distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
Expand Down
10 changes: 8 additions & 2 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,14 @@ def get_open_port() -> int:
return s.getsockname()[1]


def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def update_environment_variables(envs: Dict[str, str]) -> None:
    """Copy every entry of ``envs`` into ``os.environ``.

    A warning is emitted when a variable that is already set is replaced by
    a *different* value. Re-assigning the value a variable already holds is
    silent, so applying the same update twice does not produce misleading
    "overwriting" warnings.

    Args:
        envs: Mapping of environment variable names to the string values
            to set.
    """
    for k, v in envs.items():
        previous = os.environ.get(k)
        # Only warn on a real change; an identical value is not an overwrite.
        if previous is not None and previous != v:
            warnings.warn(
                f"Overwriting environment variable {k} "
                f"from {previous} to {v}",
                stacklevel=2)
        os.environ[k] = v


def chunk_list(lst, chunk_size):
Expand Down
Loading