-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmnist_loader.py
161 lines (132 loc) · 5.58 KB
/
mnist_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os, struct
from array import array as pyarray
import numpy as np
from numpy import append, array, int8, uint8, zeros
from random import shuffle
def load_mnist(dataset="training", digits=None, path=None, asbytes=False, selection=None, return_labels=True, return_indices=False):
    """
    Load MNIST files into a 3D numpy array.

    You have to download the data separately from [MNIST]_. Use the ``path`` parameter
    to specify the directory that contains all four downloaded MNIST files.

    Parameters
    ----------
    dataset : str
        Either "training" or "testing", depending on which dataset you want to
        load.
    digits : list
        Integer collection of digits to load. The entire database is loaded if
        set to ``None`` or to an empty collection. Default is ``None``.
    path : str
        Path to your MNIST datafiles. The default is ``None``, which will try
        to take the path from your environment variable ``MNIST``. The data can
        be downloaded from http://yann.lecun.com/exdb/mnist/.
    asbytes : bool
        If True, returns data as ``numpy.uint8`` in [0, 255] as opposed to
        ``numpy.float64`` in [0.0, 1.0].
    selection : slice
        Using a `slice` object, specify what subset of the dataset to load. An
        example is ``slice(0, 20, 2)``, which would load every other digit
        until--but not including--the twentieth. When ``digits`` is given, the
        slice is applied to the already filtered list of matching images.
    return_labels : bool
        Specify whether or not labels should be returned. Skipping them is also
        a speed-up if digits are not specified, since then the labels file
        does not need to be read at all.
    return_indices : bool
        Specify whether or not to return the MNIST indices that were fetched.
        This is valuable only if digits is specified, because in that case it
        can be valuable to know how far in the database it reached.

    Returns
    -------
    images : ndarray
        Image data of shape ``(N, rows, cols)``, where ``N`` is the number of
        images. If neither labels nor indices are returned, then this is
        returned directly, and not inside a 1-sized tuple.
    labels : ndarray
        ``int8`` array of size ``N`` describing the labels. Returned only if
        ``return_labels`` is `True`, which is default.
    indices : ndarray
        The indices in the database that were returned. Returned only if
        ``return_indices`` is `True`.

    Raises
    ------
    ValueError
        If ``path`` is ``None`` and ``$MNIST`` is not set, if ``dataset`` is
        not a known data set name, or if a data file holds fewer entries than
        its header announces.

    Examples
    --------
    Assuming that you have downloaded the MNIST database and set the
    environment variable ``$MNIST`` point to the folder, this will load all
    images and labels from the training set:

    >>> images, labels = ag.io.load_mnist('training') # doctest: +SKIP

    Load 100 sevens from the testing set:

    >>> sevens = ag.io.load_mnist('testing', digits=[7], selection=slice(0, 100), return_labels=False) # doctest: +SKIP
    """
    # The files are assumed to have these names and should be found in 'path'
    files = {
        'training': ('train-images-idx3-ubyte', 'train-labels-idx1-ubyte'),
        'testing': ('t10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte'),
    }

    if path is None:
        try:
            path = os.environ['MNIST']
        except KeyError:
            raise ValueError("Unspecified path requires environment variable $MNIST to be set")

    try:
        images_name, labels_name = files[dataset]
    except KeyError:
        raise ValueError("Data set must be 'testing' or 'training'")
    images_fname = os.path.join(path, images_name)
    labels_fname = os.path.join(path, labels_name)

    # We can skip the labels file only if digits aren't specified and labels aren't asked for
    labels_raw = None
    if return_labels or digits is not None:
        with open(labels_fname, 'rb') as flbl:
            struct.unpack(">II", flbl.read(8))  # header: magic number, item count
            labels_raw = np.frombuffer(flbl.read(), dtype=int8)

    with open(images_fname, 'rb') as fimg:
        _, size, rows, cols = struct.unpack(">IIII", fimg.read(16))
        images_raw = np.frombuffer(fimg.read(), dtype=uint8)

    # Fail early and clearly on truncated files instead of deep inside indexing.
    pixel_count = size * rows * cols
    if images_raw.size < pixel_count:
        raise ValueError("Image file %r is truncated: expected %d pixel bytes, found %d"
                         % (images_fname, pixel_count, images_raw.size))
    if labels_raw is not None and labels_raw.size < size:
        raise ValueError("Label file %r is truncated: expected %d labels, found %d"
                         % (labels_fname, size, labels_raw.size))

    # An empty digit collection means "no filter", same as None; the explicit
    # length test also keeps ndarray arguments from raising on truthiness.
    if digits is not None and len(digits) > 0:
        wanted = np.asarray(list(digits))
        indices = np.flatnonzero(np.isin(labels_raw[:size], wanted))
    else:
        indices = np.arange(size)

    if selection is not None:
        indices = indices[selection]

    # Indexing with an integer array copies, so the results are writable and
    # independent of the read-only file buffers.
    all_images = images_raw[:pixel_count].reshape((size, rows, cols))
    images = all_images[indices]
    if return_labels:
        labels = labels_raw[indices]

    if not asbytes:
        images = images.astype(float) / 255.0

    ret = (images,)
    if return_labels:
        ret += (labels,)
    if return_indices:
        ret += (indices,)

    if len(ret) == 1:
        return ret[0]  # Don't return a tuple of one
    return ret
def randIndices(num, max):
    """Return up to ``num`` distinct integers drawn at random from ``range(max)``.

    The whole range is shuffled with the module-level ``random`` generator and
    its leading ``num`` entries are handed back, so fewer than ``num`` values
    come back when ``max`` is smaller than ``num``.
    """
    pool = [value for value in range(max)]
    shuffle(pool)
    return pool[:num]
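# returns a class-balanced list of objects: every label found in the data set contributes
# the same number of randomly chosen samples (the count of the rarest of the labels 0-9);
# each object consists of the label and a flattened representation of the image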
def gather_data(dataset, path):
    """Load an MNIST data set and return a class-balanced list of samples.

    Images are grouped by label and the occurrences of each label 0-9 are
    counted. The rarest label and its count are printed, and that count is
    then used as the number of samples drawn at random from every label
    group.

    Parameters
    ----------
    dataset : str
        Data set name forwarded to ``load_mnist``.
    path : str
        Directory forwarded to ``load_mnist``.

    Returns
    -------
    list of dict
        One ``{'label': ..., 'img': ...}`` entry per drawn sample, where
        ``img`` is the flattened image. Entries are grouped by label, in the
        order in which the labels first appear in the data set.
    """
    images, labels = load_mnist(dataset=dataset, path=path)

    by_label = {}
    label_counts = [0] * 10
    for label, img in zip(labels, images):
        label_counts[label] += 1
        by_label.setdefault(label, []).append({'label': label, 'img': img})

    antimode = np.argmin(label_counts)
    min_count = min(label_counts)
    print('antimode: ', antimode)
    print('min: ', min_count)

    # Draw the same number of samples from every label so classes are balanced.
    balanced = []
    for digit, samples in by_label.items():
        for pick in randIndices(min_count, len(samples)):
            balanced.append({'label': digit, 'img': samples[pick]['img'].flatten()})
    return balanced