exemplars.py (forked from GMvandeVen/continual-learning)
import abc
import torch
from torch import nn
from torch.nn import functional as F
import utils
import copy
import numpy as np


class ExemplarHandler(nn.Module, metaclass=abc.ABCMeta):
    """Abstract module for a classifier that can store and use exemplars.

    Adds exemplar-methods to subclasses, and requires them to provide a 'feature_extractor' method."""

    def __init__(self):
        super().__init__()

        # list with exemplar-sets
        self.exemplar_sets = []   #--> each exemplar_set is an <np.array> of N images with shape (N, Ch, H, W)
        self.exemplar_means = []
        self.compute_means = True

        # settings
        self.memory_budget = 2000
        self.norm_exemplars = True
        self.herding = True
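        # Note (added comment): [memory_budget] is assumed to be the total number of exemplars to store,
        # with the calling script dividing it over all classes seen so far when choosing [n] for
        # [construct_exemplar_set] and [m] for [reduce_exemplar_sets].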
    def _device(self):
        return next(self.parameters()).device

    def _is_on_cuda(self):
        return next(self.parameters()).is_cuda

    @abc.abstractmethod
    def feature_extractor(self, images):
        '''To be implemented by subclass: map <tensor> [images] to a 2D <tensor> of size (bsz, feature_size).'''
        pass

    ####----MANAGING EXEMPLAR SETS----####

    def reduce_exemplar_sets(self, m):
        for y, P_y in enumerate(self.exemplar_sets):
            self.exemplar_sets[y] = P_y[:m]
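    # Note (added comment, not in the original code): because herding picks exemplars in order of
    # importance (each prefix approximates the class-mean as well as possible), reducing a set to
    # its first [m] entries keeps the best available subset; with random selection the truncation
    # is arbitrary but still valid.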
    def construct_exemplar_set(self, dataset, n):
        '''Construct a set of [n] exemplars from [dataset] using 'herding'.

        Note that [dataset] should be from a specific class; selected sets are added to [self.exemplar_sets] in order.'''

        # set model to eval()-mode
        mode = self.training
        self.eval()

        n_max = len(dataset)
        exemplar_set = []

        if self.herding:
            # compute features for each example in [dataset]
            first_entry = True
            dataloader = utils.get_data_loader(dataset, 128, cuda=self._is_on_cuda())
            for (image_batch, _) in dataloader:
                image_batch = image_batch.to(self._device())
                with torch.no_grad():
                    feature_batch = self.feature_extractor(image_batch)
                if first_entry:
                    features = feature_batch
                    first_entry = False
                else:
                    features = torch.cat([features, feature_batch], dim=0)
            if self.norm_exemplars:
                features = F.normalize(features, p=2, dim=1)

            # calculate mean of all features
            class_mean = torch.mean(features, dim=0, keepdim=True)
            if self.norm_exemplars:
                class_mean = F.normalize(class_mean, p=2, dim=1)

            # one by one, select the exemplar that makes the mean of all exemplars as close to [class_mean] as possible
            exemplar_features = torch.zeros_like(features[:min(n, n_max)])
            list_of_selected = []
            for k in range(min(n, n_max)):
                if k > 0:
                    exemplar_sum = torch.sum(exemplar_features[:k], dim=0).unsqueeze(0)
                    features_means = (features + exemplar_sum) / (k + 1)
                    # if self.norm_exemplars:
                    #     # perhaps this normalization should be left out??
                    #     features_means = F.normalize(features_means, p=2, dim=1)
                    features_dists = features_means - class_mean
                else:
                    features_dists = features - class_mean
                # (.item() turns the 0-dim <tensor> into an <int>; works whether [features] is on CPU or GPU)
                index_selected = torch.argmin(torch.norm(features_dists, p=2, dim=1)).item()
                if index_selected in list_of_selected:
                    raise ValueError("Exemplars should not be repeated!!!!")
                list_of_selected.append(index_selected)
                exemplar_set.append(dataset[index_selected][0].numpy())
                exemplar_features[k] = copy.deepcopy(features[index_selected])

                # add a large offset to the selected feature, to make sure this example won't be selected again
                features[index_selected] = features[index_selected] + 10000
        else:
            indices_selected = np.random.choice(n_max, size=min(n, n_max), replace=False)
            for k in indices_selected:
                exemplar_set.append(dataset[k][0].numpy())

        # add this [exemplar_set] as a [n]x[ich]x[isz]x[isz] <np.array> to the list of [exemplar_sets]
        self.exemplar_sets.append(np.array(exemplar_set))

        # set mode of model back
        self.train(mode=mode)
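    # A note on the herding objective (added comment, not in the original code): at step [k] the
    # selected exemplar [p_k] is the remaining candidate minimizing
    #     || class_mean - (phi(p_k) + sum_{j<k} phi(p_j)) / (k+1) ||_2
    # where [phi] is the (normalized) feature-extractor. This is the greedy nearest-mean selection
    # of iCaRL (Rebuffi et al., 2017); the "+ 10000" offset above masks already-selected examples
    # from the argmin.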

    ####----CLASSIFICATION----####

    def classify_with_exemplars(self, x, allowed_classes=None):
        """Classify images by nearest-mean-of-exemplars (after transform to feature representation).

        INPUT:      x = <tensor> of size (bsz,ich,isz,isz) with input image batch
                    allowed_classes = None or <list> containing all "active classes" between which should be chosen

        OUTPUT:     preds = <tensor> of size (bsz,)"""

        # Set model to eval()-mode
        mode = self.training
        self.eval()

        batch_size = x.size(0)

        # Do the exemplar-means need to be recomputed?
        if self.compute_means:
            exemplar_means = []  #--> list of 1D-tensors (of size [feature_size]), list is of length [n_classes]
            for P_y in self.exemplar_sets:
                exemplars = []
                # Collect all exemplars in P_y into a <tensor> and extract their features
                for ex in P_y:
                    exemplars.append(torch.from_numpy(ex))
                exemplars = torch.stack(exemplars).to(self._device())
                with torch.no_grad():
                    features = self.feature_extractor(exemplars)
                if self.norm_exemplars:
                    features = F.normalize(features, p=2, dim=1)
                # Calculate their mean and add to list
                mu_y = features.mean(dim=0, keepdim=True)
                if self.norm_exemplars:
                    mu_y = F.normalize(mu_y, p=2, dim=1)
                exemplar_means.append(mu_y.squeeze())  # -> squeeze removes all dimensions of size 1
            # Update model's attributes
            self.exemplar_means = exemplar_means
            self.compute_means = False

        # Reorganize the [exemplar_means]-<tensor>
        exemplar_means = self.exemplar_means if allowed_classes is None else [
            self.exemplar_means[i] for i in allowed_classes
        ]
        means = torch.stack(exemplar_means)        # (n_classes, feature_size)
        means = torch.stack([means] * batch_size)  # (batch_size, n_classes, feature_size)
        means = means.transpose(1, 2)              # (batch_size, feature_size, n_classes)

        # Extract features for input data (and reorganize)
        with torch.no_grad():
            feature = self.feature_extractor(x)    # (batch_size, feature_size)
        if self.norm_exemplars:
            feature = F.normalize(feature, p=2, dim=1)
        feature = feature.unsqueeze(2)             # (batch_size, feature_size, 1)
        feature = feature.expand_as(means)         # (batch_size, feature_size, n_classes)

        # For each data-point in [x], find which exemplar-mean is closest to its extracted features
        dists = (feature - means).pow(2).sum(dim=1)  # (batch_size, n_classes)
        _, preds = dists.min(1)

        # Set mode of model back
        self.train(mode=mode)

        return preds
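

# ------------------------------------------------------------------------------------------------------------- #
# Minimal usage sketch (added for illustration; not part of the original file). [ToyExemplarHandler] and its
# fixed linear feature-extractor are hypothetical stand-ins for the repo's actual classifiers. Herding is
# switched off so the demo does not need [utils.get_data_loader]; with [herding=True] the same two calls would
# select exemplars greedily instead. Run from within the repo (the module imports [utils] at the top).
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    class ToyExemplarHandler(ExemplarHandler):
        '''Hypothetical subclass: flattens images and projects them with a (random, untrained) linear map.'''
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3 * 8 * 8, 16)

        def feature_extractor(self, images):
            return self.fc(images.view(images.size(0), -1))

    model = ToyExemplarHandler()
    model.herding = False  # random selection instead of herding

    # two toy "classes" of 20 images each (3x8x8); store 5 exemplars per class
    for c in range(2):
        images = torch.randn(20, 3, 8, 8) + 2. * c  # shift class 1 so the two classes are separable
        dataset = TensorDataset(images, torch.full((20,), c, dtype=torch.long))
        model.construct_exemplar_set(dataset, n=5)

    # classify a batch drawn from class 1 by nearest-mean-of-exemplars
    preds = model.classify_with_exemplars(torch.randn(4, 3, 8, 8) + 2.)
    print(preds)  # -> <tensor> of size (4,), mostly predicting class 1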