Add group as an argument in broadcast ops (vllm-project#2522)
GindaChen authored Jan 21, 2024
1 parent 00efdc8 commit 5b23c3f
Showing 1 changed file with 36 additions and 17 deletions.
vllm/model_executor/parallel_utils/communication_op.py (36 additions, 17 deletions)
@@ -1,6 +1,8 @@
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union

+from torch.distributed import ProcessGroup
+
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -86,47 +88,59 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return output_tensor


-def broadcast(input_: torch.Tensor, src: int = 0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return input_
     # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
     return input_


-def broadcast_object_list(obj_list: List[Any], src: int = 0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return obj_list
     # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
     return obj_list


 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
-                                                                 Any]]] = None,
-                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
     """Broadcast the input tensor dictionary."""
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict

+    rank = torch.distributed.get_rank()
     if rank == src:
         assert isinstance(
             tensor_dict,
@@ -141,14 +155,18 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                     (key, TensorMetadata(value.dtype, value.size())))
             else:
                 metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                src=src,
+                                                group=group)
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
                 torch.distributed.broadcast(tensor, src=src)
     else:
         recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                src=src,
+                                                group=group)
         metadata_list = recv_metadata_list[0]
         tensor_dict = {}
         async_handles = []
@@ -159,7 +177,8 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                      device="cuda")
                 async_handle = torch.distributed.broadcast(tensor,
                                                            src=src,
-                                                           async_op=True)
+                                                           async_op=True,
+                                                           group=group)
                 async_handles.append(async_handle)
                 tensor_dict[key] = tensor
             else:
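For context, here is a minimal usage sketch of the updated `broadcast` and `broadcast_object_list` signatures. It is illustrative only, not part of this commit: it assumes `torch.distributed` has already been initialized with a GPU backend (e.g. NCCL), that the snippet runs on every rank of the job, and the subgroup of ranks 0 and 1 is hypothetical.

```python
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast, broadcast_object_list)

# Hypothetical subgroup covering the first two ranks; new_group() must be
# called collectively on all ranks, even those that are not members.
group = dist.new_group(ranks=[0, 1])

if dist.get_rank() in (0, 1):
    # `src` is a global rank and must be a member of `group`
    # (the new `assert src in ranks` check in the diff above).
    tensor = torch.zeros(4, device="cuda")
    if dist.get_rank() == 0:
        tensor += 1.0
    broadcast(tensor, src=0, group=group)  # tensor is all ones on ranks 0 and 1

    objs = ["sampling params"] if dist.get_rank() == 0 else [None]
    broadcast_object_list(objs, src=0, group=group)  # objs[0] matches on both ranks
```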

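Likewise, a sketch (illustrative, not from the commit) of how a caller might use `broadcast_tensor_dict`: the source rank describes each tensor with `TensorMetadata` (dtype and size), broadcasts the metadata as a pickled object list, and the other ranks allocate empty CUDA tensors and fill them with async broadcasts. Called without `group`, it uses the default WORLD group as before; the dict keys and values below are made up for the example.

```python
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)

if dist.get_rank() == 0:
    payload = {
        "input_ids": torch.randint(0, 32000, (8, 128), device="cuda"),
        "max_tokens": 16,  # non-tensor values ride along in the metadata list
    }
else:
    payload = None  # non-src ranks pass nothing and get a reconstructed dict back

payload = broadcast_tensor_dict(payload, src=0)
# On every rank, payload["input_ids"] is an (8, 128) CUDA tensor and
# payload["max_tokens"] == 16.
```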