-
Notifications
You must be signed in to change notification settings - Fork 63
/
train.py
267 lines (226 loc) · 11.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import torch
from torch import nn
from models import SAINT
from data_openml import data_prep_openml,task_dset_ids,DataSetCatCon
import argparse
from torch.utils.data import DataLoader
import torch.optim as optim
from utils import count_parameters, classification_scores, mean_sq_error
from augmentations import embed_data_mask
from augmentations import add_noise
import os
import numpy as np
def _build_arg_parser():
    """Return the command-line parser for this training script.

    ``--dset_id`` and ``--task`` are the only required flags; everything else
    has a default.  Flags are registered in a fixed order because argparse
    lists them in ``--help`` in registration order.
    """
    p = argparse.ArgumentParser()

    # --- Dataset and task ---------------------------------------------------
    p.add_argument('--dset_id', type=int, required=True)
    p.add_argument('--vision_dset', action='store_true')
    p.add_argument('--task', type=str, required=True,
                   choices=['binary', 'multiclass', 'regression'])

    # --- Model architecture (forwarded to the SAINT constructor) ------------
    p.add_argument('--cont_embeddings', type=str, default='MLP',
                   choices=['MLP', 'Noemb', 'pos_singleMLP'])
    p.add_argument('--embedding_size', type=int, default=32)
    p.add_argument('--transformer_depth', type=int, default=6)
    p.add_argument('--attention_heads', type=int, default=8)
    p.add_argument('--attention_dropout', type=float, default=0.1)
    p.add_argument('--ff_dropout', type=float, default=0.1)
    p.add_argument('--attentiontype', type=str, default='colrow',
                   choices=['col', 'colrow', 'row', 'justmlp', 'attn', 'attnmlp'])

    # --- Optimisation -------------------------------------------------------
    p.add_argument('--optimizer', type=str, default='AdamW',
                   choices=['AdamW', 'Adam', 'SGD'])
    p.add_argument('--scheduler', type=str, default='cosine',
                   choices=['cosine', 'linear'])
    p.add_argument('--lr', type=float, default=0.0001)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--batchsize', type=int, default=256)

    # --- Run bookkeeping: checkpoint location, seeds, wandb logging ---------
    p.add_argument('--savemodelroot', type=str, default='./bestmodels')
    p.add_argument('--run_name', type=str, default='testrun')
    p.add_argument('--set_seed', type=int, default=1)
    p.add_argument('--dset_seed', type=int, default=5)
    p.add_argument('--active_log', action='store_true')

    # --- Pretraining --------------------------------------------------------
    # Apart from --pretrain itself, these are not read directly in this
    # script; presumably they are consumed through `opt` by the pretraining
    # routine -- verify there before changing defaults.
    p.add_argument('--pretrain', action='store_true')
    p.add_argument('--pretrain_epochs', type=int, default=50)
    p.add_argument('--pt_tasks', type=str, nargs='*',
                   default=['contrastive', 'denoising'],
                   choices=['contrastive', 'contrastive_sim', 'denoising'])
    p.add_argument('--pt_aug', type=str, nargs='*', default=[],
                   choices=['mixup', 'cutmix'])
    p.add_argument('--pt_aug_lam', type=float, default=0.1)
    p.add_argument('--mixup_lam', type=float, default=0.3)
    p.add_argument('--train_mask_prob', type=float, default=0)
    p.add_argument('--mask_prob', type=float, default=0)
    p.add_argument('--ssl_avail_y', type=int, default=0)
    p.add_argument('--pt_projhead_style', type=str, default='diff',
                   choices=['diff', 'same', 'nohead'])
    p.add_argument('--nce_temp', type=float, default=0.7)
    p.add_argument('--lam0', type=float, default=0.5)
    p.add_argument('--lam1', type=float, default=10)
    p.add_argument('--lam2', type=float, default=1)
    p.add_argument('--lam3', type=float, default=10)

    # --- Prediction head (forwarded to the SAINT constructor) ---------------
    p.add_argument('--final_mlp_style', type=str, default='sep',
                   choices=['common', 'sep'])
    return p


