refactor: restructuring packaging #83

Merged · 1 commit · Aug 21, 2024
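Every file in this PR makes the same substitution: `WorldManager` is no longer pulled out of the patched `torch.distributed`; it is imported from the package's own `multiworld.manager` module. A minimal before/after sketch for a caller, using only names that appear in the diffs below (the world name, ranks, address, and port are illustrative values, not taken from the PR):

```python
import asyncio

# Before this PR (relied on the patched torch injecting WorldManager):
#   import torch.distributed as dist
#   world_manager = dist.WorldManager()

# After this PR:
from multiworld.manager import WorldManager


async def main() -> None:
    world_manager = WorldManager()
    # Same call the examples use; it rendezvouses with the other ranks,
    # so running this stand-alone will simply wait for a peer to join.
    await world_manager.initialize_world(
        "world1", 0, 2, backend="gloo", addr="127.0.0.1", port=29500
    )


if __name__ == "__main__":
    asyncio.run(main())
```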
4 changes: 2 additions & 2 deletions examples/all_gather/m8d.py
@@ -23,7 +23,7 @@
import asyncio

import torch
-import torch.distributed as dist
+from multiworld.manager import WorldManager

NUM_OF_STEPS = 100

@@ -114,7 +114,7 @@ async def main(args):
world_size = 3
global world_manager

-world_manager = dist.WorldManager()
+world_manager = WorldManager()

assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2"

3 changes: 2 additions & 1 deletion examples/all_reduce/m8d.py
@@ -24,6 +24,7 @@

import torch
import torch.distributed as dist
+from multiworld.manager import WorldManager

NUM_OF_STEPS = 100

@@ -102,7 +103,7 @@ async def main(args):
world_size = 3
global world_manager

-world_manager = dist.WorldManager()
+world_manager = WorldManager()

assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2"

4 changes: 2 additions & 2 deletions examples/broadcast/m8d.py
@@ -23,7 +23,7 @@
import asyncio

import torch
-import torch.distributed as dist
+from multiworld.manager import WorldManager

NUM_OF_STEPS = 100

@@ -109,7 +109,7 @@ async def main(args):
world_size = 3
global world_manager

-world_manager = dist.WorldManager()
+world_manager = WorldManager()

assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2"

4 changes: 2 additions & 2 deletions examples/gather/m8d.py
@@ -23,7 +23,7 @@
import asyncio

import torch
-import torch.distributed as dist
+from multiworld.manager import WorldManager

NUM_OF_STEPS = 100

@@ -117,7 +117,7 @@ async def main(args):
world_size = 3
global world_manager

-world_manager = dist.WorldManager()
+world_manager = WorldManager()

assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2"

3 changes: 2 additions & 1 deletion examples/reduce/m8d.py
@@ -24,6 +24,7 @@

import torch
import torch.distributed as dist
+from multiworld.manager import WorldManager

NUM_OF_STEPS = 100

@@ -108,7 +109,7 @@ async def main(args):
world_size = 3
global world_manager

-world_manager = dist.WorldManager()
+world_manager = WorldManager()

assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2"

4 changes: 2 additions & 2 deletions examples/resnet/m8d.py
@@ -35,9 +35,9 @@
import time

import torch
-import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.transforms as transforms
+from multiworld.manager import WorldManager
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import AutoModelForImageClassification
@@ -190,7 +190,7 @@ async def init_world(

if world_manager is None:
# TODO: make WorldManager as singleton
-world_manager = dist.WorldManager()
+world_manager = WorldManager()

await world_manager.initialize_world(
world_name, rank, size, backend=backend, addr=addr, port=port
4 changes: 2 additions & 2 deletions examples/scatter/m8d.py
@@ -23,7 +23,7 @@
import asyncio

import torch
-import torch.distributed as dist
+from multiworld.manager import WorldManager

NUM_OF_STEPS = 100

@@ -114,7 +114,7 @@ async def main(args):
world_size = 3
global world_manager

-world_manager = dist.WorldManager()
+world_manager = WorldManager()

assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2"

5 changes: 2 additions & 3 deletions examples/send_recv/m8d.py
@@ -22,11 +22,10 @@

import argparse
import asyncio
-import os
import time

import torch
-import torch.distributed as dist
+from multiworld.manager import WorldManager


async def init_world(world_name, rank, size, backend="gloo", addr="127.0.0.1", port=-1):
@@ -142,7 +141,7 @@ async def main(args):
size = 2
global world_manager

-world_manager = dist.WorldManager()
+world_manager = WorldManager()

assert len(args.worldinfo) <= 2, "the number of worldinfo arguments must be <= 2"

Changes in another multiworld module (file name not shown in this view):
@@ -24,11 +24,10 @@

import torch.distributed as dist
from torch import Tensor
-from torch.distributed import Work
-from torch.distributed.distributed_c10d import DEFAULT_WORLD_NAME
+from torch.distributed import DEFAULT_WORLD_NAME, Work

if TYPE_CHECKING:
-from torch.distributed.world_manager import WorldManager
+from multiworld.manager import WorldManager

logger = logging.getLogger(__name__)

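A note on the hunk above: `WorldManager` stays behind `if TYPE_CHECKING:` so the annotation can follow the class to `multiworld.manager` without importing that module at runtime. If, as the manager diff below suggests, this is the communicator module that `multiworld/manager.py` itself imports, a runtime import here would be circular. A minimal sketch of the pattern (the class and attribute below are illustrative stand-ins, not code from the file):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; at runtime this module is imported by
    # multiworld.manager, so importing it back here would be circular.
    from multiworld.manager import WorldManager


class Communicator:
    """Illustrative stand-in for the class defined in this module."""

    def __init__(self, manager: WorldManager) -> None:
        self._manager = manager
```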
15 changes: 8 additions & 7 deletions multiworld/world_manager.py → multiworld/manager.py
@@ -25,9 +25,10 @@
from queue import Queue as SyncQ

import torch.distributed as dist
-import torch.distributed.distributed_c10d as dist_c10d
-from torch.distributed.world_communicator import WorldCommunicator
+from torch.distributed import _World as dist_c10d_World
+from torch.distributed import _worlds as dist_c10d_worlds

+from multiworld.communicator import WorldCommunicator
from multiworld.watchdog import WatchDog

logger = logging.Logger(__name__)
@@ -160,18 +161,18 @@ async def initialize_world(

def add_world(self, world_name: str, backend: str) -> None:
"""Add a new world to the world manager."""
-if world_name in dist_c10d._worlds:
+if world_name in dist_c10d_worlds:
raise ValueError(f"World {world_name} already exists.")

-world = dist_c10d._World(world_name)
+world = dist_c10d_World(world_name)

-dist_c10d._worlds[world_name] = world
+dist_c10d_worlds[world_name] = world

self._communicator.add_world(world_name, backend)

def remove_world(self, world_name: str) -> None:
"""Remove a world from the world manager."""
-if world_name not in dist_c10d._worlds:
+if world_name not in dist_c10d_worlds:
raise ValueError(f"World {world_name} does not exist.")

self._communicator.remove_world(world_name)
@@ -187,7 +188,7 @@ def remove_world(self, world_name: str) -> None:
# we need to find out a right timing/way to call them.
# calling them is temporarily disabled.
# dist.destroy_process_group(name=world_name)
-# del dist_c10d._worlds[world_name]
+# del dist_c10d_worlds[world_name]
logger.debug(f"done removing world {world_name}")

@property
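For orientation, the renamed `multiworld/manager.py` keeps the same bookkeeping it had before the move: each named world is registered as a `_World` object in the `_worlds` dictionary that the patched `torch.distributed` re-exports, alongside a per-world communicator. A rough sketch of that registration step, assuming a PyTorch build with the bundled patch applied (the free function below is a simplification of the `add_world` method shown above, not the method itself):

```python
# Requires a PyTorch build patched with multiworld/patch/pytorch-v2.4.0.patch,
# which re-exports _World and _worlds from torch.distributed.
from torch.distributed import _World as dist_c10d_World
from torch.distributed import _worlds as dist_c10d_worlds


def add_world(world_name: str) -> None:
    """Register a named world, mirroring WorldManager.add_world above."""
    if world_name in dist_c10d_worlds:
        raise ValueError(f"World {world_name} already exists.")

    # One _World object per named world, stored in the shared _worlds registry.
    dist_c10d_worlds[world_name] = dist_c10d_World(world_name)
```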
44 changes: 23 additions & 21 deletions multiworld/patch/pytorch-v2.4.0.patch
@@ -1,5 +1,5 @@
diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py
-index 6816ea9b1e..b979ef3c82 100644
+index 6816ea9b1e1..b979ef3c824 100644
--- a/torch/_dynamo/variables/distributed.py
+++ b/torch/_dynamo/variables/distributed.py
@@ -92,7 +92,7 @@ class WorldMetaClassVariable(DistributedVariable):
@@ -12,20 +12,22 @@ index 6816ea9b1e..b979ef3c82 100644


diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py
-index b8e911c873..3b093244ef 100644
+index b8e911c8738..7e9a85cc21e 100644
--- a/torch/distributed/__init__.py
+++ b/torch/distributed/__init__.py
-@@ -123,6 +123,8 @@ if is_available():
+@@ -121,6 +121,10 @@ if is_available():
_CoalescingManager,
_get_process_group_name,
get_node_local_rank,
++        Work,
++        _worlds,
++        _World,
++        DEFAULT_WORLD_NAME,
)

-+    from .world_manager import WorldManager
-+
from .rendezvous import (
rendezvous,
_create_store_from_options,
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
-index 9ac89166b2..c0f5c93d15 100644
+index 9ac89166b25..c0f5c93d157 100644
--- a/torch/distributed/_functional_collectives.py
+++ b/torch/distributed/_functional_collectives.py
@@ -16,6 +16,7 @@ try:
@@ -82,7 +84,7 @@ index 9ac89166b2..c0f5c93d15 100644

output = all_gather_tensor(tensor, 0, group, tag)
diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py
-index 4c1d184036..728f7be47c 100644
+index 4c1d1840366..728f7be47cd 100644
--- a/torch/distributed/_tensor/_collective_utils.py
+++ b/torch/distributed/_tensor/_collective_utils.py
@@ -23,6 +23,7 @@ from torch.distributed.distributed_c10d import (
@@ -112,7 +114,7 @@ index 4c1d184036..728f7be47c 100644

return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
-index 621e46fc19..8a9f701f1d 100644
+index 621e46fc198..8a9f701f1d6 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -12,12 +12,13 @@ __all__ = [
@@ -149,7 +151,7 @@ index 621e46fc19..8a9f701f1d 100644

buffer = (
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
-index 3528f39874..693b5a728d 100644
+index 3528f398747..693b5a728dc 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
@@ -8,6 +8,7 @@ from . import default_hooks as default
@@ -170,7 +172,7 @@ index 3528f39874..693b5a728d 100644

# The input tensor is a flattened 1D tensor.
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
-index fbc3b9e873..467728516f 100644
+index fbc3b9e8739..467728516f3 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -10,6 +10,8 @@ import torch.distributed as dist
@@ -201,7 +203,7 @@ index fbc3b9e873..467728516f 100644

# The input tensor is a flattened 1D tensor.
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
-index cbc1290e76..2b09ad525c 100644
+index cbc1290e76e..2b09ad525c6 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -3,6 +3,7 @@ import torch
@@ -231,7 +233,7 @@ index cbc1290e76..2b09ad525c 100644
world_size = group_to_use.size()

diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py
-index 178efd1dba..75d117fd5b 100644
+index 178efd1dbad..75d117fd5b7 100644
--- a/torch/distributed/algorithms/model_averaging/averagers.py
+++ b/torch/distributed/algorithms/model_averaging/averagers.py
@@ -8,6 +8,8 @@ import torch.distributed.algorithms.model_averaging.utils as utils
@@ -253,7 +255,7 @@ index 178efd1dba..75d117fd5b 100644
self.step = 0

diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py
-index de1977959d..d19c5a7626 100644
+index de1977959d2..d19c5a7626d 100644
--- a/torch/distributed/algorithms/model_averaging/utils.py
+++ b/torch/distributed/algorithms/model_averaging/utils.py
@@ -10,6 +10,8 @@ import torch.distributed as dist
@@ -275,7 +277,7 @@ index de1977959d..d19c5a7626 100644
if dist._rank_not_in_group(group_to_use):
return
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
-index bd81fd61b0..225a55f08c 100644
+index bd81fd61b02..225a55f08c0 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -450,18 +450,20 @@ class _CollOp:
@@ -2633,7 +2635,7 @@ index bd81fd61b0..225a55f08c 100644
# This ops are not friendly to TorchDynamo. So, we decide to disallow these ops
# in FX graph, allowing them to run them on eager, with torch.compile.
diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py
-index 3487e01263..82e1910dfb 100644
+index 3487e01263c..82e1910dfbd 100644
--- a/torch/distributed/fsdp/sharded_grad_scaler.py
+++ b/torch/distributed/fsdp/sharded_grad_scaler.py
@@ -8,6 +8,8 @@ import torch.distributed as dist
@@ -2670,7 +2672,7 @@ index 3487e01263..82e1910dfb 100644
device,
init_scale=init_scale,
diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py
-index e90a78a693..a45dcf1e29 100644
+index e90a78a6932..a45dcf1e29b 100644
--- a/torch/distributed/nn/functional.py
+++ b/torch/distributed/nn/functional.py
@@ -7,7 +7,9 @@ from torch.autograd import Function
@@ -2766,7 +2768,7 @@ index e90a78a693..a45dcf1e29 100644
Reduces the tensor data across all machines in such a way that all get the final result.

diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py
-index 8a3be3b018..30f706d11a 100644
+index 8a3be3b0181..30f706d11a9 100644
--- a/torch/distributed/optim/zero_redundancy_optimizer.py
+++ b/torch/distributed/optim/zero_redundancy_optimizer.py
@@ -19,6 +19,7 @@ from torch.distributed.algorithms.join import Join, Joinable, JoinHook
@@ -2823,7 +2825,7 @@ index 8a3be3b018..30f706d11a 100644
self.world_size: int = dist.get_world_size(self.process_group)
self.rank: int = dist.get_rank(self.process_group)
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
-index 75c8b5504d..55cfdb78df 100644
+index 75c8b5504d4..55cfdb78dfd 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -570,6 +570,7 @@ class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
@@ -2844,7 +2846,7 @@ index 75c8b5504d..55cfdb78df 100644
process_group = self.process_group
world_size = torch.distributed.get_world_size(process_group)
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
-index 0ec5dd2224..cbab1ac374 100644
+index 0ec5dd22244..cbab1ac3747 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -655,7 +655,7 @@ class DistributedTest:
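Taken together with the earlier hunks, the regenerated patch no longer injects a `WorldManager` into `torch.distributed`; it only re-exports the low-level names (`Work`, `_worlds`, `_World`, `DEFAULT_WORLD_NAME`) that the relocated `multiworld` modules import. A hedged sanity check one could run against a patched build (the expectations below are inferred from the hunks above, not from executing the patch):

```python
# Import-level sanity check for a PyTorch build with the regenerated patch
# applied; on an unpatched build the underscored names are not re-exported.
import torch.distributed as dist

from multiworld.manager import WorldManager  # new home after this PR

for name in ("Work", "_worlds", "_World", "DEFAULT_WORLD_NAME"):
    assert hasattr(dist, name), f"patched torch.distributed should expose {name}"

# The patch no longer adds WorldManager to torch.distributed; it now lives
# only in the multiworld package.
assert not hasattr(dist, "WorldManager")
assert WorldManager.__module__ == "multiworld.manager"

print("torch.distributed exposes the symbols multiworld now imports")
```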