-
Notifications
You must be signed in to change notification settings - Fork 0
/
web_app.py
122 lines (90 loc) · 3.44 KB
/
web_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import itertools
import random

import pandas as pd
import numpy as np
from flask import Flask
from flask import request
import matplotlib.pyplot as plt

import config
from data_extractor import get_data
from models.vgg_rnn import VggRNN, VggLSTM
from models.inception_rnn import InceptionRNN, InceptionLSTM
from batch_generator import BatchGenerator
# Global training state. Both are populated in the ``__main__`` block and
# rebuilt by the /reset endpoint; the route handlers read them directly.
model = None
batch_gen = None

# (model_name, model class, params class) for every architecture the app can
# run; the ``model_name`` field of config.DataParams selects one entry.
_MODEL_TABLE = (
    ("vggrnn", VggRNN, config.VggRNNParams),
    ("vgglstm", VggLSTM, config.VggLSTMParams),
    ("inceptionrnn", InceptionRNN, config.InceptionRNNParams),
    ("inceptionlstm", InceptionLSTM, config.InceptionLSTMParams),
)

# Registry: model_name -> {"model": <model class>, "params": <params class>}.
models = {
    name: {"model": model_cls, "params": params_cls}
    for name, model_cls, params_cls in _MODEL_TABLE
}

app = Flask(__name__)
@app.route('/train', methods=['GET', 'POST'])
def train():
    """Train the global ``model`` on batches from the "train" split.

    Query parameters:
        epochs: number of passes to run (int, default 1). Previously a missing
            or non-integer value reached ``range(None)`` and raised TypeError.
        batches: maximum number of batches consumed per epoch (int, default
            100, the value that used to be hard-coded). Negative values are
            treated as 0.

    Returns:
        The plain-text string "Done!" followed by a newline, once all epochs
        have finished.
    """
    epochs = request.args.get('epochs', default=1, type=int)
    max_batches = max(0, request.args.get('batches', default=100, type=int))
    for e in range(epochs):
        print(f"Epoch num: {e}")
        # islice stops after max_batches without pulling (and discarding) an
        # extra batch from the generator, as enumerate + break used to.
        batch_iter = itertools.islice(batch_gen.generate('train'), max_batches)
        for idx, (im, cap) in enumerate(batch_iter):
            loss = model.fit(im, cap)
            # "\r" redraws the same console line as a growing progress bar.
            print(f"\rTraining: {loss!s} [" + "=" * idx, end="", flush=True)
        print("]")
    return "Done!\n"
@app.route("/caption", methods=['GET', 'POST'])
def caption():
    """Caption one batch of training images and plot a random subset.

    Query parameters:
        caps: number of captioned images to display (int, default 1). The
            value is clamped to ``[0, len(batch)]``. Previously, a missing,
            non-integer, negative or too-large value made ``random.sample``
            raise and the request fail.

    Returns:
        The plain-text string "Done!" followed by a newline.
    """
    num_caps = request.args.get('caps', default=1, type=int)
    # First batch of a fresh "train" generator; the reference captions are
    # not used here.
    ims, _ = next(batch_gen.generate("train"))
    generated_captions = model.caption(ims)
    # Sample from the batch that was actually returned instead of the
    # configured parameters["batch_size"]. A shorter batch can no longer cause
    # an IndexError, and the handler no longer depends on the ``parameters``
    # global, which exists only when the file is run as a script.
    batch_len = len(ims)
    num_caps = max(0, min(num_caps, batch_len))
    for idx in random.sample(range(batch_len), num_caps):
        cap = generated_captions[idx]
        img = ims[idx]
        # Move axis 0 last (presumably CHW -> HWC, the layout imshow expects)
        # and min-max rescale the pixel values to [0, 1] for display.
        img = (img.permute(1, 2, 0) - img.min()) / (img.max() - img.min())
        plt.imshow(img)
        plt.tight_layout()
        plt.title(cap)
        # NOTE(review): with an interactive backend plt.show() blocks this
        # request until the window is closed, and it draws on the server's
        # display rather than returning anything to the client. Confirm this
        # is intended.
        plt.show()
    return "Done!\n"
