-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstreamlit_app.py
355 lines (300 loc) · 14.4 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import streamlit as st
import re
import pandas as pd
import numpy as np
import plotly.express as px
import json
import os
import os.path as osp
import shutil
from streamlit_timeline_maker import app_timeline_maker
from streamlit_video_maker import app_video_maker
import sys
import pickle
sys.path.append(os.path.join(os.path.dirname(__file__), "face_embedding"))
sys.path.append(os.path.join(os.path.dirname(__file__), "body_embedding"))
import download_youtube.YoutubeDownloader as ytdownload
from mmtracking.utils.Tracker import tracking
from mmtracking.utils.Postprocesser import postprocessing
from timeline.TimeLineMaker import make_timeline
import sampler
import face_embedding
import group_recognizer
import predictor
from body_embedding.BodyEmbed import body_embedding_extractor
from body_embedding.BodyEmbed import generate_body_anchor
from visualization.sampling_visualization import visualize_sample
from video_generator.VideoGenerator import video_generator
### 1️. module code startline ###
def save_pickle(path, obj):
    """Serialize *obj* to *path* with pickle (binary mode)."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)
### 1. module code endline ###
### 2. streamlit code startline ###
# for on_click
def session_change_to_timeline():
    """Button callback: switch the app to the timeline page (page index 1)."""
    st.session_state["page"] = 1
# for on_click
def session_change_to_video():
    """Button callback: switch the app to the video page (page index 2)."""
    st.session_state["page"] = 2
# main page
def main_page():
    """Landing page: take a YouTube URL, validate it, and preview the video.

    On a valid URL the link, video id, clip window and output directory are
    stashed in st.session_state so the later pages (timeline / video) can
    reuse them across Streamlit reruns.
    """
    st.title("Torch-kpop")
    st.title("AI makes personal videos for you 😍")
    url = st.text_input(label="Input youtube URL 🔻", placeholder="https://www.youtube.com/watch?v=KXX3F4j1xjo")
    # NOTE(review): split('=')[-1] breaks for URLs with extra query params
    # (e.g. ...&t=30s) — confirm expected input format.
    youtube_id = url.split('=')[-1]
    start_sec = 0  # ⭐ fixed demo clip start (seconds)
    end_sec = 260  # ⭐ fixed demo clip end (seconds)
    save_dir = osp.join('./streamlit_output', youtube_id, str(start_sec) + '_' + str(end_sec))
    # if input btn clicked
    if st.button("SHOW TARGET VIDEO"):
        # check youtube url # regex reference from https://stackoverflow.com/questions/19377262/regex-for-youtube-url
        # FIX: raw string — the original plain string relied on invalid escape
        # sequences (\/, \w, \-, \S), which raise SyntaxWarning on modern Python.
        if re.match(r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube(-nocookie)?\.com|youtu.be))(\/(?:[\w\-]+\?v=|embed\/|v\/)?)([\w\-]+)(\S+)?$", url):
            # if matching youtube url, show input video, and if you click SHOW TIMELINE button, session_change_to_timeline
            st.session_state.url = url
            st.session_state.youtube_id = youtube_id
            st.session_state.start_sec = start_sec  # ⭐
            st.session_state.end_sec = end_sec  # ⭐
            st.session_state.save_dir = save_dir
            st.video(url)
            # NOTE(review): this button is nested inside another button's branch;
            # it works only because on_click fires before the rerun — confirm UX.
            if st.button("SHOW TIMELINE", on_click=session_change_to_timeline):
                pass
        else:
            st.write("Input is not youtube URL, check URL")
# make timeline
def get_timeline_fig(timeline, meta_info):
    """Build a plotly scatter figure showing which frames each member appears in.

    timeline:  dict mapping member name -> iterable of frame indices
    meta_info: dict with at least a 'member_list' key
    Returns a plotly Figure (one horizontal band per member).
    """
    member_list = meta_info['member_list']
    df_timeline_list = []
    for i, member in enumerate(member_list):  # member example: 'aespa_karina'
        member_timeline = list(set(timeline[member]))  # de-duplicate frames
        # constant value i+1 per column gives each member its own band on the y-axis
        df_member_timeline = pd.DataFrame(np.ones_like(member_timeline) * int(i+1), columns=[member], index=member_timeline)
        df_timeline_list.append(df_member_timeline)
    df_group = pd.concat(df_timeline_list, axis=1)
    for i, member in enumerate(member_list):
        # BUG FIX: DataFrame.replace is NOT in-place — the original discarded the
        # result, so the numeric band codes were never replaced by member names.
        df_group = df_group.replace(i+1, member)
    fig = px.scatter(df_group, width=1000, labels={'index':'\nframe', 'value':'member'})
    return fig
def get_meta_info():
    """Return the video meta info, downloading/capturing frames on a cache miss.

    The parsed JSON is cached at <save_dir>/<youtube_id>.json; when that file
    already exists the expensive download-and-capture step is skipped.
    """
    start_sec = st.session_state.start_sec  # ⭐
    end_sec = st.session_state.end_sec  # ⭐
    YOUTUBE_LINK = st.session_state.url
    save_dir = st.session_state.save_dir
    youtube_id = YOUTUBE_LINK.split('=')[-1]
    meta_info_path = osp.join(save_dir, f'{youtube_id}.json')
    if not osp.exists(meta_info_path):
        # cache miss: mp4 download and frame capture
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(osp.join(save_dir, 'csv'), exist_ok=True)
        return ytdownload.download_and_capture(YOUTUBE_LINK, start_sec, end_sec, save_dir)
    with open(meta_info_path) as meta_info_file:
        meta_info = json.load(meta_info_file)
    print(f'🎉 download_and_capture 함수 skip')
    print(f'load 경로 : {meta_info_path}')
    return meta_info
def get_raw_df1(meta_info):
    """Return the raw tracking dataframe, loading the cached CSV when present."""
    save_dir = st.session_state.save_dir
    raw_df1_path = osp.join(save_dir, 'csv/df1_raw.csv')
    if not osp.exists(raw_df1_path):
        # clipped variant is unused here, keep only the raw dataframe
        _, raw_df1 = tracking(meta_info, output=save_dir, ANALYSIS=False)  # output is save dir
        return raw_df1
    print(f'🎉 tracking 함수 skip')
    print(f'load 경로 : {raw_df1_path}')
    return pd.read_csv(raw_df1_path)
def get_df1_postprocessed(raw_df1, meta_info, sec=1):
    """Return the postprocessed tracking dataframe, cached as a pickle.

    raw_df1:   raw tracking dataframe
    meta_info: video meta info dict
    sec:       window (seconds) forwarded to postprocessing()
    """
    save_dir = st.session_state.save_dir
    df1_postprocessed_path = osp.join(save_dir, "csv/df1_postprocessed.pickle")
    if os.path.exists(df1_postprocessed_path):
        with open(df1_postprocessed_path, 'rb') as df1_postprocessed_pickle:
            df1 = pickle.load(df1_postprocessed_pickle)
        print(f'🎉 Postprocessing 함수 skip')
        print(f'load 경로 : {df1_postprocessed_path}')
    else:
        # BUG FIX: forward the caller-supplied `sec` — the original hard-coded
        # sec=1, silently ignoring the parameter.
        df1 = postprocessing(raw_df1, meta_info, sec=sec)
        save_pickle(df1_postprocessed_path, df1)  # save
    return df1
def get_df2_sampled(df1, meta_info, seconds_per_frame=1):
    """Return the sampled dataframe (one frame per interval), cached as a pickle.

    df1:               postprocessed tracking dataframe
    meta_info:         video meta info dict
    seconds_per_frame: sampling interval forwarded to sampler.sampler()
    """
    save_dir = st.session_state.save_dir
    df2_sampled_path = osp.join(save_dir, "csv/df2_sampled.pickle")
    if os.path.exists(df2_sampled_path):
        with open(df2_sampled_path, 'rb') as df2_sampled_pickle:
            df2 = pickle.load(df2_sampled_pickle)
        print(f'🎉 sampler 함수 skip')
        print(f'load 경로 : {df2_sampled_path}')
    else:
        # BUG FIX: forward the caller-supplied `seconds_per_frame` — the original
        # hard-coded 1, silently ignoring the parameter.
        df2 = sampler.sampler(df1, meta_info, seconds_per_frame=seconds_per_frame)
        save_pickle(df2_sampled_path, df2)  # save
    return df2
def get_group_recognized_meta_info(meta_info, anchor_face_embedding, df1, df2):
    """Return meta_info augmented with the recognized idol group, cached as JSON.

    On a cache miss the GroupRecognizer guesses the group from the face anchors
    and the two dataframes, and the result is persisted for later runs.
    """
    save_dir = st.session_state.save_dir
    meta_info_path = osp.join(save_dir, "csv/meta_info.json")
    if osp.exists(meta_info_path):
        with open(meta_info_path) as meta_info_file:
            meta_info = json.load(meta_info_file)
        print(f'🎉 group_recognizer 함수 skip')
        print(f'load 경로 : {meta_info_path}')
    else:
        GR = group_recognizer.GroupRecognizer(meta_info=meta_info, anchors=anchor_face_embedding)
        GR.register_dataframes(df1=df1, df2=df2)
        meta_info = GR.guess_group()
        # FIX: the original passed a bare open() to json.dump and leaked the
        # file handle; a with-block guarantees it is flushed and closed.
        with open(meta_info_path, 'w') as meta_info_file:
            json.dump(meta_info, meta_info_file, indent=4, ensure_ascii=False)
    return meta_info
def get_current_face_anchors(meta_info, anchor_face_embedding):
    """Filter the anchor face embeddings down to the recognized group's members.

    Returns a new dict containing only the entries of *anchor_face_embedding*
    whose key appears in meta_info['member_list'].
    """
    members = set(meta_info['member_list'])
    return {name: emb for name, emb in anchor_face_embedding.items() if name in members}
def get_df1_face(df1, df2, current_face_anchors, meta_info):
    """Return df1 with face embeddings attached, cached as a pickle."""
    save_dir = st.session_state.save_dir
    df1_face_path = osp.join(save_dir, "csv/df1_face.pickle")
    if not osp.exists(df1_face_path):
        df1 = face_embedding.face_embedding_extractor_all(df1, df2, current_face_anchors, meta_info)
        save_pickle(df1_face_path, df1)  # save
        return df1
    with open(df1_face_path, 'rb') as df1_face_pickle:
        df1 = pickle.load(df1_face_pickle)
    print(f'🎉 face_embedding_extractor_all 함수 skip')
    print(f'load 경로 : {df1_face_path}')
    return df1
def get_df2_out_of_face_embedding(df1, df2, current_face_anchors, meta_info):
    """Return df2 with face-similarity results attached, cached as a pickle."""
    save_dir = st.session_state.save_dir
    df2_out_of_face_embedding_path = osp.join(save_dir, 'csv/df2_out_of_face_embedding.pickle')
    if not osp.exists(df2_out_of_face_embedding_path):
        df2 = face_embedding.face_embedding_extractor(df1, df2, current_face_anchors, meta_info)
        save_pickle(df2_out_of_face_embedding_path, df2)  # save
        return df2
    with open(df2_out_of_face_embedding_path, 'rb') as df2_out_of_face_embedding_pickle:
        df2 = pickle.load(df2_out_of_face_embedding_pickle)
    print(f'🎉 face_embedding_extractor 함수 skip')
    print(f'load 경로 : {df2_out_of_face_embedding_path}')
    return df2
def get_df2_out_of_body_embedding(df1, df2, save_dir, meta_info):
    """Return df2 with body-embedding results attached, cached as a pickle.

    Unlike the sibling helpers, save_dir is passed in explicitly here.
    """
    df2_out_of_body_embedding_path = osp.join(save_dir, 'csv/df2_out_of_body_embedding.pickle')
    if not osp.exists(df2_out_of_body_embedding_path):
        # build anchors first, then score every sampled crop against them
        body_anchors = generate_body_anchor(df1, df2, save_dir, meta_info=meta_info)  # , group_name="aespa"
        df2 = body_embedding_extractor(df1, df2, body_anchors, meta_info=meta_info)
        save_pickle(df2_out_of_body_embedding_path, df2)  # save
        return df2
    with open(df2_out_of_body_embedding_path, 'rb') as df2_out_of_body_embedding_pickle:
        df2 = pickle.load(df2_out_of_body_embedding_pickle)
    print(f'🎉 generate_body_anchor, body_embedding_extractor 함수 skip')
    print(f'load 경로 : {df2_out_of_body_embedding_path}')
    return df2
def get_pred(df1, df2, face_coefficient=1, body_coefficient=1, no_duplicate=True):
    """Return per-track member predictions, cached as a pickle.

    face_coefficient / body_coefficient: relative weights forwarded to
        predictor.predictor().
    no_duplicate: forwarded flag (semantics defined by predictor.predictor).
    """
    save_dir = st.session_state.save_dir
    pred_path = osp.join(save_dir, 'csv/pred.pickle')
    if osp.exists(pred_path):
        with open(pred_path, 'rb') as pred_pickle:
            pred = pickle.load(pred_pickle)
        print(f'🎉 predictor 함수 skip')
        print(f'load 경로 : {pred_path}')
    else:
        # BUG FIX: forward the caller's arguments — the original hard-coded
        # face_coefficient=1, body_coefficient=1, no_duplicate=True, silently
        # ignoring all three parameters.
        pred = predictor.predictor(df1, df2, face_coefficient=face_coefficient,
                                   body_coefficient=body_coefficient, no_duplicate=no_duplicate)
        save_pickle(pred_path, pred)
    return pred
# timeline page
def timeline_page():
    """Run the full inference pipeline and render the per-member timeline.

    Pipeline order matters: download/capture -> tracking -> postprocessing ->
    sampling -> group recognition -> face/body embedding -> prediction ->
    timeline. Intermediate results are cached on disk by the get_* helpers.
    Final artifacts are stashed in st.session_state for the video page.
    """
    # show text
    st.title("Timeline 🎥")
    # get timeline by inference
    with st.spinner('please wait...'):
        start_sec = st.session_state.start_sec # ⭐
        end_sec = st.session_state.end_sec # ⭐
        YOUTUBE_LINK = st.session_state.url
        save_dir = st.session_state.save_dir
        # DOWNLOAD_PATH = './data'
        youtube_id = YOUTUBE_LINK.split('=')[-1]
        # 0. mp4 download and frame capture
        meta_info = get_meta_info()
        st.info('🎉 Download and Capture complete')
        # 1. tracking
        raw_df1 = get_raw_df1(meta_info)
        st.info('🎉 Tracking complete')
        # 2. postprocessing
        df1 = get_df1_postprocessed(raw_df1, meta_info, sec=1)
        st.info('🎉 Postprocessing complete')
        # 3. sampling for extract body, face feature
        df2 = get_df2_sampled(df1, meta_info, seconds_per_frame=1)
        st.info('🎉 Sampler complete')
        ## load pretrained face embedding (member name -> embedding vector)
        with open("./pretrained_weight/integrated_face_embedding.json", "r", encoding="utf-8") as f:
            anchor_face_embedding = json.load(f)
        # 3-1. Group Recognizer (may replace meta_info with the cached/guessed one)
        meta_info = get_group_recognized_meta_info(meta_info, anchor_face_embedding, df1, df2)
        # 3-2. Make new anchor face dict containing current group members
        current_face_anchors = get_current_face_anchors(meta_info, anchor_face_embedding)
        st.info(f'🎉 Group Recognizer complete 🔥 group : {meta_info["group"]}')
        # 4. sampling for extract body, face feature
        df1 = get_df1_face(df1, df2, current_face_anchors, meta_info)
        st.info('🎉 Face Embedding Extractor All complete')
        # 5. query face similarity
        df2 = get_df2_out_of_face_embedding(df1, df2, current_face_anchors, meta_info)
        st.info('🎉 Face Embedding Extractor complete')
        # 6. make body representation
        df2 = get_df2_out_of_body_embedding(df1, df2, save_dir, meta_info=meta_info)
        st.info('🎉 Body Embedding Extractor complete')
        # extra. sampling df2 visualization (writes images; side effect only)
        visualize_sample(df1, df2, save_dir, meta_info=meta_info)
        # 7. predictor
        pred = get_pred(df1, df2, face_coefficient=1, body_coefficient=1, no_duplicate=True)
        st.info('🎉 Predictor complete')
        # timeline maker
        df1_name_tagged, timeline = make_timeline(df1, pred)
        df1_name_tagged_path = osp.join(save_dir, 'csv/df1_name_tagged.csv')
        df1_name_tagged.to_csv(df1_name_tagged_path) ## save
        timeline_path = osp.join(save_dir, 'csv/timeline.pickle')
        save_pickle(timeline_path, timeline) ## save
    # show group name
    st.subheader(f'{meta_info["group"]} timeline')
    # get timeline figure
    timeline_fig = get_timeline_fig(timeline, meta_info)
    st.plotly_chart(timeline_fig, use_container_width=False)
    # stash results for video_page
    st.session_state.df1_name_tagged_GT = df1_name_tagged
    st.session_state.meta_info = meta_info
    st.session_state.pred = pred
    if st.button("MAKE PERSONAL VIDEO", on_click=session_change_to_video):
        pass
# video page
def video_page():
    """Final page: generate (or reuse) each member's personal video and show it.

    Per member: render the mp4 if missing, re-encode it to h264 in place so
    browsers can play it, then stream the bytes into a st.video widget.
    """
    import subprocess  # local import: only needed on this page for re-encoding

    st.title("show all members Video 🎵")
    with st.spinner('please wait...'):
        df1 = st.session_state.df1_name_tagged_GT
        meta_info = st.session_state.meta_info
        pred = st.session_state.pred
        save_dir = st.session_state.save_dir
        member_list = meta_info['member_list']
        for member in member_list:
            # reuse the video file if it has already been rendered to this path
            member_video_path = osp.join(save_dir, f'make_video_video_{member}', f'{member}_output.mp4')
            if osp.exists(member_video_path):  # load the existing video file
                print(f'The {member} video file already exists.')
            else:  # generate the video when it does not exist yet
                print(member)
                member_video_path = video_generator(df1, meta_info, member, pred, save_dir, face_loc=3, video_size=0.4)
                # encoding h264(dst)
                src = member_video_path
                dst = "./temp.mp4"
                # FIX: argument-list subprocess.run instead of an os.system shell
                # string (no shell quoting issues), and check=True so a failed
                # encode raises instead of letting us delete the source below.
                subprocess.run(["ffmpeg", "-i", src,
                                "-c:v", "h264",
                                "-c:a", "copy", dst], check=True)
                # delete mp4v(src)
                os.remove(src)
                # move (dst) to (src)
                shutil.move(dst, src)
            # show video
            st.subheader(f'{member} 직캠')
            # FIX: close the handle — the original leaked the opened file
            with open(member_video_path, 'rb') as video_file_per_member:
                video_bytes_per_member = video_file_per_member.read()
            st.video(video_bytes_per_member)
    st.text("!🎉 End 🎉!")
# init session_state: default to the landing page on the first run
if "page" not in st.session_state:
    st.session_state.page = 0
# render whichever page the session state currently selects;
# an unknown page index renders nothing (same as the original if/elif chain)
_page_renderer = {0: main_page, 1: timeline_page, 2: video_page}.get(st.session_state.page)
if _page_renderer is not None:
    _page_renderer()