-
Notifications
You must be signed in to change notification settings - Fork 0
/
caption_it.py
124 lines (75 loc) · 2.36 KB
/
caption_it.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import keras
import re
import nltk
from nltk.corpus import stopwords
import string
import json
from time import time
import pickle
from keras.applications.vgg16 import VGG16
from keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
from keras.preprocessing import image
from keras.models import Model, load_model
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.layers import Input, Dense, Dropout, Embedding, LSTM
from keras.layers.merge import add
# In[2]:
# Trained caption-generation model; predict_caption calls it as
# model.predict([image_features, padded_word_index_sequence]).
model = load_model('./model_weights/model_9.h5')
# NOTE(review): _make_predict_function() is a private Keras API; presumably it
# is called here to pre-build the predict function (a common workaround when
# predict() is later invoked from other threads, e.g. a web server). It does
# not exist on TF2 tf.keras models — confirm the pinned Keras version.
model._make_predict_function()
# In[3]:
# ImageNet-pretrained ResNet50 with a 224x224x3 input, used here only as the
# backbone of the image feature extractor below.
model_temp = ResNet50(weights="imagenet",input_shape=(224,224,3))
# In[4]:
# Feature extractor: same input as ResNet50 but its output is the
# second-to-last layer (presumably the global-average-pool layer just before
# the ImageNet classifier — confirm), so images map to a feature vector
# instead of class scores. encode_image uses this model.
model_resnet = Model(model_temp.input,model_temp.layers[-2].output)
# NOTE(review): same private-API pre-build as for `model` above.
model_resnet._make_predict_function()
# In[5]:
def preprocess_img(img):
    """Load an image file and turn it into a ResNet50-ready input batch.

    Args:
        img: path to an image file (anything keras ``image.load_img``
            accepts).

    Returns:
        A numpy array with a leading batch axis of size 1, holding the image
        resized to 224x224 and normalised by ResNet50's ``preprocess_input``.
    """
    loaded = image.load_img(img, target_size=(224, 224))
    batch = np.expand_dims(image.img_to_array(loaded), axis=0)
    # Normalisation specific to keras.applications.resnet50.
    return preprocess_input(batch)
# In[10]:
def encode_image(img):
    """Return the ResNet50 feature vector for the image file ``img``.

    The image is prepared by ``preprocess_img`` and run through
    ``model_resnet``. The prediction is reshaped to a 2-D array of shape
    ``(1, D)``, where D is the size of the prediction's second axis.
    """
    batch = preprocess_img(img)
    features = model_resnet.predict(batch)
    return features.reshape((1, features.shape[1]))
# In[11]:
# In[13]:
# In[18]:
def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``.

    NOTE: unpickling can execute arbitrary code; only load trusted files
    produced by this project.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Vocabulary lookups used when decoding captions: word -> index and
# index -> word. Paths are relative to the current working directory.
word_to_idx = _load_pickle("storage/word_to_idx.pkl")
idx_to_word = _load_pickle("storage/idx_to_word.pkl")
# In[19]:
def predict_caption(photo, max_len=35):
    """Generate a caption for an encoded image by greedy decoding.

    Decoding starts from the ``"startseq"`` token. At each step the caption
    model is queried and the single most probable next word is appended
    (greedy sampling). Decoding stops when ``"endseq"`` is produced or after
    ``max_len`` steps.

    Args:
        photo: image feature array, as returned by ``encode_image``.
        max_len: maximum number of decoding steps. It is also the padded
            sequence length fed to the model. Defaults to 35; presumably it
            must equal the sequence length the model was trained with, so
            confirm before changing it.

    Returns:
        The caption as a space-separated string, without the ``startseq`` /
        ``endseq`` marker tokens.
    """
    in_text = "startseq"
    for _ in range(max_len):
        # Words missing from the vocabulary are skipped when building input.
        sequence = [word_to_idx[w] for w in in_text.split() if w in word_to_idx]
        sequence = pad_sequences([sequence], maxlen=max_len, padding='post')
        ypred = model.predict([photo, sequence])
        # Greedy sampling: always take the highest-probability word.
        # NOTE(review): raises KeyError if the argmax index has no entry in
        # idx_to_word (e.g. a padding index) — confirm vocabulary coverage.
        word = idx_to_word[ypred.argmax()]
        in_text += ' ' + word
        if word == "endseq":
            break
    # Drop the start marker, and drop the last token only when it really is
    # the end marker. Unconditionally slicing [1:-1] would lose a genuine
    # word whenever decoding stopped at max_len without predicting "endseq".
    tokens = in_text.split()[1:]
    if tokens and tokens[-1] == "endseq":
        tokens = tokens[:-1]
    return ' '.join(tokens)
# In[20]:
# In[ ]:
def caption_this_image(image):
    """Return a generated caption for the image file ``image``.

    The image is turned into features by ``encode_image``, and those
    features are decoded into text by ``predict_caption``.

    Note: inside this function the parameter name shadows the module-level
    keras ``image`` import. It is kept for compatibility with keyword
    callers.
    """
    return predict_caption(encode_image(image))
# In[ ]: