forked from XuyangBai/PointDSC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_KITTI.py
116 lines (108 loc) · 4.23 KB
/
train_KITTI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import time
import shutil
import json
from config import get_config
from easydict import EasyDict as edict
from libs.loss import TransformationLoss, ClassificationLoss, SpectralMatchingLoss
from datasets.KITTI import KITTIDataset
from datasets.dataloader import get_dataloader
from libs.trainer import Trainer
from models.PointDSC import PointDSC
from torch import optim
# (path parts relative to this script's directory, file name inside snapshot_dir).
# These files are copied into every run's snapshot directory, in this order,
# so the exact training / model / loss / dataset code of the run is archived.
_SNAPSHOT_FILES = (
    (('train_KITTI.py',), 'train.py'),
    (('libs', 'trainer.py'), 'trainer.py'),
    (('models', 'PointDSC.py'), 'model.py'),  # for the model setting.
    (('libs', 'loss.py'), 'loss.py'),
    (('datasets', 'KITTI.py'), 'dataset.py'),
)


def _snapshot_sources(snapshot_dir):
    """Copy every file listed in ``_SNAPSHOT_FILES`` into ``snapshot_dir``.

    Sources are resolved relative to the directory containing this script
    (not the current working directory), so the script can be launched from
    anywhere. ``shutil.copy2`` also preserves file metadata.
    """
    src_root = os.path.dirname(os.path.abspath(__file__))
    for parts, dst_name in _SNAPSHOT_FILES:
        shutil.copy2(
            os.path.join(src_root, *parts),
            os.path.join(snapshot_dir, dst_name),
        )


def _build_optimizer(config, params):
    """Create the optimizer selected by ``config.optimizer`` over ``params``.

    Args:
        config: options object providing ``optimizer`` (``'SGD'`` or
            ``'ADAM'``), ``lr``, ``weight_decay`` and, for SGD, ``momentum``.
        params: iterable of model parameters to optimize.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance. Adam uses
        fixed betas ``(0.9, 0.999)``.

    Raises:
        ValueError: if ``config.optimizer`` names an unsupported optimizer.
    """
    name = config.optimizer
    if name == 'SGD':
        return optim.SGD(
            params,
            lr=config.lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
        )
    if name == 'ADAM':
        return optim.Adam(
            params,
            lr=config.lr,
            betas=(0.9, 0.999),
            weight_decay=config.weight_decay,
        )
    # Fail fast: otherwise the unrecognized name would be handed to the LR
    # scheduler as a plain string and fail there with an unrelated error.
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'SGD' or 'ADAM'.")


def _build_dataset(config, split):
    """Build a ``KITTIDataset`` for ``split`` from the dataset options in ``config``.

    Train and val splits are built with identical options (including the
    augmentation settings); only ``split`` differs.
    """
    return KITTIDataset(
        root=config.root,
        split=split,
        descriptor=config.descriptor,
        in_dim=config.in_dim,
        inlier_threshold=config.inlier_threshold,
        num_node=config.num_node,
        use_mutual=config.use_mutual,
        augment_axis=config.augment_axis,
        augment_rotation=config.augment_rotation,
        augment_translation=config.augment_translation,
    )


def main():
    """Train PointDSC on KITTI.

    Steps:
      1. Read and print the config.
      2. Create the output directories and snapshot the relevant source files
         and the config.
      3. Attach model, optimizer, scheduler, data loaders and loss metrics to
         the config.
      4. Hand the config to ``Trainer``.
    """
    config = get_config()
    dconfig = vars(config)
    for key, value in dconfig.items():
        print(f" {key}: {value}")
    config = edict(dconfig)

    for directory in (config.snapshot_dir, config.tboard_dir, config.save_dir):
        os.makedirs(directory, exist_ok=True)
    _snapshot_sources(config.snapshot_dir)
    # Dump the config before the model/optimizer/loaders are attached below,
    # while it still only holds the values returned by get_config().
    with open(os.path.join(config.snapshot_dir, 'config.json'), 'w') as fp:
        json.dump(config, fp, indent=4)

    # create model
    config.model = PointDSC(
        in_dim=config.in_dim,
        num_layers=config.num_layers,
        num_channels=config.num_channels,
        num_iterations=config.num_iterations,
        inlier_threshold=config.inlier_threshold,
        sigma_d=config.sigma_d,
        ratio=config.ratio,
        k=config.k,
    )

    # create optimizer (replaces the optimizer name in the config with the instance)
    config.optimizer = _build_optimizer(config, config.model.parameters())
    config.scheduler = optim.lr_scheduler.ExponentialLR(
        config.optimizer,
        gamma=config.scheduler_gamma,
    )

    # create dataset and dataloader
    train_set = _build_dataset(config, 'train')
    val_set = _build_dataset(config, 'val')
    config.train_loader = get_dataloader(
        dataset=train_set,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )
    config.val_loader = get_dataloader(
        dataset=val_set,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )

    # create evaluation: each loss and the weight it gets in the total objective
    config.evaluate_metric = {
        "ClassificationLoss": ClassificationLoss(balanced=config.balanced),
        "SpectralMatchingLoss": SpectralMatchingLoss(balanced=config.balanced),
        "TransformationLoss": TransformationLoss(re_thre=config.re_thre, te_thre=config.te_thre),
    }
    config.metric_weight = {
        "ClassificationLoss": config.weight_classification,
        "SpectralMatchingLoss": config.weight_spectralmatching,
        "TransformationLoss": config.weight_transformation,
    }

    trainer = Trainer(config)
    trainer.train()


if __name__ == '__main__':
    main()