Skip to content

Commit

Permalink
Load and parse metadata for CIFAR-10, CIFAR-100 (pytorch#502)
Browse files Browse the repository at this point in the history
* cifar10.meta['label_names']

* cifar100.meta['fine_label_names']

* cifar100.meta['coarse_label_names']
  • Loading branch information
xenosoz authored and varunagrawal committed Jul 23, 2018
1 parent 4311767 commit 3e04694
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class CIFAR10(data.Dataset):
['test_batch', '40351d587109b95175f43aff81a1287e'],
]

meta_list = [
['batches.meta', '5ff9c542aee3614f3951f8cda6e48888'],
]

def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False):
Expand Down Expand Up @@ -100,6 +104,16 @@ def __init__(self, root, train=True,
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC

f = self.meta_list[0][0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
if sys.version_info[0] == 2:
entry = pickle.load(fo)
else:
entry = pickle.load(fo, encoding='latin1')
fo.close()
self.meta = entry

def __getitem__(self, index):
"""
Args:
Expand Down Expand Up @@ -133,7 +147,7 @@ def __len__(self):

def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.test_list):
for fentry in (self.train_list + self.test_list + self.meta_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
Expand Down Expand Up @@ -187,3 +201,7 @@ class CIFAR100(CIFAR10):
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]

meta_list = [
['meta', '7973b15100ade9c7d40fb424638fde48'],
]

0 comments on commit 3e04694

Please sign in to comment.