Merge pull request pytorch#763 from mrshenli/ddp
Adding DDP example
Jessica Lin authored Apr 28, 2020
2 parents 69d2798 + 0cbd699 commit e49fa58
Showing 3 changed files with 161 additions and 0 deletions.
10 changes: 10 additions & 0 deletions distributed/ddp/README.md
@@ -0,0 +1,10 @@
# `DistributedDataParallel` Example

This example demonstrates basic use cases of `DistributedDataParallel`, and
also covers some more advanced scenarios including checkpointing models and
combining DDP with model parallelism.

```
pip install -r requirements.txt
python main.py
```
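
As written, `main.py` spawns one process per GPU and expects at least 8 GPUs (see the check at the bottom of the file). To try just the basic demo on a smaller machine, a minimal driver might look like this (a sketch, assuming at least 2 GPUs; not part of this commit):

```
# hypothetical launcher: run only the basic demo on however many GPUs exist
import torch
from main import demo_basic, run_demo

n_gpus = torch.cuda.device_count()
if n_gpus >= 2:
    run_demo(demo_basic, n_gpus)
```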
150 changes: 150 additions & 0 deletions distributed/ddp/main.py
@@ -0,0 +1,150 @@
import os
import tempfile
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
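
The demos below move models to GPUs, so in practice the `nccl` backend is usually preferred for CUDA tensors, while `gloo` keeps the example portable. A hedged variant of `setup` (illustrative, not part of this commit):

```
def setup_nccl(rank, world_size):
    # same single-machine rendezvous as setup() above
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # "nccl" is the usual choice when each process drives CUDA devices
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
```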


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # DDP moves the input to device_ids[0], so the batch can stay on CPU here
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()
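
In a real job the random tensors above would come from a `DataLoader`; a common companion pattern (a sketch, not part of this commit) uses `DistributedSampler` so each rank sees a disjoint shard of the data:

```
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def make_loader(rank, world_size, batch_size=20):
    # toy dataset matching ToyModel's 10-feature inputs and 5-dim targets
    dataset = TensorDataset(torch.randn(200, 10), torch.randn(200, 5))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```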


def run_demo(demo_fn, world_size):
    # mp.spawn calls demo_fn(rank, world_size) once per process,
    # passing the process index as the implicit first argument
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see the same parameters as they all start from
        # the same random parameters and gradients are synchronized in
        # backward passes. Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure the other processes load the model only
    # after process 0 has saved it.
    dist.barrier()
    # configure map_location properly: remap tensors saved from rank 0's GPU
    # onto this process's GPU
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Use a barrier() to make sure that all processes have finished reading
    # the checkpoint before rank 0 removes it.
    dist.barrier()

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
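
A full training checkpoint typically also captures optimizer state; a minimal sketch of that extension (assumed, not part of this commit), using the same rank-0-saves pattern:

```
def save_full_checkpoint(ddp_model, optimizer, path):
    # call from rank 0 only; the barrier pattern above still applies
    torch.save({
        'model': ddp_model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)
```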


class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)
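
`ToyMpModel` can be sanity-checked on its own before wrapping it in DDP; a quick illustrative check (assumes at least 2 GPUs; not part of this commit):

```
# forward moves activations from dev0 to dev1, so outputs land on dev1
model = ToyMpModel("cuda:0", "cuda:1")
out = model(torch.randn(20, 10))
assert out.device == torch.device("cuda:1")
```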


def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # setup mp_model and devices for this process
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    # device_ids is omitted because the module already spans multiple devices
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    if n_gpus < 8:
        print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
    else:
        run_demo(demo_basic, 8)
        run_demo(demo_checkpoint, 8)
        # demo_model_parallel uses 2 GPUs per process, so 4 processes cover 8
        run_demo(demo_model_parallel, 4)
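
Since `setup()` hard-codes `MASTER_ADDR` and `MASTER_PORT`, everything here runs on a single machine. For a multi-node job, the same rendezvous values would normally come from the launcher's environment; a hypothetical variant (not part of this commit):

```
def setup_from_env():
    # assumes MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are exported
    # by the job launcher (e.g. a cluster scheduler)
    dist.init_process_group(
        "gloo",
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )
```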
1 change: 1 addition & 0 deletions distributed/ddp/requirements.txt
@@ -0,0 +1 @@
torch
