-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
113 lines (96 loc) · 3.55 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from dataset import Dictionary, SelfCriticalDataset
from train import run
import opts
import pdb
import transformers
def weights_init_kn(m):
    """Kaiming-normal initializer for use with ``model.apply``.

    Re-initializes the weight of every ``nn.Linear`` module with
    Kaiming-normal init (negative slope ``a=0.01``, matching a LeakyReLU
    nonlinearity); all other module types are left untouched.

    Note: uses ``kaiming_normal_`` — the underscore (in-place) variant —
    since the old ``nn.init.kaiming_normal`` alias is deprecated/removed
    in modern PyTorch.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight.data, a=0.01)
if __name__ == '__main__':
    opt = opts.parse_opt()

    # Fail-safe: each dataset was tuned with a specific importance
    # threshold; refuse to run with a mismatched one. Explicit raises are
    # used instead of `assert` so the check survives `python -O`.
    _EXPECTED_THRESHOLDS = {'hatcp': 0.55, 'gqacp': 0.3, 'clevrxai': 0.85}
    _expected = _EXPECTED_THRESHOLDS.get(opt.dataset)
    if _expected is not None and opt.impt_threshold != _expected:
        raise ValueError(
            f"dataset '{opt.dataset}' requires impt_threshold == {_expected}, "
            f"got {opt.impt_threshold}")

    # Set random seeds for reproducibility. opt.seed == 0 means "draw a
    # fresh random seed". BUGFIX: the drawn seed used to be immediately
    # clobbered back to 0 (`seed = 0` right after the randint), making the
    # randint dead code; the drawn seed is now actually used.
    seed = random.randint(1, 10000) if opt.seed == 0 else opt.seed
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # NOTE(review): benchmark=True autotunes conv algorithms for speed but
    # is NOT deterministic; strict reproducibility would need
    # benchmark=False and torch.backends.cudnn.deterministic=True.
    torch.backends.cudnn.benchmark = True

    # Load the token dictionary and record the vocabulary size on opts.
    dictionary = Dictionary.load_from_file(f'{opt.data_dir}/dictionary.pkl')
    opt.ntokens = dictionary.ntoken
    print("dictionary ntoken", dictionary.ntoken)

    # The self-critical loss implies per-answer weighting.
    if opt.use_scr_loss:
        opt.apply_answer_weight = True

    ### Datasets
    # Train dataset/loader (skipped entirely when no train split is given,
    # e.g. evaluation-only runs).
    if opt.split is not None:
        train_dset = SelfCriticalDataset(
            opt.split, opt.hint_type, dictionary, opt,
            discard_items_without_hints=not opt.do_not_discard_items_without_hints)
        train_loader = DataLoader(train_dset,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    else:
        train_dset = None
        train_loader = None

    # Validation dataset/loader (never shuffled).
    eval_dset = SelfCriticalDataset(
        opt.split_test, opt.hint_type, dictionary, opt,
        discard_items_without_hints=opt.discard_items_for_test)
    eval_loader = DataLoader(eval_dset, opt.batch_size, shuffle=False,
                             num_workers=opt.num_workers)

    # Optional second test set.
    if opt.use_two_testsets:
        if opt.split_test_2 is None:
            raise ValueError("use_two_testsets requires split_test_2 to be set")
        eval_dset_2 = SelfCriticalDataset(
            opt.split_test_2, opt.hint_type, dictionary, opt,
            discard_items_without_hints=opt.discard_items_for_test)
        eval_loader_2 = DataLoader(eval_dset_2, opt.batch_size, shuffle=False,
                                   num_workers=opt.num_workers)
    else:
        eval_loader_2 = None

    # Propagate dataset-derived dimensions onto opts for model construction.
    opt.full_v_dim = eval_dset.full_v_dim
    opt.num_ans_candidates = eval_dset.num_ans_candidates
    opt.num_objects = eval_dset.num_objects

    ### Model creation. Imports are kept local to each branch so only the
    # selected model's dependencies are loaded.
    if opt.model_type == 'updn':
        from models.updn import UpDn
        model = UpDn(opt)
    elif opt.model_type == 'updn_ramen':
        from models.updn import UpDn_ramen
        model = UpDn_ramen(opt)
    elif opt.model_type == 'lang_only':
        # Language-only UpDn variant (ignores visual features).
        from models.lang_only import LangOnly
        model = LangOnly(opt)
    elif opt.model_type == 'lxmert':
        from models.lxmert import lxmert
        model = lxmert(opt)
    else:
        raise ValueError("unsupported model type")
    model = model.cuda()

    # lxmert ships pretrained weights; only the from-scratch models get
    # Kaiming re-initialization.
    if 'lxmert' not in opt.model_type:
        model.apply(weights_init_kn)
    model = nn.DataParallel(model).cuda()
    model.train()

    run(model,
        train_loader,
        eval_loader,
        eval_loader_2,
        opt)