[core][distributed] use tcp store directly (#10275)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Nov 13, 2024
1 parent 112fa0b commit 0d4ea3f
Showing 2 changed files with 29 additions and 25 deletions.
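In short: `StatelessProcessGroup.create` no longer takes a `tcp://...` `init_method` URL that is resolved through `torch.distributed.rendezvous`; it now takes an explicit host and port and builds a `TCPStore` itself, with rank 0 acting as the store server. A minimal sketch of the new calling convention, matching the diffs below (the port and world size here are illustrative, not taken from this commit):

from vllm.distributed.utils import StatelessProcessGroup

# New-style creation: pass host/port directly; rank 0 hosts the TCPStore.
# The old form was StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500", ...).
pg = StatelessProcessGroup.create(host="127.0.0.1",
                                  port=29500,  # illustrative port
                                  rank=0,
                                  world_size=1)
obj = pg.broadcast_obj({"key": "value"}, src=0)  # metadata exchange goes through the store
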
26 changes: 16 additions & 10 deletions tests/distributed/test_utils.py
@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
 
 
 def cpu_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
     data = torch.tensor([rank])
     data = pg1.broadcast_obj(data, src=2)
     assert data.item() == 2
@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
 
 def gpu_worker(rank, WORLD_SIZE, port1, port2):
     torch.cuda.set_device(rank)
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
     pynccl1.disabled = False
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
         pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def broadcast_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank == 2:
@@ -101,16 +108,15 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def allgather_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     data = pg1.all_gather_obj(rank)
     assert data == list(range(WORLD_SIZE))
     pg1.barrier()
 
 
-# TODO: investigate why this test is flaky. It hangs during initialization.
-@pytest.mark.skip("Skip the test because it is flaky.")
 @multi_gpu_test(num_gpus=4)
 @pytest.mark.parametrize(
     "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
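To make the updated test pattern above concrete, here is a small, hypothetical standalone driver for an all-gather worker like `allgather_worker`; the real test instead uses vLLM's `multi_gpu_test` harness and pytest parametrization, and the port below is arbitrary:

import multiprocessing as mp

from vllm.distributed.utils import StatelessProcessGroup


def demo_allgather_worker(rank: int, world_size: int, port: int):
    # Mirrors allgather_worker above: each rank joins the group via host/port.
    pg = StatelessProcessGroup.create(host="127.0.0.1",
                                      port=port,
                                      rank=rank,
                                      world_size=world_size)
    data = pg.all_gather_obj(rank)
    assert data == list(range(world_size))
    pg.barrier()


if __name__ == "__main__":
    world_size, port = 4, 29600  # arbitrary demo values
    procs = [
        mp.Process(target=demo_allgather_worker, args=(r, world_size, port))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
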
28 changes: 13 additions & 15 deletions vllm/distributed/utils.py
@@ -9,7 +9,7 @@
 from typing import Any, Deque, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch.distributed.rendezvous import rendezvous
+from torch.distributed import TCPStore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -97,7 +97,6 @@ class StatelessProcessGroup:
     group. Only use it to communicate metadata between processes.
     For data-plane communication, create NCCL-related objects.
     """
-    prefix: str
     rank: int
     world_size: int
     store: torch._C._distributed_c10d.Store
@@ -127,7 +126,7 @@ def __post_init__(self):
     def send_obj(self, obj: Any, dst: int):
         """Send an object to a destination rank."""
         self.expire_data()
-        key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
+        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
         self.store.set(key, pickle.dumps(obj))
         self.send_dst_counter[dst] += 1
         self.entries.append((key, time.time()))
@@ -147,8 +146,7 @@ def recv_obj(self, src: int) -> Any:
         """Receive an object from a source rank."""
         obj = pickle.loads(
             self.store.get(
-                f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
-            ))
+                f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
         self.recv_src_counter[src] += 1
         return obj
 
@@ -159,14 +157,14 @@ def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
         """
         if self.rank == src:
             self.expire_data()
-            key = (f"{self.prefix}/broadcast_from/{src}/"
+            key = (f"broadcast_from/{src}/"
                    f"{self.broadcast_send_counter}")
             self.store.set(key, pickle.dumps(obj))
             self.broadcast_send_counter += 1
             self.entries.append((key, time.time()))
             return obj
         else:
-            key = (f"{self.prefix}/broadcast_from/{src}/"
+            key = (f"broadcast_from/{src}/"
                    f"{self.broadcast_recv_src_counter[src]}")
             recv_obj = pickle.loads(self.store.get(key))
             self.broadcast_recv_src_counter[src] += 1
@@ -194,7 +192,8 @@ def barrier(self):
 
     @staticmethod
     def create(
-        init_method: str,
+        host: str,
+        port: int,
         rank: int,
         world_size: int,
         data_expiration_seconds: int = 3600,
@@ -214,15 +213,14 @@ def create(
         can call `StatelessProcessGroup.create` to form a group, and then process A, B,
         C, and D can call `StatelessProcessGroup.create` to form another group.
         """ # noqa
-        from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
-        timeout = _DEFAULT_PG_TIMEOUT
-
-        store, rank, world_size = next(
-            rendezvous(init_method, rank, world_size, timeout=timeout))
-        store.set_timeout(timeout)
+        store = TCPStore(
+            host_name=host,
+            port=port,
+            world_size=world_size,
+            is_master=(rank == 0),
+        )
 
         return StatelessProcessGroup(
-            prefix=init_method,
             rank=rank,
             world_size=world_size,
             store=store,
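For reference, the new `create` body above reduces to constructing a `TCPStore` directly, with rank 0 as the server and the other ranks as clients. A standalone sketch of that handshake, using two threads in place of two processes (the port and key names are made up for illustration):

from datetime import timedelta
from threading import Thread

from torch.distributed import TCPStore


def make_store(rank: int, world_size: int = 2, port: int = 29700) -> TCPStore:
    # Mirrors the new create(): rank 0 hosts the store, other ranks connect to it.
    return TCPStore(
        host_name="127.0.0.1",
        port=port,
        world_size=world_size,
        is_master=(rank == 0),
        timeout=timedelta(seconds=30),
    )


def worker(rank: int):
    store = make_store(rank)
    store.set(f"ready/{rank}", b"1")   # publish this rank's key
    store.get(f"ready/{1 - rank}")     # block until the peer's key appears


threads = [Thread(target=worker, args=(r,)) for r in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()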