Skip to content

Commit

Permalink
[Ray] Integration compiled DAG off by default (vllm-project#2471)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored and jvmncs committed Feb 14, 2024
1 parent e1152b1 commit 593578c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
62 changes: 55 additions & 7 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
import os
import time
import pickle
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)

Expand Down Expand Up @@ -30,6 +31,11 @@
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))


class LLMEngine:
"""An LLM engine that receives requests and generates texts.
Expand Down Expand Up @@ -124,6 +130,10 @@ def __init__(
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC)

self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()

def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

Expand Down Expand Up @@ -806,7 +816,8 @@ def step(self) -> List[RequestOutput]:
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
})
},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

# Only the driver worker returns the sampling results.
output = all_outputs[0]
Expand Down Expand Up @@ -966,6 +977,7 @@ def _run_workers(
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
Expand All @@ -974,11 +986,16 @@ def _run_workers(
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")

# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
]
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
]

if driver_args is None:
driver_args = args
Expand All @@ -991,6 +1008,37 @@ def _run_workers(

# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)

return [driver_worker_output] + ray_worker_outputs

def _compiled_ray_dag(self):
import pkg_resources
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")

from ray.dag import MultiOutputNode, InputNode
assert self.parallel_config.worker_use_ray

# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.bind(input_data)
for worker in self.workers
])
return forward_dag.experimental_compile()
18 changes: 18 additions & 0 deletions vllm/engine/ray_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle

from typing import Optional, List, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
Expand All @@ -18,6 +20,11 @@ def __init__(self, init_cached_hf_modules=False) -> None:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False

def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
Expand All @@ -40,6 +47,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)

def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True

output = self.worker.execute_model()
output = pickle.dumps(output)
return output

except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with "
Expand Down

0 comments on commit 593578c

Please sign in to comment.