-
Notifications
You must be signed in to change notification settings - Fork 32
/
sample_fake.py
85 lines (60 loc) · 2.41 KB
/
sample_fake.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from multiprocessing import Process, Manager
from imageio import imread
import pickle
import numpy as np
import os
from tqdm import tqdm
# Directories of the phase-01 training split, relative to the working directory.
fake_path = "dataset-dist/phase-01/training/fake/"
pristine_path = "dataset-dist/phase-01/training/pristine/"
# Masks are stored in a sub-folder of the fake-image directory.
mask_path = f"{fake_path}masks/"
def count_255(mask):
    """Return how many pixels of *mask* are exactly equal to 255.

    The count is done with a vectorized NumPy comparison instead of a
    per-pixel Python loop, which was slow on large masks.

    Args:
        mask: 2-D array (or array-like) indexed as ``mask[row, col]``. A 3-D
            mask whose last axis has length 1 is also accepted, since it still
            holds one value per pixel.

    Returns:
        int: number of pixels equal to 255 (``0`` for an empty mask).

    Raises:
        ValueError: if *mask* does not hold exactly one value per
            ``(row, col)`` pixel (e.g. a 3-channel mask).
    """
    mask = np.asarray(mask)
    if mask.ndim == 3 and mask.shape[2] == 1:
        # A single trailing channel still gives one value per pixel.
        mask = mask[:, :, 0]
    if mask.ndim != 2:
        raise ValueError(f'count_255 expects a 2-D mask, got shape {mask.shape}')
    return int(np.count_nonzero(mask == 255))
def sample_fake(img, mask, *, kernel_size=64, stride=32, min_count=1600):
    """Extract square patches whose mask window mixes 255 and non-255 pixels.

    A ``kernel_size`` x ``kernel_size`` window slides over *img* with step
    *stride* along both axes. Only windows that fit entirely inside the image
    are visited. A patch is kept when the matching window of *mask* has more
    than *min_count* pixels equal to 255 AND more than *min_count* pixels that
    are not 255.

    Args:
        img: image array indexed as ``img[y, x, channel]``. Only the first
            three channels are kept in each patch.
        mask: 2-D array aligned with the first two axes of *img* (a trailing
            axis of length 1 is also accepted). NOTE(review): presumably 255
            marks manipulated pixels; confirm against the dataset's mask
            convention.
        kernel_size: side length of the square window (default 64).
        stride: step between consecutive windows (default 32).
        min_count: strict lower bound applied to both the 255 and the
            non-255 pixel counts (default 1600, about 39% of a 64x64 window).

    Returns:
        list of ``img`` slices of shape ``(kernel_size, kernel_size, <=3)``,
        in row-major window order. The list is empty when no window qualifies
        or when the image is smaller than the window.

    Raises:
        ValueError: if *mask* does not hold exactly one value per pixel.
    """
    # Compare against 255 once for the whole mask instead of once per window.
    is_255 = np.asarray(mask) == 255
    if is_255.ndim == 3 and is_255.shape[2] == 1:
        is_255 = is_255[:, :, 0]
    if is_255.ndim != 2:
        raise ValueError(f'sample_fake expects a 2-D mask, got shape {is_255.shape}')

    window_area = kernel_size * kernel_size
    patches = []
    for y in range(0, img.shape[0] - kernel_size + 1, stride):
        for x in range(0, img.shape[1] - kernel_size + 1, stride):
            n_255 = int(np.count_nonzero(is_255[y:y + kernel_size, x:x + kernel_size]))
            # Keep only windows that contain plenty of both mask classes.
            if n_255 > min_count and window_area - n_255 > min_count:
                patches.append(img[y:y + kernel_size, x:x + kernel_size, :3])
    return patches
def process(batch, common_list, images, masks):
    """Worker entry point: sample patches from one batch of image/mask pairs.

    For every ``(image, mask)`` pair, the patches returned by ``sample_fake``
    are added to *common_list* and their count is printed. A completion line
    is printed once the whole batch is done.

    Args:
        batch: batch index, used only in the completion message.
        common_list: list that accumulates the patches. In this script it is
            a ``Manager().list()`` proxy shared with the parent process.
        images: sequence of image arrays.
        masks: sequence of masks, paired positionally with *images*. Extra
            items in the longer sequence are ignored (``zip`` semantics).
    """
    for img, mask in zip(images, masks):
        samples = sample_fake(img, mask)
        if samples:
            # One extend() is a single round trip to the manager process.
            # append() per patch would cost one IPC call per patch.
            common_list.extend(samples)
        print(f'Number of samples = {len(samples)}')
    print(f'batch {batch} completed')
def main():
    """Build ``sample_fakes_np.npy`` from the fake training images.

    Steps:
      1. Load the masks pickled in ``x_train_masks_0.pickle`` through
         ``x_train_masks_8.pickle`` and the image file names pickled in
         ``x_train_fakes_names.pickle`` (paths relative to the working
         directory).
      2. Read the images that will be sampled from ``fake_path``.
      3. Split the first ``n_batches * batch_size`` image/mask pairs into
         ``n_batches`` batches. Run ``process`` on each batch in its own
         process, collecting patches in a manager-backed shared list.
      4. Stack all patches into one array (one patch per leading index) and
         save it with ``np.save``.

    Raises:
        RuntimeError: if a worker process exits with a non-zero code, or if
            no patches were extracted (there would be nothing to save).
    """
    n_batches = 9
    batch_size = 40
    # NOTE(review): only the first n_batches * batch_size (= 360) image/mask
    # pairs are sampled; any further images are ignored. Confirm this cap is
    # intentional.
    n_used = n_batches * batch_size

    # NOTE: pickle.load can execute arbitrary code. Only load pickles that
    # this project produced itself.
    x_train_masks = []
    for i in range(9):
        with open(f'x_train_masks_{i}.pickle', 'rb') as f:
            x_train_masks.extend(pickle.load(f))
    with open('x_train_fakes_names.pickle', 'rb') as f:
        x_train_fakes_names = pickle.load(f)

    # Only the images that are actually handed to a worker are read.
    x_train_fake_images = [imread(fake_path + name) for name in x_train_fakes_names[:n_used]]

    with Manager() as manager:
        common_list = manager.list()
        workers = [
            Process(
                target=process,
                args=(
                    batch,
                    common_list,
                    x_train_fake_images[batch * batch_size:(batch + 1) * batch_size],
                    x_train_masks[batch * batch_size:(batch + 1) * batch_size],
                ),
            )
            for batch in range(n_batches)
        ]
        for p in workers:
            p.start()
        for p in workers:
            p.join()
        # Slicing the proxy copies all patches in one round trip. This must
        # happen before the manager shuts down at the end of this block.
        samples = common_list[:]

    # A crashed worker would otherwise silently produce an incomplete dataset.
    failed = [batch for batch, p in enumerate(workers) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f'worker process(es) for batch(es) {failed} failed')
    if not samples:
        raise RuntimeError('no fake samples were extracted; nothing to save')

    # A single np.stack copies each patch once. Repeated np.concatenate would
    # re-copy the growing array on every iteration (quadratic).
    samples_fake_np = np.stack(samples, axis=0)
    print('done')
    np.save('sample_fakes_np.npy', samples_fake_np)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    # This also matters for multiprocessing, whose child processes may
    # re-import this module.
    main()