Skip to content

Commit

Permalink
fix: asyncio-friendly nccl operations (#52)
Browse files Browse the repository at this point in the history
An NCCL operation in PyTorch's distributed package needs to set up an
NCCL communicator so that ranks can talk to one another. To set up the
communicator, the c10d key-value store needs to be consulted. This is a
blocking call, which blocks asyncio's event loop. This prevents the loop
from scheduling other coroutines. The issue is mitigated by using
run_in_executor().

Note that this doesn't seem to be a permanent fix. Depending on timing,
blocking still appears from time to time and leads to an exception, an
example of which may look like "torch.distributed.DistBackendError: [1] is
setting up NCCL communicator and retrieving ncclUniqueId from [0]
via c10d key-value store by key '0:1', but store->get('0:1') got
error: Socket Timeout".
  • Loading branch information
myungjin authored Jul 26, 2024
1 parent 6914cb0 commit ef048b1
Showing 1 changed file with 92 additions and 26 deletions.
118 changes: 92 additions & 26 deletions multiworld/world_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import logging
from typing import TYPE_CHECKING

Expand All @@ -33,22 +34,24 @@


# Error-message snippets that indicate a world is broken.  _handle_error()
# matches these as substrings against RuntimeError messages; a match tears
# the world down instead of propagating the raw error.
_errors_to_handle = [
    "NCCL Error 6",
    "NCCL communicator was aborted",
    "Connection reset by peer",
    "Connection closed by peer",
]


class BrokenWorldException(Exception):
    """Raised when communication within a world can no longer proceed.

    Carries the name of the broken world and a human-readable reason
    (e.g. the triggering error message, or a note that the watchdog
    detected the breakage) so callers can log why the world broke.
    """

    def __init__(self, world_name: str, msg: str):
        """Initialize exception instance.

        Args:
            world_name: Name of the world that broke.
            msg: Human-readable description of the failure.
        """
        self._world_name = world_name
        self._msg = msg

    def __str__(self):
        """Return a '<world> broken: <reason>' summary string."""
        return f"{self._world_name} broken: {self._msg}"

Expand Down Expand Up @@ -93,15 +96,24 @@ async def _wait_work(self, work: Work, world_name: str) -> None:
"""
while not work.is_completed():
if self._broken_world[world_name]:
raise BrokenWorldException(f"{world_name}")
raise BrokenWorldException(world_name, "watchdog raised the exception")
await asyncio.sleep(0)

async def send(
self, tensor: Tensor, dst: int, world_name: str = DEFAULT_WORLD_NAME
) -> None:
"""Send a tensor to a destination in a world."""
try:
work = dist.isend(tensor, dst=dst, name=world_name)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.isend,
tensor,
dst,
None,
0,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -112,7 +124,16 @@ async def recv(
) -> None:
"""Receive a tensor from a specific rank in a world."""
try:
work = dist.irecv(tensor, src=src, name=world_name)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.irecv,
tensor,
src,
None,
0,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -123,7 +144,16 @@ async def broadcast(
) -> None:
"""Broadcast a tensor to the world from a source (src)."""
try:
work = dist.broadcast(tensor, src, async_op=True, name=world_name)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.broadcast,
tensor,
src,
None,
True,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -137,7 +167,16 @@ async def all_reduce(
) -> None:
"""Do all-reduce for a given tensor in a world."""
try:
work = dist.all_reduce(tensor, op, async_op=True, name=world_name)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.all_reduce,
tensor,
op,
None,
True,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -155,7 +194,17 @@ async def reduce(
The rank is a receiver of the final result.
"""
try:
work = dist.reduce(tensor, dst, op, async_op=True, name=world_name)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.reduce,
tensor,
dst,
op,
None,
True,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -169,7 +218,16 @@ async def all_gather(
) -> None:
"""Do all-gather for a given tensor in a world."""
try:
work = dist.all_gather(tensors, tensor, async_op=True, name=world_name)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.all_gather,
tensors,
tensor,
None,
True,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -184,13 +242,17 @@ async def gather(
) -> None:
"""Do gather for a list of tensors in a world."""
try:
work = dist.gather(
tensor,
gahter_list=gather_list,
dst=dst,
async_op=True,
name=world_name,
)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.gather,
tensor,
gather_list,
dst,
None,
True,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -205,13 +267,17 @@ async def scatter(
) -> None:
"""Do scatter for a list of tensors from a source (src) in a world."""
try:
work = dist.scatter(
tensor,
scatter_list=scatter_list,
src=src,
async_op=True,
name=world_name,
)
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
pool,
dist.scatter,
tensor,
scatter_list,
src,
None,
True,
world_name,
)
except RuntimeError as e:
self._handle_error(e, world_name)

Expand All @@ -224,6 +290,6 @@ def _handle_error(self, error: RuntimeError, world_name: str) -> None:
if error_snippet in error_message:
logger.debug(f"broken world: {error_message}")
self._world_manager.remove_world(world_name)
raise BrokenWorldException(f"{world_name}")
raise BrokenWorldException(world_name, error_message)

raise error

0 comments on commit ef048b1

Please sign in to comment.