-
Notifications
You must be signed in to change notification settings - Fork 11
/
train.py
107 lines (89 loc) · 3.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import random
import sys
import os
import numpy as np
import torch
import torch.backends.cudnn
import torch.cuda
import torch.nn
import torch.utils.data
from torchpack import distributed as dist
from torchpack.callbacks import InferenceRunner, MaxSaver, Saver
from torchpack.environ import auto_set_run_dir, set_run_dir
from torchpack.utils.config import configs
from torchpack.utils.logging import logger
from core import builder
from core.callbacks import MeanIoU
def main() -> None:
    """Train a semantic-segmentation model on SemanticKITTI.

    Parses the config file and CLI overrides, sets up (optionally
    distributed) training state, seeds all RNGs, builds the dataflow,
    model, criterion, optimizer and scheduler via ``core.builder``, and
    runs the trainer with test-split mIoU evaluation and checkpointing.

    Side effects: mutates ``os.environ``, the global ``configs`` object,
    CUDA device state, and the run directory on disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', metavar='FILE', help='config file')
    parser.add_argument('--run-dir', metavar='DIR', help='run directory')
    parser.add_argument('--gpu', default='0', help='gpu index')
    # parse_known_args: unrecognized options become `opts`, applied as
    # dotted-key overrides on top of the YAML config.
    args, opts = parser.parse_known_args()

    configs.load(args.config, recursive=True)
    configs.update(opts)

    if configs.distributed:
        dist.init()
    # NOTE(review): CUDA_VISIBLE_DEVICES is set *after* dist.init(); it
    # only takes effect if no CUDA context exists yet — confirm dist.init
    # does not touch CUDA before this point.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(dist.local_rank())

    if args.run_dir is None:
        args.run_dir = auto_set_run_dir()
    else:
        set_run_dir(args.run_dir)
    configs.run_dir = args.run_dir

    logger.info(' '.join([sys.executable] + sys.argv))
    logger.info(f'Experiment started: "{args.run_dir}".' + '\n' + f'{configs}')

    # Seed every RNG (python / numpy / torch CPU / torch CUDA). The seed
    # is offset per rank so data-loader workers on different processes
    # draw different augmentation streams.
    if ('seed' not in configs.train) or (configs.train.seed is None):
        configs.train.seed = torch.initial_seed() % (2 ** 32 - 1)
    seed = configs.train.seed + dist.rank() * configs.workers_per_gpu * configs.num_epochs
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if configs.dataset.name == 'semantic_kitti':
        dataset = builder.make_dataset()
    else:
        # Fail loudly with context instead of a bare ValueError.
        raise ValueError(
            f'unsupported dataset: {configs.dataset.name!r} '
            "(only 'semantic_kitti' is supported)")

    # One DataLoader per split; a DistributedSampler shards each split
    # across ranks, shuffling only the training split.
    dataflow = {}
    for split in dataset:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset[split],
            num_replicas=dist.size(),
            rank=dist.rank(),
            shuffle=(split == 'train'))
        dataflow[split] = torch.utils.data.DataLoader(
            dataset[split],
            batch_size=configs.batch_size,
            sampler=sampler,
            num_workers=configs.workers_per_gpu,
            pin_memory=True,
            collate_fn=dataset[split].collate_fn)

    model = builder.make_model().cuda()
    if configs.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[dist.local_rank()], find_unused_parameters=True)

    criterion = builder.make_criterion()
    optimizer = builder.make_optimizer(model)
    scheduler = builder.make_scheduler(optimizer)

    # Imported here (not at module top) — presumably to avoid a circular
    # import between core.trainers and this entry point; verify before
    # hoisting.
    from core.trainers import SemanticKITTITrainer
    trainer = SemanticKITTITrainer(model=model,
                                   criterion=criterion,
                                   optimizer=optimizer,
                                   scheduler=scheduler,
                                   num_workers=configs.workers_per_gpu,
                                   seed=seed,
                                   amp_enabled=configs.amp_enabled)
    trainer.train_with_defaults(
        dataflow['train'],
        num_epochs=configs.num_epochs,
        # Evaluate mIoU on the test split after each epoch; keep the
        # best-mIoU checkpoint (MaxSaver) plus regular snapshots (Saver).
        callbacks=[InferenceRunner(dataflow[split],
                                   callbacks=[MeanIoU(name=f'iou/{split}',
                                                      num_classes=configs.data.num_classes,
                                                      ignore_label=configs.data.ignore_label)],
                                   ) for split in ['test']] +
                  [MaxSaver('iou/test'), Saver()])


if __name__ == '__main__':
    main()