[BugFix] Fix clean shutdown issues #8492

Merged: 4 commits, Sep 16, 2024

tests/async_engine/test_async_llm_engine.py: 8 additions & 2 deletions
@@ -26,6 +26,11 @@ class RequestOutput:
finished: bool = False


@dataclass
class MockModelConfig:
use_async_output_proc = True


class MockEngine:

def __init__(self):
@@ -35,6 +40,7 @@ def __init__(self):
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)
self.model_config = MockModelConfig()

async def step_async(self, virtual_engine):
# PP size is 1, ignore virtual engine
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

@pytest.mark.asyncio
async def test_new_requests_event():
engine = MockAsyncLLMEngine(worker_use_ray=False)
engine = MockAsyncLLMEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
@@ -113,7 +119,7 @@ async def test_new_requests_event():
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1

engine = MockAsyncLLMEngine(worker_use_ray=True)
engine = MockAsyncLLMEngine()
assert engine.get_model_config() is not None
assert engine.get_tokenizer() is not None
assert engine.get_decoding_config() is not None
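
The mock gains a model_config because AsyncLLMEngine (see the next file) now reads model_config.use_async_output_proc at construction time instead of taking a worker_use_ray argument. A tiny illustrative sketch of that interaction, with a hypothetical FakeAsyncEngine standing in for the real class:

```python
from dataclasses import dataclass


@dataclass
class MockModelConfig:
    # A plain class attribute is enough here; the engine only reads it.
    use_async_output_proc = True


class MockEngine:
    def __init__(self):
        self.model_config = MockModelConfig()


class FakeAsyncEngine:
    """Hypothetical stand-in for AsyncLLMEngine's constructor logic."""

    def __init__(self, engine):
        self.engine = engine
        # Mirrors the diff: the output-processing callback path is only
        # enabled when the model config opts into async output processing.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)


assert FakeAsyncEngine(MockEngine()).use_process_request_outputs_callback
```
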
vllm/engine/async_llm_engine.py: 45 additions & 25 deletions
@@ -1,8 +1,10 @@
import asyncio
import time
import weakref
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from weakref import ReferenceType

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
method yields the outputs from the :class:`LLMEngine` to the caller.

Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

def __init__(self,
worker_use_ray: bool,
*args,
log_requests: bool = True,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.log_requests = log_requests
self.engine = self._engine_class(*args, **kwargs)

# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
self.use_process_request_outputs_callback = True
self.use_process_request_outputs_callback = (
self.engine.model_config.use_async_output_proc)

if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
self.process_request_outputs
weak_bind(self.process_request_outputs)

self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
@@ -492,6 +491,11 @@ def __init__(self,
# Lazy initialized fields
self._request_tracker: RequestTracker

def __del__(self):
if rt := getattr(self, "request_tracker", None):
# Wake up engine loop so that it will exit cleanly
rt.new_requests_event.set()

@classmethod
def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
@@ -502,15 +506,12 @@ def _get_executor_cls(
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
@@ -531,19 +532,16 @@ def _get_executor_cls(
from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
@@ -559,19 +557,23 @@ def _get_executor_cls(
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
if engine_config is None:
engine_config = engine_args.create_engine_config()

executor_class = cls._get_executor_cls(engine_config)

if executor_class.uses_ray:
initialize_ray_cluster(engine_config.parallel_config)

# Create the async LLM engine.
engine = cls(
executor_class.uses_ray,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
@@ -628,7 +630,7 @@ def start_background_loop(self) -> None:
self._request_tracker = RequestTracker()

self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
).create_task(self.run_engine_loop(weakref.ref(self)))
self._background_loop_unshielded.add_done_callback(
partial(_log_task_completion, error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ def process_request_outputs(self, request_outputs) -> bool:
async def _engine_abort(self, request_ids: Iterable[str]):
self.engine.abort_request(request_ids)

async def run_engine_loop(self):
@staticmethod
async def run_engine_loop(engine_ref: ReferenceType):
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
engine: Optional["AsyncLLMEngine"] = engine_ref()
if not engine:
return

pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size
engine.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True:
if not any(has_requests_in_progress):
@@ -711,11 +720,21 @@ async def run_engine_loop(self):
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests()
await engine.engine.stop_remote_worker_execution_loop_async()
request_tracker = engine._request_tracker
# Allow engine to be garbage collected while
# waiting for new requests
del engine
await asyncio.sleep(0)
if engine_ref() is None:
return
await request_tracker.wait_for_new_requests()
engine = engine_ref()
if not engine:
return
logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(self.engine_step(ve))
asyncio.create_task(engine.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
@@ -733,19 +752,20 @@ async def run_engine_loop(self):
result = task.result()
virtual_engine = requests_in_progress.index(task)
has_unfinished_requests = (
self.engine.has_unfinished_requests_for_virtual_engine(
engine.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
self.engine_step(virtual_engine)))
engine.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
self.set_errored(exc)
engine.set_errored(exc)
raise
await asyncio.sleep(0)

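
The heart of the shutdown fix in this file is that run_engine_loop is now a static method holding only a weakref to the engine: it re-resolves the reference around every await, drops the strong reference while idle, and __del__ sets new_requests_event so a waiting loop wakes up, observes the dead weakref, and exits. A minimal self-contained sketch of that pattern, using a hypothetical Worker class rather than the real engine:

```python
import asyncio
import weakref
from typing import Optional


class Worker:
    """Hypothetical stand-in for the engine; not vLLM code."""

    def __init__(self):
        self.new_work = asyncio.Event()
        # The background task receives a weakref, not self, so the task
        # does not keep this object alive.
        self._task = asyncio.get_running_loop().create_task(
            Worker._run_loop(weakref.ref(self)))

    def __del__(self):
        # Wake the loop so it notices the dead weakref and exits cleanly,
        # mirroring AsyncLLMEngine.__del__ setting new_requests_event.
        self.new_work.set()

    @staticmethod
    async def _run_loop(worker_ref: "weakref.ReferenceType[Worker]") -> None:
        while True:
            worker: Optional[Worker] = worker_ref()
            if worker is None:
                return
            event = worker.new_work
            # Drop the strong reference while idle so the owner can be
            # garbage collected, then re-resolve it after waking up.
            del worker
            await event.wait()
            event.clear()
            if worker_ref() is None:
                return
            # ... one iteration of real work would go here ...


async def main() -> None:
    worker = Worker()
    task = worker._task
    await asyncio.sleep(0)  # let the background loop start waiting
    del worker              # in CPython, __del__ fires here and wakes the loop
    await task              # the loop sees the dead weakref and returns


asyncio.run(main())
```

The design point is the same as in the diff: the long-lived asyncio task must never be the thing that keeps its owner alive, otherwise the owner is only collected at interpreter shutdown and the process cannot exit cleanly.
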
vllm/engine/llm_engine.py: 13 additions & 8 deletions
@@ -1,8 +1,8 @@
import functools
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
@@ -51,7 +51,7 @@
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device
from vllm.utils import Counter, Device, weak_bind
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -382,11 +382,16 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
for _ in range(self.parallel_config.pipeline_parallel_size)
]

self.async_callbacks = [
functools.partial(self._process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
if model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)

self.async_callbacks = [
partial(process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
else:
self.async_callbacks = []

# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
@@ -869,8 +874,8 @@ def has_unfinished_requests_for_virtual_engine(
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()

@staticmethod
def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
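
Both this file and async_llm_engine.py now route callbacks through weak_bind from vllm.utils (_process_model_outputs and process_request_outputs respectively) so the stored callback does not keep the engine alive through a bound-method reference. The helper itself is not part of this diff; the following is only a rough sketch of what such a helper can look like, not necessarily vLLM's implementation:

```python
import weakref
from typing import Any, Callable


def weak_bind(bound_method: Callable[..., Any]) -> Callable[..., None]:
    """Return a callable holding only a weak reference to the method's
    receiver. After the receiver is garbage collected, calling it is a
    no-op instead of an error. (Illustrative sketch.)
    """
    ref = weakref.WeakMethod(bound_method)

    def weak_bound(*args, **kwargs) -> None:
        method = ref()
        if method is not None:
            method(*args, **kwargs)

    return weak_bound


class Engine:
    def _process_model_outputs(self, ctx: str) -> None:
        print(f"processing outputs for {ctx}")


engine = Engine()
callback = weak_bind(engine._process_model_outputs)
callback(ctx="ve0")   # runs normally
del engine
callback(ctx="ve0")   # silently does nothing; nothing kept the engine alive
```
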
vllm/entrypoints/launcher.py: 10 additions & 11 deletions
@@ -1,21 +1,20 @@
import asyncio
import signal
from http import HTTPStatus
from typing import Any
from typing import Any, Optional

import uvicorn
from fastapi import FastAPI, Response
from fastapi import FastAPI, Request, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import AsyncEngineClient
from vllm.logger import init_logger
from vllm.utils import find_process_using_port

logger = init_logger(__name__)


async def serve_http(app: FastAPI, engine: AsyncEngineClient,
async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
**uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,

# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if engine.limit_concurrency is not None:
if limit_concurrency is not None:
logger.info(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing", engine.limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
"--disable-frontend-multiprocessing", limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = limit_concurrency

config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine)
_add_shutdown_handlers(app, server)

loop = asyncio.get_running_loop()

@@ -68,15 +67,15 @@ async def dummy_shutdown() -> None:
return server.shutdown()


def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
engine: AsyncEngineClient) -> None:
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server"""

@app.exception_handler(RuntimeError)
async def runtime_error_handler(_, __):
async def runtime_error_handler(request: Request, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine = request.app.state.engine_client
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
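
The launcher follows the same theme: serve_http now takes only the scalar limit_concurrency instead of the engine client, and the RuntimeError handler resolves the client from request.app.state.engine_client at request time, so the HTTP layer holds no long-lived engine reference of its own. A condensed sketch of the resulting wiring; the caller-side setup in the trailing comment is an assumption for illustration, not code from this PR:

```python
from http import HTTPStatus

from fastapi import FastAPI, Request, Response

app = FastAPI()


def _add_shutdown_handlers(app: FastAPI) -> None:
    """Sketch of the handler after the change: no engine argument captured."""

    @app.exception_handler(RuntimeError)
    async def runtime_error_handler(request: Request, _exc: RuntimeError):
        # Resolve the engine client lazily from application state so this
        # closure does not pin it in memory.
        engine = request.app.state.engine_client
        if engine.errored and not engine.is_running:
            ...  # the real code sends SIGTERM here to shut the server down
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)


_add_shutdown_handlers(app)

# Assumed caller-side wiring (not part of this diff): whoever builds the app
# stores the client on app.state before serving and passes only the scalar
# limit into serve_http, e.g.
#   app.state.engine_client = engine_client
#   await serve_http(app, limit_concurrency=engine_client.limit_concurrency,
#                    host="0.0.0.0", port=8000)
```
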