Skip to content

Commit

Permalink
[Core] Add multiproc_worker_utils for multiprocessing-based workers
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored May 1, 2024
1 parent 24750f4 commit a657bfc
Show file tree
Hide file tree
Showing 2 changed files with 440 additions and 0 deletions.
176 changes: 176 additions & 0 deletions tests/engine/test_multiproc_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import sleep
from typing import Any, List, Tuple

import pytest

from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)


class DummyWorker:
    """Dummy version of vllm.worker.worker.Worker"""

    def __init__(self, rank: int):
        self.rank = rank

    def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
        """Echo ``worker_input`` back together with this worker's rank.

        Args:
            worker_input: Arbitrary payload. If it is an exception instance
                it is raised instead of returned, to simulate a failing task.

        Returns:
            A ``(rank, worker_input)`` tuple.

        Raises:
            Exception: ``worker_input`` itself when it is an exception.
        """
        sleep(0.05)

        if isinstance(worker_input, Exception):
            # simulate error case
            raise worker_input

        # Return the argument that was passed in. (Previously this returned
        # the ``input`` builtin, so the tests never checked the round trip.)
        return self.rank, worker_input


def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
    """Create and start eight dummy workers plus their monitor.

    All workers share a single ``ResultHandler``. Worker ``i`` in the returned
    list wraps a ``DummyWorker`` constructed with ``rank=i``.

    Returns:
        The list of worker wrappers and the (already started) monitor.
    """
    result_handler = ResultHandler()
    workers = [
        ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
        for rank in range(8)
    ]

    worker_monitor = WorkerMonitor(workers, result_handler)
    assert not worker_monitor.is_alive()

    result_handler.start()
    worker_monitor.start()
    assert worker_monitor.is_alive()

    return workers, worker_monitor


def test_local_workers() -> None:
    """Test workers with sync task submission"""

    workers, worker_monitor = _start_workers()

    def execute_workers(worker_input: str) -> None:
        # Submit to every worker first, then collect, so the calls overlap.
        worker_outputs = [
            worker.execute_method("worker_method", worker_input)
            for worker in workers
        ]

        for rank, output in enumerate(worker_outputs):
            assert output.get() == (rank, worker_input)

    # Test concurrent submission from different threads. The context manager
    # guarantees the pool's threads are released even if an assertion fails.
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(execute_workers, f"thread {thread_num}")
            for thread_num in range(4)
        ]

        for future in futures:
            # Re-raises any assertion failure from the worker thread.
            future.result()

    # Test error case: the exception raised in the worker must surface here
    # with its original type and message.
    exception = ValueError("fake error")
    result = workers[0].execute_method("worker_method", exception)
    with pytest.raises(ValueError) as exc_info:
        result.get()
    assert str(exc_info.value) == "fake error"

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(2)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        workers[0].execute_method("worker_method", "test")


def test_local_workers_clean_shutdown() -> None:
    """Test clean shutdown"""

    workers, worker_monitor = _start_workers()

    assert worker_monitor.is_alive()
    assert all(worker.process.is_alive() for worker in workers)

    # Clean shutdown
    worker_monitor.close()

    worker_monitor.join(5)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        workers[0].execute_method("worker_method", "test")


@pytest.mark.asyncio
async def test_local_workers_async() -> None:
    """Test local workers with async task submission"""

    workers, worker_monitor = _start_workers()

    async def execute_workers(worker_input: str) -> None:
        worker_coros = [
            worker.execute_method_async("worker_method", worker_input)
            for worker in workers
        ]

        results = await asyncio.gather(*worker_coros)
        for rank, result in enumerate(results):
            assert result == (rank, worker_input)

    # Schedule all submissions up front so they run concurrently.
    tasks = [
        asyncio.create_task(execute_workers(f"task {task_num}"))
        for task_num in range(4)
    ]

    for task in tasks:
        await task

    # Test error case: the exception raised in the worker must surface here
    # with its original type and message.
    exception = ValueError("fake error")
    with pytest.raises(ValueError) as exc_info:
        await workers[0].execute_method_async("worker_method", exception)
    assert str(exc_info.value) == "fake error"

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(2)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        await workers[0].execute_method_async("worker_method", "test")
Loading

0 comments on commit a657bfc

Please sign in to comment.