-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathtrain.py
202 lines (173 loc) · 8.66 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from __future__ import print_function
from model import TFN
from utils import MultimodalDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
import argparse
import torch
import random
import torch.nn as nn
import torch.optim as optim
import numpy as np
def preprocess(options):
    """Load the dataset and prepare its train/valid/test splits for TFN.

    Visual features are divided by the per-feature maximum absolute value
    observed on the training split. Audio and visual features are then
    averaged over time, and any NaN left in them is replaced with 0.
    Text features are left untouched.

    Args:
        options: dict-like of parsed command-line options; this function reads
            the keys 'dataset', 'model_path' and 'max_len'.

    Returns:
        (train_set, valid_set, test_set, input_dims), where input_dims is the
        (audio_dim, visual_dim, text_dim) tuple read from the first training
        example. The split objects are modified in place (their .visual and
        .audio attributes are replaced).
    """
    dataset = options['dataset']
    max_len = options['max_len']

    # prepare the paths for storing models
    model_path = os.path.join(options['model_path'], "tfn.pt")
    print("Temp location for saving model: {}".format(model_path))

    # prepare the datasets
    print("Currently using {} dataset.".format(dataset))
    mosi = MultimodalDataset(dataset, max_len=max_len)
    train_set, valid_set, test_set = mosi.train_set, mosi.valid_set, mosi.test_set
    splits = (train_set, valid_set, test_set)

    audio_dim = train_set[0][0].shape[1]
    print("Audio feature dimension is: {}".format(audio_dim))
    visual_dim = train_set[0][1].shape[1]
    print("Visual feature dimension is: {}".format(visual_dim))
    text_dim = train_set[0][2].shape[1]
    print("Text feature dimension is: {}".format(text_dim))
    input_dims = (audio_dim, visual_dim, text_dim)

    # Per-feature scale taken from the training split only. nanmax (rather
    # than max) keeps a single NaN from turning a whole feature's scale into
    # NaN -- which would NaN every value of that feature in every split and
    # get the entire feature zeroed by the NaN scrub below. Features whose
    # scale is 0 or undefined (all-NaN) are left unscaled.
    visual_max = np.nanmax(np.abs(train_set.visual), axis=(0, 1))
    visual_max[np.isnan(visual_max) | (visual_max == 0)] = 1

    for split in splits:
        split.visual = split.visual / visual_max
        # For visual and audio we average across time. Per the original note,
        # the data is laid out as (max_len, num_examples, feature_dim); after
        # averaging over axis 0 with keepdims it is (1, num_examples, feature_dim).
        split.visual = np.mean(split.visual, axis=0, keepdims=True)
        split.audio = np.mean(split.audio, axis=0, keepdims=True)
        # remove possible NaN values
        split.visual[np.isnan(split.visual)] = 0
        split.audio[np.isnan(split.audio)] = 0

    return train_set, valid_set, test_set, input_dims
def display(test_loss, test_binacc, test_precision, test_recall, test_f1, test_septacc, test_corr):
    """Print each test-set metric on its own line, in a fixed order.

    Each value is rendered with str.format, so any printable object works.
    """
    report = (
        ("MAE on test set is {}", test_loss),
        ("Binary accuracy on test set is {}", test_binacc),
        ("Precision on test set is {}", test_precision),
        ("Recall on test set is {}", test_recall),
        ("F1 score on test set is {}", test_f1),
        ("Seven-class accuracy on test set is {}", test_septacc),
        ("Correlation w.r.t human evaluation on test set is {}", test_corr),
    )
    for template, value in report:
        print(template.format(value))
def _to_scalar(value):
    """Return the Python number held by a one-element loss tensor/Variable.

    Uses .item() (PyTorch >= 0.4). Falls back to .data[0] for older releases;
    on current PyTorch .data[0] raises on 0-dim tensors.
    """
    try:
        return value.item()
    except AttributeError:
        return value.data[0]


def _unpack_batch(batch, dtype):
    """Split a DataLoader batch into (x_a, x_v, x_t, y) of tensor type `dtype`.

    Per this script's data-format note, batches are [batch_size, seq_len,
    feature_dim] or [batch_size, 1, feature_dim]. Only dim 1 of audio/visual
    is squeezed, so a batch with a single example keeps its batch dimension
    (a bare .squeeze() would drop it). y is reshaped to a column (-1, 1).
    """
    x_a = Variable(batch[0].float().type(dtype), requires_grad=False).squeeze(1)
    x_v = Variable(batch[1].float().type(dtype), requires_grad=False).squeeze(1)
    x_t = Variable(batch[2].float().type(dtype), requires_grad=False)
    y = Variable(batch[-1].view(-1, 1).float().type(dtype), requires_grad=False)
    return x_a, x_v, x_t, y


