-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 2 changed files with 157 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
from __future__ import print_function | ||
import torch.utils.data as data | ||
from PIL import Image | ||
import os | ||
import os.path | ||
import errno | ||
import torch | ||
import json | ||
import codecs | ||
import numpy as np | ||
|
||
class MNIST(data.Dataset):
    """MNIST dataset backed by tensors cached on disk.

    The idx archives listed in ``urls`` are downloaded into ``<root>/raw``,
    converted once with ``read_image_file`` / ``read_label_file`` and stored
    with ``torch.save`` as ``<root>/processed/training.pt`` and
    ``<root>/processed/test.pt``.

    Args:
        root: Directory that holds (or will hold) the ``raw`` and
            ``processed`` sub-folders.
        train: If true, serve the training split, otherwise the test split.
        transform: Optional callable applied to the PIL image of a sample.
        target_transform: Optional callable applied to the label of a sample.
        download: If true, download and convert the data first, unless the
            processed files already exist.

    Raises:
        RuntimeError: If the processed files are missing after the optional
            download step.
    """

    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               + ' You can use download=True to download it')

        # NOTE: torch.load unpickles the file, so ``root`` should only point
        # at data this class (or another trusted source) produced.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is converted to a mode ``'L'`` PIL image before
        ``transform`` is applied; ``target_transform`` is applied to the
        label.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        # Ask the loaded tensor instead of assuming the canonical 60000/10000
        # sizes, so the length always agrees with what __getitem__ can index.
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_exists(self):
        """Return True when both processed files exist under ``root``."""
        processed_dir = os.path.join(self.root, self.processed_folder)
        return os.path.exists(os.path.join(processed_dir, self.training_file)) and \
            os.path.exists(os.path.join(processed_dir, self.test_file))

    def _ensure_folders(self):
        """Create the raw and processed folders, tolerating existing ones."""
        # Each folder gets its own try block: with a shared one, an already
        # existing raw folder would skip the creation of the processed folder.
        for folder in (self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

    def download(self):
        """Download the archives in ``urls`` and convert them to torch files.

        Only prints a notice and returns when both processed files already
        exist. Otherwise every archive is fetched into the raw folder,
        decompressed next to itself and deleted; the decompressed idx files
        are then parsed and saved as the training and test files.
        """
        import gzip

        if self._check_exists():
            print('Files already downloaded')
            return

        # Imported only once a download is really needed.
        from six.moves import urllib

        self._ensure_folders()
        raw_dir = os.path.join(self.root, self.raw_folder)

        # download files
        for url in self.urls:
            print('Downloading ' + url)
            response = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            gz_path = os.path.join(raw_dir, filename)
            # Strip the suffix from the file name only, so a '.gz' that
            # happens to occur inside ``root`` is left untouched.
            out_path = os.path.join(raw_dir, filename.replace('.gz', ''))
            try:
                with open(gz_path, 'wb') as f:
                    f.write(response.read())
            finally:
                # try/finally because the Python 2 response object is not a
                # context manager.
                response.close()
            with open(out_path, 'wb') as out_f, \
                    gzip.GzipFile(gz_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(gz_path)

        # process and save as torch files
        print('Processing')

        training_set = (
            read_image_file(os.path.join(raw_dir, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(raw_dir, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(raw_dir, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(raw_dir, 't10k-labels-idx1-ubyte'))
        )
        processed_dir = os.path.join(self.root, self.processed_folder)
        with open(os.path.join(processed_dir, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(processed_dir, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')
|
||
def get_int(b):
    """Interpret the byte string ``b`` as an unsigned big-endian integer."""
    # Hex-encode the bytes in order, then parse the digits as one base-16
    # number; the first byte therefore ends up most significant.
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)
|
||
def parse_byte(b):
    """Return the integer value of one element taken from a byte string.

    A one-character ``str`` (what indexing a byte string yields on Python 2)
    is converted with ``ord``; any other value is returned unchanged.
    """
    return ord(b) if isinstance(b, str) else b
|
||
def read_label_file(path):
    """Read an idx label file into a 1-D ``torch.LongTensor``.

    The file starts with two big-endian unsigned 32-bit integers: the magic
    number 2049 and the number of labels. Every remaining byte is one label.

    Args:
        path: Path of the decompressed label file.

    Returns:
        A ``torch.LongTensor`` holding one entry per label byte.

    Raises:
        ValueError: If the file is shorter than the 8-byte header.
        AssertionError: If the magic number is not 2049 or the number of
            payload bytes differs from the count announced in the header.
    """
    import struct

    with open(path, 'rb') as f:
        data = f.read()
    if len(data) < 8:
        raise ValueError('label file header is truncated: %r' % (path,))
    magic, length = struct.unpack('>II', data[:8])
    assert magic == 2049
    # bytearray yields ints on both Python 2 and 3, so the payload is
    # converted in one pass without a per-byte helper call.
    labels = list(bytearray(data[8:]))
    assert len(labels) == length
    return torch.LongTensor(labels)
|
||
def read_image_file(path):
    """Read an idx image file into a ``(count, rows, cols)`` byte tensor.

    The file starts with four big-endian unsigned 32-bit integers: the magic
    number 2051, the number of images, the rows per image and the columns per
    image. The pixels follow as one unsigned byte each, image by image and
    row by row. Bytes beyond the announced pixel count are ignored.

    Args:
        path: Path of the decompressed image file.

    Returns:
        A ``torch.uint8`` tensor shaped ``(count, rows, cols)`` using the
        dimensions stored in the header (``(count, 28, 28)`` for MNIST).

    Raises:
        ValueError: If the file is shorter than the 16-byte header.
        AssertionError: If the magic number is not 2051.
        IndexError: If the file holds fewer pixel bytes than the header
            announces.
    """
    import struct

    with open(path, 'rb') as f:
        data = f.read()
    if len(data) < 16:
        raise ValueError('image file header is truncated: %r' % (path,))
    magic, length, num_rows, num_cols = struct.unpack('>IIII', data[:16])
    assert magic == 2051
    num_pixels = length * num_rows * num_cols
    if len(data) - 16 < num_pixels:
        raise IndexError('image file holds fewer pixels than its header '
                         'announces: %r' % (path,))
    # One vectorised pass over the payload instead of a Python-level call per
    # pixel. The copy makes the array writable, because a buffer over an
    # immutable byte string is read-only.
    pixels = np.frombuffer(data, dtype=np.uint8, count=num_pixels, offset=16).copy()
    # Shape comes from the header rather than a hard-coded 28x28.
    return torch.from_numpy(pixels).view(length, num_rows, num_cols)