From 58d456ebec2b682d5701340c3dfee3528478dbfb Mon Sep 17 00:00:00 2001
From: Myungjin Lee
Date: Tue, 20 Aug 2024 19:08:58 -0700
Subject: [PATCH] refactor: restructure packaging

Currently, world_manager.py and world_communicator.py are copied into the
torch package and become part of torch after installing multiworld. This
is inappropriate. We now keep those files under the multiworld package
and instead patch PyTorch so that the variables multiworld requires
(_worlds, _World, etc.) are exposed by patching __init__.py of
torch.distributed. Accordingly, the PyTorch patch file is updated, and
all other necessary changes are made so that all the examples still run.
---
 examples/all_gather/m8d.py | 4 +-
 examples/all_reduce/m8d.py | 3 +-
 examples/broadcast/m8d.py | 4 +-
 examples/gather/m8d.py | 4 +-
 examples/reduce/m8d.py | 3 +-
 examples/resnet/m8d.py | 4 +-
 examples/scatter/m8d.py | 4 +-
 examples/send_recv/m8d.py | 5 +--
 ...{world_communicator.py => communicator.py} | 5 +--
 multiworld/{world_manager.py => manager.py} | 15 ++++---
 multiworld/patch/pytorch-v2.4.0.patch | 44 ++++++++++---------
 multiworld/post_setup.py | 23 ++++++----
 12 files changed, 63 insertions(+), 55 deletions(-)
 rename multiworld/{world_communicator.py => communicator.py} (98%)
 rename multiworld/{world_manager.py => manager.py} (94%)

diff --git a/examples/all_gather/m8d.py b/examples/all_gather/m8d.py index 27a0fee..842840f 100644 --- a/examples/all_gather/m8d.py +++ b/examples/all_gather/m8d.py @@ -23,7 +23,7 @@ import asyncio import torch -import torch.distributed as dist +from multiworld.manager import WorldManager NUM_OF_STEPS = 100 @@ -114,7 +114,7 @@ async def main(args): world_size = 3 global world_manager - world_manager = dist.WorldManager() + world_manager = WorldManager() assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2" diff --git a/examples/all_reduce/m8d.py b/examples/all_reduce/m8d.py index 4001e66..d0708ce 100644 --- a/examples/all_reduce/m8d.py +++ b/examples/all_reduce/m8d.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist +from multiworld.manager import WorldManager NUM_OF_STEPS = 100 @@ -102,7 +103,7 @@ async def main(args): world_size = 3 global world_manager - world_manager = dist.WorldManager() + world_manager = WorldManager() assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2" diff --git a/examples/broadcast/m8d.py b/examples/broadcast/m8d.py index 3b78a63..5aa795d 100644 --- a/examples/broadcast/m8d.py +++ b/examples/broadcast/m8d.py @@ -23,7 +23,7 @@ import asyncio import torch -import torch.distributed as dist +from multiworld.manager import WorldManager NUM_OF_STEPS = 100 @@ -109,7 +109,7 @@ async def main(args): world_size = 3 global world_manager - world_manager = dist.WorldManager() + world_manager = WorldManager() assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2" diff --git a/examples/gather/m8d.py b/examples/gather/m8d.py index 8c03b86..d137a7f 100644 --- a/examples/gather/m8d.py +++ b/examples/gather/m8d.py @@ -23,7 +23,7 @@ import asyncio import torch -import torch.distributed as dist +from multiworld.manager import WorldManager NUM_OF_STEPS = 100 @@ -117,7 +117,7 @@ async def main(args): world_size = 3 global world_manager - world_manager = dist.WorldManager() + world_manager = WorldManager() assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2" diff --git a/examples/reduce/m8d.py b/examples/reduce/m8d.py index ca5b69c..312a14d 100644 ---
a/examples/reduce/m8d.py +++ b/examples/reduce/m8d.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist +from multiworld.manager import WorldManager NUM_OF_STEPS = 100 @@ -108,7 +109,7 @@ async def main(args): world_size = 3 global world_manager - world_manager = dist.WorldManager() + world_manager = WorldManager() assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2" diff --git a/examples/resnet/m8d.py b/examples/resnet/m8d.py index 49b1bfc..f8d5d25 100644 --- a/examples/resnet/m8d.py +++ b/examples/resnet/m8d.py @@ -35,9 +35,9 @@ import time import torch -import torch.distributed as dist import torch.multiprocessing as mp import torchvision.transforms as transforms +from multiworld.manager import WorldManager from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 from transformers import AutoModelForImageClassification @@ -190,7 +190,7 @@ async def init_world( if world_manager is None: # TODO: make WorldManager as singleton - world_manager = dist.WorldManager() + world_manager = WorldManager() await world_manager.initialize_world( world_name, rank, size, backend=backend, addr=addr, port=port diff --git a/examples/scatter/m8d.py b/examples/scatter/m8d.py index 9cc7d5f..5030328 100644 --- a/examples/scatter/m8d.py +++ b/examples/scatter/m8d.py @@ -23,7 +23,7 @@ import asyncio import torch -import torch.distributed as dist +from multiworld.manager import WorldManager NUM_OF_STEPS = 100 @@ -114,7 +114,7 @@ async def main(args): world_size = 3 global world_manager - world_manager = dist.WorldManager() + world_manager = WorldManager() assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2" diff --git a/examples/send_recv/m8d.py b/examples/send_recv/m8d.py index 1c21cd6..7a1d1d8 100644 --- a/examples/send_recv/m8d.py +++ b/examples/send_recv/m8d.py @@ -22,11 +22,10 @@ import argparse import asyncio -import os import time import torch -import torch.distributed as dist +from multiworld.manager import WorldManager async def init_world(world_name, rank, size, backend="gloo", addr="127.0.0.1", port=-1): @@ -142,7 +141,7 @@ async def main(args): size = 2 global world_manager - world_manager = dist.WorldManager() + world_manager = WorldManager() assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2" diff --git a/multiworld/world_communicator.py b/multiworld/communicator.py similarity index 98% rename from multiworld/world_communicator.py rename to multiworld/communicator.py index fb727b6..b130e74 100644 --- a/multiworld/world_communicator.py +++ b/multiworld/communicator.py @@ -24,11 +24,10 @@ import torch.distributed as dist from torch import Tensor -from torch.distributed import Work -from torch.distributed.distributed_c10d import DEFAULT_WORLD_NAME +from torch.distributed import DEFAULT_WORLD_NAME, Work if TYPE_CHECKING: - from torch.distributed.world_manager import WorldManager + from multiworld.manager import WorldManager logger = logging.getLogger(__name__) diff --git a/multiworld/world_manager.py b/multiworld/manager.py similarity index 94% rename from multiworld/world_manager.py rename to multiworld/manager.py index e48e7b1..ebe7d7d 100644 --- a/multiworld/world_manager.py +++ b/multiworld/manager.py @@ -25,9 +25,10 @@ from queue import Queue as SyncQ import torch.distributed as dist -import torch.distributed.distributed_c10d as dist_c10d -from torch.distributed.world_communicator import WorldCommunicator +from torch.distributed import _World as dist_c10d_World +from 
torch.distributed import _worlds as dist_c10d_worlds +from multiworld.communicator import WorldCommunicator from multiworld.watchdog import WatchDog logger = logging.Logger(__name__) @@ -160,18 +161,18 @@ async def initialize_world( def add_world(self, world_name: str, backend: str) -> None: """Add a new world to the world manager.""" - if world_name in dist_c10d._worlds: + if world_name in dist_c10d_worlds: raise ValueError(f"World {world_name} already exists.") - world = dist_c10d._World(world_name) + world = dist_c10d_World(world_name) - dist_c10d._worlds[world_name] = world + dist_c10d_worlds[world_name] = world self._communicator.add_world(world_name, backend) def remove_world(self, world_name: str) -> None: """Remove a world from the world manager.""" - if world_name not in dist_c10d._worlds: + if world_name not in dist_c10d_worlds: raise ValueError(f"World {world_name} does not exist.") self._communicator.remove_world(world_name) @@ -187,7 +188,7 @@ def remove_world(self, world_name: str) -> None: # we need to find out a right timing/way to call them. # calling them is temporarily disabled. # dist.destroy_process_group(name=world_name) - # del dist_c10d._worlds[world_name] + # del dist_c10d_worlds[world_name] logger.debug(f"done removing world {world_name}") @property diff --git a/multiworld/patch/pytorch-v2.4.0.patch b/multiworld/patch/pytorch-v2.4.0.patch index 3e49f50..84c9907 100644 --- a/multiworld/patch/pytorch-v2.4.0.patch +++ b/multiworld/patch/pytorch-v2.4.0.patch @@ -1,5 +1,5 @@ diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py -index 6816ea9b1e..b979ef3c82 100644 +index 6816ea9b1e1..b979ef3c824 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -92,7 +92,7 @@ class WorldMetaClassVariable(DistributedVariable): @@ -12,20 +12,22 @@ index 6816ea9b1e..b979ef3c82 100644 diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py -index b8e911c873..3b093244ef 100644 +index b8e911c8738..7e9a85cc21e 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py -@@ -123,6 +123,8 @@ if is_available(): +@@ -121,6 +121,10 @@ if is_available(): + _CoalescingManager, + _get_process_group_name, get_node_local_rank, ++ Work, ++ _worlds, ++ _World, ++ DEFAULT_WORLD_NAME, ) -+ from .world_manager import WorldManager -+ from .rendezvous import ( - rendezvous, - _create_store_from_options, diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py -index 9ac89166b2..c0f5c93d15 100644 +index 9ac89166b25..c0f5c93d157 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -16,6 +16,7 @@ try: @@ -82,7 +84,7 @@ index 9ac89166b2..c0f5c93d15 100644 output = all_gather_tensor(tensor, 0, group, tag) diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py -index 4c1d184036..728f7be47c 100644 +index 4c1d1840366..728f7be47cd 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/_tensor/_collective_utils.py @@ -23,6 +23,7 @@ from torch.distributed.distributed_c10d import ( @@ -112,7 +114,7 @@ index 4c1d184036..728f7be47c 100644 return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py -index 621e46fc19..8a9f701f1d 100644 +index 
621e46fc198..8a9f701f1d6 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -12,12 +12,13 @@ __all__ = [ @@ -149,7 +151,7 @@ index 621e46fc19..8a9f701f1d 100644 buffer = ( diff --git a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py -index 3528f39874..693b5a728d 100644 +index 3528f398747..693b5a728dc 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -8,6 +8,7 @@ from . import default_hooks as default @@ -170,7 +172,7 @@ index 3528f39874..693b5a728d 100644 # The input tensor is a flattened 1D tensor. diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py -index fbc3b9e873..467728516f 100644 +index fbc3b9e8739..467728516f3 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -10,6 +10,8 @@ import torch.distributed as dist @@ -201,7 +203,7 @@ index fbc3b9e873..467728516f 100644 # The input tensor is a flattened 1D tensor. diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py -index cbc1290e76..2b09ad525c 100644 +index cbc1290e76e..2b09ad525c6 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -3,6 +3,7 @@ import torch @@ -231,7 +233,7 @@ index cbc1290e76..2b09ad525c 100644 world_size = group_to_use.size() diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py -index 178efd1dba..75d117fd5b 100644 +index 178efd1dbad..75d117fd5b7 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -8,6 +8,8 @@ import torch.distributed.algorithms.model_averaging.utils as utils @@ -253,7 +255,7 @@ index 178efd1dba..75d117fd5b 100644 self.step = 0 diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py -index de1977959d..d19c5a7626 100644 +index de1977959d2..d19c5a7626d 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -10,6 +10,8 @@ import torch.distributed as dist @@ -275,7 +277,7 @@ index de1977959d..d19c5a7626 100644 if dist._rank_not_in_group(group_to_use): return diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py -index bd81fd61b0..225a55f08c 100644 +index bd81fd61b02..225a55f08c0 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -450,18 +450,20 @@ class _CollOp: @@ -2633,7 +2635,7 @@ index bd81fd61b0..225a55f08c 100644 # This ops are not friendly to TorchDynamo. So, we decide to disallow these ops # in FX graph, allowing them to run them on eager, with torch.compile. 
diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py -index 3487e01263..82e1910dfb 100644 +index 3487e01263c..82e1910dfbd 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -8,6 +8,8 @@ import torch.distributed as dist @@ -2670,7 +2672,7 @@ index 3487e01263..82e1910dfb 100644 device, init_scale=init_scale, diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py -index e90a78a693..a45dcf1e29 100644 +index e90a78a6932..a45dcf1e29b 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -7,7 +7,9 @@ from torch.autograd import Function @@ -2766,7 +2768,7 @@ index e90a78a693..a45dcf1e29 100644 Reduces the tensor data across all machines in such a way that all get the final result. diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py -index 8a3be3b018..30f706d11a 100644 +index 8a3be3b0181..30f706d11a9 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -19,6 +19,7 @@ from torch.distributed.algorithms.join import Join, Joinable, JoinHook @@ -2823,7 +2825,7 @@ index 8a3be3b018..30f706d11a 100644 self.world_size: int = dist.get_world_size(self.process_group) self.rank: int = dist.get_rank(self.process_group) diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py -index 75c8b5504d..55cfdb78df 100644 +index 75c8b5504d4..55cfdb78dfd 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -570,6 +570,7 @@ class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): @@ -2844,7 +2846,7 @@ index 75c8b5504d..55cfdb78df 100644 process_group = self.process_group world_size = torch.distributed.get_world_size(process_group) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py -index 0ec5dd2224..cbab1ac374 100644 +index 0ec5dd22244..cbab1ac3747 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -655,7 +655,7 @@ class DistributedTest: diff --git a/multiworld/post_setup.py b/multiworld/post_setup.py index 7d57339..08adf47 100644 --- a/multiworld/post_setup.py +++ b/multiworld/post_setup.py @@ -19,11 +19,15 @@ import pathlib import shutil import site + import torch + def main(): parser = argparse.ArgumentParser() - parser.add_argument("patchfile", nargs='?', default=None, help="Path to the patch file") + parser.add_argument( + "patchfile", nargs="?", default=None, help="Path to the patch file" + ) args = parser.parse_args() path_to_sitepackages = site.getsitepackages()[0] @@ -31,8 +35,15 @@ def main(): if args.patchfile: patchfile = args.patchfile else: - torch_version = torch.__version__.split('+')[0] # torch version is in "2.2.1+cu121" format - patchfile = os.path.join(path_to_sitepackages, "multiworld", "patch", "pytorch-v" + torch_version + ".patch") + torch_version = torch.__version__.split("+")[ + 0 + ] # torch version is in "2.2.1+cu121" format + patchfile = os.path.join( + path_to_sitepackages, + "multiworld", + "patch", + "pytorch-v" + torch_version + ".patch", + ) patch_basename = os.path.basename(patchfile) @@ -45,12 +56,6 @@ def main(): p = pathlib.Path(patch_basename) p.unlink() - files_to_copy = ["world_manager.py", "world_communicator.py"] - for f in files_to_copy: - src = 
os.path.join(path_to_sitepackages, "multiworld", f) - dst = os.path.join(path_to_sitepackages, "torch/distributed", f) - shutil.copyfile(src, dst) - if __name__ == "__main__": main()
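
Usage note: after this change, applications import WorldManager from multiworld.manager instead of torch.distributed, and the patched torch.distributed only re-exports the low-level names multiworld needs (Work, _worlds, _World, DEFAULT_WORLD_NAME). The sketch below is a minimal, illustrative driver distilled from the updated examples and is not part of the diff above; it assumes multiworld is installed and the bundled pytorch-v2.4.0.patch has been applied via post_setup.py, and the world name, rank/size, address, and port are placeholders.

import asyncio

from multiworld.manager import WorldManager  # previously torch.distributed.WorldManager


async def main():
    # One WorldManager per process; worlds are created/joined by name.
    world_manager = WorldManager()

    # World name, rank, size, address, and port below are placeholders.
    await world_manager.initialize_world(
        "world1", 0, 2, backend="gloo", addr="127.0.0.1", port=29500
    )

    # ... exchange tensors as in the updated examples (send_recv, all_gather, ...) ...


if __name__ == "__main__":
    asyncio.run(main())

The collective and point-to-point calls in the example scripts are unchanged by this commit; only the import path and the WorldManager construction differ.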