forked from danfeiX/scene-graph-TF-release
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvg_to_imdb.py
128 lines (104 loc) · 4.09 KB
/
vg_to_imdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# coding=utf8
import argparse, os, json, string
from Queue import Queue
from threading import Thread, Lock
import h5py
import numpy as np
from scipy.misc import imread, imresize
def build_filename_dict(data):
    """Build two-way lookup tables between image basenames and 1-based indices.

    Args:
        data: iterable of dicts, each carrying an ``'image_path'`` entry.

    Returns:
        A ``(filename_to_idx, idx_to_filename)`` tuple of dicts. Indices start
        at 1 and follow the order of ``data``.

    Raises:
        AssertionError: if two entries share the same basename.
    """
    names = [os.path.basename(entry['image_path']) for entry in data]
    # Both tables are keyed on the basename alone, so duplicates would
    # silently overwrite each other; refuse them up front.
    assert len(names) == len(set(names))

    filename_to_idx = {}
    idx_to_filename = {}
    for position, name in enumerate(names, 1):
        filename_to_idx[name] = position
        idx_to_filename[position] = name
    return filename_to_idx, idx_to_filename
def encode_filenames(data, filename_to_idx):
    """Translate each image entry into the index assigned to its basename.

    Args:
        data: iterable of dicts, each carrying an ``'image_path'`` entry.
        filename_to_idx: mapping from image basename to integer index.

    Returns:
        A 1-D ``np.int32`` array with one index per entry of ``data``, in the
        same order.

    Raises:
        KeyError: if a basename is missing from ``filename_to_idx``.
    """
    encoded = [
        filename_to_idx[os.path.basename(entry['image_path'])]
        for entry in data
    ]
    return np.asarray(encoded, dtype=np.int32)
def add_images(im_data, h5_file, args):
    """Resize every available image and store it in ``h5_file``.

    For each record in ``im_data`` the file ``<image_id>.jpg`` is looked up in
    ``args.image_dir``. Records whose file is missing, or whose basename is in
    the hard-coded list of corrupted images, are skipped.

    Datasets created in ``h5_file``:
        image_ids:        ``image_id`` of every kept record (int32).
        valid_idx:        position of every kept record inside ``im_data``.
        images:           uint8, shape ``(N, 3, image_size, image_size)``.
                          Each image is stored channel-first in BGR order in
                          the top-left ``H x W`` corner of its slot; the rest
                          of the slot is never written.
        image_heights / image_widths:       size after resizing.
        original_heights / original_widths: size before resizing.

    Args:
        im_data: sequence of dicts, each with an ``'image_id'`` entry.
        h5_file: open, writable HDF5 file (anything exposing
            ``create_dataset``).
        args: namespace providing ``image_dir``, ``image_size`` and
            ``num_workers``.

    Returns:
        List of the image paths that were selected for processing.

    Raises:
        RuntimeError: if one or more images could not be read, resized or
            written. Every failure is collected and reported together once
            all images have been attempted.
    """
    corrupted_ims = frozenset(
        ['1592.jpg', '1722.jpg', '4616.jpg', '4617.jpg'])

    fns, ids, idx = [], [], []
    for i, img in enumerate(im_data):
        basename = str(img['image_id']) + '.jpg'
        if basename in corrupted_ims:
            continue
        filename = os.path.join(args.image_dir, basename)
        if os.path.exists(filename):
            fns.append(filename)
            ids.append(img['image_id'])
            idx.append(i)

    h5_file.create_dataset('image_ids', data=np.array(ids, dtype=np.int32))
    h5_file.create_dataset('valid_idx', data=np.array(idx, dtype=np.int32))

    num_images = len(fns)
    shape = (num_images, 3, args.image_size, args.image_size)
    image_dset = h5_file.create_dataset('images', shape, dtype=np.uint8)

    original_heights = np.zeros(num_images, dtype=np.int32)
    original_widths = np.zeros(num_images, dtype=np.int32)
    image_heights = np.zeros(num_images, dtype=np.int32)
    image_widths = np.zeros(num_images, dtype=np.int32)

    lock = Lock()
    failures = []  # (filename, exception) pairs, appended under `lock`
    q = Queue()
    for item in enumerate(fns):
        q.put(item)

    def process(i, filename):
        """Load, resize and store the image destined for slot ``i``."""
        img = imread(filename)
        # handle grayscale: replicate the single channel three times
        if img.ndim == 2:
            img = img[:, :, None][:, :, [0, 0, 0]]
        H0, W0 = img.shape[0], img.shape[1]
        # Scale so that the longer side becomes (at most) image_size.
        img = imresize(img, float(args.image_size) / max(H0, W0))
        H, W = img.shape[0], img.shape[1]
        # Select channels 2, 1, 0 (RGB -> BGR), then go from HWC to CHW.
        # Picking exactly three channels also discards any extra channel
        # instead of failing on the shape mismatch below.
        chw = img[:, :, [2, 1, 0]].transpose(2, 0, 1)
        # All writes to shared state, including the HDF5 dataset, are
        # serialized; `with` guarantees the lock is released on error.
        with lock:
            original_heights[i] = H0
            original_widths[i] = W0
            image_heights[i] = H
            image_widths[i] = W
            image_dset[i, :, :H, :W] = chw

    def worker():
        while True:
            i, filename = q.get()
            try:
                if i % 10000 == 0:
                    print('processing %i images...' % i)
                process(i, filename)
            except Exception as err:
                # Record the failure instead of letting the thread die.
                with lock:
                    failures.append((filename, err))
            finally:
                # Must run for every item, otherwise q.join() never returns.
                q.task_done()

    # At least one worker is required or q.join() would wait forever.
    for _ in range(max(1, args.num_workers)):
        t = Thread(target=worker)
        t.daemon = True
        t.start()
    q.join()

    if failures:
        report = ', '.join('%s (%s)' % (fn, err) for fn, err in failures)
        raise RuntimeError(
            'failed to process %d image(s): %s' % (len(failures), report))

    h5_file.create_dataset('image_heights', data=image_heights)
    h5_file.create_dataset('image_widths', data=image_widths)
    h5_file.create_dataset('original_heights', data=original_heights)
    h5_file.create_dataset('original_widths', data=original_widths)

    return fns
def main(args):
    """Build ``imdb_<image_size>.h5`` inside ``args.imh5_dir``.

    Loads the image metadata JSON from ``args.metadata_input``, creates (or
    truncates) the HDF5 file and fills it through ``add_images``.

    Args:
        args: namespace providing ``metadata_input``, ``image_size`` and
            ``imh5_dir``, plus the attributes ``add_images`` reads.
    """
    with open(args.metadata_input) as metadata_file:
        im_metadata = json.load(metadata_file)

    h5_fn = 'imdb_' + str(args.image_size) + '.h5'
    h5_path = os.path.join(args.imh5_dir, h5_fn)

    # write the h5 file; the context manager closes it even if image
    # processing raises, so the handle is never left open.
    with h5py.File(h5_path, 'w') as f:
        # load images
        add_images(im_metadata, f, args)
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    cli = argparse.ArgumentParser()
    option_table = (
        ('--image_dir', {'default': 'VG/images'}),
        ('--image_size', {'default': 1024, 'type': int}),
        ('--imh5_dir', {'default': '.'}),
        ('--num_workers', {'default': 20, 'type': int}),
        ('--metadata_input', {'default': 'VG/image_data.json', 'type': str}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    args = cli.parse_args()
    main(args)