-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
135 lines (112 loc) · 4.91 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
from oceanbase import ObImgVec
from PIL import Image
import os
import tempfile
import time
import logging
from towhee import ops,pipe,AutoPipes,AutoConfig,DataCollection
# Build the towhee text/image embedding pipeline used by img_embedding().
# NOTE(review): Streamlit re-executes the whole script on each rerun, so this
# model load likely repeats on every interaction; consider caching it
# (e.g. st.cache_resource) if the installed Streamlit version supports it.
logging.info("init embedding model....")
img_pipe = AutoPipes.pipeline('text_image_embedding')
logging.info("init embedding model finished.")
# system
server_img_store_path = os.getenv("SERVER_IMG_STORE_PATH", "./img")
# oceanbase
def get_max_imgid(store_path=None):
    """Return the largest numeric image id already present in the image store.

    Stored images are named ``<id><ext>``; the upload tab builds these names
    from ``obvec.get_imgid()`` plus the original suffix. Entries that are not
    regular files, have a non-image extension, or have a non-integer stem are
    ignored.

    Args:
        store_path: Directory to scan. Defaults to ``server_img_store_path``.

    Returns:
        The maximum id found. Returns 0 when the directory is missing, is
        empty, or holds no numerically named images.
    """
    if store_path is None:
        store_path = server_img_store_path
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
    try:
        entries = os.listdir(store_path)
    except FileNotFoundError:
        # A fresh deployment may not have created the store yet. No stored
        # images means no ids are in use, so don't crash at startup.
        return 0
    max_number = 0
    for item in entries:
        full_path = os.path.join(store_path, item)
        if not os.path.isfile(full_path):
            continue
        stem, extension = os.path.splitext(item)
        if extension.lower() not in image_extensions:
            continue
        try:
            number = int(stem)
        except ValueError:
            # Not one of our "<id><ext>" files.
            continue
        max_number = max(max_number, number)
    return max_number
from urllib.parse import quote_plus

# True until process_image() has created the img2img table in this run.
first_embedding = True
# OceanBase connection settings. The defaults are the previous hard-coded
# values; override them with environment variables instead of editing the
# source (e.g. OB_HOST=localhost for a local test).
ob_host = os.getenv("OB_HOST", "oceanbase")
ob_port = int(os.getenv("OB_PORT", "2881"))
ob_database = os.getenv("OB_DATABASE", "test")
ob_user = os.getenv("OB_USER", "root@test")
ob_password = os.getenv("OB_PASSWORD", "")
# The password is URL-encoded so characters such as '@', ':' or '/' cannot
# corrupt the connection URL. quote_plus("") == "", so the default URL is
# unchanged.
# NOTE(review): ob_user is left as-is to keep the default URL byte-identical;
# encode it too if user names containing ':' or '/' are expected.
connection_str = (
    f"mysql+pymysql://{ob_user}:{quote_plus(ob_password)}"
    f"@{ob_host}:{ob_port}/{ob_database}?charset=utf8mb4"
)
# Vector store over the "img2img" table; it is seeded with the largest image id
# already on disk (presumably so new uploads continue numbering from there).
obvec = ObImgVec(connection_str, "img2img", get_max_imgid())
def img_embedding(path):
    """Run the image at *path* through ``img_pipe`` and return its embedding.

    The pipeline result is read with ``.get()``, and the embedding is the first
    element of that row (presumably a vector object exposing ``tolist()``, as
    process_image relies on).
    """
    first_row = img_pipe(path).get()
    return first_row[0]
def process_image(image_path, target_path):
    """Embed the image at *image_path* and insert the vector under *target_path*.

    On the first call while ``first_embedding`` is set, this first calls
    ``obvec.ob_create_img2img`` with the embedding's length, presumably to
    create a table sized to the model's vector dimension. It then clears the
    flag.
    """
    global first_embedding
    vector = img_embedding(image_path)
    if first_embedding:
        # The dimension is only known once the model has produced a vector.
        obvec.ob_create_img2img(len(vector.tolist()))
        first_embedding = False
    obvec.ob_insert_img2img(vector, target_path)
