-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_cifar_10.py
97 lines (95 loc) · 4.3 KB
/
load_cifar_10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
"""
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000
training images and 10000 test images.
The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains
exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random
order, but some training batches may contain more images from one class than another. Between them, the training
batches contain exactly 5000 images from each class.
"""
def unpickle(file):
    """Deserialize one pickled CIFAR-10 file and return the loaded object.

    The file is opened in binary mode and decoded with ``encoding='bytes'``,
    so strings written by Python 2 come back as ``bytes`` (which is why the
    callers index the result with keys such as ``b'data'``).

    NOTE: ``pickle`` can execute arbitrary code while loading; only pass
    files obtained from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')
def _to_channels_last(flat_images, negatives):
    """Turn flat CIFAR rows into images laid out as (N, 32, 32, 3).

    Each row is first viewed as (3, 32, 32) and the size-3 axis is then moved
    to the end. When ``negatives`` is true the result is cast to float32.
    """
    images = flat_images.reshape((len(flat_images), 3, 32, 32)).transpose(0, 2, 3, 1)
    if negatives:
        images = images.astype(np.float32)
    return images


def load_cifar_10_data(data_dir, negatives=False):
    """Load the pickled CIFAR-10 batches found in ``data_dir``.

    Args:
        data_dir: Directory holding ``batches.meta``, ``data_batch_1`` to
            ``data_batch_5`` and ``test_batch``. May be a ``str`` or any
            path-like object.
        negatives: When true, both image arrays are cast to ``np.float32``;
            otherwise they keep the dtype stored in the batch files.

    Returns:
        A 7-tuple ``(train_data, train_filenames, train_labels, test_data,
        test_filenames, test_labels, label_names)`` of NumPy arrays. The two
        image arrays have shape (N, 32, 32, 3). Filenames and label names
        are ``bytes`` values and need decoding before being used as text.
    """
    # batches.meta keys:
    #   num_cases_per_batch: 10000
    #   label_names: ['airplane', 'automobile', 'bird', 'cat', 'deer',
    #                 'dog', 'frog', 'horse', 'ship', 'truck']
    #   num_vis: 3072
    meta_data_dict = unpickle(os.path.join(data_dir, "batches.meta"))
    cifar_label_names = np.array(meta_data_dict[b'label_names'])

    # Training data is spread over five batch files. Each batch dict has:
    #   'batch_label' (e.g. 'training batch 5 of 5'), 'data' (ndarray),
    #   'filenames' (list), 'labels' (list)
    train_batches = []
    cifar_train_filenames = []
    cifar_train_labels = []
    for i in range(1, 6):
        batch = unpickle(os.path.join(data_dir, "data_batch_{}".format(i)))
        train_batches.append(batch[b'data'])
        cifar_train_filenames.extend(batch[b'filenames'])
        cifar_train_labels.extend(batch[b'labels'])

    # Stack once after the loop instead of re-copying the growing array on
    # every iteration.
    cifar_train_data = _to_channels_last(np.vstack(train_batches), negatives)
    cifar_train_filenames = np.array(cifar_train_filenames)
    cifar_train_labels = np.array(cifar_train_labels)

    # The single test batch uses the same dict layout as a training batch.
    cifar_test_data_dict = unpickle(os.path.join(data_dir, "test_batch"))
    cifar_test_data = _to_channels_last(cifar_test_data_dict[b'data'], negatives)
    cifar_test_filenames = np.array(cifar_test_data_dict[b'filenames'])
    cifar_test_labels = np.array(cifar_test_data_dict[b'labels'])

    return (
        cifar_train_data, cifar_train_filenames, cifar_train_labels,
        cifar_test_data, cifar_test_filenames, cifar_test_labels,
        cifar_label_names,
    )
if __name__ == "__main__":
    # Smoke test: load the dataset, report the array shapes and preview a
    # few random training images.
    cifar_10_dir = "./data/cifar10"
    (train_data, train_filenames, train_labels,
     test_data, test_filenames, test_labels,
     label_names) = load_cifar_10_data(cifar_10_dir)

    shape_report = (
        ("Train data: ", train_data),
        ("Train filenames: ", train_filenames),
        ("Train labels: ", train_labels),
        ("Test data: ", test_data),
        ("Test filenames: ", test_filenames),
        ("Test labels: ", test_labels),
        ("Label names: ", label_names),
    )
    for caption, array in shape_report:
        print(caption, array.shape)

    # NOTE: label_names and the filename arrays hold bytes and must be
    # decoded before being used as text.

    # Show randomly chosen training images in a num_plot x num_plot (5x5) grid.
    num_plot = 5
    f, ax = plt.subplots(num_plot, num_plot)
    for cell in ax.flat:  # row-major, same order as nested row/column loops
        idx = np.random.randint(0, train_data.shape[0])
        cell.imshow(train_data[idx])
        cell.get_xaxis().set_visible(False)
        cell.get_yaxis().set_visible(False)
    f.subplots_adjust(hspace=0.1, wspace=0)
    plt.show()