parser = _build_arg_parser()
opt = parser.parse_args()
# Checkpoints are written under <cwd>/<savemodelroot>/<task>/<dset_id>/<run_name>.
modelsave_path = os.path.join(
    os.getcwd(),
    opt.savemodelroot,
    opt.task,
    str(opt.dset_id),
    opt.run_name,
)

# Short task tag ('reg' or 'clf') that is handed to the dataset wrappers.
opt.dtask = 'reg' if opt.task == 'regression' else 'clf'

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device is {device}.")

torch.manual_seed(opt.set_seed)
os.makedirs(modelsave_path, exist_ok=True)
# Weights & Biases is imported lazily so it is only required when
# --active_log is passed.
if opt.active_log:
    import wandb

    # Fine-tuning runs on multiclass tasks log to their own project; every
    # other run (including all pretraining runs) logs to "saint_v2_all".
    if opt.task == 'multiclass' and not opt.pretrain:
        wandb_project = "saint_v2_all_kamal"
    else:
        wandb_project = "saint_v2_all"

    # Pretraining runs carry a 'pretrain_' prefix in the run name.
    wandb_name_prefix = 'pretrain_' if opt.pretrain else ''
    wandb_run_name = (
        f'{wandb_name_prefix}{opt.task}_{str(opt.attentiontype)}'
        f'_{str(opt.dset_id)}_{str(opt.set_seed)}'
    )
    wandb.init(project=wandb_project, group=opt.run_name, name=wandb_run_name)
print('Downloading and processing the dataset, it might take some time.')

# data_prep_openml returns the column metadata, the three splits
# (requested as 65% / 15% / 20%), and two training-set statistics.
(
    cat_dims, cat_idxs, con_idxs,
    X_train, y_train,
    X_valid, y_valid,
    X_test, y_test,
    train_mean, train_std,
) = data_prep_openml(
    opt.dset_id,
    opt.dset_seed,
    opt.task,
    datasplit=[.65, .15, .2],
)

# Stack the two statistics into one float32 array (mean first, std second);
# it is passed to every DataSetCatCon below.
continuous_mean_std = np.array([train_mean, train_std]).astype(np.float32)
##### Setting some hyperparams based on inputs and dataset
# NOTE: the rules below silently override values given on the command line.
_, nfeat = X_train['data'].shape

# Datasets with more than 100 features: cap embedding width and batch size.
if nfeat > 100:
    opt.embedding_size = min(8, opt.embedding_size)
    opt.batchsize = min(64, opt.batchsize)

# Every attention type except 'col': force a single layer, cap the heads and
# embedding width, and set both dropout rates to 0.8.
if opt.attentiontype != 'col':
    opt.transformer_depth = 1
    opt.attention_heads = min(4, opt.attention_heads)
    opt.attention_dropout = 0.8
    opt.embedding_size = min(32, opt.embedding_size)
    opt.ff_dropout = 0.8

# Show the effective configuration after the overrides.
print(nfeat, opt.batchsize)
print(opt)

if opt.active_log:
    wandb.config.update(opt)
def _build_split(features, targets, shuffle):
    """Wrap one data split in a DataSetCatCon and a DataLoader over it.

    Returns the ``(dataset, loader)`` pair.  Batch size comes from
    ``opt.batchsize``; every loader uses 4 worker processes.
    """
    dataset = DataSetCatCon(features, targets, cat_idxs, opt.dtask, continuous_mean_std)
    loader = DataLoader(dataset, batch_size=opt.batchsize, shuffle=shuffle, num_workers=4)
    return dataset, loader


# Only the training split is shuffled.
train_ds, trainloader = _build_split(X_train, y_train, shuffle=True)
valid_ds, validloader = _build_split(X_valid, y_valid, shuffle=False)
test_ds, testloader = _build_split(X_test, y_test, shuffle=False)

# Output width: 1 for regression, otherwise the number of distinct labels
# found in the first column of the training targets.
y_dim = 1 if opt.task == 'regression' else len(np.unique(y_train['data'][:, 0]))

