diff --git a/mmedit/apis/train.py b/mmedit/apis/train.py
index 72df80c53b..fcfaecf2f4 100644
--- a/mmedit/apis/train.py
+++ b/mmedit/apis/train.py
@@ -99,15 +99,29 @@ def _dist_train(model,
     """
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            dist=True,
-            drop_last=cfg.data.get('drop_last', False),
-            seed=cfg.seed) for ds in dataset
-    ]
+    if torch.__version__ == 'parrots':
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                dist=True,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False)) for ds in dataset
+        ]
+    else:
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                dist=True,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed) for ds in dataset
+        ]
+
     # put model on gpus
     find_unused_parameters = cfg.get('find_unused_parameters', False)
     model = DistributedDataParallelWrapper(
@@ -146,12 +160,22 @@ def _dist_train(model,
                                        cfg.data.samples_per_gpu)
         workers_per_gpu = cfg.data.get('val_workers_per_gpu',
                                        cfg.data.workers_per_gpu)
-        data_loader = build_dataloader(
-            dataset,
-            samples_per_gpu=samples_per_gpu,
-            workers_per_gpu=workers_per_gpu,
-            dist=True,
-            shuffle=False)
+        if torch.__version__ == 'parrots':
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False))
+        else:
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False)
         save_path = osp.join(cfg.work_dir, 'val_visuals')
         runner.register_hook(
             DistEvalIterHook(
@@ -185,16 +209,30 @@ def _non_dist_train(model,
     """
     # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            cfg.gpus,
-            dist=False,
-            drop_last=cfg.data.get('drop_last', False),
-            seed=cfg.seed) for ds in dataset
-    ]
+    if torch.__version__ == 'parrots':
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                cfg.gpus,
+                dist=False,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False)) for ds in dataset
+        ]
+    else:
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                cfg.gpus,
+                dist=False,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed) for ds in dataset
+        ]
     # put model on gpus
     model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

@@ -229,12 +267,22 @@ def _non_dist_train(model,
                                        cfg.data.samples_per_gpu)
         workers_per_gpu = cfg.data.get('val_workers_per_gpu',
                                        cfg.data.workers_per_gpu)
-        data_loader = build_dataloader(
-            dataset,
-            samples_per_gpu=samples_per_gpu,
-            workers_per_gpu=workers_per_gpu,
-            dist=True,
-            shuffle=False)
+        if torch.__version__ == 'parrots':
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False))
+        else:
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False)
         save_path = osp.join(cfg.work_dir, 'val_visuals')
         runner.register_hook(
             EvalIterHook(data_loader, save_path=save_path, **cfg.evaluation))
diff --git a/mmedit/datasets/builder.py b/mmedit/datasets/builder.py
index 7bbba26091..086763891a 100644
--- a/mmedit/datasets/builder.py
+++ b/mmedit/datasets/builder.py
@@ -4,15 +4,21 @@ from functools import partial

 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import build_from_cfg
-from torch.utils.data import ConcatDataset, DataLoader
+from torch.utils.data import ConcatDataset

 from .dataset_wrappers import RepeatDataset
 from .registry import DATASETS
 from .samplers import DistributedSampler

+if torch.__version__ == 'parrots':
+    from torch.utils.data import PoolDataLoader as DataLoader
+else:
+    from torch.utils.data import DataLoader
+
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
     import resource