-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
62 lines (46 loc) · 1.86 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import random
from itertools import islice

import matplotlib.pyplot as plt

import config
from data_extractor import get_data
from models.vgg_rnn import VggRNN, VggLSTM
from models.inception_rnn import InceptionRNN, InceptionLSTM
from batch_generator import BatchGenerator
# Registry of the supported architectures, looked up by the configured
# ``model_name``. Each entry pairs a model class with the config class whose
# instance attributes supply that model's parameters.
models = {
    name: {"model": model_cls, "params": params_cls}
    for name, model_cls, params_cls in (
        ("vggrnn", VggRNN, config.VggRNNParams),
        ("vgglstm", VggLSTM, config.VggLSTMParams),
        ("inceptionrnn", InceptionRNN, config.InceptionRNNParams),
        ("inceptionlstm", InceptionLSTM, config.InceptionLSTMParams),
    )
}

# NOTE(review): this flag is not read anywhere in the visible part of the file.
extract_data = False
def _show_previews(ims, captions, indices):
    """Display the images at *indices*, each titled with its caption.

    Every image is min-max scaled before being handed to ``plt.imshow``.
    ``plt.show()`` is called once per image.
    """
    for i in indices:
        img = ims[i]
        # assumes img is a channels-first (C, H, W) tensor, hence the permute
        # to channels-last for imshow -- TODO confirm against BatchGenerator
        lo = img.min()
        span = img.max() - lo
        # A constant image has zero span; dividing by 1 instead keeps the
        # result finite (all zeros) rather than filling it with NaNs.
        img = (img.permute(1, 2, 0) - lo) / (span if span > 0 else 1)
        plt.imshow(img)
        plt.tight_layout()
        plt.title(captions[i])
        plt.show()


def _train_one_epoch(model, batch_gen, max_batches):
    """Fit *model* on at most *max_batches* training batches.

    After each batch the latest loss and a growing ``=`` bar are redrawn on
    the same console line; the bar is closed with ``]`` when the epoch ends.
    """
    # islice stops after max_batches without pulling an extra, unused batch
    # from the generator.
    batches = islice(batch_gen.generate("train"), max_batches)
    for step, (im, cap) in enumerate(batches):
        loss = model.fit(im, cap)
        print("\rTraining: " + str(loss) + " [" + "=" * step,
              end="", flush=True)
    print("]")


def main(*, num_previews=5, batches_per_epoch=20):
    """Build the configured model and train it, previewing captions each epoch.

    The data parameters and the selected model's parameters are merged into a
    single dict (data parameters win on key clashes). That dict is passed to
    ``get_data``, to the model constructor, and, as keyword arguments, to
    ``BatchGenerator``. One training batch is then drawn and captioned, and a
    random subset of it is displayed at the start of every epoch before
    training runs.

    Args:
        num_previews: Maximum number of images from the first batch to show
            per epoch. Capped at the configured ``batch_size``.
        batches_per_epoch: Maximum number of training batches fitted per
            epoch.
    """
    data_parameters = config.DataParams().__dict__
    model_entry = models[data_parameters["model_name"]]

    # Copy first so the params instance's own __dict__ is not mutated.
    parameters = dict(model_entry["params"]().__dict__)
    parameters.update(data_parameters)

    get_data(parameters)
    model = models[parameters["model_name"]]["model"](parameters)
    batch_gen = BatchGenerator(**parameters)

    ims, _ = next(batch_gen.generate("train"))
    # NOTE(review): captions are generated once, before any training, so the
    # per-epoch preview always shows these same captions. If the preview is
    # meant to track training progress, move this call into the epoch loop.
    generated_captions = model.caption(ims)

    # random.sample raises ValueError if asked for more items than exist, so
    # cap the preview count at the batch size.
    batch_size = parameters["batch_size"]
    preview_indices = random.sample(range(batch_size),
                                    min(num_previews, batch_size))

    for epoch in range(parameters["num_epochs"]):
        _show_previews(ims, generated_captions, preview_indices)
        print("Epoch num: " + str(epoch))
        _train_one_epoch(model, batch_gen, batches_per_epoch)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()