train.py
import os
import random
import logging

import numpy as np
import torch
import torch.distributed as dist

from core.train_options import TrainOptions
from core.cfgs import cfg, parse_args_extend
from core.trainer import Trainer
from utils.train_utils import prepare_env

logger = logging.getLogger(__name__)

def main(gpu, ngpus_per_node, options):
    """Run training in a single process (one GPU per process in the distributed case)."""
    parse_args_extend(options)

    options.batch_size = cfg.TRAIN.BATCH_SIZE
    options.workers = cfg.TRAIN.NUM_WORKERS

    options.gpu = gpu
    options.ngpus_per_node = ngpus_per_node

    if options.distributed:
        # Join the process group. Each process passes its local_rank as its
        # global rank, which assumes a single-node setup (checked below).
        dist.init_process_group(backend=options.dist_backend, init_method=options.dist_url,
                                world_size=options.world_size, rank=options.local_rank)

    if options.multiprocessing_distributed:
        # For multiprocessing distributed training, rank needs to be the
        # global rank among all the processes.
        options.rank, world_size = dist.get_rank(), dist.get_world_size()
        assert options.rank == options.local_rank
        assert world_size == options.world_size

    trainer = Trainer(options)
    trainer.fit()

if __name__ == '__main__':
    options = TrainOptions().parse_args()
    parse_args_extend(options)

    if options.local_rank == 0:
        # Only the rank-0 process prepares the experiment environment
        # (logging and checkpoint directories).
        prepare_env(options)
    else:
        options.checkpoint_dir = ''

    if cfg.SEED_VALUE >= 0:
        logger.info(f'Seed value for the experiment: {cfg.SEED_VALUE}')
        os.environ['PYTHONHASHSEED'] = str(cfg.SEED_VALUE)
        random.seed(cfg.SEED_VALUE)
        torch.manual_seed(cfg.SEED_VALUE)
        np.random.seed(cfg.SEED_VALUE)
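        # Note: the calls above seed Python, NumPy, and the CPU-side torch
        # RNG. A sketch of what fully deterministic CUDA runs would
        # additionally need (not part of this script, shown commented out):
        # torch.cuda.manual_seed_all(cfg.SEED_VALUE)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False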

    ngpus_per_node = torch.cuda.device_count()
    options.distributed = (ngpus_per_node > 1) or options.multiprocessing_distributed
    if options.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly.
        options.world_size = ngpus_per_node * options.world_size
        # Worker processes are expected to be spawned by an external launcher,
        # one per GPU; an in-script alternative would be torch.multiprocessing:
        # mp.spawn(main, nprocs=ngpus_per_node, args=(ngpus_per_node, options))
        main(options.local_rank, ngpus_per_node, options)
    else:
        # Single-process (single-GPU or CPU) training.
        main(None, ngpus_per_node, options)
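
# Usage sketch (hypothetical invocations; the exact flag names live in
# core/train_options.py, which is not shown here):
#
#   # single-process run
#   python train.py
#
#   # one process per GPU on a single node, spawned by the PyTorch launcher,
#   # which supplies the --local_rank argument this script relies on
#   python -m torch.distributed.launch --nproc_per_node=<NUM_GPUS> train.py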