utils.py
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.nn as nn
import clip


def cls_acc(output, target, topk=1):
    # Top-k classification accuracy (in percent) for logits `output` and integer labels `target`.
    pred = output.topk(topk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
    acc = 100 * acc / target.shape[0]
    return acc


def clip_classifier(classnames, template, clip_model):
    # Build the zero-shot classifier: one prompt-ensembled, L2-normalized text
    # embedding per class, stacked as the columns of `clip_weights`.
    with torch.no_grad():
        clip_weights = []

        for classname in classnames:
            # Tokenize the prompts
            classname = classname.replace('_', ' ')
            texts = [t.format(classname) for t in template]
            texts = clip.tokenize(texts).cuda()
            # prompt ensemble for ImageNet
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            clip_weights.append(class_embedding)

        clip_weights = torch.stack(clip_weights, dim=1).cuda()
    return clip_weights


def build_cache_model(cfg, clip_model, train_loader_cache):
    # Build (or load) the key-value cache: keys are the averaged, normalized image
    # features of the few-shot training set; values are one-hot labels.
    if not cfg['load_cache']:
        cache_keys = []
        cache_values = []

        with torch.no_grad():
            # Data augmentation for the cache model
            for augment_idx in range(cfg['augment_epoch']):
                train_features = []

                print('Augment Epoch: {:} / {:}'.format(augment_idx, cfg['augment_epoch']))
                for i, (images, target) in enumerate(tqdm(train_loader_cache)):
                    images = images.cuda()
                    image_features = clip_model.encode_image(images)
                    train_features.append(image_features)
                    if augment_idx == 0:
                        target = target.cuda()
                        cache_values.append(target)
                cache_keys.append(torch.cat(train_features, dim=0).unsqueeze(0))

        # Average the features over the augmentation epochs, normalize, and
        # lay the keys out as [feature_dim, num_cached_samples].
        cache_keys = torch.cat(cache_keys, dim=0).mean(dim=0)
        cache_keys /= cache_keys.norm(dim=-1, keepdim=True)
        cache_keys = cache_keys.permute(1, 0)
        cache_values = F.one_hot(torch.cat(cache_values, dim=0)).half()

        torch.save(cache_keys, cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt")
        torch.save(cache_values, cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt")
    else:
        cache_keys = torch.load(cfg['cache_dir'] + '/keys_' + str(cfg['shots']) + "shots.pt")
        cache_values = torch.load(cfg['cache_dir'] + '/values_' + str(cfg['shots']) + "shots.pt")

    return cache_keys, cache_values


def pre_load_features(cfg, split, clip_model, loader):
    # Extract (or load) L2-normalized CLIP image features and labels for a data split.
    if not cfg['load_pre_feat']:
        features, labels = [], []

        with torch.no_grad():
            for i, (images, target) in enumerate(tqdm(loader)):
                images, target = images.cuda(), target.cuda()
                image_features = clip_model.encode_image(images)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                features.append(image_features)
                labels.append(target)

        features, labels = torch.cat(features), torch.cat(labels)

        torch.save(features, cfg['cache_dir'] + "/" + split + "_f.pt")
        torch.save(labels, cfg['cache_dir'] + "/" + split + "_l.pt")
    else:
        features = torch.load(cfg['cache_dir'] + "/" + split + "_f.pt")
        labels = torch.load(cfg['cache_dir'] + "/" + split + "_l.pt")

    return features, labels


def search_hp(cfg, cache_keys, cache_values, features, labels, clip_weights, adapter=None):
    # Grid-search the sharpness beta and residual ratio alpha on pre-extracted features.
    best_beta, best_alpha = 0, 0

    if cfg['search_hp']:
        beta_list = [i * (cfg['search_scale'][0] - 0.1) / cfg['search_step'][0] + 0.1 for i in range(cfg['search_step'][0])]
        alpha_list = [i * (cfg['search_scale'][1] - 0.1) / cfg['search_step'][1] + 0.1 for i in range(cfg['search_step'][1])]

        best_acc = 0

        for beta in beta_list:
            for alpha in alpha_list:
                if adapter:
                    affinity = adapter(features)
                else:
                    affinity = features @ cache_keys

                cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
                clip_logits = 100. * features @ clip_weights
                tip_logits = clip_logits + cache_logits * alpha
                acc = cls_acc(tip_logits, labels)

                if acc > best_acc:
                    print("New best setting, beta: {:.2f}, alpha: {:.2f}; accuracy: {:.2f}".format(beta, alpha, acc))
                    best_acc = acc
                    best_beta = beta
                    best_alpha = alpha

        print("\nAfter searching, the best accuracy: {:.2f}.\n".format(best_acc))

    return best_beta, best_alpha
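

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (not part of the original file). It assumes a cfg
# dict containing the keys referenced above ('augment_epoch', 'cache_dir',
# 'shots', 'load_cache', 'load_pre_feat', 'search_hp', 'search_scale',
# 'search_step'), plus `classnames`, `template` and the dataloaders, all built
# elsewhere in the repository; the backbone name "RN50" is only an example.
#
# if __name__ == '__main__':
#     clip_model, preprocess = clip.load("RN50")
#     clip_model.eval()
#
#     clip_weights = clip_classifier(classnames, template, clip_model)
#     cache_keys, cache_values = build_cache_model(cfg, clip_model, train_loader_cache)
#     test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader)
#
#     # Zero-shot CLIP logits plus cache logits, then grid-search beta / alpha.
#     best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values,
#                                       test_features, test_labels, clip_weights)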