-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_similarity.py
235 lines (179 loc) · 7.35 KB
/
image_similarity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import argparse
import os
import pathlib
import numpy as np
import keras
import keras.applications as kapp
from PIL import Image, ExifTags
import random
import scipy
from collections import namedtuple
from jinja2 import Environment, FileSystemLoader
# One record per input image, bundling everything the pipeline needs:
#   src      - path of the original image file found on disk
#   filename - stem used for the generated files (main() passes the index)
#   path     - directory the processed .jpg copy is written into
#   uri      - relative URI of that processed copy (main() builds "images/<i>.jpg")
ImageFile = namedtuple("ImageFile", ["src", "filename", "path", "uri"])
# Jinja filter: return the file-name part of a path or relative URI.
def basename(text):
    """Return the final component of *text*, a file path or relative URI.

    Both the platform separator (``os.path.sep``) and ``/`` are treated as
    separators, so native paths and ``/``-separated URIs both work.
    On POSIX the result is identical to a plain ``split("/")[-1]``.
    A trailing separator yields an empty string.
    """
    # On Windows os.path.sep is "\\"; splitting only on it left URIs such as
    # "images/3.jpg" untouched. Normalise to "/" first.
    if os.path.sep != "/":
        text = text.replace(os.path.sep, "/")
    return text.rsplit("/", 1)[-1]
# Jinja environment: page templates are loaded from ./templates (resolved
# against the current working directory), and `basename` is registered so
# templates can use it as a filter.
env = Environment(loader=FileSystemLoader("./templates"))
env.filters.update(basename=basename)
def is_image(filename):
    """Return True if *filename* has a .jpg, .jpeg or .png extension.

    The comparison is case-insensitive. The leading dot is part of the match,
    so names that merely end in those letters (``"notes_jpg"``,
    ``"backup.xpng"``) are not mistaken for images.
    """
    # str.endswith accepts a tuple of suffixes: one call instead of an `or` chain.
    return filename.lower().endswith((".jpg", ".jpeg", ".png"))
def find_image_files(root):
    """Recursively collect the paths of image files below *root*.

    A leading ``~`` in *root* is expanded to the user's home directory. Every
    sub-directory is visited via ``os.walk``, and a file is kept when
    ``is_image`` accepts its name.

    :returns: list of paths (walk directory joined with file name), in
        ``os.walk`` traversal order.
    """
    start = os.path.expanduser(root)
    return [
        os.path.join(dirpath, entry)
        for dirpath, _subdirs, entries in os.walk(start)
        for entry in entries
        if is_image(entry)
    ]
def build_model(model_name):
    """Create an ImageNet-weighted model without its classification head.

    :param model_name: ``"resnet50"`` or ``"vgg16"``.
    :returns: ``(model, preprocess_fn)``, where ``preprocess_fn`` is the
        matching ``keras.applications`` input-preprocessing function.
    :raises ValueError: if *model_name* is not supported. ``ValueError`` is a
        subclass of ``Exception``, so callers catching the previous generic
        exception keep working.
    """
    if model_name == "resnet50":
        model = kapp.resnet50.ResNet50(weights="imagenet", include_top=False)
        return model, kapp.resnet50.preprocess_input
    if model_name == "vgg16":
        model = kapp.vgg16.VGG16(weights="imagenet", include_top=False)
        return model, kapp.vgg16.preprocess_input
    # Name the offending value so the failure is actionable.
    raise ValueError(
        f"Unsupported model error: {model_name!r} (expected 'resnet50' or 'vgg16')"
    )
def fix_orientation(image):
    """Return *image* transformed so it displays upright.

    Reads the EXIF ``Orientation`` tag and applies the matching PIL transpose
    for every non-identity value (2-8). The mirrored variants (2, 4, 5, 7) are
    included; 3, 6 and 8 are the 180/270/90-degree rotations handled before.

    This is best-effort. Any of the following cause *image* to be returned
    unchanged instead of raising:

    - the image has no EXIF data;
    - the format has no ``_getexif`` method;
    - there is no Orientation tag;
    - the tag has an unrecognised value.

    :param image: a PIL image.
    :returns: the corrected image, or *image* itself if nothing was done.
    """
    # Pillow >= 9.1 groups the transpose constants in the Image.Transpose
    # enum; older releases expose the same names on the Image module.
    ops = getattr(Image, "Transpose", Image)
    transpose_for = {
        2: ops.FLIP_LEFT_RIGHT,
        3: ops.ROTATE_180,
        4: ops.FLIP_TOP_BOTTOM,
        5: ops.TRANSPOSE,
        6: ops.ROTATE_270,
        7: ops.TRANSVERSE,
        8: ops.ROTATE_90,
    }
    try:
        orientation_tag = next(
            tag for tag, name in ExifTags.TAGS.items() if name == "Orientation"
        )
        # _getexif() returns None when there is no EXIF block; the `or {}`
        # turns that into "no orientation". Formats without _getexif raise
        # AttributeError, and an unhashable tag value raises TypeError.
        exif = image._getexif() or {}
        method = transpose_for.get(exif.get(orientation_tag))
    except (AttributeError, KeyError, IndexError, StopIteration, TypeError):
        return image
    return image if method is None else image.transpose(method)
