From 20e9221f235ca0630cf05c61661547a028af9dff Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Sat, 16 Nov 2024 20:36:51 -0800
Subject: [PATCH] [SymmetricMemory] introduce user-facing APIs empty() and rendezvous() (#139677)

Previously, `SymmetricMemory` only had private pybind APIs:

```python
from torch.distributed._symmetric_memory import _SymmetricMemory

t = _SymmetricMemory.empty_strided_p2p(
    size=(64,),
    stride=(1,),
    dtype=torch.float32,
    device=device,
)
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group_name=group.group_name)
```

This PR introduces the user-facing APIs `empty()` and `rendezvous()`:

```python
import torch.distributed._symmetric_memory as symm_mem

t = symm_mem.empty(64, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=group.group_name)
```

Notable differences compared to the pybind APIs:
- `empty()` now resembles `torch.empty()`:
  - the shape can be either an integer sequence or a variadic pack of integers
  - the stride can no longer be specified (and no longer needs to be)
  - the device can be either a `torch.device` or a string
- `group_name` needs to be specified at rendezvous time as opposed to
  allocation time. See https://github.com/pytorch/pytorch/pull/139529 for the
  rationale. I feel the new semantics are superior, hence they are enforced in
  the public API.
  - Currently, the pybind APIs still support specifying `group_name` at
    allocation time. This PR does not change the behavior of the pybind APIs.
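
To illustrate the new call forms, a short sketch (assuming an initialized
default process group and a CUDA device; the variable names are arbitrary):

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# The shape can be passed as a variadic pack or as a sequence, and the
# device as a string or a torch.device:
a = symm_mem.empty(64, 64, device="cuda")
b = symm_mem.empty((64, 64), dtype=torch.float32, device=torch.device("cuda"))

# The group is supplied at rendezvous time, either as a group name or as a
# ProcessGroup object:
hdl_a = symm_mem.rendezvous(a, group=dist.group.WORLD.group_name)
hdl_b = symm_mem.rendezvous(b, group=dist.group.WORLD)
```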

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139677
Approved by: https://github.com/lw
ghstack dependencies: #139529
---
 test/distributed/test_symmetric_memory.py      | 142 ++++++++----
 .../distributed/_symmetric_memory/__init__.py  | 108 +++++++++++++
 2 files changed, 174 insertions(+), 76 deletions(-)

diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py
index bbe229ba098971..d02a2dd1005058 100644
--- a/test/distributed/test_symmetric_memory.py
+++ b/test/distributed/test_symmetric_memory.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
 from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
@@ -81,9 +82,25 @@ def _init_process(self):
             rank=self.rank,
             store=store,
         )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
         torch.manual_seed(42 + self.rank)
 
+    @skipIfRocm
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_nvlink_connectivity_detection(self) -> None:
+        from torch._C._distributed_c10d import _detect_dma_connectivity
+
+        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
+        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
+        self.assertEqual(connectivity.connection_type, "nvlink")
+        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
+        for row in connectivity.matrix:
+            self.assertEqual(len(row), torch.cuda.device_count())
+
+    @skipIfRocm
+    def test_large_alloc(self) -> None:
+        t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
+        self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
+
     def _get_test_alloc_args(self):
         shape = (64, 64)
         stride = (64, 1)
@@ -92,47 +109,38 @@ def _get_test_alloc_args(self):
         group_name = "0"
         return (shape, stride, dtype, device, group_name)
 
-    def _verify_symmetric_memory(self, symm_mem):
-        self.assertEqual(symm_mem.world_size, 2)
+    def _verify_symmetric_memory(self, symm_mem_hdl):
+        self.assertEqual(symm_mem_hdl.world_size, 2)
 
-        buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
+        buf = symm_mem_hdl.get_buffer(
+            0, (symm_mem_hdl.buffer_size // 4,), torch.float32
+        )
         self.assertEqual(buf.storage_offset(), 0)
-        self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
+        self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)
 
-        if symm_mem.rank == 0:
-            symm_mem.wait_signal(src_rank=1)
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.wait_signal(src_rank=1)
             self.assertTrue(buf.eq(42).all())
         else:
             buf.fill_(42)
-            symm_mem.put_signal(dst_rank=0)
+            symm_mem_hdl.put_signal(dst_rank=0)
 
-        symm_mem.barrier()
+        symm_mem_hdl.barrier()
 
-        if symm_mem.rank == 0:
-            symm_mem.barrier()
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.barrier()
             self.assertTrue(buf.eq(43).all())
         else:
             buf.fill_(43)
-            symm_mem.barrier()
+            symm_mem_hdl.barrier()
 
-        symm_mem.barrier()
-
-    @skipIfRocm
-    @skip_if_lt_x_gpu(2)
-    def test_cuda_nvlink_connectivity_detection(self) -> None:
-        from torch._C._distributed_c10d import _detect_dma_connectivity
-
-        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
-        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
-        self.assertEqual(connectivity.connection_type, "nvlink")
-        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
-        for row in connectivity.matrix:
-            self.assertEqual(len(row), torch.cuda.device_count())
+        symm_mem_hdl.barrier()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_empty_strided_p2p(self) -> None:
         self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)
 
         alloc_args = self._get_test_alloc_args()
 
@@ -140,16 +148,17 @@ def test_empty_strided_p2p(self) -> None:
         self.assertIsNone(_SymmetricMemory.rendezvous(t))
 
         t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
         del t
-        self._verify_symmetric_memory(symm_mem)
+        self._verify_symmetric_memory(symm_mem_hdl)
         dist.destroy_process_group()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_empty_strided_p2p_persistent(self) -> None:
         self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)
 
         alloc_args = self._get_test_alloc_args()
 
@@ -168,8 +177,8 @@ def test_empty_strided_p2p_persistent(self) -> None:
         t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
         self.assertEqual(t.data_ptr(), data_ptr)
 
-        symm_mem = _SymmetricMemory.rendezvous(t)
-        self._verify_symmetric_memory(symm_mem)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
+        self._verify_symmetric_memory(symm_mem_hdl)
         dist.destroy_process_group()
 
     @skipIfRocm
@@ -177,42 +186,38 @@ def test_empty_strided_p2p_persistent(self) -> None:
     def test_get_signal_pad(self) -> None:
         self._init_process()
 
-        t = _SymmetricMemory.empty_strided_p2p(*self._get_test_alloc_args())
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
         peer_rank = (self.rank + 1) % self.world_size
 
-        signal_pad = symm_mem.get_signal_pad(self.rank)
-        self.assertEqual(signal_pad.data_ptr(), symm_mem.signal_pad_ptrs[symm_mem.rank])
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
+        self.assertEqual(
+            signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
+        )
 
-        signal_pad = symm_mem.get_signal_pad(peer_rank)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
         self.assertEqual(signal_pad.dtype, torch.uint32)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 4)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)
 
         # Only specify sizes
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8))
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
         self.assertEqual(signal_pad.dtype, torch.uint32)
         self.assertEqual(signal_pad.numel(), 64)
 
         # Only specify dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
         self.assertEqual(signal_pad.dtype, torch.uint64)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 8)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)
 
         # Specify both sizes and dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
         self.assertEqual(signal_pad.dtype, torch.uint64)
         self.assertEqual(signal_pad.numel(), 64)
 
         # Sanity check that writes to buffer doesn't corrupt signal_pad
-        t = _SymmetricMemory.empty_strided_p2p(
-            (0,),
-            (0,),
-            torch.float32,
-            self.device,
-            dist.group.WORLD.group_name,
-        )
-        symm_mem = _SymmetricMemory.rendezvous(t)
-        signal_pad = symm_mem.get_signal_pad(self.rank)
+        t = symm_mem.empty(0, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
         signal_pad.fill_(42)
         t.fill_(0)
         self.assertTrue(signal_pad.eq(42).all())
@@ -224,14 +229,12 @@ def test_get_signal_pad(self) -> None:
     def test_barrier_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
-                symm_mem.barrier(timeout_ms=1000)
+                symm_mem_hdl.barrier(timeout_ms=1000)
             torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -247,17 +250,15 @@ def test_barrier_timeout(self) -> None:
     def test_put_signal_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
                 # First, put a signal into rank 1's signal pad. Since rank 1
                 # doesn't wait on this signal, the subsequent put will time out.
-                symm_mem.put_signal(dst_rank=1)
-                symm_mem.put_signal(dst_rank=1, timeout_ms=1000)
+                symm_mem_hdl.put_signal(dst_rank=1)
+                symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
             torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -273,14 +274,12 @@ def test_put_signal_timeout(self) -> None:
     def test_wait_signal_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
-                symm_mem.wait_signal(src_rank=1, timeout_ms=1000)
+                symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
             torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -685,7 +684,6 @@ def _init_process(self):
             rank=self.rank,
             store=store,
         )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
         torch.manual_seed(42 + self.rank)
 
     @skipIfRocm
@@ -699,18 +697,10 @@ def test_subgroup(self) -> None:
         world = dist.group.WORLD
         subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
 
-        enable_symm_mem_for_group(subgroup.group_name)
-        t = _SymmetricMemory.empty_strided_p2p(
-            size=(64,),
-            stride=(1,),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
-        symm_mem_subgroup = _SymmetricMemory.rendezvous(
-            t, group_name=subgroup.group_name
-        )
+        t = symm_mem.empty(64, device="cuda")
+        symm_mem_world = symm_mem.rendezvous(t, group=world)
+        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
 
         self.assertEqual(symm_mem_world.world_size, world.size())
         self.assertEqual(symm_mem_world.rank, world.rank())
diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py
index 38c927f37d8671..80c2bdb711fcb2 100644
--- a/torch/distributed/_symmetric_memory/__init__.py
+++ b/torch/distributed/_symmetric_memory/__init__.py
@@ -110,6 +110,8 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
         _SymmetricMemory: the symmetric memory workspace associated with the
         group.
     """
+    enable_symm_mem_for_group(group_name)
+
     tensor = _group_name_to_workspace_tensor.get(group_name)
     size = tensor.numel() * tensor.element_size() if tensor is not None else 0
     if tensor is None or size < min_size:
@@ -1386,3 +1388,109 @@ def _low_contention_reduce_scatter(
         return _low_contention_reduce_scatter_with_workspace(
             tensor, reduce_op, workspace
         )
+
+
+# =============================================================================
+# User-facing APIs
+# =============================================================================
+
+
+from typing import Any, overload, Sequence, TYPE_CHECKING, Union
+
+from torch.types import _device, _dtype, _int
+
+
+if TYPE_CHECKING:
+    from torch._C._distributed_c10d import ProcessGroup
+
+
+@overload
+def empty(
+    *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
+) -> torch.Tensor:
+    ...
+
+
+@overload
+def empty(
+    size: Sequence[_int],
+    *,
+    dtype: Optional[_dtype] = None,
+    device: Optional[_device] = None,
+) -> torch.Tensor:
+    ...
+
+
+def empty(  # type: ignore[misc]
+    *size: Any,
+    dtype: Optional[_dtype] = None,
+    device: Optional[_device] = None,
+) -> torch.Tensor:
+    r"""
+    empty(*size, *, dtype=None, device=None) -> Tensor
+
+    Similar to :func:`torch.empty()`. The returned tensor can be used by
+    :func:`torch.distributed._symmetric_memory.rendezvous()` to establish a
+    symmetric memory tensor among participating processes.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output tensor.
+            Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of the returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        device (:class:`torch.device`, optional): the desired device of the returned tensor.
+            Default: if ``None``, uses the current device for the default tensor type
+            (see :func:`torch.set_default_device`). :attr:`device` will be the CPU
+            for CPU tensor types and the current CUDA device for CUDA tensor types.
+    """
+    if len(size) == 1 and isinstance(size[0], Sequence):
+        size = tuple(size[0])
+    else:
+        size = tuple(size)
+
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    if device is None:
+        device = torch.get_default_device()
+
+    return _SymmetricMemory.empty_strided_p2p(
+        size=size,
+        stride=torch._prims_common.make_contiguous_strides_for(size),
+        dtype=dtype,
+        device=torch.device(device),
+    )
+
+
+def rendezvous(
+    tensor: torch.Tensor, group: Union[str, "ProcessGroup"]
+) -> _SymmetricMemory:
+    r"""
+    rendezvous(tensor, group) -> _SymmetricMemory
+
+    Establish a symmetric memory tensor among participating processes. This is
+    a collective operation.
+
+    Args:
+        tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric
+            memory tensor. It must be allocated via
+            :func:`torch.distributed._symmetric_memory.empty()`. The shape, dtype, and
+            device type must be identical across all participating processes.
+        group (Union[str, :class:`torch.distributed.ProcessGroup`]): the group identifying
+            the participating processes. This can be either a group name or a process
+            group object.
+    """
+    from torch._C._distributed_c10d import ProcessGroup
+
+    if isinstance(group, str):
+        group_name = group
+    elif isinstance(group, ProcessGroup):
+        group_name = group.group_name
+    else:
+        raise TypeError(f"rendezvous: unsupported group type: {type(group)}")
+
+    enable_symm_mem_for_group(group_name)
+    return _SymmetricMemory.rendezvous(tensor, group_name)
+
+
+__all__ = ["empty", "rendezvous"]
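
For reference, a minimal end-to-end sketch of how the new APIs compose,
mirroring `_verify_symmetric_memory` in the test changes above (a sketch, not
part of the patch; it assumes two CUDA GPUs and a launch via
`torchrun --nproc-per-node 2`):

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

t = symm_mem.empty(64, device="cuda")                 # symmetric allocation
hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)  # collective rendezvous

# Both ranks view rank 0's buffer; rank 1 writes it remotely, then signals.
buf = hdl.get_buffer(0, (64,), torch.float32)
if hdl.rank == 0:
    hdl.wait_signal(src_rank=1)
    assert buf.eq(42).all()
else:
    buf.fill_(42)
    hdl.put_signal(dst_rank=0)

hdl.barrier()
dist.destroy_process_group()
```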