def _load_model(path):
    """Load a module pickled with torch.save(model, path).

    PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
    whole pickled modules; releases before 1.13 reject the keyword (TypeError).
    Full unpickling is only acceptable because this is the checkpoint this
    script wrote itself -- never point it at untrusted files.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        return torch.load(path)


def main(options):
    """Train TFN with patience-based early stopping, then report test metrics.

    Args:
        options: dict-like of parsed command-line options; this function reads
            'cuda', 'batch_size', 'patience', 'epochs' and 'model_path'
            (preprocess() reads further keys).

    The model with the lowest summed validation L1 loss is pickled to
    <model_path>/tfn.pt. Test metrics are computed and printed only if no NaN
    loss was seen and a checkpoint was written during this run.
    """
    DTYPE = torch.FloatTensor
    train_set, valid_set, test_set, input_dims = preprocess(options)

    # Positional hyper-parameters for TFN (defined in model.py).
    model = TFN(input_dims, (4, 16, 128), 64, (0.3, 0.3, 0.3, 0.3), 32)
    if options['cuda']:
        model = model.cuda()
        DTYPE = torch.cuda.FloatTensor
    print("Model initialized")

    # Summed (not averaged) absolute error; divided by the split size when reported.
    try:
        criterion = nn.L1Loss(reduction='sum')
    except TypeError:  # PyTorch < 0.4.1 only understands size_average
        criterion = nn.L1Loss(size_average=False)
    # don't optimize the first 2 params, they should be fixed (output_range and shift)
    optimizer = optim.Adam(list(model.parameters())[2:])

    # setup training
    complete = True
    saved_best = False
    min_valid_loss = float('Inf')
    batch_sz = options['batch_size']
    patience = options['patience']
    epochs = options['epochs']
    # Save to <model_path>/tfn.pt -- the location preprocess() prints -- rather
    # than to the directory path itself.
    model_dir = options['model_path']
    model_path = os.path.join(model_dir, "tfn.pt")
    if model_dir and not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    train_iterator = DataLoader(train_set, batch_size=batch_sz, num_workers=4, shuffle=True)
    # Validation and test run as one full-size batch, so order does not matter.
    valid_iterator = DataLoader(valid_set, batch_size=len(valid_set), num_workers=4, shuffle=False)
    test_iterator = DataLoader(test_set, batch_size=len(test_set), num_workers=4, shuffle=False)

    curr_patience = patience
    for e in range(epochs):
        model.train()
        train_loss = 0.0
        for batch in train_iterator:
            model.zero_grad()
            x_a, x_v, x_t, y = _unpack_batch(batch, DTYPE)
            output = model(x_a, x_v, x_t)
            loss = criterion(output, y)
            loss.backward()
            # summed batch loss / N accumulates the per-example epoch average
            train_loss += _to_scalar(loss) / len(train_set)
            optimizer.step()
        print("Epoch {} complete! Average Training loss: {}".format(e, train_loss))

        # Terminate the training process if run into NaN
        if np.isnan(train_loss):
            print("Training got into NaN values...\n\n")
            complete = False
            break

        # On validation set we don't have to compute metrics other than MAE and accuracy
        model.eval()
        for batch in valid_iterator:
            x_a, x_v, x_t, y = _unpack_batch(batch, DTYPE)
            output = model(x_a, x_v, x_t)
            valid_loss = _to_scalar(criterion(output, y))
        output_valid = output.cpu().data.numpy().reshape(-1)
        y = y.cpu().data.numpy().reshape(-1)

        if np.isnan(valid_loss):
            print("Training got into NaN values...\n\n")
            complete = False
            break  # stop training entirely, not just the validation loop

        valid_binacc = accuracy_score(output_valid >= 0, y >= 0)
        print("Validation loss is: {}".format(valid_loss / len(valid_set)))
        print("Validation binary accuracy is: {}".format(valid_binacc))

        if valid_loss < min_valid_loss:
            curr_patience = patience
            min_valid_loss = valid_loss
            torch.save(model, model_path)
            saved_best = True
            print("Found new best model, saving to disk...")
        else:
            curr_patience -= 1

        if curr_patience <= 0:
            break
        print("\n\n")

    if not complete:
        return
    if not saved_best:
        # Avoid evaluating a missing file or a stale checkpoint from another run.
        print("No model was saved during training; skipping test evaluation.")
        return

    best_model = _load_model(model_path)
    best_model.eval()
    for batch in test_iterator:
        x_a, x_v, x_t, y = _unpack_batch(batch, DTYPE)
        output_test = best_model(x_a, x_v, x_t)
        test_loss = _to_scalar(criterion(output_test, y))
    output_test = output_test.cpu().data.numpy().reshape(-1)
    y = y.cpu().data.numpy().reshape(-1)

    test_binacc = accuracy_score(output_test >= 0, y >= 0)
    test_precision, test_recall, test_f1, _ = precision_recall_fscore_support(
        y >= 0, output_test >= 0, average='binary')
    # exact match after rounding both scores (reported as seven-class accuracy)
    test_septacc = (output_test.round() == y.round()).mean()
    # compute the correlation between true and predicted scores
    test_corr = np.corrcoef([output_test, y])[0][1]  # corrcoef returns a matrix
    test_loss = test_loss / len(test_set)
    display(test_loss, test_binacc, test_precision, test_recall, test_f1, test_septacc, test_corr)
    return
def str2bool(value):
    """Parse a command-line boolean such as "True", "false", "1" or "no".

    argparse's type=bool treats every non-empty string -- including "False" --
    as True, so it cannot be used for flags like --cuda.

    Raises:
        argparse.ArgumentTypeError: if `value` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


def build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', dest='dataset', type=str, default='MOSI')
    parser.add_argument('--epochs', dest='epochs', type=int, default=50)
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=32)
    parser.add_argument('--patience', dest='patience', type=int, default=20)
    # Accepts "--cuda", "--cuda True" and "--cuda False" (the last was
    # previously parsed as True).
    parser.add_argument('--cuda', dest='cuda', type=str2bool, nargs='?',
                        const=True, default=False)
    parser.add_argument('--model_path', dest='model_path', type=str, default='models')
    parser.add_argument('--max_len', dest='max_len', type=int, default=20)
    return parser


if __name__ == "__main__":
    PARAMS = vars(build_arg_parser().parse_args())
    main(PARAMS)