[Core] Shut down aDAG workers with clean async llm engine exit (vllm-project#7224)

Signed-off-by: Rui Qiao <[email protected]>
ruisearch42 authored Aug 13, 2024
1 parent 774cd1d commit 198d6a2
Showing 5 changed files with 40 additions and 25 deletions.
12 changes: 4 additions & 8 deletions tests/distributed/test_pipeline_parallel.py
@@ -34,9 +34,6 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")

-    USE_RAY_ADAG_NCCL = 0
-    USE_RAY_ADAG = 0
-
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -70,14 +67,13 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
     pp_env = None
-    if USE_RAY_ADAG:
-        assert DIST_BACKEND == "ray", (
-            "Ray ADAG is only supported with Ray distributed backend")
+    if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
+            and CHUNKED_PREFILL):
+        # Test Ray ADAG for a subset of the tests
         pp_env = {
             "VLLM_USE_RAY_COMPILED_DAG": "1",
             "VLLM_USE_RAY_SPMD_WORKER": "1",
-            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
-            str(int(USE_RAY_ADAG_NCCL)),
+            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
         }

     compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
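For context, a minimal sketch of launching a Ray-backed vLLM server with the same aDAG environment variables this test now sets unconditionally; the model name and parallel sizes are placeholder assumptions, and the server flags mirror the test's pp_args:

# Sketch: launch vLLM with the Ray compiled-DAG (aDAG) settings the test
# exercises. Model and sizes are placeholders; adjust for your hardware.
import os
import subprocess

env = dict(os.environ)
env.update({
    "VLLM_USE_RAY_COMPILED_DAG": "1",               # enable Ray aDAG
    "VLLM_USE_RAY_SPMD_WORKER": "1",                # SPMD-style Ray workers
    "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",  # NCCL channel for aDAG
})

subprocess.run(
    [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", "meta-llama/Meta-Llama-3-8B",  # placeholder model
        "--dtype", "half",
        "--tensor-parallel-size", "2",
        "--pipeline-parallel-size", "2",
        "--distributed-executor-backend", "ray",
        "--enable-chunked-prefill",
    ],
    env=env,
    check=True,
)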
14 changes: 14 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -661,6 +661,20 @@ def start_background_loop(self) -> None:
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)

+    def shutdown_background_loop(self) -> None:
+        """
+        Shut down the background loop.
+
+        This method needs to be called during cleanup to remove
+        references to `self` and properly GC the resources held
+        by the async LLM engine (e.g., the executors as well as
+        their resources).
+        """
+        if self._background_loop_unshielded is not None:
+            self._background_loop_unshielded.cancel()
+            self._background_loop_unshielded = None
+        self.background_loop = None
+
     def _init_engine(self, *args,
                      **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
         if not self.engine_use_ray:
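A minimal usage sketch for the new hook, assuming vLLM's public AsyncLLMEngine/AsyncEngineArgs API; the serve_once wrapper and model name are hypothetical:

# Sketch: explicit engine teardown via shutdown_background_loop().
# AsyncEngineArgs/AsyncLLMEngine are vLLM's public classes; everything
# else here is illustrative.
import gc

from vllm import AsyncEngineArgs, AsyncLLMEngine

async def serve_once() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model
    try:
        ...  # submit requests via engine.generate(...)
    finally:
        # Cancel the background task that holds a reference back to the
        # engine; without this, that reference keeps the engine (and its
        # executor workers) alive after the caller drops it.
        engine.shutdown_background_loop()
        del engine
        gc.collect()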
17 changes: 11 additions & 6 deletions vllm/engine/llm_engine.py
@@ -245,9 +245,18 @@ def __init__(
         if not self.model_config.skip_tokenizer_init:
             self.tokenizer = self._init_tokenizer()
             self.detokenizer = Detokenizer(self.tokenizer)
+            tokenizer_group = self.get_tokenizer_group()
         else:
             self.tokenizer = None
             self.detokenizer = None
+            tokenizer_group = None
+
+        # Ensure that the function doesn't contain a reference to self,
+        # to avoid engine GC issues
+        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
+            assert tokenizer_group, ("tokenizer_group cannot be None, "
+                                     "make sure skip_tokenizer_init is False")
+            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
@@ -356,10 +365,10 @@ def __init__(
                     self.detokenizer,
                     self.scheduler,
                     self.seq_counter,
-                    self.get_tokenizer_for_seq,
+                    get_tokenizer_for_seq,
                     stop_checker=StopChecker(
                         self.scheduler_config.max_model_len,
-                        self.get_tokenizer_for_seq,
+                        get_tokenizer_for_seq,
                     ),
                 ))

@@ -491,10 +500,6 @@ def get_tokenizer(
     ) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

-    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
-        return self.get_tokenizer_group().get_lora_tokenizer(
-            sequence.lora_request)
-
     def _init_tokenizer(self) -> BaseTokenizerGroup:
         return init_tokenizer_from_configs(
             model_config=self.model_config,
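The switch from the bound method self.get_tokenizer_for_seq to a local closure matters because a bound method captures self: anything holding the callback (the output processor, the StopChecker) would pin the whole engine in memory. A standalone sketch of the difference, with hypothetical Engine/Helper names:

# Sketch: a bound method pins its instance; a closure over just the data
# it needs does not. Engine/Helper are hypothetical stand-ins.
import weakref

class Helper:
    def __init__(self, fn):
        self.fn = fn  # callback held for the helper's lifetime

class Engine:
    def __init__(self, use_closure: bool):
        self.tokenizer_group = object()
        tokenizer_group = self.tokenizer_group
        if use_closure:
            # Captures only tokenizer_group, not `self`.
            def lookup():
                return tokenizer_group
            self.helper = Helper(lookup)
        else:
            # Bound method: captures `self`, so Helper pins the Engine.
            self.helper = Helper(self.lookup)

    def lookup(self):
        return self.tokenizer_group

for use_closure in (False, True):
    e = Engine(use_closure)
    ref, helper = weakref.ref(e), e.helper
    del e
    print(use_closure, ref() is None)  # bound method: False; closure: True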
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/rpc/server.py
@@ -36,6 +36,7 @@ def cleanup(self):
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
self.engine.shutdown_background_loop()

async def get_model_config(self, identity):
"""Send the ModelConfig"""
Expand Down
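A self-contained sketch of the teardown ordering the patched cleanup() follows (transport first, engine last); EngineStub stands in for the real AsyncLLMEngine, and pyzmq is assumed:

# Sketch: release transport resources before breaking the engine's
# background loop, mirroring the patched cleanup(). Requires pyzmq.
import zmq

class EngineStub:
    def shutdown_background_loop(self):
        print("background loop cancelled")

class RPCServer:
    def __init__(self):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.engine = EngineStub()

    def cleanup(self):
        self.socket.close()     # stop serving RPC traffic
        self.context.destroy()  # release zmq resources
        # Drop the engine's self-reference so it can be GC'd.
        self.engine.shutdown_background_loop()

RPCServer().cleanup()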
21 changes: 10 additions & 11 deletions vllm/executor/ray_gpu_executor.py
@@ -60,6 +60,14 @@ def _init_executor(self) -> None:
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)

+    def shutdown(self) -> None:
+        if hasattr(self, "forward_dag") and self.forward_dag is not None:
+            self.forward_dag.teardown()
+            import ray
+            for worker in self.workers:
+                ray.kill(worker)
+            self.forward_dag = None
+
     def _configure_ray_workers_use_nsight(self,
                                           ray_remote_kwargs) -> Dict[str, Any]:
         # If nsight profiling is enabled, we need to set the profiling
@@ -117,7 +125,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
logger.info("driver_ip: %s", driver_ip)
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
@@ -446,11 +453,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
         return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)

     def __del__(self):
-        if self.forward_dag is not None:
-            self.forward_dag.teardown()
-            import ray
-            for worker in self.workers:
-                ray.kill(worker)
+        self.shutdown()


 class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
@@ -523,8 +526,4 @@ async def _start_worker_execution_loop(self):
         return await asyncio.gather(*coros)

     def __del__(self):
-        if self.forward_dag is not None:
-            self.forward_dag.teardown()
-            import ray
-            for worker in self.workers:
-                ray.kill(worker)
+        self.shutdown()

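Both __del__ implementations now delegate to the single idempotent shutdown(); a sketch with hypothetical FakeDag/Executor stand-ins showing why the hasattr guard and the forward_dag = None reset matter:

# Sketch: idempotent shutdown callable from both __del__ and explicit
# cleanup. FakeDag/Executor stand in for vLLM's compiled DAG and
# RayGPUExecutor.
class FakeDag:
    def teardown(self):
        print("dag torn down")

class Executor:
    def __init__(self, fail_early: bool = False):
        if fail_early:
            raise RuntimeError("died before forward_dag existed")
        self.forward_dag = FakeDag()
        self.workers = []  # Ray actor handles in the real executor

    def shutdown(self):
        # hasattr() protects against __del__ on a half-constructed object;
        # the None reset makes repeated calls a no-op.
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            # The real executor also ray.kill()s each worker here.
            self.forward_dag = None

    def __del__(self):
        self.shutdown()

ex = Executor()
ex.shutdown()  # tears down once
del ex         # __del__ runs; shutdown() is now a no-op

try:
    Executor(fail_early=True)  # __del__ still runs on the dead instance
except RuntimeError:
    pass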