Skip to content

Commit

Permalink
[Fix] Fix examples/distributed_training.py does not work in DDP (open…
Browse files Browse the repository at this point in the history
…-mmlab#700)

* Update distributed_training.py

Better example for DDP training

* Update distributed_training.py

* Update distributed_training.py

Update according to reviewer's suggestions.

* Update distributed_training.py

* Update distributed_training.py

The previous update copied data from the main branch, which was a mistake.
This update fixes that mistake, and the code has been tested.
  • Loading branch information
gaopinghai authored Nov 9, 2022
1 parent b35196a commit 46209b8
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions examples/distributed_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
Expand Down Expand Up @@ -57,29 +56,33 @@ def parse_args():
def main():
args = parse_args()
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
train_set = torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
]))
valid_set = torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(**norm_cfg)]))
train_dataloader = dict(
batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
val_dataloader = DataLoader(
dataset=train_set,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'))
val_dataloader = dict(
batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(**norm_cfg)])))
dataset=valid_set,
sampler=dict(type='DefaultSampler', shuffle=False),
collate_fn=dict(type='default_collate'))
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
Expand Down

0 comments on commit 46209b8

Please sign in to comment.