diff --git a/distributed/ddp/README.md b/distributed/ddp/README.md
new file mode 100644
index 0000000000..b6d09bd22f
--- /dev/null
+++ b/distributed/ddp/README.md
@@ -0,0 +1,10 @@
+`DistributedDataParallel` Example
+
+This example demonstrates basic use cases of `DistributedDataParallel`, and
+also covers some more advanced scenarios including checkpointing models and
+combining DDP with model parallelism.
+
+```
+pip install -r requirements.txt
+python main.py
+```
diff --git a/distributed/ddp/main.py b/distributed/ddp/main.py
new file mode 100644
index 0000000000..34a855f051
--- /dev/null
+++ b/distributed/ddp/main.py
@@ -0,0 +1,150 @@
+import os
+import tempfile
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.optim as optim
+
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+def setup(rank, world_size):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(10, 5)
+
+    def forward(self, x):
+        return self.net2(self.relu(self.net1(x)))
+
+
+def demo_basic(rank, world_size):
+    print(f"Running basic DDP example on rank {rank}.")
+    setup(rank, world_size)
+
+    # create model and move it to GPU with id rank
+    model = ToyModel().to(rank)
+    ddp_model = DDP(model, device_ids=[rank])
+
+    loss_fn = nn.MSELoss()
+    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
+
+    optimizer.zero_grad()
+    outputs = ddp_model(torch.randn(20, 10))
+    labels = torch.randn(20, 5).to(rank)
+    loss_fn(outputs, labels).backward()
+    optimizer.step()
+
+    cleanup()
+
+
+def run_demo(demo_fn, world_size):
+    mp.spawn(demo_fn,
+             args=(world_size,),
+             nprocs=world_size,
+             join=True)
+
+
+def demo_checkpoint(rank, world_size):
+    print(f"Running DDP checkpoint example on rank {rank}.")
+    setup(rank, world_size)
+
+    model = ToyModel().to(rank)
+    ddp_model = DDP(model, device_ids=[rank])
+
+    loss_fn = nn.MSELoss()
+    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
+
+    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
+    if rank == 0:
+        # All processes should see same parameters as they all start from same
+        # random parameters and gradients are synchronized in backward passes.
+        # Therefore, saving it in one process is sufficient.
+        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
+
+    # Use a barrier() to make sure that process 1 loads the model after process
+    # 0 saves it.
+    dist.barrier()
+    # configure map_location properly
+    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
+    ddp_model.load_state_dict(
+        torch.load(CHECKPOINT_PATH, map_location=map_location))
+
+    optimizer.zero_grad()
+    outputs = ddp_model(torch.randn(20, 10))
+    labels = torch.randn(20, 5).to(rank)
+    loss_fn = nn.MSELoss()
+    loss_fn(outputs, labels).backward()
+    optimizer.step()
+
+    # Use a barrier() to make sure that all processes have finished reading the
+    # checkpoint
+    dist.barrier()
+
+    if rank == 0:
+        os.remove(CHECKPOINT_PATH)
+
+    cleanup()
+
+
+class ToyMpModel(nn.Module):
+    def __init__(self, dev0, dev1):
+        super(ToyMpModel, self).__init__()
+        self.dev0 = dev0
+        self.dev1 = dev1
+        self.net1 = torch.nn.Linear(10, 10).to(dev0)
+        self.relu = torch.nn.ReLU()
+        self.net2 = torch.nn.Linear(10, 5).to(dev1)
+
+    def forward(self, x):
+        x = x.to(self.dev0)
+        x = self.relu(self.net1(x))
+        x = x.to(self.dev1)
+        return self.net2(x)
+
+
+def demo_model_parallel(rank, world_size):
+    print(f"Running DDP with model parallel example on rank {rank}.")
+    setup(rank, world_size)
+
+    # setup mp_model and devices for this process
+    dev0 = rank * 2
+    dev1 = rank * 2 + 1
+    mp_model = ToyMpModel(dev0, dev1)
+    ddp_mp_model = DDP(mp_model)
+
+    loss_fn = nn.MSELoss()
+    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
+
+    optimizer.zero_grad()
+    # outputs will be on dev1
+    outputs = ddp_mp_model(torch.randn(20, 10))
+    labels = torch.randn(20, 5).to(dev1)
+    loss_fn(outputs, labels).backward()
+    optimizer.step()
+
+    cleanup()
+
+
+if __name__ == "__main__":
+    n_gpus = torch.cuda.device_count()
+    if n_gpus < 8:
+        print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
+    else:
+        run_demo(demo_basic, 8)
+        run_demo(demo_checkpoint, 8)
+        run_demo(demo_model_parallel, 4)
diff --git a/distributed/ddp/requirements.txt b/distributed/ddp/requirements.txt
new file mode 100644
index 0000000000..12c6d5d5ea
--- /dev/null
+++ b/distributed/ddp/requirements.txt
@@ -0,0 +1 @@
+torch
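
The `__main__` block in `main.py` above is hard-coded for an 8-GPU machine. If fewer GPUs are available, one possible variation (not part of this diff, just a sketch under that assumption) is to size the world to whatever `torch.cuda.device_count()` reports. The hypothetical `sketch_run.py` below imports the functions added in `main.py` above and reuses them unchanged:

```
# sketch_run.py -- illustrative only, not part of the example above.
# Assumes main.py from this diff sits in the same directory.
import torch

from main import demo_basic, demo_checkpoint, demo_model_parallel, run_demo

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    if n_gpus < 2:
        print(f"This sketch assumes at least 2 GPUs, but got {n_gpus}.")
    else:
        # single-device demos: one process per visible GPU
        run_demo(demo_basic, n_gpus)
        run_demo(demo_checkpoint, n_gpus)
        # model-parallel demo: each process uses two GPUs
        # (dev0 = rank * 2, dev1 = rank * 2 + 1), so halve the process count
        run_demo(demo_model_parallel, n_gpus // 2)
```

The process count is halved for `demo_model_parallel` because each of its ranks places one half of `ToyMpModel` on `rank * 2` and the other on `rank * 2 + 1`, exactly as in `main.py` above.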