diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index 53654dc40d10d..bf0f31df02fa5 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -8,12 +8,11 @@
 import ray
 import torch
 
-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 
-from ..utils import (init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+from ..utils import init_test_distributed_environment, multi_process_parallel
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
+    if not get_pp_group().is_first_rank:
+        recv_dict = get_pp_group().recv_tensor_dict()
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send_tensor_dict(test_dict)
+
+    if not get_pp_group().is_first_rank:
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
+                          distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    size = 64
+    test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
+
+    if not get_pp_group().is_first_rank:
+        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send(test_tensor)
+
+    if not get_pp_group().is_first_rank:
+        assert torch.allclose(test_tensor, recv_tensor)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
         broadcast_tensor_dict_test_worker
     ])
 def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize(
+    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)
diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py
index 9a39160b8a462..3c281a45fcaf1 100644
--- a/tests/distributed/test_custom_all_reduce.py
+++ b/tests/distributed/test_custom_all_reduce.py
@@ -12,8 +12,7 @@
                               get_tp_group, graph_capture)
 
 from ..utils import (ensure_model_parallel_initialized,
-                     init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+                     init_test_distributed_environment, multi_process_parallel)
 
 random.seed(42)
 test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
-    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index 964dbc5423e75..e0e424439e3a5 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -168,9 +168,13 @@ def send_recv_worker_fn():
                          dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
         if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
         else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     assert result == 1
 
@@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
                              device=device)
     with pynccl_comm.change_state(enable=True):
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
         else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
         assert result == 1
diff --git a/tests/utils.py b/tests/utils.py
index bc30515c83100..174efca4af532 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
     ensure_model_parallel_initialized(tp_size, pp_size)
 
 
-def multi_process_tensor_parallel(
+def multi_process_parallel(
     tp_size: int,
     pp_size: int,
     test_target,
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 83eec264b6f81..7319566545678 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -121,10 +121,7 @@ def all_reduce(self,
                             ncclRedOpTypeEnum.from_torch(op), self.comm,
                             cudaStream_t(stream.cuda_stream))
 
-    def send(self,
-             tensor: torch.Tensor,
-             dst: Optional[int] = None,
-             stream=None):
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
@@ -132,16 +129,11 @@ def send(self,
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if dst is None:
-            dst = (self.rank + 1) % self.world_size
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
 
-    def recv(self,
-             tensor: torch.Tensor,
-             src: Optional[int] = None,
-             stream=None):
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
@@ -149,8 +141,6 @@ def recv(self,
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if src is None:
-            src = (self.rank - 1) % self.world_size
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 5188fadbb92a5..5f1decb376af5 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -20,6 +20,7 @@
 steps.
 """
 import contextlib
+import pickle
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -28,6 +29,7 @@
 from unittest.mock import patch
 
 import torch
+import torch.distributed
 from torch.distributed import Backend, ProcessGroup
 
 import vllm.envs as envs
@@ -180,6 +182,16 @@ def last_rank(self):
         """Return the global rank of the last process in the group"""
         return self.ranks[-1]
 
+    @property
+    def is_first_rank(self):
+        """Return whether the caller is the first process in the group"""
+        return self.rank == self.first_rank
+
+    @property
+    def is_last_rank(self):
+        """Return whether the caller is the last process in the group"""
+        return self.rank == self.last_rank
+
     @property
     def next_rank(self):
         """Return the global rank of the process that follows the caller"""
@@ -374,6 +386,70 @@ def broadcast_object_list(self,
                                                 group=self.device_group)
         return obj_list
 
+    def send_object(self, obj: Any, dst: int) -> None:
+        """Send the input object to the destination rank."""
+        """NOTE: `dst` is the local rank of the destination rank."""
+
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+
+        assert dst != self.rank, (
+            "Invalid destination rank. Destination rank is the same "
+            "as the current rank.")
+
+        # Serialize object to tensor and get the size as well
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+
+        size_tensor = torch.tensor([object_tensor.numel()],
+                                   dtype=torch.long,
+                                   device="cpu")
+
+        # Send object size
+
+        torch.distributed.send(size_tensor,
+                               dst=self.ranks[dst],
+                               group=self.cpu_group)
+
+        # Send object
+        torch.distributed.send(object_tensor,
+                               dst=self.ranks[dst],
+                               group=self.cpu_group)
+
+        return None
+
+    def recv_object(self, src: int) -> Any:
+        """Receive an object from the source rank."""
+        """NOTE: `src` is the local rank of the source rank."""
+
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        assert src != self.rank, (
+            "Invalid source rank. Source rank is the same as the current rank."
+        )
+
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+
+        # Receive object size
+        rank_size = torch.distributed.recv(size_tensor,
+                                           src=src,
+                                           group=self.cpu_group)
+
+        # Tensor to receive serialized objects into.
+        object_tensor = torch.empty(  # type: ignore[call-overload]
+            size_tensor.item(),  # type: ignore[arg-type]
+            dtype=torch.uint8,
+            device="cpu")
+
+        rank_object = torch.distributed.recv(object_tensor,
+                                             src=src,
+                                             group=self.cpu_group)
+
+        assert rank_object == rank_size, (
+            "Received object sender rank does not match the size sender rank.")
+
+        obj = pickle.loads(object_tensor.numpy().tobytes())
+
+        return obj
+
     def broadcast_tensor_dict(
         self,
         tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
@@ -459,6 +535,88 @@ def broadcast_tensor_dict(
             async_handle.wait()
         return tensor_dict
 
+    def send_tensor_dict(
+        self,
+        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        dst: Optional[int] = None
+    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+        """Send the input tensor dictionary.
+        NOTE: `dst` is the local rank of the destination rank.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if not torch.distributed.is_initialized() or self.world_size == 1:
+            return tensor_dict
+
+        group = self.device_group
+        metadata_group = self.cpu_group
+
+        if dst is None:
+            dst = self.next_rank
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+
+        metadata_list: List[Tuple[Any, Any]] = []
+        assert isinstance(
+            tensor_dict,
+            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `send_object` has serialization & deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
+        self.send_object(metadata_list, dst=dst)
+        for tensor in tensor_list:
+            if tensor.numel() == 0:
+                # Skip sending empty tensors.
+                continue
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                torch.distributed.send(tensor, dst=dst, group=metadata_group)
+            else:
+                # use group for GPU tensors
+                torch.distributed.send(tensor, dst=dst, group=group)
+        return None
+
+    def recv_tensor_dict(
+        self,
+        src: Optional[int] = None
+    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+        """Recv the input tensor dictionary.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if not torch.distributed.is_initialized() or self.world_size == 1:
+            return None
+
+        group = self.device_group
+        metadata_group = self.cpu_group
+
+        if src is None:
+            src = self.prev_rank
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        recv_metadata_list = self.recv_object(src=src)
+        tensor_dict = {}
+        for key, value in recv_metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device=value.device)
+                if tensor.numel() == 0:
+                    # Skip receiving empty tensors.
+                    tensor_dict[key] = tensor
+                    continue
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    torch.distributed.recv(tensor,
+                                           src=src,
+                                           group=metadata_group)
+                else:
+                    # use group for GPU tensors
+                    torch.distributed.recv(tensor, src=src, group=group)
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        return tensor_dict
+
     def barrier(self):
         """Barrier synchronization among the group.
         NOTE: don't use `device_group` here! `barrier` in NCCL is
@@ -468,6 +626,35 @@ def barrier(self):
         """
         torch.distributed.barrier(group=self.cpu_group)
 
+    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
+        """Sends a tensor to the destination rank in a non-blocking way"""
+        """NOTE: `dst` is the local rank of the destination rank."""
+        if dst is None:
+            dst = self.next_rank
+
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.send(tensor, dst)
+        else:
+            torch.distributed.send(tensor, self.ranks[dst], self.device_group)
+
+    def recv(self,
+             size: torch.Size,
+             dtype: torch.dtype,
+             src: Optional[int] = None) -> torch.Tensor:
+        """Receives a tensor from the src rank."""
+        """NOTE: `src` is the local rank of the source rank."""
+        if src is None:
+            src = self.prev_rank
+
+        tensor = torch.empty(size, dtype=dtype, device=self.device)
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.recv(tensor, src)
+        else:
+            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
+        return tensor
+
     def destroy(self):
         if self.device_group is not None:
             torch.distributed.destroy_process_group(self.device_group)
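
Usage sketch (not part of the patch): the snippet below mirrors how the new point-to-point helpers on GroupCoordinator are exercised by the test workers above. It assumes the distributed environment is already initialized (e.g. via init_test_distributed_environment) so that get_pp_group() returns the pipeline-parallel group; the payload dictionary is an illustrative placeholder, not something defined in this diff.

    import torch
    from vllm.distributed import get_pp_group

    pp_group = get_pp_group()
    payload = {"hidden_states": torch.ones(4, 8, device="cuda"), "step": 3}

    # Every stage except the first receives from the previous stage;
    # recv_tensor_dict() defaults its source to prev_rank.
    if not pp_group.is_first_rank:
        received = pp_group.recv_tensor_dict()

    # Every stage except the last forwards its outputs to the next stage;
    # send_tensor_dict() defaults its destination to next_rank.
    if not pp_group.is_last_rank:
        pp_group.send_tensor_dict(payload)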