# Put a cardinality of 1 in front of the categorical cardinalities for the
# CLS token; this is later used to generate embeddings.
cls_cardinality = np.array([1])
cat_dims = np.append(cls_cardinality, np.array(cat_dims)).astype(int)
# Build the SAINT network from the (possibly overridden) options.
model = SAINT(
    # Column description of the dataset.
    categories=tuple(cat_dims),        # cardinalities, CLS entry first
    num_continuous=len(con_idxs),
    # Transformer configuration.
    dim=opt.embedding_size,
    dim_out=1,
    depth=opt.transformer_depth,
    heads=opt.attention_heads,
    attn_dropout=opt.attention_dropout,
    ff_dropout=opt.ff_dropout,
    mlp_hidden_mults=(4, 2),
    cont_embeddings=opt.cont_embeddings,
    attentiontype=opt.attentiontype,
    # Prediction head.
    final_mlp_style=opt.final_mlp_style,
    y_dim=y_dim,
)

# Shorthand passed to the embedding and scoring helpers further down.
vision_dset = opt.vision_dset
# Choose the training loss from the task and the number of target classes.
#   binary      -> cross-entropy, requires exactly 2 classes in the training labels
#   multiclass  -> cross-entropy, requires more than 2 classes
#   regression  -> mean squared error
# Any other combination (for example --task binary on a dataset with three
# labels) is rejected with an explicit error.  Raising a bare string is not
# valid in Python 3 and would only surface as an unrelated TypeError.
if y_dim == 2 and opt.task == 'binary':
    criterion = nn.CrossEntropyLoss().to(device)
elif y_dim > 2 and opt.task == 'multiclass':
    criterion = nn.CrossEntropyLoss().to(device)
elif opt.task == 'regression':
    criterion = nn.MSELoss().to(device)
else:
    raise ValueError(
        f"Unsupported task/label combination: task={opt.task!r} with "
        f"{y_dim} distinct training labels (binary needs exactly 2, "
        f"multiclass needs more than 2)."
    )

model.to(device)
# Optional pretraining stage; the returned model replaces the fresh one.
if opt.pretrain:
    from pretraining import SAINT_pretrain
    model = SAINT_pretrain(
        model, cat_idxs, X_train, y_train, continuous_mean_std, opt, device
    )

## Choosing the optimizer
# NOTE: a learning-rate scheduler is only created for SGD; with Adam or
# AdamW the --scheduler option has no effect in this script.
if opt.optimizer == 'AdamW':
    optimizer = optim.AdamW(model.parameters(), lr=opt.lr)
elif opt.optimizer == 'Adam':
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
elif opt.optimizer == 'SGD':
    from utils import get_scheduler
    optimizer = optim.SGD(
        model.parameters(),
        lr=opt.lr,
        momentum=0.9,
        weight_decay=5e-4,
    )
    scheduler = get_scheduler(opt, optimizer)
# Best-so-far trackers, updated whenever a validation pass improves.
# Classification metrics start at 0 because higher is better.
best_valid_auroc = 0
best_valid_accuracy = 0
best_test_auroc = 0
best_test_accuracy = 0
# RMSE is lower-is-better, so start from +inf: the first finite validation
# RMSE always counts as an improvement, whatever the scale of the target.
best_valid_rmse = float('inf')
# Must exist before the final report, even if no evaluation ever improves
# (or --epochs is 0); otherwise the regression summary raises NameError.
best_test_rmse = float('inf')
print('Training begins now.')
# Supervised training loop.
# Each epoch runs one pass over `trainloader`; on every 5th epoch (starting
# with the first) the model is scored on the validation and test loaders and
# the weights are checkpointed whenever the validation metric improves.
best_model_path = '%s/bestmodel.pth' % (modelsave_path)

