-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_retrieval.py
120 lines (97 loc) · 5.59 KB
/
image_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import cv2
import time
import argparse
import numpy as np
import torch
from datetime import timedelta
from retrieval.create_thumb_images import create_thumb_images
from flask import Flask, render_template, request, redirect, url_for, make_response,jsonify, flash
from retrieval.retrieval import load_model, load_data, extract_feature, load_query_image, sort_img, extract_feature_query
# Parse command-line instructions.
parser = argparse.ArgumentParser(description='Image Retrieval')
parser.add_argument('--update', action='store_true', default=False, help='update database')
args = parser.parse_args()

# Cache files written by a `--update` run and read back by every other run.
_GALLERY_FEATURE_FILE = './retrieval/models/gallery_feature.npy'
_IMAGE_PATHS_FILE = './retrieval/models/image_paths.npy'

# Create thumb images (only when the database is being rebuilt).
if args.update:
    create_thumb_images(full_folder='./static/image_database/',
                        thumb_folder='./static/thumb_images/',
                        suffix='',
                        height=200,
                        del_former_thumb=True,
                        )

# Prepare data set.
data_loader = load_data(data_path='./static/image_database/',
                        batch_size=1,
                        shuffle=False,
                        transform='default',
                        )

# Prepare model: load the pretrained weights.
# NOTE(review): use_gpu is hard-coded to True; presumably this requires a CUDA
# device -- confirm against retrieval.load_model before running on CPU hosts.
model = load_model(pretrained_model='./retrieval/models/net_best.pth', use_gpu=True)
print("Model load successfully!")

# Extract database features.
# While the database images are unchanged, the feature vectors saved by an
# earlier `--update` run are reused instead of being recomputed, to save time.
if args.update:
    # The original author noted a result of torch.Size([59, 2048]) here.
    gallery_feature, image_paths = extract_feature(model=model, dataloaders=data_loader)
    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # normalise to a detached CPU tensor first.  This also makes this branch
    # produce the same kind of tensor as the cached branch below.
    gallery_feature = gallery_feature.detach().cpu()
    np.save(_GALLERY_FEATURE_FILE, gallery_feature.numpy())
    np.save(_IMAGE_PATHS_FILE, np.array(image_paths))
else:
    _missing = [path for path in (_GALLERY_FEATURE_FILE, _IMAGE_PATHS_FILE)
                if not os.path.isfile(path)]
    if _missing:
        # Same exception type np.load would raise, but with an actionable hint.
        raise FileNotFoundError(
            'Cached gallery data not found: {}. '
            'Run once with --update to build it.'.format(', '.join(_missing)))
    gallery_feature = torch.from_numpy(np.load(_GALLERY_FEATURE_FILE))
    image_paths = np.load(_IMAGE_PATHS_FILE).tolist()
print("extract_feature successfully!")
# Picture extensions supported.  allowed_file() compares case-insensitively,
# so the mixed-case entries are kept only for backward compatibility.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG', 'bmp', 'jpeg', 'JPEG'}


def allowed_file(filename):
    """Return True if *filename* ends in a supported picture extension.

    The extension is the text after the last dot and is compared without
    regard to case, so ``photo.BMP`` and ``photo.Png`` are accepted just like
    ``photo.bmp`` and ``photo.png``.  A missing, empty or dot-less filename
    yields False.
    """
    if not filename or '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {allowed.lower() for allowed in ALLOWED_EXTENSIONS}
# Flask application object for this module.
app = Flask(__name__)

# Static files expire from the cache after one second.
app.config.update(SEND_FILE_MAX_AGE_DEFAULT=timedelta(seconds=1))
@app.route('/', methods=['POST', 'GET'])  # add route
def image_retrieval():
    """Serve the upload page and handle the two form actions.

    A GET request renders ``upload.html``.  A POST request dispatches on the
    ``submit`` form field:

    * ``upload``: check the ``picture`` file's extension and store the file
      as the query image, then render ``upload_finish.html``.
    * ``retrieval``: extract features from the stored query image, rank the
      gallery against it and render the best matches (at most nine) in
      ``retrieval.html``.

    Any other ``submit`` value falls through to ``upload.html``.
    """
    max_results = 9  # the template is given sml1..sml9 / img1_tmb..img9_tmb
    query_url = './static/upload_image/query.jpg?123456'
    basepath = os.path.dirname(__file__)  # current path
    upload_path = os.path.join(basepath, 'static/upload_image', 'query.jpg')

    if request.method != 'POST':
        return render_template('upload.html')

    action = request.form['submit']

    if action == 'upload':
        f = request.files.get('picture')
        if f is None:
            return render_template('upload_finish.html',
                                   message='Please select a picture file!',
                                   img_query=query_url)
        if not (f and allowed_file(f.filename)):
            return render_template('upload_finish.html',
                                   message='Examine picture extension, png、PNG、jpg、JPG、bmp support.',
                                   img_query='./static/upload_image/query.jpg')
        # Make sure the target directory exists before saving into it.
        os.makedirs(os.path.dirname(upload_path), exist_ok=True)
        f.save(upload_path)
        # Page shown after a successful upload.
        return render_template('upload_finish.html',
                               message='Upload successfully!',
                               img_query=query_url)

    if action == 'retrieval':
        start_time = time.time()
        # Query: read the image back from the exact path the upload wrote,
        # rather than from a path relative to the working directory.
        query_image = load_query_image(upload_path)
        # Extract query features (noted as [1, 2048] by the original author).
        query_feature = extract_feature_query(model=model, img=query_image)
        # Sort.
        similarity, index = sort_img(query_feature, gallery_feature)
        sorted_paths = [image_paths[i] for i in index]
        print(sorted_paths)  # gallery paths ordered by similarity to the query
        tmb_images = ['./static/thumb_images/' + os.path.split(sorted_path)[1]
                      for sorted_path in sorted_paths[:max_results]]

        # Pass only the ranks that exist.  With fewer than nine gallery images
        # the remaining template variables are simply not supplied, instead of
        # the request failing with an IndexError.
        results = {}
        for position, thumb in enumerate(tmb_images):
            rank = position + 1
            results['sml{}'.format(rank)] = similarity[position]
            results['img{}_tmb'.format(rank)] = thumb

        return render_template(
            'retrieval.html',
            message="Retrieval finished, cost {:.3f} seconds.".format(time.time() - start_time),
            img_query=query_url,
            **results)

    return render_template('upload.html')
if __name__ == '__main__':
    # Start the built-in server on the loopback interface with debug enabled.
    server_options = {'host': '127.0.0.1', 'port': 8090, 'debug': True}
    app.run(**server_options)