diff --git a/examples/all_gather/m8d.py b/examples/all_gather/m8d.py index e61c8a4..f984297 100644 --- a/examples/all_gather/m8d.py +++ b/examples/all_gather/m8d.py @@ -159,10 +159,6 @@ async def main(args): # for example: --worldinfo 1,0` means world with the index 1 will have a rank 0 parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/all_reduce/m8d.py b/examples/all_reduce/m8d.py index 32b8bd1..0286d04 100644 --- a/examples/all_reduce/m8d.py +++ b/examples/all_reduce/m8d.py @@ -150,10 +150,6 @@ async def main(args): parser.add_argument("--addr", default="127.0.0.1") parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/broadcast/m8d.py b/examples/broadcast/m8d.py index 497bbf3..efa7482 100644 --- a/examples/broadcast/m8d.py +++ b/examples/broadcast/m8d.py @@ -154,10 +154,6 @@ async def main(args): # for example: --worldinfo 1,0` means world with the index 1 will have a rank 0 parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/reduce/m8d.py b/examples/reduce/m8d.py index 3d137c2..9b9c1d0 100644 --- a/examples/reduce/m8d.py +++ b/examples/reduce/m8d.py @@ -95,7 +95,12 @@ async def reduce(world_name, world_size, rank, backend): if dst == rank: print( - "Rank ", rank, " within world ", world_name, " has reduced tensor", tensor + "Rank ", + rank, + " within world ", + world_name, + " has reduced tensor", + tensor, ) print(f"done with step: {step}") @@ -154,10 +159,6 @@ async def main(args): parser.add_argument("--addr", default="127.0.0.1") parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/resnet/m8d.py b/examples/resnet/m8d.py index e96a486..b9c63c8 100644 --- a/examples/resnet/m8d.py +++ b/examples/resnet/m8d.py @@ -384,10 +384,6 @@ async def multi_host(args): "--multihost", action=argparse.BooleanOptionalAction, default=False ) - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() atexit.register(cleanup) diff --git a/examples/send_recv/m8d.py b/examples/send_recv/m8d.py index 71e8beb..1c21cd6 100644 --- a/examples/send_recv/m8d.py +++ b/examples/send_recv/m8d.py @@ -191,10 +191,6 @@ async def main(args): parser.add_argument("--addr", default="127.0.0.1") parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/multiworld/world_manager.py b/multiworld/world_manager.py index 834852e..b1824fd 100644 --- a/multiworld/world_manager.py +++ b/multiworld/world_manager.py @@ -38,6 +38,12 @@ class WorldManager: def __init__(self, enable_monitor=True): """Initialize a world manager.""" + # https://github.com/pytorch/pytorch/blob/v2.4.0/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L118-L130 + # "2" is CleanUpOnly + # We use CleanupOnly in order to allow error handling at user process + # level without tearing down the process. + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" + self._worlds_stores: dict[str, dist.TCPStore] = dict() self._communicator = WorldCommunicator(self) self._current_world = ""