From 51a5f2ec40e92470e8a83f36f8c067ba20578431 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 25 Jun 2024 21:56:02 -0700
Subject: [PATCH] [bugfix][distributed] fix shm broadcast when the queue size
 is full (#5801)

---
 tests/distributed/test_shm_broadcast.py       | 49 +++++++++----
 .../device_communicators/shm_broadcast.py     | 73 +++++++++++--------
 2 files changed, 76 insertions(+), 46 deletions(-)

diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py
index d92900ffce00b..2c2466f81bb8a 100644
--- a/tests/distributed/test_shm_broadcast.py
+++ b/tests/distributed/test_shm_broadcast.py
@@ -1,7 +1,9 @@
 import multiprocessing
 import random
 import time
+from typing import List
 
+import numpy as np
 import torch.distributed as dist
 
 from vllm.distributed.device_communicators.shm_broadcast import (
@@ -9,6 +11,14 @@
 from vllm.utils import update_environment_variables
 
 
+def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
+    np.random.seed(seed)
+    sizes = np.random.randint(1, 10_000, n)
+    # on average, each array will have 5k elements
+    # with int64, each array will be about 40 KB
+    return [np.random.randint(1, 100, i) for i in sizes]
+
+
 def distributed_run(fn, world_size):
     number_of_processes = world_size
     processes = []
@@ -47,24 +57,31 @@ def wrapped_fn(env):
 
 
 def worker_fn():
     writer_rank = 2
     broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024, 2, writer_rank)
+        dist.group.WORLD, 1024 * 1024, 2, writer_rank)
+    if dist.get_rank() == writer_rank:
+        seed = random.randint(0, 1000)
+        dist.broadcast_object_list([seed], writer_rank)
+    else:
+        recv = [None]
+        dist.broadcast_object_list(recv, writer_rank)
+        seed = recv[0]  # type: ignore
+    dist.barrier()
+    # in case we find a race condition,
+    # print the seed so that we can reproduce the error
+    print(f"Rank {dist.get_rank()} got seed {seed}")
+    # test broadcasting with about 400MB of data
+    N = 10_000
     if dist.get_rank() == writer_rank:
-        time.sleep(random.random())
-        broadcaster.broadcast_object(0)
-        time.sleep(random.random())
-        broadcaster.broadcast_object({})
-        time.sleep(random.random())
-        broadcaster.broadcast_object([])
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            broadcaster.broadcast_object(x)
+            time.sleep(random.random() / 1000)
     else:
-        time.sleep(random.random())
-        a = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        b = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        c = broadcaster.broadcast_object(None)
-        assert a == 0
-        assert b == {}
-        assert c == []
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            y = broadcaster.broadcast_object(None)
+            assert np.array_equal(x, y)
+            time.sleep(random.random() / 1000)
     dist.barrier()
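The new test streams roughly 400 MB through a 1 MB, 2-chunk ring buffer, so the
writer constantly runs into a full queue, which is exactly the condition the fix
below addresses. As a quick standalone sanity check of the sizes quoted in the
get_arrays comments (a sketch, not part of the patch; it assumes NumPy's default
integer dtype is int64, which holds on 64-bit Linux):

    import numpy as np

    np.random.seed(0)
    sizes = np.random.randint(1, 10_000, 10_000)
    arrays = [np.random.randint(1, 100, i) for i in sizes]  # int64 on 64-bit Linux
    total = sum(a.nbytes for a in arrays)

    print(sizes.mean())         # ~5,000 elements per array on average
    print(total / len(arrays))  # ~40,000 bytes per array
    print(total / 1e6)          # ~400 MB in total

With only two 512 KB-class chunks in flight, every 40 KB message forces tight
writer/reader interleaving, which is what reproduced the full-queue race.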
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
index c44bd2f11ee8b..550271f881df5 100644
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -14,6 +14,12 @@
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
+# time to wait if the queue is full or empty:
+# if we sleep for too short a time, we consume too much CPU;
+# if we sleep for too long, we slow down the writer/reader.
+# 0.1 us is a good balance
+RINGBUFFER_SLEEP_INTERVAL = 1e-7
+
 logger = init_logger(__name__)
 
 
@@ -145,8 +151,7 @@ def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
     @contextmanager
     def acquire_write(self):
         assert self._is_writer, "Only writers can acquire write"
-        start_index = self.current_idx
-        start_time = time.time()
+        start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
@@ -153,20 +158,22 @@ def acquire_write(self):
                 read_count = sum(metadata_buffer[1:])
                 written_flag = metadata_buffer[0]
                 if written_flag and read_count != self.buffer.n_reader:
                     # this block is written and not read by all readers
-                    # try to write to the next block
-                    self.current_idx = (self.current_idx +
-                                        1) % self.buffer.max_chunks
-                    if self.current_idx == start_index:
-                        # no empty block found
-                        if time.time(
-                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
-                            logger.warning(
-                                "No available block found in %s second. ",
-                                VLLM_RINGBUFFER_WARNING_INTERVAL)
-                            n_warning += 1
-                        # wait for a while (0.1 us)
-                        time.sleep(1e-7)
+                    # for writers, `self.current_idx` is the next block to write
+                    # if this block is not ready to write,
+                    # we need to wait until it is read by all readers
+
+                    # wait for a while
+                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                    # if we wait for a long time, we should warn the user
+                    if time.monotonic(
+                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
+                        logger.warning(
+                            "No available block found in %s seconds.",
+                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        n_warning += 1
+
                     continue
                 # found a block that is either
                 # (1) not written
@@ -179,13 +186,14 @@ def acquire_write(self):
                     metadata_buffer[i] = 0
                 # mark the block as written
                 metadata_buffer[0] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
                 break
 
     @contextmanager
     def acquire_read(self):
         assert self._is_reader, "Only readers can acquire read"
-        start_index = self.current_idx
-        start_time = time.time()
+        start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
@@ -195,27 +203,32 @@ def acquire_read(self):
                     # this block is either
                     # (1) not written
                     # (2) already read by this reader
-                    # try to read the next block
-                    self.current_idx = (self.current_idx +
-                                        1) % self.buffer.max_chunks
-                    if self.current_idx == start_index:
-                        # no block found
-                        if time.time(
-                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
-                            logger.warning(
-                                "No available block found in %s second. ",
-                                VLLM_RINGBUFFER_WARNING_INTERVAL)
-                            n_warning += 1
-                        # wait for a while (0.1 us)
-                        time.sleep(1e-7)
+
+                    # for readers, `self.current_idx` is the next block to read
+                    # if this block is not ready,
+                    # we need to wait until it is written
+
+                    # wait for a while
+                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                    # if we wait for a long time, we should warn the user
+                    if time.monotonic(
+                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
+                        logger.warning(
+                            "No available block found in %s seconds.",
+                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        n_warning += 1
+
                     continue
                 # found a block that is not read by this reader
                 # let caller read from the buffer
                 with self.buffer.get_data(self.current_idx) as buf:
                     yield buf
                 # caller has read from the buffer
                 # set the read flag
                 metadata_buffer[self.reader_rank + 1] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
                 break
 
     def enqueue(self, obj):
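The invariant behind the fix: the ring buffer is a FIFO queue, so the writer and
every reader advance through blocks strictly in order, each waiting on its own
current block instead of scanning the whole ring for any available one (the old
scan is what misbehaved once every block was busy). Below is a minimal
single-process sketch of that protocol; RingQueue is a hypothetical stand-in for
ShmRingBufferIO, using Python threads and plain lists instead of shared memory,
with the patch's flag layout (slot 0 = written flag, slots 1..n = per-reader
read flags):

    import threading
    import time

    SLEEP = 1e-7  # mirrors RINGBUFFER_SLEEP_INTERVAL


    class RingQueue:
        """Hypothetical stand-in for ShmRingBufferIO, for illustration only."""

        def __init__(self, n_reader: int, max_chunks: int = 4):
            self.n_reader = n_reader
            self.max_chunks = max_chunks
            self.data = [None] * max_chunks
            # per block: [written_flag, read_flag_1, ..., read_flag_n]
            self.metadata = [[0] * (n_reader + 1) for _ in range(max_chunks)]
            self.w_idx = 0               # writer's current block
            self.r_idx = [0] * n_reader  # each reader's current block

        def put(self, obj):
            meta = self.metadata[self.w_idx]
            # the writer waits on its own next block: the block is free once
            # it was never written, or every reader has read it
            while meta[0] == 1 and sum(meta[1:]) != self.n_reader:
                time.sleep(SLEEP)
            self.data[self.w_idx] = obj
            for r in range(1, self.n_reader + 1):
                meta[r] = 0  # clear the read flags first
            meta[0] = 1      # then publish the block
            self.w_idx = (self.w_idx + 1) % self.max_chunks  # advance on success

        def get(self, rank: int):
            idx = self.r_idx[rank]
            meta = self.metadata[idx]
            # a reader waits until its current block is written
            # and it has not read that block yet
            while meta[0] == 0 or meta[rank + 1] == 1:
                time.sleep(SLEEP)
            obj = self.data[idx]
            meta[rank + 1] = 1  # mark as read by this reader
            self.r_idx[rank] = (idx + 1) % self.max_chunks
            return obj


    q = RingQueue(n_reader=2)
    threading.Thread(target=lambda: [q.put(i) for i in range(100)]).start()
    for rank in range(2):
        threading.Thread(
            target=lambda rank=rank: [q.get(rank) for _ in range(100)]).start()

Both sides advance their private index only after a successful operation, which
is exactly what the patch does by moving the
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks update to
after the yield in acquire_write/acquire_read.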