-
Notifications
You must be signed in to change notification settings - Fork 94
/
MNIST_reader.py
70 lines (51 loc) · 2.19 KB
/
MNIST_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import struct
import numpy as np
import pickle
"""
Loosely inspired by http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py
which is GPL licensed.
"""
def read(dataset = "training", path = "."):
    """Load one split of the raw (uncompressed) MNIST IDX files.

    Args:
        dataset: ``"training"`` selects the ``train-*`` files, ``"testing"``
            selects the ``t10k-*`` files.
        path: Directory containing the IDX files.

    Returns:
        ``(img, lbl)`` where ``img`` is a float64 array of shape
        ``(num_images, rows * cols)`` with pixel values scaled into
        ``[0, 1]``, and ``lbl`` is an int8 array with one label per image.

    Raises:
        ValueError: if *dataset* is not ``"training"`` or ``"testing"``, if a
            file's magic number is not the expected IDX value, or if the
            label and image files disagree on the number of items.
    """
    # Compare with ==, not `is`: string identity only holds by accident of
    # interning and fails for strings built at runtime.
    if dataset == "training":
        fname_img = os.path.join(path, 'train-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 'train-labels-idx1-ubyte')
    elif dataset == "testing":
        fname_img = os.path.join(path, 't10k-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 't10k-labels-idx1-ubyte')
    else:
        raise ValueError("dataset must be 'testing' or 'training'")
    print(fname_lbl)

    # Label file: big-endian (magic, count) header followed by one byte per label.
    with open(fname_lbl, 'rb') as flbl:
        magic, num_lbl = struct.unpack(">II", flbl.read(8))
        if magic != 2049:
            raise ValueError(
                "%s: bad magic number %d for an IDX label file" % (fname_lbl, magic))
        lbl = np.fromfile(flbl, dtype=np.int8)

    # Image file: big-endian (magic, count, rows, cols) header followed by
    # row-major uint8 pixels.
    with open(fname_img, 'rb') as fimg:
        magic, num_img, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 2051:
            raise ValueError(
                "%s: bad magic number %d for an IDX image file" % (fname_img, magic))
        if num_img != num_lbl:
            raise ValueError(
                "image file has %d items but label file has %d" % (num_img, num_lbl))
        img = np.fromfile(fimg, dtype=np.uint8).reshape(num_img, rows, cols)

    # Flatten each image to a row vector and scale uint8 pixels into [0, 1].
    img = img.reshape(num_img, rows * cols) / 255.0
    return img, lbl
def get_data(d, n_train=50000):
    """Load MNIST from ``<d>/MNIST_original`` and split it for federated training.

    The official training split is divided into its first ``n_train``
    samples (training) and the remaining samples (validation). The training
    portion is then reordered by label, so that federated clients can be
    given non-i.i.d. data.

    Args:
        d: Base directory containing an ``MNIST_original`` sub-directory with
            the raw IDX files.
        n_train: Number of leading training samples kept for training; every
            later sample goes to the validation set. Defaults to 50000.

    Returns:
        ``(sorted_x_train, sorted_y_train, x_vali, y_vali, x_test, y_test)``.
        Each is a Python list. The ``x_*`` lists hold one float row per image
        and the ``y_*`` lists hold one float label per image.
    """
    data_dir = os.path.join(d, 'MNIST_original')
    x_train, y_train = read('training', data_dir)
    x_test, y_test = read('testing', data_dir)

    # Validation set: everything after the first n_train training samples.
    x_vali = list(x_train[n_train:].astype(float))
    y_vali = list(y_train[n_train:].astype(float))

    # Training set: the first n_train training samples.
    x_train = x_train[:n_train].astype(float)
    y_train = y_train[:n_train].astype(float)

    # Sort the training set by label to make federated learning non-i.i.d.
    # NOTE(review): argsort's default kind is not stable. It is left
    # unchanged so that the within-class sample order matches earlier runs.
    indices_train = np.argsort(y_train)
    sorted_x_train = list(x_train[indices_train])
    sorted_y_train = list(y_train[indices_train])

    # Test set: the official test split, converted to per-sample lists.
    x_test = list(x_test.astype(float))
    y_test = list(y_test.astype(float))
    return sorted_x_train, sorted_y_train, x_vali, y_vali, x_test, y_test
class Data:
    """Bundle the MNIST splits with a pre-computed client partition.

    Attributes set in ``__init__``:
        client_set: the object unpickled from
            ``<save_dir>/DATA/clients/<n>_clients.pkl``.
        sorted_x_train, sorted_y_train, x_vali, y_vali, x_test, y_test:
            the six values returned by ``get_data(save_dir)``.
    """

    def __init__(self, save_dir, n):
        """Load the client partition file for *n* and the MNIST data.

        Args:
            save_dir: Base directory holding ``DATA/clients`` and the
                ``MNIST_original`` data directory.
            n: Selects which ``<n>_clients.pkl`` file is loaded (presumably
                the number of clients; confirm against the file's producer).
        """
        raw_directory = os.path.join(save_dir, 'DATA')
        client_file = os.path.join(raw_directory, 'clients', str(n) + '_clients.pkl')
        # Use a context manager so the file handle is closed promptly.
        # SECURITY: pickle.load can execute arbitrary code; only load client
        # files from a trusted location.
        with open(client_file, 'rb') as f:
            self.client_set = pickle.load(f)
        (self.sorted_x_train, self.sorted_y_train,
         self.x_vali, self.y_vali,
         self.x_test, self.y_test) = get_data(save_dir)