[SymmetricMemory] introduce user-facing APIs empty() and rendezvous() (pytorch#139677)

Previously `SymmetricMemory` only had private pybind APIs:
```python
from torch.distributed._symmetric_memory import _SymmetricMemory
t = _SymmetricMemory.empty_strided_p2p(
    size=(64,),
    stride=(1,),
    dtype=torch.float32,
    device=device,
)
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group_name=group.group_name)
```

This PR introduces the user-facing APIs `empty()` and `rendezvous()`:
```python
import torch.distributed._symmetric_memory as symm_mem
t = symm_mem.empty(64, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=group.group_name)
```
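
For reference, a minimal sketch of what the returned handle can do, mirroring the `_verify_symmetric_memory` helper in the updated tests (assumes two ranks and that `t` and `symm_mem_hdl` come from the snippet above):
```python
# Sketch only: the signal/barrier pattern exercised by the updated tests.
buf = symm_mem_hdl.get_buffer(0, (symm_mem_hdl.buffer_size // 4,), torch.float32)
if symm_mem_hdl.rank == 0:
    symm_mem_hdl.wait_signal(src_rank=1)  # block until rank 1 signals
    assert buf.eq(42).all()               # rank 1's write is now visible
else:
    buf.fill_(42)                         # write into rank 0's buffer view
    symm_mem_hdl.put_signal(dst_rank=0)   # tell rank 0 the data is ready
symm_mem_hdl.barrier()
```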

Notable differences compared to the pybind APIs (see the sketch after this list):
- `empty()` now resembles `torch.empty()`:
  - the shape can be passed either as an integer sequence or as a variadic argument pack
  - the stride can no longer be specified; the returned tensor is always contiguous
  - the device can be either a `torch.device` or a string
- `group_name` is now specified at rendezvous time rather than at allocation time. See pytorch#139529 for the rationale. I feel the new semantics are superior, hence they are enforced in the public API.
  - Currently, the pybind API still supports specifying `group_name` at rendezvous time.

This PR does not change the behavior of the pybind APIs.
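
A minimal sketch of the call forms described above (assumes an initialized default process group; the shapes and `"cuda"` device are illustrative):
```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

t1 = symm_mem.empty(64, 64, device="cuda")             # shape as an argument pack, device as a string
t2 = symm_mem.empty((64, 64), dtype=torch.float32,
                    device=torch.device("cuda"))       # shape as a sequence, device as torch.device
hdl = symm_mem.rendezvous(t1, group=dist.group.WORLD)  # group object (or group name) at rendezvous time
```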

Pull Request resolved: pytorch#139677
Approved by: https://github.com/lw
ghstack dependencies: pytorch#139529
yifuwang authored and youssef62 committed Nov 23, 2024
1 parent d492213 commit 20e9221
Showing 2 changed files with 174 additions and 76 deletions.
142 changes: 66 additions & 76 deletions test/distributed/test_symmetric_memory.py
@@ -5,6 +5,7 @@

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
@@ -81,9 +82,25 @@ def _init_process(self):
rank=self.rank,
store=store,
)
enable_symm_mem_for_group(dist.group.WORLD.group_name)
torch.manual_seed(42 + self.rank)

@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_cuda_nvlink_connectivity_detection(self) -> None:
from torch._C._distributed_c10d import _detect_dma_connectivity

connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
self.assertEqual(connectivity.device_type, DeviceType.CUDA)
self.assertEqual(connectivity.connection_type, "nvlink")
self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
for row in connectivity.matrix:
self.assertEqual(len(row), torch.cuda.device_count())

@skipIfRocm
def test_large_alloc(self) -> None:
t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)

def _get_test_alloc_args(self):
shape = (64, 64)
stride = (64, 1)
@@ -92,64 +109,56 @@ def _get_test_alloc_args(self):
group_name = "0"
return (shape, stride, dtype, device, group_name)

def _verify_symmetric_memory(self, symm_mem):
self.assertEqual(symm_mem.world_size, 2)
def _verify_symmetric_memory(self, symm_mem_hdl):
self.assertEqual(symm_mem_hdl.world_size, 2)

buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
buf = symm_mem_hdl.get_buffer(
0, (symm_mem_hdl.buffer_size // 4,), torch.float32
)
self.assertEqual(buf.storage_offset(), 0)
self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)

if symm_mem.rank == 0:
symm_mem.wait_signal(src_rank=1)
if symm_mem_hdl.rank == 0:
symm_mem_hdl.wait_signal(src_rank=1)
self.assertTrue(buf.eq(42).all())
else:
buf.fill_(42)
symm_mem.put_signal(dst_rank=0)
symm_mem_hdl.put_signal(dst_rank=0)

symm_mem.barrier()
symm_mem_hdl.barrier()

if symm_mem.rank == 0:
symm_mem.barrier()
if symm_mem_hdl.rank == 0:
symm_mem_hdl.barrier()
self.assertTrue(buf.eq(43).all())
else:
buf.fill_(43)
symm_mem.barrier()
symm_mem_hdl.barrier()

symm_mem.barrier()

@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_cuda_nvlink_connectivity_detection(self) -> None:
from torch._C._distributed_c10d import _detect_dma_connectivity

connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
self.assertEqual(connectivity.device_type, DeviceType.CUDA)
self.assertEqual(connectivity.connection_type, "nvlink")
self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
for row in connectivity.matrix:
self.assertEqual(len(row), torch.cuda.device_count())
symm_mem_hdl.barrier()

@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_empty_strided_p2p(self) -> None:
self._init_process()
enable_symm_mem_for_group(dist.group.WORLD.group_name)

alloc_args = self._get_test_alloc_args()

t = torch.empty((64, 64), device=self.device)
self.assertIsNone(_SymmetricMemory.rendezvous(t))

t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
symm_mem_hdl = _SymmetricMemory.rendezvous(t)

del t
self._verify_symmetric_memory(symm_mem)
self._verify_symmetric_memory(symm_mem_hdl)
dist.destroy_process_group()

@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_empty_strided_p2p_persistent(self) -> None:
self._init_process()
enable_symm_mem_for_group(dist.group.WORLD.group_name)

alloc_args = self._get_test_alloc_args()

@@ -168,51 +177,47 @@ def test_empty_strided_p2p_persistent(self) -> None:
t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
self.assertEqual(t.data_ptr(), data_ptr)

symm_mem = _SymmetricMemory.rendezvous(t)
self._verify_symmetric_memory(symm_mem)
symm_mem_hdl = _SymmetricMemory.rendezvous(t)
self._verify_symmetric_memory(symm_mem_hdl)
dist.destroy_process_group()

@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_get_signal_pad(self) -> None:
self._init_process()

t = _SymmetricMemory.empty_strided_p2p(*self._get_test_alloc_args())
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
peer_rank = (self.rank + 1) % self.world_size

signal_pad = symm_mem.get_signal_pad(self.rank)
self.assertEqual(signal_pad.data_ptr(), symm_mem.signal_pad_ptrs[symm_mem.rank])
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
self.assertEqual(
signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
)

signal_pad = symm_mem.get_signal_pad(peer_rank)
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
self.assertEqual(signal_pad.dtype, torch.uint32)
self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 4)
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)

# Only specify sizes
signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8))
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
self.assertEqual(signal_pad.dtype, torch.uint32)
self.assertEqual(signal_pad.numel(), 64)

# Only specify dtype
signal_pad = symm_mem.get_signal_pad(peer_rank, dtype=torch.uint64)
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
self.assertEqual(signal_pad.dtype, torch.uint64)
self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 8)
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)

# Specify both sizes and dtype
signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
self.assertEqual(signal_pad.dtype, torch.uint64)
self.assertEqual(signal_pad.numel(), 64)

# Sanity check that writes to buffer doesn't corrupt signal_pad
t = _SymmetricMemory.empty_strided_p2p(
(0,),
(0,),
torch.float32,
self.device,
dist.group.WORLD.group_name,
)
symm_mem = _SymmetricMemory.rendezvous(t)
signal_pad = symm_mem.get_signal_pad(self.rank)
t = symm_mem.empty(0, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t)
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
signal_pad.fill_(42)
t.fill_(0)
self.assertTrue(signal_pad.eq(42).all())
@@ -224,14 +229,12 @@ def test_get_signal_pad(self) -> None:
def test_barrier_timeout(self) -> None:
self._init_process()

alloc_args = self._get_test_alloc_args()

t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)

if self.rank == 0:
with self.assertRaises(RuntimeError):
symm_mem.barrier(timeout_ms=1000)
symm_mem_hdl.barrier(timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
@@ -247,17 +250,15 @@ def test_barrier_timeout(self) -> None:
def test_put_signal_timeout(self) -> None:
self._init_process()

alloc_args = self._get_test_alloc_args()

t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)

if self.rank == 0:
with self.assertRaises(RuntimeError):
# First, put a signal into rank 1's signal pad. Since rank 1
# doesn't wait on this signal, the subsequent put will timeout.
symm_mem.put_signal(dst_rank=1)
symm_mem.put_signal(dst_rank=1, timeout_ms=1000)
symm_mem_hdl.put_signal(dst_rank=1)
symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
@@ -273,14 +274,12 @@ def test_put_signal_timeout(self) -> None:
def test_wait_signal_timeout(self) -> None:
self._init_process()

alloc_args = self._get_test_alloc_args()

t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
symm_mem = _SymmetricMemory.rendezvous(t)
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)

if self.rank == 0:
with self.assertRaises(RuntimeError):
symm_mem.wait_signal(src_rank=1, timeout_ms=1000)
symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
@@ -685,7 +684,6 @@ def _init_process(self):
rank=self.rank,
store=store,
)
enable_symm_mem_for_group(dist.group.WORLD.group_name)
torch.manual_seed(42 + self.rank)

@skipIfRocm
@@ -699,18 +697,10 @@ def test_subgroup(self) -> None:

world = dist.group.WORLD
subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
enable_symm_mem_for_group(subgroup.group_name)

t = _SymmetricMemory.empty_strided_p2p(
size=(64,),
stride=(1,),
dtype=torch.float32,
device=self.device,
)
symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
symm_mem_subgroup = _SymmetricMemory.rendezvous(
t, group_name=subgroup.group_name
)
t = symm_mem.empty(64, device="cuda")
symm_mem_world = symm_mem.rendezvous(t, group=world)
symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)

self.assertEqual(symm_mem_world.world_size, world.size())
self.assertEqual(symm_mem_world.rank, world.rank())
108 changes: 108 additions & 0 deletions torch/distributed/_symmetric_memory/__init__.py
@@ -110,6 +110,8 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
_SymmetricMemory: the symmetric memory workspace associated with the
group.
"""
enable_symm_mem_for_group(group_name)

tensor = _group_name_to_workspace_tensor.get(group_name)
size = tensor.numel() * tensor.element_size() if tensor is not None else 0
if tensor is None or size < min_size:
@@ -1386,3 +1388,109 @@ def _low_contention_reduce_scatter(
return _low_contention_reduce_scatter_with_workspace(
tensor, reduce_op, workspace
)


# =============================================================================
# User-facing APIs
# =============================================================================


from typing import Any, overload, Sequence, TYPE_CHECKING, Union

from torch.types import _device, _dtype, _int


if TYPE_CHECKING:
from torch._C._distributed_c10d import ProcessGroup


@overload
def empty(
*size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
) -> torch.Tensor:
...


@overload
def empty(
size: Sequence[_int],
*,
dtype: Optional[_dtype] = None,
device: Optional[_device] = None,
) -> torch.Tensor:
...


def empty( # type: ignore[misc]
*size: Any,
dtype: Optional[_dtype] = None,
device: Optional[_device] = None,
) -> torch.Tensor:
r"""
empty(*size, *, dtype=None, device=None) -> Tensor
Similar to :func:`torch.empty()`. The returned tensor can be used by
:func:`torch._distributed._symmetric_memory.rendezvous()` to establish a
symmetric memory tensor among participating processes.
Args:
size (int...): a sequence of integers defining the shape of the output tensor.
Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
device (:class:`torch.device`, optional): the desired device of returned tensor.
Default: if ``None``, uses the current device for the default tensor type
(see :func:`torch.set_default_device`). :attr:`device` will be the CPU
for CPU tensor types and the current CUDA device for CUDA tensor types.
"""
if len(size) == 1 and isinstance(size[0], Sequence):
size = tuple(size[0])
else:
size = tuple(size)

if dtype is None:
dtype = torch.get_default_dtype()

if device is None:
device = torch.get_default_device()

return _SymmetricMemory.empty_strided_p2p(
size=size,
stride=torch._prims_common.make_contiguous_strides_for(size),
dtype=dtype,
device=torch.device(device),
)


def rendezvous(
tensor: torch.Tensor, group: Union[str, "ProcessGroup"]
) -> _SymmetricMemory:
r"""
rendezvous(tensor, group) -> _SymmetricMemory
Establish a symmetric memory tensor among participating processes. This is
a collective operation.
Args:
tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric memory tensor.
It must be allocated via :func:`torch._distributed._symmetric_memory.empty()`. The shape,
dtype, and device type must be identical across all participating processes.
group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the
participating processes. This can be either a group name or a process group object.
"""
from torch._C._distributed_c10d import ProcessGroup

if isinstance(group, str):
group_name = group
elif isinstance(group, ProcessGroup):
group_name = group.group_name
else:
raise TypeError(f"rendezvous: unsupported group type: {type(group)}")

enable_symm_mem_for_group(group_name)
return _SymmetricMemory.rendezvous(tensor, group_name)


__all__ = ["empty", "rendezvous"]
