"""Prepare the IIIT5K dataset (prepare_IIIT5K_dataset.py).

Converts the ``.mat`` annotation files under ``dataset/IIIT5K/`` into pickle
files and loads the referenced images as fixed-size grayscale arrays.
"""
import os
import pickle
import cv2
import numpy as np
import scipy.io as sio
def char_to_label(char):
    """Map one ground-truth character to its integer class label.

    Uppercase letters ``'A'``-``'Z'`` map to 1-26 and digits ``'0'``-``'9'``
    map to 27-36. Label 0 is never produced by this function.

    Args:
        char: A single-character string.

    Returns:
        The integer label, in the range 1-36.

    Raises:
        ValueError: If ``char`` is neither an uppercase ASCII letter nor an
            ASCII digit. Such characters used to fall through to the digit
            formula and silently received a meaningless label.
        TypeError: If ``char`` is not a single character (raised by ``ord``).
    """
    code = ord(char)
    if ord('A') <= code <= ord('Z'):
        return code - ord('A') + 1
    if ord('0') <= code <= ord('9'):
        return 26 + code - ord('0') + 1
    raise ValueError('Unsupported character for labelling: %r' % (char,))
def load_data_mat(name):
    """Load image paths and encoded labels from an IIIT5K ``.mat`` file.

    Reads ``dataset/IIIT5K/<name>.mat``, takes the first row of the array
    stored under the key ``name`` and, for every entry, reads its ``ImgName``
    and ``GroundTruth`` fields.

    Args:
        name: Base name of the annotation file; it is also the key looked up
            inside the loaded file (``main`` passes ``'traindata'`` and
            ``'testdata'``).

    Returns:
        A tuple ``(images, labels)``. ``images`` is a list of image paths,
        each prefixed with ``dataset/IIIT5K/``. ``labels`` is a 1-D object
        array with one ``int32`` array per entry, produced by applying
        ``char_to_label`` to every character of the ground-truth word.
    """
    print('Loading %s ...' % name)
    mat = sio.loadmat('dataset/IIIT5K/' + name + '.mat')[name][0]
    count = mat.shape[0]
    images = []
    # Words differ in length, so the per-word label arrays are ragged.
    # np.asarray() on a ragged list raises ValueError on NumPy >= 1.24, so the
    # container is allocated explicitly as a 1-D object array. This also keeps
    # the result shape the same when all words happen to have equal length.
    labels = np.empty(count, dtype=object)
    for i, entry in enumerate(mat):
        word = entry['GroundTruth'][0]
        images.append('dataset/IIIT5K/' + entry['ImgName'][0])
        labels[i] = np.asarray([char_to_label(ch) for ch in word],
                               dtype='int32')
    return images, labels
def prepare_images(images, height=32, width=256):
    """Load images as grayscale and fit each one to ``height`` x ``width``.

    Every image is read with OpenCV in grayscale mode and scaled uniformly so
    that its height becomes ``height``. An image that is then narrower than
    ``width`` is zero-padded on both sides (the right side gets the extra
    column when the padding is odd); any other image is resized horizontally
    to ``width``.

    Args:
        images: Iterable of image file paths.
        height: Target height in pixels. Defaults to 32.
        width: Target width in pixels. Defaults to 256.

    Returns:
        A ``float32`` array holding the pixel values divided by 255. For
        non-empty input its shape is ``(len(images), height, width)``.

    Raises:
        OSError: If an image cannot be read (``cv2.imread`` returned ``None``).
        ValueError: If an image does not end up with shape
            ``(height, width)`` after resizing and padding.
    """
    decoded_images = []
    for path in images:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals a missing or undecodable file by returning
            # None rather than raising; fail here with the offending path.
            raise OSError('Could not read image: %s' % path)
        scale = height / img.shape[0]
        img = cv2.resize(img, None, fx=scale, fy=scale)
        if img.shape[1] < width:
            # Centre the image horizontally on a zero background.
            missing = width - img.shape[1]
            left = missing // 2
            img = np.pad(img, ((0, 0), (left, missing - left)),
                         mode='constant')
        else:
            img = cv2.resize(img, None, fx=width / img.shape[1], fy=1)
        if img.shape != (height, width):
            raise ValueError('shape = %d,%d' % img.shape)
        decoded_images.append(img)
    return np.asarray(decoded_images, np.float32) / 255
def convert_if_needed(name):
    """Create ``dataset/IIIT5K/<name>.pickle`` unless it already exists.

    When the pickle file is missing, the data returned by ``load_data_mat``
    for ``name`` is dumped to it as an ``(images, labels)`` tuple. When the
    file is already present, nothing is done.

    Args:
        name: Base name of the dataset split (without extension).
    """
    pickle_path = 'dataset/IIIT5K/' + name + '.pickle'
    if not os.path.exists(pickle_path):
        paths, encoded = load_data_mat(name)
        with open(pickle_path, 'wb') as out_file:
            pickle.dump((paths, encoded), out_file)
def load_data(name):
    """Return the object stored in ``dataset/IIIT5K/<name>.pickle``.

    Args:
        name: Base name of the pickle file (without extension).

    Returns:
        Whatever object was pickled into the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load pickle files
    # that come from a trusted source.
    pickle_path = 'dataset/IIIT5K/' + name + '.pickle'
    with open(pickle_path, 'rb') as in_file:
        payload = pickle.load(in_file)
    return payload
def main():
    """Build missing pickle files, then load and preprocess both splits.

    The prepared arrays are only bound to local variables; this function
    returns nothing and does not write them anywhere itself.
    """
    for split in ('traindata', 'testdata'):
        convert_if_needed(split)

    train_paths, train_labels = load_data('traindata')
    train_images = prepare_images(train_paths)

    test_paths, test_labels = load_data('testdata')
    test_images = prepare_images(test_paths)
    # Sanity checks such as `assert not np.any(np.isnan(train_images))` are
    # currently disabled; they can be added here when debugging the data.
# Run the preparation steps only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()