def extract_center(image):
    """Crop the largest centred square out of *image*.

    Most of the models need a small square input, so the centre square is
    extracted here (the caller resizes it afterwards).

    Integer pixel coordinates guarantee the crop is exactly
    ``min(width, height)`` on both sides. The former float arithmetic put box
    edges on half pixels, which Pillow rounds independently. When the size
    difference was odd, the crop could come out one pixel short of square
    (e.g. 4x5 for an 8x5 image).

    :param image: a PIL image.
    :returns: the cropped PIL image.
    """
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))
def process_images(file_names, preprocess_fn):
    """Load and normalise the images, then run the model's preprocessing.

    Each image goes through these steps:

    1. Opened from ``src``.
    2. Rotated upright (``fix_orientation``).
    3. Cropped to its centre square (``extract_center``).
    4. Resized to 224x224 and converted to RGB.

    The processed copy is saved as ``<path>/<filename>.jpg`` and its pixels
    are stored in the batch array.

    :param file_names: sequence of ``ImageFile`` tuples; ``src`` is read,
        ``path`` and ``filename`` decide where the processed copy is saved.
    :param preprocess_fn: preprocessing applied to the whole batch (the
        function returned by ``build_model``).
    :returns: ``preprocess_fn`` applied to the stacked
        ``(len(file_names), 224, 224, 3)`` float array.
    """
    image_size = 224
    total = len(file_names)
    # Report roughly every 10%. The old float interval (`total / 10`) made
    # `i % interval == 0` fire on every image for small inputs and almost
    # never for sizes not divisible by 10; max(1, ...) also avoids modulo 0.
    print_interval = max(1, total // 10)
    image_data = np.empty(shape=(total, image_size, image_size, 3))
    for i, ifn in enumerate(file_names):
        # The context manager closes the source file deterministically.
        with Image.open(ifn.src) as im:
            im = fix_orientation(im)
            im = extract_center(im)
            im = im.resize((image_size, image_size))
            im = im.convert(mode="RGB")
            im.save(os.path.join(ifn.path, ifn.filename + ".jpg"))
            image_data[i] = np.array(im)
        if i % print_interval == 0:
            print("Processing image:", i, "of", total)
    return preprocess_fn(image_data)
def generate_features(model, images):
    """Feed the preprocessed image batch through *model* and return its output.

    :param model: object exposing ``predict`` (the Keras model in this script).
    :param images: batch passed unchanged to ``model.predict``.
    :returns: the prediction result as returned by ``model.predict``.
    """
    predictions = model.predict(images)
    return predictions
def calculate_distances(features, metric="cosine"):
    """Return the matrix of pairwise distances between feature rows.

    :param features: 2-D array with one feature vector per row.
    :param metric: any metric accepted by ``scipy.spatial.distance.cdist``;
        defaults to ``"cosine"`` (the previous fixed behaviour).
    :returns: ``(n, n)`` array where ``[i, j]`` is the distance between rows
        ``i`` and ``j``.
    """
    # A bare `import scipy` does not load the `spatial` subpackage on older
    # SciPy releases, so `scipy.spatial.distance` could raise AttributeError.
    from scipy.spatial.distance import cdist

    return cdist(features, features, metric)
def generate_site(output_path, names, features):
    """Write one HTML page per image plus an ``index.html``.

    Pages are built from pairwise distances between feature rows
    (``calculate_distances``):

    - ``image.html`` is rendered for each image, with the 12 closest and
      12 farthest *other* images.
    - ``index.html`` is rendered with the total image count.

    :param output_path: directory the ``.html`` files are written to
        (assumed to exist; main() creates a subdirectory inside it first).
    :param names: sequence of ``ImageFile`` tuples, in the same order as the
        rows of *features*.
    :param features: 2-D array, one feature vector per image.
    """
    num_neighbours = 12
    template = env.get_template("image.html")
    # Integer interval (at least 1) so progress is reported about every 10%.
    print_interval = max(1, len(names) // 10)
    # Calculate all pairwise distances once.
    distances = calculate_distances(features)
    for idx, ifn in enumerate(names):
        # Rank the other images by distance, closest first. Dropping the image
        # itself *before* slicing guarantees at most `num_neighbours` results
        # per list, even if distance ties push it out of the first slots.
        ranked = [i for i in np.argsort(distances[idx]) if i != idx]
        similar_images = [names[i] for i in ranked[:num_neighbours]]
        dissimilar_images = [names[i] for i in ranked[-num_neighbours:]]
        html = template.render(
            target_image=ifn,
            similar_images=similar_images,
            dissimilar_images=dissimilar_images,
        )
        output_filename = os.path.join(output_path, ifn.filename + ".html")
        # Explicit UTF-8 so output does not depend on the platform encoding.
        with open(output_filename, "w", encoding="utf-8") as text_file:
            text_file.write(html)
        if idx % print_interval == 0:
            print("Generating page:", idx, "of", len(names))
    template = env.get_template("index.html")
    index_path = os.path.join(output_path, "index.html")
    with open(index_path, "w", encoding="utf-8") as text_file:
        text_file.write(template.render(num_images=len(names)))
def _positive_int(text):
    """Parse a command-line value as an integer that is at least 1.

    Used as an argparse ``type=`` callable. Raising ``ArgumentTypeError`` (or
    the ``ValueError`` from ``int``) makes argparse print a usage error and
    exit with status 2. Without this check, a negative count failed later with
    a traceback from ``random.sample``, and 0 produced an empty batch.
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {value}")
    return value


def parse_args():
    """Define and parse the command-line options.

    :returns: ``argparse.Namespace`` with ``inputdir``, ``outputdir``,
        ``num_images`` and ``model`` attributes.
    """
    parser = argparse.ArgumentParser(description="Image similarity")
    parser.add_argument(
        "-i",
        "--inputdir",
        help="The directory to search for images",
        default="~/Pictures",
    )
    parser.add_argument(
        "-o",
        "--outputdir",
        help="The directory to write the output HTML files to",
        default="./public",
    )
    parser.add_argument(
        "-n",
        "--num_images",
        help="The maximum number of images to process. If more images are found n of them will be chosen randomly.",
        # Reject 0 and negative counts at parse time.
        type=_positive_int,
        default=1000,
    )
    parser.add_argument(
        "-m",
        "--model",
        help="The model to use to extract features.",
        choices=["resnet50", "vgg16"],
        default="vgg16",
    )
    return parser.parse_args()
def ensure_directory(directory):
    """Create *directory*, including missing parents; no-op if it already exists."""
    target = pathlib.Path(directory)
    target.mkdir(exist_ok=True, parents=True)
def main():
    """Command-line entry point: find images, extract features, build the site.

    Steps:

    1. Collect up to ``--num_images`` images (a random subset when more exist).
    2. Write processed copies to ``<outputdir>/images``.
    3. Run the chosen model to get one flattened feature vector per image.
    4. Generate the HTML pages in ``<outputdir>``.
    """
    args = parse_args()
    input_dir = args.inputdir
    # Expand "~" here too, matching how find_image_files treats the input
    # directory (a quoted or "--outputdir=~/x" argument is not shell-expanded).
    output_dir = os.path.expanduser(args.outputdir)
    image_output_dir = os.path.join(output_dir, "images")
    num_images = args.num_images

    # Find the image files; if there are more than requested, pick them randomly.
    image_file_names = find_image_files(input_dir)
    n = len(image_file_names)
    if n == 0:
        # Bail out before loading model weights or feeding an empty batch.
        print(f"No images found in {input_dir}")
        return
    if n > num_images:
        n = num_images
        image_file_names = random.sample(image_file_names, n)
    print(f"Using {n} images from {input_dir}")

    model, preprocess_fn = build_model(args.model)

    # ImageFile(src, filename, path, uri)
    image_files = [
        ImageFile(src_file, str(i), image_output_dir, f"images/{i}.jpg")
        for i, src_file in enumerate(image_file_names)
    ]
    ensure_directory(image_output_dir)
    image_data = process_images(image_files, preprocess_fn)
    print("Image data shape:", image_data.shape)

    # Run predict on the image data to generate the features.
    print("Generating features for images")
    features = generate_features(model, image_data)
    # Flatten everything after the batch axis into one vector per image.
    features = features.reshape(features.shape[0], -1)
    print("features shape:", features.shape)

    # Generate the static pages using the images and the feature matrix.
    print("Generating static pages")
    generate_site(output_dir, image_files, features)
    print("Done")


if __name__ == "__main__":
    main()