Skip to content

Commit

Permalink
Merge pull request #134 from jiaoml1996/jml/add_pooldataloader_pat
Browse files Browse the repository at this point in the history
add PoolDataLoader for parrots
  • Loading branch information
nbei authored Sep 3, 2020
2 parents 4c2345f + 85ee4e5 commit 7253ba8
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 32 deletions.
110 changes: 79 additions & 31 deletions mmedit/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,29 @@ def _dist_train(model,
"""
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
dist=True,
drop_last=cfg.data.get('drop_last', False),
seed=cfg.seed) for ds in dataset
]
# Build one training loader per dataset. Under parrots, build_dataloader
# forwards the extra PoolDataLoader options (prefetch_num / pin_memory);
# the plain PyTorch DataLoader does not accept them, so the shared
# arguments are collected once and the parrots-only ones added
# conditionally instead of duplicating the whole call in both branches.
loader_args = dict(
    dist=True,
    drop_last=cfg.data.get('drop_last', False),
    seed=cfg.seed)
if torch.__version__ == 'parrots':
    loader_args.update(
        prefetch_num=cfg.data.get('prefetch_num', 2),
        pin_memory=cfg.data.get('pin_memory', False))
data_loaders = [
    build_dataloader(
        ds,
        cfg.data.samples_per_gpu,
        cfg.data.workers_per_gpu,
        **loader_args) for ds in dataset
]

# put model on gpus
find_unused_parameters = cfg.get('find_unused_parameters', False)
model = DistributedDataParallelWrapper(
Expand Down Expand Up @@ -146,12 +160,22 @@ def _dist_train(model,
cfg.data.samples_per_gpu)
workers_per_gpu = cfg.data.get('val_workers_per_gpu',
cfg.data.workers_per_gpu)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=workers_per_gpu,
dist=True,
shuffle=False)
# Validation loader: shared keyword arguments, plus the PoolDataLoader-only
# options (prefetch_num / pin_memory) when running under parrots. This
# replaces two nearly identical build_dataloader calls with one.
loader_args = dict(
    samples_per_gpu=samples_per_gpu,
    workers_per_gpu=workers_per_gpu,
    dist=True,
    shuffle=False)
if torch.__version__ == 'parrots':
    loader_args.update(
        prefetch_num=cfg.data.get('prefetch_num', 2),
        pin_memory=cfg.data.get('pin_memory', False))
data_loader = build_dataloader(dataset, **loader_args)
save_path = osp.join(cfg.work_dir, 'val_visuals')
runner.register_hook(
DistEvalIterHook(
Expand Down Expand Up @@ -185,16 +209,30 @@ def _non_dist_train(model,
"""
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
cfg.gpus,
dist=False,
drop_last=cfg.data.get('drop_last', False),
seed=cfg.seed) for ds in dataset
]
# Build one training loader per dataset (non-distributed path: cfg.gpus is
# passed positionally and dist=False). Under parrots, build_dataloader
# additionally accepts the PoolDataLoader options (prefetch_num /
# pin_memory); collect the shared arguments once and add the parrots-only
# ones conditionally instead of duplicating the whole call.
loader_args = dict(
    dist=False,
    drop_last=cfg.data.get('drop_last', False),
    seed=cfg.seed)
if torch.__version__ == 'parrots':
    loader_args.update(
        prefetch_num=cfg.data.get('prefetch_num', 2),
        pin_memory=cfg.data.get('pin_memory', False))
data_loaders = [
    build_dataloader(
        ds,
        cfg.data.samples_per_gpu,
        cfg.data.workers_per_gpu,
        cfg.gpus,
        **loader_args) for ds in dataset
]
# put model on gpus
model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

Expand Down Expand Up @@ -229,12 +267,22 @@ def _non_dist_train(model,
cfg.data.samples_per_gpu)
workers_per_gpu = cfg.data.get('val_workers_per_gpu',
cfg.data.workers_per_gpu)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=workers_per_gpu,
dist=True,
shuffle=False)
# Validation loader: shared keyword arguments, plus the PoolDataLoader-only
# options (prefetch_num / pin_memory) when running under parrots.
# NOTE(review): dist=True inside the non-distributed training path looks
# suspicious (the distributed branch uses the same value) — preserved as-is,
# but worth confirming against build_dataloader's sampler selection.
loader_args = dict(
    samples_per_gpu=samples_per_gpu,
    workers_per_gpu=workers_per_gpu,
    dist=True,
    shuffle=False)
if torch.__version__ == 'parrots':
    loader_args.update(
        prefetch_num=cfg.data.get('prefetch_num', 2),
        pin_memory=cfg.data.get('pin_memory', False))
data_loader = build_dataloader(dataset, **loader_args)
save_path = osp.join(cfg.work_dir, 'val_visuals')
runner.register_hook(
EvalIterHook(data_loader, save_path=save_path, **cfg.evaluation))
Expand Down
8 changes: 7 additions & 1 deletion mmedit/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
from functools import partial

import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import build_from_cfg
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data import ConcatDataset

from .dataset_wrappers import RepeatDataset
from .registry import DATASETS
from .samplers import DistributedSampler

# Parrots ships PoolDataLoader, a DataLoader replacement that additionally
# accepts prefetch_num / pin_memory (see the parrots branches in
# mmedit/apis/train.py). Alias it so the rest of this module can use the
# name ``DataLoader`` uniformly regardless of backend.
if torch.__version__ == 'parrots':
    from torch.utils.data import PoolDataLoader as DataLoader
else:
    from torch.utils.data import DataLoader

if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
import resource
Expand Down

0 comments on commit 7253ba8

Please sign in to comment.