From 4a4e5c66f67a71ac4f7146d90b8e0f048b57fad4 Mon Sep 17 00:00:00 2001 From: jiaomenglei Date: Tue, 1 Sep 2020 18:45:16 +0800 Subject: [PATCH 1/7] add PoolDataLoader for parrots --- mmedit/datasets/builder.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mmedit/datasets/builder.py b/mmedit/datasets/builder.py index 7bbba26091..608f4b67ae 100644 --- a/mmedit/datasets/builder.py +++ b/mmedit/datasets/builder.py @@ -7,7 +7,13 @@ from mmcv.parallel import collate from mmcv.runner import get_dist_info from mmcv.utils import build_from_cfg -from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.data import ConcatDataset +import torch +if torch.__version__ == 'parrots': + from torch.utils.data import PoolDataLoader + DataLoader = partial(PoolDataLoader, prefetch_num=2) +else: + from torch.utils.data import DataLoader from .dataset_wrappers import RepeatDataset from .registry import DATASETS @@ -133,6 +139,8 @@ def build_dataloader(dataset, worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + if torch.__version__ == 'parrots': + pin_memory = False data_loader = DataLoader( dataset, batch_size=batch_size, From b9f6916070779ce8563b24be4210c68d7fa42bc4 Mon Sep 17 00:00:00 2001 From: jiaomenglei Date: Tue, 1 Sep 2020 18:50:13 +0800 Subject: [PATCH 2/7] fix lint --- mmedit/datasets/builder.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mmedit/datasets/builder.py b/mmedit/datasets/builder.py index 608f4b67ae..f34ee40569 100644 --- a/mmedit/datasets/builder.py +++ b/mmedit/datasets/builder.py @@ -4,20 +4,22 @@ from functools import partial import numpy as np +import torch from mmcv.parallel import collate from mmcv.runner import get_dist_info from mmcv.utils import build_from_cfg from torch.utils.data import ConcatDataset -import torch + +from .dataset_wrappers import RepeatDataset +from .registry import DATASETS +from .samplers import DistributedSampler + if torch.__version__ == 'parrots': from torch.utils.data import PoolDataLoader DataLoader = partial(PoolDataLoader, prefetch_num=2) else: from torch.utils.data import DataLoader -from .dataset_wrappers import RepeatDataset -from .registry import DATASETS -from .samplers import DistributedSampler if platform.system() != 'Windows': # https://github.com/pytorch/pytorch/issues/973 From 348ef577b97c1566e03cf04b10f570954d732dd5 Mon Sep 17 00:00:00 2001 From: jiaomenglei Date: Tue, 1 Sep 2020 18:53:41 +0800 Subject: [PATCH 3/7] fix lint --- mmedit/datasets/builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mmedit/datasets/builder.py b/mmedit/datasets/builder.py index f34ee40569..84053f6c1f 100644 --- a/mmedit/datasets/builder.py +++ b/mmedit/datasets/builder.py @@ -20,7 +20,6 @@ else: from torch.utils.data import DataLoader - if platform.system() != 'Windows': # https://github.com/pytorch/pytorch/issues/973 import resource From 18c2c9aba6cbbe8b2942934385e4547037c20cab Mon Sep 17 00:00:00 2001 From: jiaomenglei Date: Wed, 2 Sep 2020 19:28:13 +0800 Subject: [PATCH 4/7] pin_memory and prefetch_num can be got by config --- mmedit/datasets/builder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmedit/datasets/builder.py b/mmedit/datasets/builder.py index 84053f6c1f..d5d15c437e 100644 --- a/mmedit/datasets/builder.py +++ b/mmedit/datasets/builder.py @@ -16,7 +16,6 @@ if torch.__version__ == 'parrots': from torch.utils.data import PoolDataLoader - DataLoader = partial(PoolDataLoader, prefetch_num=2) else: from torch.utils.data import DataLoader @@ -120,6 +119,10 @@ def build_dataloader(dataset, Returns: DataLoader: A PyTorch dataloader. """ + if torch.__version__ == 'parrots': + prefetch_num = kwargs.get('prefetch_num', 2) + DataLoader = partial(PoolDataLoader, prefetch_num=prefetch_num) + rank, world_size = get_dist_info() if dist: sampler = DistributedSampler( @@ -140,8 +143,6 @@ def build_dataloader(dataset, worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None - if torch.__version__ == 'parrots': - pin_memory = False data_loader = DataLoader( dataset, batch_size=batch_size, From 206b90a55ae444863b12d9ce2fff7e0d9921892b Mon Sep 17 00:00:00 2001 From: jiaomenglei Date: Wed, 2 Sep 2020 19:38:00 +0800 Subject: [PATCH 5/7] fix lint --- mmedit/datasets/builder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmedit/datasets/builder.py b/mmedit/datasets/builder.py index d5d15c437e..d8d0c47963 100644 --- a/mmedit/datasets/builder.py +++ b/mmedit/datasets/builder.py @@ -119,10 +119,6 @@ def build_dataloader(dataset, Returns: DataLoader: A PyTorch dataloader. """ - if torch.__version__ == 'parrots': - prefetch_num = kwargs.get('prefetch_num', 2) - DataLoader = partial(PoolDataLoader, prefetch_num=prefetch_num) - rank, world_size = get_dist_info() if dist: sampler = DistributedSampler( @@ -143,6 +139,10 @@ def build_dataloader(dataset, worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + if torch.__version__ == 'parrots': + prefetch_num = kwargs.get('prefetch_num', 2) + DataLoader = partial(PoolDataLoader, prefetch_num=prefetch_num) + data_loader = DataLoader( dataset, batch_size=batch_size, From 8a84dd3c7f2e1aeffccead68c22ce47de4bb9207 Mon Sep 17 00:00:00 2001 From: jiaomenglei Date: Wed, 2 Sep 2020 19:44:51 +0800 Subject: [PATCH 6/7] fix lint --- mmedit/datasets/builder.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mmedit/datasets/builder.py b/mmedit/datasets/builder.py index d8d0c47963..086763891a 100644 --- a/mmedit/datasets/builder.py +++ b/mmedit/datasets/builder.py @@ -15,7 +15,7 @@ from .samplers import DistributedSampler if torch.__version__ == 'parrots': - from torch.utils.data import PoolDataLoader + from torch.utils.data import PoolDataLoader as DataLoader else: from torch.utils.data import DataLoader @@ -139,10 +139,6 @@ def build_dataloader(dataset, worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None - if torch.__version__ == 'parrots': - prefetch_num = kwargs.get('prefetch_num', 2) - DataLoader = partial(PoolDataLoader, prefetch_num=prefetch_num) - data_loader = DataLoader( dataset, batch_size=batch_size, From 85ee4e515f548b7db9f6e3548588b866001ba2ee Mon Sep 17 00:00:00 2001 From: jiaomenglei Date: Thu, 3 Sep 2020 09:35:30 +0800 Subject: [PATCH 7/7] update train.py --- mmedit/apis/train.py | 110 +++++++++++++++++++++++++++++++------------ 1 file changed, 79 insertions(+), 31 deletions(-) diff --git a/mmedit/apis/train.py b/mmedit/apis/train.py index 72df80c53b..fcfaecf2f4 100644 --- a/mmedit/apis/train.py +++ b/mmedit/apis/train.py @@ -99,15 +99,29 @@ def _dist_train(model, """ # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] - data_loaders = [ - build_dataloader( - ds, - cfg.data.samples_per_gpu, - cfg.data.workers_per_gpu, - dist=True, - drop_last=cfg.data.get('drop_last', False), - seed=cfg.seed) for ds in dataset - ] + if torch.__version__ == 'parrots': + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + dist=True, + drop_last=cfg.data.get('drop_last', False), + seed=cfg.seed, + prefetch_num=cfg.data.get('prefetch_num', 2), + pin_memory=cfg.data.get('pin_memory', False)) for ds in dataset + ] + else: + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + dist=True, + drop_last=cfg.data.get('drop_last', False), + seed=cfg.seed) for ds in dataset + ] + # put model on gpus find_unused_parameters = cfg.get('find_unused_parameters', False) model = DistributedDataParallelWrapper( @@ -146,12 +160,22 @@ def _dist_train(model, cfg.data.samples_per_gpu) workers_per_gpu = cfg.data.get('val_workers_per_gpu', cfg.data.workers_per_gpu) - data_loader = build_dataloader( - dataset, - samples_per_gpu=samples_per_gpu, - workers_per_gpu=workers_per_gpu, - dist=True, - shuffle=False) + if torch.__version__ == 'parrots': + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + dist=True, + shuffle=False, + prefetch_num=cfg.data.get('prefetch_num', 2), + pin_memory=cfg.data.get('pin_memory', False)) + else: + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + dist=True, + shuffle=False) save_path = osp.join(cfg.work_dir, 'val_visuals') runner.register_hook( DistEvalIterHook( @@ -185,16 +209,30 @@ def _non_dist_train(model, """ # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] - data_loaders = [ - build_dataloader( - ds, - cfg.data.samples_per_gpu, - cfg.data.workers_per_gpu, - cfg.gpus, - dist=False, - drop_last=cfg.data.get('drop_last', False), - seed=cfg.seed) for ds in dataset - ] + if torch.__version__ == 'parrots': + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + cfg.gpus, + dist=False, + drop_last=cfg.data.get('drop_last', False), + seed=cfg.seed, + prefetch_num=cfg.data.get('prefetch_num', 2), + pin_memory=cfg.data.get('pin_memory', False)) for ds in dataset + ] + else: + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + cfg.gpus, + dist=False, + drop_last=cfg.data.get('drop_last', False), + seed=cfg.seed) for ds in dataset + ] # put model on gpus model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda() @@ -229,12 +267,22 @@ def _non_dist_train(model, cfg.data.samples_per_gpu) workers_per_gpu = cfg.data.get('val_workers_per_gpu', cfg.data.workers_per_gpu) - data_loader = build_dataloader( - dataset, - samples_per_gpu=samples_per_gpu, - workers_per_gpu=workers_per_gpu, - dist=True, - shuffle=False) + if torch.__version__ == 'parrots': + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + dist=True, + shuffle=False, + prefetch_num=cfg.data.get('prefetch_num', 2), + pin_memory=cfg.data.get('pin_memory', False)) + else: + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + dist=True, + shuffle=False) save_path = osp.join(cfg.work_dir, 'val_visuals') runner.register_hook( EvalIterHook(data_loader, save_path=save_path, **cfg.evaluation))