def find_similar_images(image_path, num_results):
    """Return the stored images nearest to the image at *image_path*.

    Returns:
        A tuple ``(paths, similarities, cost)``. ``paths`` are the stored image
        paths of the hits. ``similarities`` are ``1 - distance`` for each hit.
        ``cost`` is the value reported by ``obvec.ob_ann_search``, which the UI
        shows in seconds.

    NOTE(review): ``1 - distance`` assumes the "<~>" operator yields a
    cosine-style distance; confirm against ObImgVec.
    """
    hits, cost = obvec.ob_ann_search("<~>", img_embedding(image_path), num_results)
    paths = [hit.path for hit in hits]
    scores = [1 - hit.distance for hit in hits]
    return paths, scores, cost
# Use the wide layout for the whole app.
st.set_page_config(layout="wide")
# Per-tab flags recording which uploader changed most recently; the upload
# callbacks below toggle them. Start both as False on a new session.
for _flag in ('files_uploaded_tab1', 'file_uploaded_tab2'):
    if _flag not in st.session_state:
        st.session_state[_flag] = False
del _flag
def on_file_uploaded_tab1():
    """on_change callback for the tab-1 uploader: flag tab 1 and clear tab 2."""
    st.session_state["files_uploaded_tab1"] = True
    st.session_state["file_uploaded_tab2"] = False
def on_file_uploaded_tab2():
    """on_change callback for the tab-2 uploader: flag tab 2 and clear tab 1."""
    st.session_state["file_uploaded_tab2"] = True
    st.session_state["files_uploaded_tab1"] = False
# Create the two feature tabs.
tab1, tab2 = st.tabs(["Image Embedding", "Similar Search"])
with tab1:
    st.header("Image Embedding")
    # File upload widget; its on_change callback flags a fresh upload.
    uploaded_files = st.file_uploader("Select your images",
                                      accept_multiple_files=True,
                                      type=['png', 'jpg', 'jpeg'],
                                      on_change=on_file_uploaded_tab1)
    if uploaded_files and st.session_state.files_uploaded_tab1:
        # Consume the flag before processing. The uploader keeps returning the
        # same files on every later rerun (e.g. moving the slider in the other
        # tab). Without this, those reruns, or a retry after a failure midway,
        # would save and embed every file again.
        st.session_state.files_uploaded_tab1 = False
        progress_bar = st.progress(0)
        total = len(uploaded_files)
        for index, uploaded_file in enumerate(uploaded_files, start=1):
            suffix = os.path.splitext(uploaded_file.name)[-1]
            # Store the image as "<id><original suffix>" in the image store.
            target_path = os.path.join(server_img_store_path, f"{obvec.get_imgid()}{suffix}")
            with open(target_path, "wb") as f:
                f.write(uploaded_file.getvalue())
            process_image(target_path, target_path)
            progress_bar.progress(index / total)
        st.success("Image Embedding finish!")
with tab2:
    st.header("Similar Search")
    # File upload widget for the query image.
    uploaded_file = st.file_uploader("Choose one image",
                                     type=['png', 'jpg', 'jpeg'],
                                     key="uploader2",
                                     on_change=on_file_uploaded_tab2)
    # Slider choosing how many results to return (1 to 10).
    num_results = st.slider("How many images do you want to search?", 1, 10, value=5)
    # Run the search while the tab-2 upload is the active one.
    if uploaded_file and st.session_state.file_uploaded_tab2:
        # Save the upload to a temporary file, because the embedding pipeline
        # reads images by path. The file is closed before being read by name.
        # Closing flushes all bytes to disk; reading an open, unflushed file
        # could see a truncated image, and reopening an open
        # NamedTemporaryFile by name is not possible on Windows.
        suffix = os.path.splitext(uploaded_file.name)[-1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmpfile:
            tmpfile.write(uploaded_file.getvalue())
        image_path = tmpfile.name
        try:
            st.write("Your input image:")
            st.image(image_path, width=300)
            # Search for similar images.
            similar_images, similarity, cost = find_similar_images(image_path, num_results)
            # Display the similar images that were found.
            st.write(f"Search finish in {cost} s. Here are the similar images:")
            for similar_image, sim in zip(similar_images, similarity):
                st.write(f"Similarity: {sim}")
                st.image(similar_image, width=800)
        finally:
            # delete=False means cleanup is our job.
            os.remove(image_path)