gen_face_masks.py
#!/usr/bin/python
# -*- encoding: utf-8 -*-
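#
# This script generates binary face masks with a BiSeNet face-parsing model.
# For each image under the source folder it:
#   1. resizes the image to 512x512 and runs face parsing;
#   2. moves images with <= 9 parsed parts to a "<src>_trash" folder and images
#      with >= 18 parsed parts to a "<src>_inspect" folder;
#   3. saves the resized image and its binary "<name>_mask.png" to the result folder.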
import sys
# git clone https://github.com/zllrunning/face-parsing.PyTorch face_parsing
sys.path.append('face_parsing')
from face_parsing.model import BiSeNet
import torch
import os
import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
from pprint import pprint
import argparse

def vis_parsing_maps(im, parsing_anno, stride, save_im, save_path):
    # Colors for all 20 parts
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]

    im = np.array(im)
    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
    vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255

    num_of_class = np.max(vis_parsing_anno)

    for pi in range(1, num_of_class + 1):
        index = np.where(vis_parsing_anno == pi)
        vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]

    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
    # print(vis_parsing_anno_color.shape, vis_im.shape)
    # vis_im = im.copy().astype(np.uint8)
    # vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)

    # Save result or not
    if save_im:
        # Convert to binary mask
        vis_parsing_anno[vis_parsing_anno != 0] = 255
        cv2.imwrite(save_path[:-4] + '_mask.png', vis_parsing_anno)
        # cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

    # return vis_im
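

# gen_masks() arguments (as used below):
#   ckpt_path: path to the BiSeNet checkpoint.
#   src_paths: either a list/tuple of flat image folders, or a single folder
#              containing one subfolder per subject.
#   result_path: output folder for resized images and "*_mask.png" files.
#   exist_path: optional folder of previously processed subjects; matching
#               subfolders are skipped.
#   max_imgs_per_person: cap on images kept per subject; -1 means no limit.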
def gen_masks(ckpt_path, src_paths, result_path, exist_path=None,
              trash_path_suffix='trash', inspect_path_suffix='inspect', max_imgs_per_person=-1,
              device='cuda'):
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    if isinstance(src_paths, (list, tuple)):
        src_path = src_paths[0]
    else:
        src_path = src_paths
    if src_path.endswith('/') or src_path.endswith('\\'):
        src_path = src_path[:-1]

    # trash_path: a folder to save images with too few (<=9) parts.
    trash_path = src_path + "_" + trash_path_suffix
    if not os.path.exists(trash_path):
        os.makedirs(trash_path)
    # inspect_path: a folder to save images with too many (>=18) parts.
    inspect_path = src_path + "_" + inspect_path_suffix
    if not os.path.exists(inspect_path):
        os.makedirs(inspect_path)

    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    # map_location keeps checkpoint loading working when device is 'cpu'.
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    parts_number_stats = {}
    img_count = 0

    # If src_paths is a list/tuple, then we assume there are no subfolders under src_paths.
    # We use src_paths as subj_paths.
    if isinstance(src_paths, (list, tuple)):
        subj_paths = src_paths
        # Don't include subfolders in the path when saving files.
        # osp.join(result_path, "") will add a trailing "/" to result_path, which is harmless.
        subj_dirs = [ "" for _ in subj_paths ]
    # Otherwise, we assume there are subfolders under src_paths. We put them into subj_paths.
    else:
        subj_dirs = list(os.listdir(src_paths))
        subj_paths = []
        subj_dirs2 = []

        for subj_dir in subj_dirs:
            # Already exists in exist_path. Skip.
            if exist_path is not None and os.path.isdir(osp.join(exist_path, subj_dir)):
                print(f"{subj_dir} has been processed before, skip")
                continue
            if not os.path.isdir(osp.join(src_paths, subj_dir)):
                continue
            subj_paths.append(osp.join(src_paths, subj_dir))
            subj_dirs2.append(subj_dir)

        subj_dirs = subj_dirs2
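
    # For each subject folder: run face parsing on every image, filter by the
    # number of parsed parts, and save the resized image plus its binary mask.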
    for (subj_dir, subj_path) in zip(subj_dirs, subj_paths):
        print(f"Processing {subj_path} '{subj_dir}'")
        subj_img_count = 0

        for img_path in sorted(os.listdir(subj_path)):
            if img_path[-4:].lower() not in ['.jpg', '.png', 'jpeg', '.bmp', '.tif', 'tiff', 'webp']:
                continue
            if img_path.endswith("_mask.png"):
                continue

            img_full_path = osp.join(subj_path, img_path)
            new_img_path = osp.join(result_path, subj_dir, img_path)
            if os.path.exists(new_img_path):
                continue

            subj_img_count += 1
            if max_imgs_per_person > 0 and subj_img_count > max_imgs_per_person:
                break

            image_obj = Image.open(img_full_path)
            image_obj = image_obj.resize((512, 512), Image.BILINEAR).convert("RGB")
            img = to_tensor(image_obj)
            img = torch.unsqueeze(img, 0)
            img = img.to(device)

            with torch.no_grad():
                out = net(img)[0]

            parsing = out.squeeze(0).cpu().numpy().argmax(0)
            unique_parts = np.unique(parsing)
            # print(img_path, unique_parts)
            parts_number_stats[len(unique_parts)] = parts_number_stats.get(len(unique_parts), 0) + 1
            img_count += 1
            if img_count % 100 == 0:
                print(f'{img_count}: ', end="")
                pprint(parts_number_stats)

            # Bad images. Move to trash folder.
            if len(unique_parts) <= 9:
                img_trash_path = osp.join(trash_path, subj_dir, img_path)
                if not os.path.exists(osp.join(trash_path, subj_dir)):
                    os.makedirs(osp.join(trash_path, subj_dir))
                print(f"{img_full_path} -> {img_trash_path}")
                os.rename(img_full_path, img_trash_path)
                continue

            if len(unique_parts) >= 18:
                img_inspect_path = osp.join(inspect_path, subj_dir, img_path)
                if not os.path.exists(osp.join(inspect_path, subj_dir)):
                    os.makedirs(osp.join(inspect_path, subj_dir))
                os.rename(img_full_path, img_inspect_path)
                print(f'Image {img_inspect_path} has {len(unique_parts)} parts.')
                continue

            if not os.path.exists(osp.join(result_path, subj_dir)):
                os.makedirs(osp.join(result_path, subj_dir))
            print(f"{img_full_path} -> {new_img_path}")
            # Save the image, instead of copying it, so that the new image will be (512, 512).
            image_obj.save(new_img_path, compress_level=1)
            # shutil.copy(img_full_path, new_img_path)
            vis_parsing_maps(image_obj, parsing, stride=1, save_im=True,
                             save_path=osp.join(result_path, subj_dir, img_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--face_folder', type=str, help='The folder containing face images.')
    parser.add_argument('--device', type=int, default=0, help='GPU index to run the model.')
    args = parser.parse_args()
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

    ckpt_path = osp.join('models/BiSeNet', '79999_iter.pth')
    face_folder = args.face_folder
    if face_folder.endswith('/') or face_folder.endswith('\\'):
        face_folder = face_folder[:-1]
    face_folder_par_dir, face_folder_name = osp.split(face_folder)
    # If the face folder doesn't have subfolders, we need to put it in a list/tuple.
    gen_masks(ckpt_path, src_paths=[face_folder],
              result_path=osp.join(face_folder_par_dir, f"{face_folder_name}_masks"),
              device=device)
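
# Example usage (a sketch; paths are placeholders). It assumes the BiSeNet checkpoint
# has been downloaded to models/BiSeNet/79999_iter.pth and face-parsing.PyTorch has
# been cloned into ./face_parsing as noted at the top of this file:
#
#   python gen_face_masks.py --face_folder /path/to/face_images --device 0
#
# Resized images and masks are written to /path/to/face_images_masks; filtered images
# are moved to /path/to/face_images_trash and /path/to/face_images_inspect.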