Commit

update train.py
magicdream2222 committed Sep 3, 2020
1 parent 8a84dd3 commit 85ee4e5
Showing 1 changed file with 79 additions and 31 deletions.
mmedit/apis/train.py — 110 changes: 79 additions & 31 deletions
@@ -99,15 +99,29 @@ def _dist_train(model,
     """
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            dist=True,
-            drop_last=cfg.data.get('drop_last', False),
-            seed=cfg.seed) for ds in dataset
-    ]
+    if torch.__version__ == 'parrots':
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                dist=True,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False)) for ds in dataset
+        ]
+    else:
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                dist=True,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed) for ds in dataset
+        ]
+
     # put model on gpus
     find_unused_parameters = cfg.get('find_unused_parameters', False)
     model = DistributedDataParallelWrapper(
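Both new keyword arguments are optional: they are read from the config's `data` section with `cfg.data.get`, defaulting to 2 and False, so existing configs run unchanged. Below is a minimal sketch of a config that opts in; the surrounding keys and all values are hypothetical, and `prefetch_num` presumably controls how deep the parrots dataloader prefetches:

    # Hypothetical excerpt from an mmedit config (configs are plain Python).
    data = dict(
        samples_per_gpu=4,   # batch size per GPU
        workers_per_gpu=4,   # dataloader workers per GPU
        drop_last=True,      # already supported before this commit
        prefetch_num=4,      # parrots-only; defaults to 2 when omitted
        pin_memory=True)     # parrots-only; defaults to False when omitted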
@@ -146,12 +160,22 @@ def _dist_train(model,
                                        cfg.data.samples_per_gpu)
         workers_per_gpu = cfg.data.get('val_workers_per_gpu',
                                        cfg.data.workers_per_gpu)
-        data_loader = build_dataloader(
-            dataset,
-            samples_per_gpu=samples_per_gpu,
-            workers_per_gpu=workers_per_gpu,
-            dist=True,
-            shuffle=False)
+        if torch.__version__ == 'parrots':
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False))
+        else:
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False)
         save_path = osp.join(cfg.work_dir, 'val_visuals')
         runner.register_hook(
             DistEvalIterHook(
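The two branches differ only in the parrots-specific keyword arguments, so the same dispatch could be written once with a kwargs dict. A sketch of that equivalent form, not part of the commit, assuming the build_dataloader signatures shown in the diff:

    # Build the shared kwargs once, then add the parrots-only ones.
    loader_kwargs = dict(
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        dist=True,
        shuffle=False)
    if torch.__version__ == 'parrots':
        loader_kwargs.update(
            prefetch_num=cfg.data.get('prefetch_num', 2),
            pin_memory=cfg.data.get('pin_memory', False))
    data_loader = build_dataloader(dataset, **loader_kwargs)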
@@ -185,16 +209,30 @@ def _non_dist_train(model,
     """
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            cfg.gpus,
-            dist=False,
-            drop_last=cfg.data.get('drop_last', False),
-            seed=cfg.seed) for ds in dataset
-    ]
+    if torch.__version__ == 'parrots':
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                cfg.gpus,
+                dist=False,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False)) for ds in dataset
+        ]
+    else:
+        data_loaders = [
+            build_dataloader(
+                ds,
+                cfg.data.samples_per_gpu,
+                cfg.data.workers_per_gpu,
+                cfg.gpus,
+                dist=False,
+                drop_last=cfg.data.get('drop_last', False),
+                seed=cfg.seed) for ds in dataset
+        ]
     # put model on gpus
     model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
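The `torch.__version__ == 'parrots'` test is the backend check that appears throughout OpenMMLab code: when running on SenseTime's Parrots framework, the torch-compatible module reports the literal string 'parrots' rather than a normal version number. A tiny self-contained sketch of the check (version strings are illustrative):

    import torch

    # Stock PyTorch reports something like '1.6.0', so the else branch
    # runs there; the Parrots shim reports the literal 'parrots'.
    if torch.__version__ == 'parrots':
        print('parrots backend: prefetch_num/pin_memory are passed through')
    else:
        print('pytorch backend: dataloaders are built exactly as before')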

@@ -229,12 +267,22 @@ def _non_dist_train(model,
                                        cfg.data.samples_per_gpu)
         workers_per_gpu = cfg.data.get('val_workers_per_gpu',
                                        cfg.data.workers_per_gpu)
-        data_loader = build_dataloader(
-            dataset,
-            samples_per_gpu=samples_per_gpu,
-            workers_per_gpu=workers_per_gpu,
-            dist=True,
-            shuffle=False)
+        if torch.__version__ == 'parrots':
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False,
+                prefetch_num=cfg.data.get('prefetch_num', 2),
+                pin_memory=cfg.data.get('pin_memory', False))
+        else:
+            data_loader = build_dataloader(
+                dataset,
+                samples_per_gpu=samples_per_gpu,
+                workers_per_gpu=workers_per_gpu,
+                dist=True,
+                shuffle=False)
         save_path = osp.join(cfg.work_dir, 'val_visuals')
         runner.register_hook(
             EvalIterHook(data_loader, save_path=save_path, **cfg.evaluation))
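As in the distributed path, omitting the new keys keeps the old behavior (prefetch_num defaults to 2, pin_memory to False). Note that this validation loader in _non_dist_train passed dist=True before the change and still does; the commit leaves that flag untouched.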
