-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Core] Multiprocessing executor for single-node multi-GPU deployment #3466
Closed
Closed
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
937753b
Make ray optional for single-node deployment
njhill c0bad3d
Use getattr to access model_executor in engine destructor
njhill d0a8709
Address a couple of review comments
njhill 5e214a3
Extend existing distributed correctness test
njhill 0fb0743
Use factory for worker initialization
njhill e048001
Test local worker mechanics in isolation
njhill 1938c35
Add pid prefix to process stdout/stderr instead of logger
njhill 56a1ad4
Update new chunked prefill distributed test to include non-Ray
njhill File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
# Common dependencies | ||
-r requirements-common.txt | ||
|
||
# Dependencies for AMD GPUs | ||
ray == 2.9.3 | ||
# No specific dependencies currently for AMD GPUs | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import asyncio | ||
from concurrent.futures import ThreadPoolExecutor | ||
from functools import partial | ||
from time import sleep | ||
from typing import Any, List, Tuple | ||
|
||
import pytest | ||
|
||
from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler, | ||
WorkerMonitor) | ||
|
||
|
||
class DummyWorker: | ||
"""Dummy version of vllm.worker.worker""" | ||
|
||
def __init__(self, rank: int): | ||
self.rank = rank | ||
|
||
def worker_method(self, worker_input: Any) -> Tuple[int, Any]: | ||
sleep(0.05) | ||
|
||
if isinstance(worker_input, Exception): | ||
# simulate error case | ||
raise worker_input | ||
|
||
return self.rank, input | ||
|
||
|
||
def _start_workers() -> Tuple[List[LocalWorkerVllm], WorkerMonitor]: | ||
result_handler = ResultHandler() | ||
workers = [ | ||
LocalWorkerVllm(result_handler, partial(DummyWorker, rank=rank)) | ||
for rank in range(8) | ||
] | ||
|
||
for worker in workers: | ||
worker.start() | ||
|
||
worker_monitor = WorkerMonitor(workers, result_handler) | ||
assert not worker_monitor.is_alive() | ||
|
||
result_handler.start() | ||
worker_monitor.start() | ||
assert worker_monitor.is_alive() | ||
|
||
return workers, worker_monitor | ||
|
||
|
||
def test_local_workers() -> None: | ||
"""Test workers with sync task submission""" | ||
|
||
workers, worker_monitor = _start_workers() | ||
|
||
def execute_workers(worker_input: str) -> None: | ||
worker_outputs = [ | ||
worker.execute_method("worker_method", worker_input) | ||
for worker in workers | ||
] | ||
|
||
for rank, output in enumerate(worker_outputs): | ||
assert output.get() == (rank, input) | ||
|
||
executor = ThreadPoolExecutor(max_workers=4) | ||
|
||
# Test concurrent submission from different threads | ||
futures = [ | ||
executor.submit(partial(execute_workers, f"thread {thread_num}")) | ||
for thread_num in range(4) | ||
] | ||
|
||
for future in futures: | ||
future.result() | ||
|
||
# Test error case | ||
exception = ValueError("fake error") | ||
result = workers[0].execute_method("worker_method", exception) | ||
try: | ||
result.get() | ||
pytest.fail("task should have failed") | ||
except Exception as e: | ||
assert isinstance(e, ValueError) | ||
assert str(e) == "fake error" | ||
|
||
# Test cleanup when a worker fails | ||
assert worker_monitor.is_alive() | ||
workers[3].kill() | ||
|
||
# Other workers should get shut down here | ||
worker_monitor.join(2) | ||
|
||
# Ensure everything is stopped | ||
assert not worker_monitor.is_alive() | ||
assert all(not worker.is_alive() for worker in workers) | ||
|
||
# Further attempts to submit tasks should fail | ||
try: | ||
_result = workers[0].execute_method("worker_method", "test") | ||
pytest.fail("task should fail once workers have been shut down") | ||
except Exception as e: | ||
assert isinstance(e, ChildProcessError) | ||
|
||
|
||
def test_local_workers_clean_shutdown() -> None: | ||
"""Test clean shutdown""" | ||
|
||
workers, worker_monitor = _start_workers() | ||
|
||
assert worker_monitor.is_alive() | ||
assert all(worker.is_alive() for worker in workers) | ||
|
||
# Clean shutdown | ||
worker_monitor.close() | ||
|
||
worker_monitor.join(2) | ||
|
||
# Ensure everything is stopped | ||
assert not worker_monitor.is_alive() | ||
assert all(not worker.is_alive() for worker in workers) | ||
|
||
# Further attempts to submit tasks should fail | ||
try: | ||
_result = workers[0].execute_method("worker_method", "test") | ||
pytest.fail("task should fail once workers have been shut down") | ||
except Exception as e: | ||
assert isinstance(e, ChildProcessError) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_local_workers_async() -> None: | ||
"""Test local workers with async task submission""" | ||
|
||
workers, worker_monitor = _start_workers() | ||
|
||
async def execute_workers(worker_input: str) -> None: | ||
worker_coros = [ | ||
worker.execute_method_async("worker_method", worker_input) | ||
for worker in workers | ||
] | ||
|
||
results = await asyncio.gather(*worker_coros) | ||
for rank, result in enumerate(results): | ||
assert result == (rank, input) | ||
|
||
tasks = [ | ||
asyncio.create_task(execute_workers(f"task {task_num}")) | ||
for task_num in range(4) | ||
] | ||
|
||
for task in tasks: | ||
await task | ||
|
||
# Test error case | ||
exception = ValueError("fake error") | ||
try: | ||
_result = await workers[0].execute_method_async( | ||
"worker_method", exception) | ||
pytest.fail("task should have failed") | ||
except Exception as e: | ||
assert isinstance(e, ValueError) | ||
assert str(e) == "fake error" | ||
|
||
# Test cleanup when a worker fails | ||
assert worker_monitor.is_alive() | ||
workers[3].kill() | ||
|
||
# Other workers should get shut down here | ||
worker_monitor.join(2) | ||
|
||
# Ensure everything is stopped | ||
assert not worker_monitor.is_alive() | ||
assert all(not worker.is_alive() for worker in workers) | ||
|
||
# Further attempts to submit tasks should fail | ||
try: | ||
_result = await workers[0].execute_method_async( | ||
"worker_method", "test") | ||
pytest.fail("task should fail once workers have been shut down") | ||
except Exception as e: | ||
assert isinstance(e, ChildProcessError) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we check whether ray is successfully imported in
vllm/engine/ray_utils.py
?