for epoch in range(opt.epochs):
    model.train()
    running_loss = 0.0
    for data in trainloader:
        optimizer.zero_grad()
        # x_categ is the categorical data, x_cont has continuous data, y_gts
        # has ground truth ys. cat_mask is an array of ones same shape as
        # x_categ and an additional column (corresponding to CLS token) set
        # to 0s. con_mask is an array of ones same shape as x_cont.
        x_categ, x_cont, y_gts, cat_mask, con_mask = [data[k].to(device) for k in range(5)]

        # Convert the raw columns to embeddings.
        _, x_categ_enc, x_cont_enc = embed_data_mask(
            x_categ, x_cont, cat_mask, con_mask, model, vision_dset
        )
        reps = model.transformer(x_categ_enc, x_cont_enc)

        # Keep only the representation at position 0 (the CLS token) and
        # run the prediction MLP on it.
        y_reps = reps[:, 0, :]
        y_outs = model.mlpfory(y_reps)

        if opt.task == 'regression':
            loss = criterion(y_outs, y_gts)
        else:
            # Flatten the labels to 1-D.  A bare squeeze() would also drop
            # the batch axis when the final batch holds a single row, giving
            # CrossEntropyLoss a 0-d target for a (1, C) input.
            # (assumes y_gts arrives as (batch, 1) -- TODO confirm in DataSetCatCon)
            loss = criterion(y_outs, y_gts.reshape(-1))

        loss.backward()
        optimizer.step()
        # The scheduler only exists for SGD and is stepped once per batch.
        if opt.optimizer == 'SGD':
            scheduler.step()
        running_loss += loss.item()

    if opt.active_log:
        # 'train_epoch_loss' is the sum of the batch losses for this epoch;
        # 'loss' is the loss of the last batch.
        wandb.log({
            'epoch': epoch,
            'train_epoch_loss': running_loss,
            'loss': loss.item(),
        })

    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():
            if opt.task in ['binary', 'multiclass']:
                accuracy, auroc = classification_scores(
                    model, validloader, device, opt.task, vision_dset
                )
                test_accuracy, test_auroc = classification_scores(
                    model, testloader, device, opt.task, vision_dset
                )

                print('[EPOCH %d] VALID ACCURACY: %.3f, VALID AUROC: %.3f' %
                      (epoch + 1, accuracy, auroc))
                print('[EPOCH %d] TEST ACCURACY: %.3f, TEST AUROC: %.3f' %
                      (epoch + 1, test_accuracy, test_auroc))
                if opt.active_log:
                    wandb.log({'valid_accuracy': accuracy, 'valid_auroc': auroc})
                    wandb.log({'test_accuracy': test_accuracy, 'test_auroc': test_auroc})

                # Binary and multiclass both select the best checkpoint by
                # validation accuracy (not AUROC) and remember the test
                # scores measured at that same evaluation.
                if accuracy > best_valid_accuracy:
                    best_valid_accuracy = accuracy
                    best_test_auroc = test_auroc
                    best_test_accuracy = test_accuracy
                    torch.save(model.state_dict(), best_model_path)
            else:
                valid_rmse = mean_sq_error(model, validloader, device, vision_dset)
                test_rmse = mean_sq_error(model, testloader, device, vision_dset)
                print('[EPOCH %d] VALID RMSE: %.3f' %
                      (epoch + 1, valid_rmse))
                print('[EPOCH %d] TEST RMSE: %.3f' %
                      (epoch + 1, test_rmse))
                if opt.active_log:
                    wandb.log({'valid_rmse': valid_rmse, 'test_rmse': test_rmse})

                # Lower validation RMSE wins.
                if valid_rmse < best_valid_rmse:
                    best_valid_rmse = valid_rmse
                    best_test_rmse = test_rmse
                    torch.save(model.state_dict(), best_model_path)
        model.train()
# Final summary: parameter count plus the test metric recorded at the
# evaluation where the validation metric was best.
total_parameters = count_parameters(model)
print('TOTAL NUMBER OF PARAMS: %d' % (total_parameters))

# Headline metric per task: AUROC for binary, accuracy for multiclass,
# RMSE otherwise.
if opt.task == 'binary':
    summary_line = 'AUROC on best model: %.3f' % (best_test_auroc)
elif opt.task == 'multiclass':
    summary_line = 'Accuracy on best model: %.3f' % (best_test_accuracy)
else:
    summary_line = 'RMSE on best model: %.3f' % (best_test_rmse)
print(summary_line)

if opt.active_log:
    final_metrics = {'total_parameters': total_parameters}
    if opt.task == 'regression':
        final_metrics['test_rmse_bestep'] = best_test_rmse
    else:
        final_metrics['test_auroc_bestep'] = best_test_auroc
        final_metrics['test_accuracy_bestep'] = best_test_accuracy
    # Column counts of the dataset, logged alongside the metrics.
    final_metrics['cat_dims'] = len(cat_idxs)
    final_metrics['con_dims'] = len(con_idxs)
    wandb.log(final_metrics)