-
Notifications
You must be signed in to change notification settings - Fork 0
/
cifar100_utils.py
59 lines (55 loc) · 3.61 KB
/
cifar100_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
from torchvision.datasets import CIFAR100
def sparse2coarse(targets):
    """Map CIFAR-100 fine ("sparse") label indices to coarse superclass indices.

    ``targets`` may be anything NumPy accepts as an index into a length-100
    array, e.g. a single ``int`` or a list/array of fine labels in ``[0, 99]``.
    The result comes from NumPy indexing, so a list or array of labels yields a
    ``numpy.ndarray`` and a single ``int`` yields a NumPy scalar.

    Usage:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)
    """
    # Row r holds the coarse labels of fine classes 10*r .. 10*r + 9, so the
    # flattened table can be indexed directly by the fine label.
    fine_to_coarse_rows = [
        [4, 1, 14, 8, 0, 6, 7, 7, 18, 3],
        [3, 14, 9, 18, 7, 11, 3, 9, 7, 11],
        [6, 11, 5, 10, 7, 6, 13, 15, 3, 15],
        [0, 11, 1, 10, 12, 14, 16, 9, 11, 5],
        [5, 19, 8, 8, 15, 13, 14, 17, 18, 10],
        [16, 4, 17, 4, 2, 0, 17, 4, 18, 17],
        [10, 3, 2, 12, 12, 16, 12, 1, 9, 19],
        [2, 10, 0, 1, 16, 12, 9, 13, 15, 13],
        [16, 19, 2, 4, 6, 19, 5, 5, 8, 19],
        [18, 1, 2, 15, 6, 0, 17, 8, 14, 13],
    ]
    lookup = np.array(fine_to_coarse_rows).ravel()
    return lookup[targets]
class CIFAR100Coarse(CIFAR100):
    """CIFAR-100 dataset relabelled with its 20 coarse (superclass) labels.

    Loads exactly like ``torchvision.datasets.CIFAR100`` and then rewrites the
    label metadata:

    * ``targets`` holds coarse labels in ``[0, 19]``. This is a NumPy array,
      because it is produced by ``sparse2coarse``.
    * ``classes[k]`` is the list of the five fine class names that make up
      coarse label ``k``.
    * ``class_to_idx`` maps every fine class name to its coarse label, so it
      agrees with ``targets``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        # Replace the fine labels loaded by the parent with their superclass
        # indices. Reuse the module-level mapping instead of keeping a second
        # copy of the 100-entry table here, so the two cannot drift apart.
        self.targets = sparse2coarse(self.targets)
        # classes[k] lists the fine class names that belong to coarse label k;
        # the order of the outer list must match the indices sparse2coarse emits.
        self.classes = [['beaver', 'dolphin', 'otter', 'seal', 'whale'],
                        ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
                        ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
                        ['bottle', 'bowl', 'can', 'cup', 'plate'],
                        ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'],
                        ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
                        ['bed', 'chair', 'couch', 'table', 'wardrobe'],
                        ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
                        ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
                        ['bridge', 'castle', 'house', 'road', 'skyscraper'],
                        ['cloud', 'forest', 'mountain', 'plain', 'sea'],
                        ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
                        ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
                        ['crab', 'lobster', 'snail', 'spider', 'worm'],
                        ['baby', 'boy', 'girl', 'man', 'woman'],
                        ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
                        ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
                        ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],
                        ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],
                        ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor']]
        # torchvision's CIFAR base class derives class_to_idx from the fine
        # class names (fine name -> fine index 0..99), which no longer matches
        # the coarse targets. Remap each fine name to the coarse label it is
        # now reported as.
        self.class_to_idx = {
            name: coarse_idx
            for coarse_idx, members in enumerate(self.classes)
            for name in members
        }