@app.route("/reset")
def reset():
    """Rebuild the global model and batch generator from ``parameters``.

    ``parameters`` is the module-level dict assembled in the ``__main__``
    block. This endpoint therefore only works when the file is run as a
    script.

    Returns:
        The plain-text string "Reset!" followed by a newline.
    """
    global model, batch_gen
    model_cls = models[parameters["model_name"]]["model"]
    model = model_cls(parameters)
    batch_gen = BatchGenerator(**parameters)
    return "Reset!\n"
def _default_caption_weights(captions, vocab_size=1004, count_cap=50000):
    """Return per-token weights for the "default" dataset.

    Each token id ``0 .. vocab_size - 1`` is counted in ``captions``. The
    count is mapped to ``w = 1 / log(count)`` and then rescaled as
    ``w + (w - 0.13) * 12 + 0.4``. These constants are empirical; their
    derivation is not shown here.

    Args:
        captions: array of caption values, presumably integer token ids
            (any shape; every element is compared against each id).
        vocab_size: number of token ids to count, starting at 0.
        count_cap: counts above this are clipped to it. This bounds how small
            the weight of very frequent tokens can get.

    Returns:
        1-D float array of length ``vocab_size``.
    """
    histogram = np.array([(captions == i).sum() for i in range(vocab_size)])
    # Floor at 2 so that log(count) > 0. Previously, a count of 1 gave
    # log = 0 and an infinite weight. A count of 0 gave log = -inf and a
    # negative weight (-1.16).
    histogram = np.clip(histogram, 2, count_cap)
    inverted_weights = 1 / np.log(histogram)
    return inverted_weights + (inverted_weights - 0.13) * 12 + 0.4


def _flickr_caption_weights(captions, vocab_size=6690, zeroed_index=6639):
    """Return per-token weights for the "flickr" dataset.

    Entry ``i`` counts occurrences of token id ``i + 1`` in ``captions``; id 0
    is not counted. Unseen ids get a count of 1, and each weight is
    ``max_count / count``. The vector is then divided by the sum of the
    adjusted counts, and entry ``zeroed_index`` is forced to 0.

    Args:
        captions: array of caption values, presumably integer token ids.
        vocab_size: number of ids to count (ids ``1 .. vocab_size``); must be
            positive.
        zeroed_index: entry whose weight is set to 0. It must be below
            ``vocab_size``.

    Returns:
        1-D float array of length ``vocab_size``.
    """
    histogram = np.array([(captions == i + 1).sum() for i in range(vocab_size)])
    # Avoid division by zero for ids that never occur.
    histogram[histogram == 0] = 1
    weights = histogram.max() / histogram / histogram.sum()
    # NOTE(review): presumably a special token (e.g. padding) that must not
    # contribute to the loss. Confirm which id this entry corresponds to.
    weights[zeroed_index] = 0.0
    return weights


if __name__ == '__main__':
    # Data params are applied after model params, so they win on key
    # collisions.
    data_parameters = config.DataParams().__dict__
    model_parameters = models[data_parameters["model_name"]]["params"]().__dict__
    parameters = model_parameters.copy()
    parameters.update(data_parameters)
    get_data(parameters)
    # Per-token weights handed to the model via parameters["weight"]. How the
    # model consumes them is not visible in this file.
    if parameters["data_source"] == "default":
        captions = np.array(pd.read_csv("./dataset/captions.csv"))
        scaled = _default_caption_weights(captions)
    elif parameters["data_source"] == "flickr":
        captions = np.array(pd.read_csv("./dataset/flickr/captions.csv"))
        scaled = _flickr_caption_weights(captions)
    else:
        scaled = None
    parameters.update({"weight": scaled})
    # ``parameters``, ``model`` and ``batch_gen`` become module globals here;
    # the route handlers read them.
    model = models[parameters["model_name"]]["model"](parameters)
    batch_gen = BatchGenerator(**parameters)
    # NOTE(review): host 0.0.0.0 exposes the unauthenticated /train and
    # /reset endpoints on every network interface. Restrict this if that is
    # not intended.
    app.run(host='0.0.0.0')