train.py
import argparse
import os
import shutil

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
# share tensors through the file system to avoid hitting open-file-descriptor
# limits in DataLoader workers
mp.set_sharing_strategy('file_system')

from libs import load_opt, Trainer
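
# This launcher creates experiments/<name>/, freezes a copy of the training
# options there, and runs one training process per visible GPU (DDP over NCCL
# when more than one GPU is available).
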
def main(rank, opt):
    # bind this worker to its GPU and enable cuDNN autotuning
    torch.cuda.set_device(rank)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    print(f"Training process: {rank}")
    if opt['_distributed']:
        # single-node rendezvous over localhost; MASTER_PORT may be set in the
        # environment beforehand to override the default
        os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            os.environ['MASTER_PORT'] = '29500'
        dist.init_process_group(
            backend='nccl', init_method='env://',
            rank=rank, world_size=opt['_world_size']
        )

    trainer = Trainer(opt)
    trainer.run()

    if opt['_distributed']:
        dist.destroy_process_group()
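
# Typical launch (a sketch; 'default.yaml' and 'demo' below are placeholder
# names, not files that necessarily ship with the repo):
#
#   python train.py --opt default.yaml --name demo
#
# With more than one visible GPU the script spawns its own workers, so no
# torchrun / torch.distributed.launch wrapper is needed.
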
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help="training options")
    parser.add_argument('--name', type=str, help="job name")
    args = parser.parse_args()

    # create experiment folder
    os.makedirs('experiments', exist_ok=True)
    root = os.path.join('experiments', args.name)
    os.makedirs(root, exist_ok=True)
    try:
        # an existing job: reuse the options frozen alongside it
        opt = load_opt(os.path.join(root, 'opt.yaml'), is_training=True)
    except Exception:
        # a fresh job: load the requested options and freeze a copy
        opt_path = os.path.join('opts', args.opt)
        opt = load_opt(opt_path, is_training=True)
        shutil.copyfile(opt_path, os.path.join(root, 'opt.yaml'))
    os.makedirs(os.path.join(root, 'models'), exist_ok=True)
    os.makedirs(os.path.join(root, 'states'), exist_ok=True)

    opt['_root'] = root
    # resume only when both the model and trainer-state checkpoints exist
    opt['_resume'] = (
        os.path.exists(os.path.join(root, 'models', 'last.pth'))
        and os.path.exists(os.path.join(root, 'states', 'last.pth'))
    )
    # set up distributed training
    ## NOTE: only supports single-node training
    opt['_world_size'] = n_gpus = torch.cuda.device_count()
    opt['_distributed'] = n_gpus > 1
    if opt['_distributed']:
        # one training process per GPU; mp.spawn passes the process index
        # (the rank) as the first argument to main
        mp.spawn(main, nprocs=n_gpus, args=(opt,))
    else:
        main(0, opt)
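
# To run two single-node jobs on one machine concurrently, give each its own
# rendezvous port (a sketch; any two free ports work, and 'a.yaml' / 'b.yaml'
# are placeholder option files):
#
#   MASTER_PORT=29500 python train.py --opt a.yaml --name job_a
#   MASTER_PORT=29501 python train.py --opt b.yaml --name job_b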