From 5b23c3f26fcf08fdbe2b7e6beecb4d970e632897 Mon Sep 17 00:00:00 2001
From: Junda Chen <32371474+GindaChen@users.noreply.github.com>
Date: Sat, 20 Jan 2024 16:00:26 -0800
Subject: [PATCH] Add `group` as an argument in broadcast ops (#2522)

---
 .../parallel_utils/communication_op.py | 53 +++++++++++++------
 1 file changed, 36 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py
index 64992d05527e8..fff6920be72b0 100644
--- a/vllm/model_executor/parallel_utils/communication_op.py
+++ b/vllm/model_executor/parallel_utils/communication_op.py
@@ -1,6 +1,8 @@
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union

+from torch.distributed import ProcessGroup
+
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -86,47 +88,59 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return output_tensor


-def broadcast(input_: torch.Tensor, src: int = 0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return input_
     # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
     return input_


-def broadcast_object_list(obj_list: List[Any], src: int = 0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return obj_list
     # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
     return obj_list


 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
-                                                                 Any]]] = None,
-                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
     """Broadcast the input tensor dictionary."""
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict

+    rank = torch.distributed.get_rank()
     if rank == src:
         assert isinstance(
             tensor_dict,
@@ -141,14 +155,18 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                 (key, TensorMetadata(value.dtype, value.size())))
             else:
                 metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                src=src,
+                                                group=group)
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
                 torch.distributed.broadcast(tensor, src=src)
     else:
         recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                src=src,
+                                                group=group)
         metadata_list = recv_metadata_list[0]
         tensor_dict = {}
         async_handles = []
@@ -159,7 +177,8 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                      device="cuda")
             async_handle = torch.distributed.broadcast(tensor,
                                                        src=src,
-                                                       async_op=True)
+                                                       async_op=True,
+                                                       group=group)
             async_handles.append(async_handle)
             tensor_dict[key] = tensor
         else:
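
For context, a minimal usage sketch (not part of the patch) of how a caller might pass the new `group` argument to the patched `broadcast_tensor_dict`. It assumes torch.distributed has already been initialized on every rank; the subgroup of ranks 0 and 1, the dictionary contents, and the variable names are illustrative only.

import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)

# Assumption: dist.init_process_group(...) was already called on every rank.
# new_group() must be invoked by all ranks, even those not in the subgroup.
subgroup = dist.new_group(ranks=[0, 1])

if dist.get_rank() in (0, 1):
    # Only rank 0 (the source) supplies the dictionary; the other subgroup
    # rank passes None and receives the broadcast tensors/objects. Tensors
    # must live on CUDA, matching the assertion inside the function.
    data = None
    if dist.get_rank() == 0:
        data = {
            "input_ids": torch.ones(4, dtype=torch.long, device="cuda"),
            "step": 3,
        }
    data = broadcast_tensor_dict(data, src=0, group=subgroup)

When `group` is omitted, the patched functions fall back to torch.distributed.group.WORLD, so existing call sites keep their previous behavior.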