-
Notifications
You must be signed in to change notification settings - Fork 2
/
main_isl.py
100 lines (91 loc) · 4.52 KB
/
main_isl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import numpy as np
import random
import sys
from ctrl.utils.logger import Logger
from ctrl.utils.common_config import get_model, get_criterion, get_optimizer,\
get_optimizer_v2, get_data_loaders, convert_yaml_to_edict, setup_exp_params
from train_isl import train_model
import yaml
import os
from ctrl.utils.train_utils import print_output_paths
import argparse
def main():
    """Entry point for CTRL-UDA ISL (iterative self-learning) training.

    Parses command-line options, maps ``--expid`` to an experiment YAML
    config, seeds all RNGs for reproducibility, builds the model,
    discriminator, criteria, optimizers and dataloaders, dumps the resolved
    config to the snapshot directory, and launches ``train_model``.

    Raises:
        ValueError: if ``--expid`` is not one of the known experiment ids.
    """
    DEBUG = False
    parser = argparse.ArgumentParser(description='CTRL-UDA Training')
    parser.add_argument('--machine', type=int, default=-1, help='which machine to use')
    parser.add_argument('--expid', type=int, default=1, help='experiment id')
    parser.add_argument('--reso', type=str, default='FULL', help='inputs resolution full or low')
    parser.add_argument('--isl', type=str, default='true', help='activate the ISL training')
    parser.add_argument('--exp_root_dir', type=str, help='experiment root folder')
    parser.add_argument('--data_root', type=str, help='dataset root folder')
    parser.add_argument('--pret_model', type=str, help='pretrained weights to be used for initialization')
    parser.add_argument('--model_path', default='isl/model/path', type=str, help='trained model path for ISL training')
    cmdline_inputs = parser.parse_args()
    # Map experiment id -> config file. A dict lookup replaces the original
    # if/elif chain, which left exp_file unbound (NameError) for any id
    # outside 1-5; now an unknown id fails fast with a clear message.
    exp_files = {
        1: 'ctrl/configs/synthia_to_cityscapes_16cls_isl.yml',
        2: 'ctrl/configs/synthia_to_cityscapes_7cls_fr_isl.yml',
        3: 'ctrl/configs/synthia_to_cityscapes_7cls_lr_isl.yml',
        4: 'ctrl/configs/synthia_to_mapillary_7cls_fr_isl.yml',
        5: 'ctrl/configs/synthia_to_mapillary_7cls_lr_isl.yml',
    }
    expid = cmdline_inputs.expid
    try:
        exp_file = exp_files[expid]
    except KeyError:
        raise ValueError(
            'unknown --expid {}; expected one of {}'.format(expid, sorted(exp_files)))
    cfg = convert_yaml_to_edict(exp_file)
    cfg = setup_exp_params(cfg, cmdline_inputs, DEBUG)
    # Seed every RNG source (torch CPU/CUDA, numpy, python) for reproducibility.
    torch.manual_seed(cfg.TRAIN.RANDOM_SEED)
    torch.cuda.manual_seed(cfg.TRAIN.RANDOM_SEED)
    np.random.seed(cfg.TRAIN.RANDOM_SEED)
    random.seed(cfg.TRAIN.RANDOM_SEED)
    # Tee stdout to the training log file so console output is persisted.
    sys.stdout = Logger(cfg.TRAIN_LOG_FNAME)
    print_output_paths(cfg, is_isl_training=True)
    # get_model also restores optimizer state and resume iteration when a
    # checkpoint is configured.
    model, discriminator, optim_state_dict, disc_optim_state_dict, resume_iteration = get_model(cfg)
    if cfg.USE_DATA_PARALLEL:
        model = torch.nn.DataParallel(model)
    model = model.to(cfg.GPU_ID)
    if cfg.USE_DATA_PARALLEL:
        discriminator = torch.nn.DataParallel(discriminator)
    discriminator = discriminator.to(cfg.GPU_ID)
    # Wrap each loss in DataParallel as well so per-GPU outputs are reduced
    # consistently with the model.
    criterion_dict = get_criterion()
    if cfg.USE_DATA_PARALLEL:
        criterion_dict['semseg'] = torch.nn.DataParallel(criterion_dict['semseg'])
        criterion_dict['depth'] = torch.nn.DataParallel(criterion_dict['depth'])
        criterion_dict['disc_loss'] = torch.nn.DataParallel(criterion_dict['disc_loss'])
    criterion_dict['semseg'].to(cfg.GPU_ID)
    criterion_dict['depth'].to(cfg.GPU_ID)
    criterion_dict['disc_loss'].to(cfg.GPU_ID)
    print(criterion_dict)
    # benchmark=True auto-tunes conv kernels; NOTE(review): combined with
    # deterministic=True these flags are in tension — cuDNN behavior depends
    # on version; confirm the intended reproducibility guarantee.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # NOTE: 'USeDataParallel' spelling matches the keyword declared by
    # get_optimizer_v2 — do not "fix" it here without changing the callee.
    optimizer, optimizer_disc = get_optimizer_v2(cfg, model, USeDataParallel=cfg.USE_DATA_PARALLEL,
                                                 discriminator=discriminator, optim_state_dict=optim_state_dict,
                                                 disc_optim_state_dict=disc_optim_state_dict)
    print(optimizer)
    if cfg.ENABLE_DISCRIMINATOR:
        print(optimizer_disc)
    # get_target_train_loader=False: the target train loader is (re)built
    # inside train_model() — see comment below.
    source_train_loader, target_train_loader, target_val_loader,\
    source_train_nsamp, target_train_nsamp, target_test_nsamp = get_data_loaders(cfg, get_target_train_loader=False)
    # Persist the fully-resolved config next to the snapshots for provenance.
    cfg_file = os.path.join(cfg.TRAIN.SNAPSHOT_DIR, 'cfg.yml')
    with open(cfg_file, 'w') as fp:
        yaml.dump(dict(cfg), fp)
    print('cfg written to: {}'.format(cfg_file))
    # The target_train_loader will be dynamically initialized (and
    # re-initialized) in train_model(), because training switches between two
    # dataloader instances: one loading original semantic GT labels and one
    # loading pseudo semantic labels.
    # NOTE(review): optimizer_disc is deliberately discarded here (the
    # discriminator is not optimized during ISL training, presumably) —
    # confirm against train_isl.train_model before relying on this.
    optimizer_disc = None
    train_model(cfg, model, discriminator, resume_iteration, criterion_dict, optimizer, optimizer_disc, source_train_loader,
                target_train_loader, target_val_loader, source_train_nsamp, target_train_nsamp, target_test_nsamp)
# Standard script guard: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()