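"""Generate zero-shot CLIP pseudolabels for the training split of a dataset.

For every training image the script scores a set of dataset-specific CLIP text
prompts, keeps the most confident predictions per predicted label (optionally
thresholded), backfills classes that received no first-choice pseudolabel from
lower-ranked guesses, and saves a new metadata CSV in which the training labels
and category ids are replaced by pseudolabels while val/test splits stay intact.

Assumes `{root_data_dir}/{dataset}/{dataset}_meta.csv` exists with at least the
columns `img_path`, `label`, `category_id`, and `img_set`.
"""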
import argparse
import os

import clip
import numpy as np
import pandas as pd
import torch
from PIL import Image


def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = '{}{}/'.format(args['root_data_dir'], args['dataset'])
    print(data_dir)

    # Load the dataset metadata (image paths, labels, split assignments)
    metadir = data_dir + '{}_meta.csv'.format(args['dataset'])
    meta = pd.read_csv(metadir, index_col=0)
    meta['category_id'] = meta.category_id.astype(int)
    if args['dataset'] in ['cct20', 'kenya', 'icct', 'serengeti']:
        ## 'shoats' in the Kenya dataset means sheep or goats; rename accordingly
        meta['label'] = meta['label'].apply(lambda x: x.replace('shoats', 'sheep or goat'))
        # Make the camelCase Serengeti labels CLIP-readable
        meta['label'] = meta['label'].apply(lambda x: x.replace('guineaFowl', 'guineafowl')
                                            .replace('lionFemale', 'lion female')
                                            .replace('gazelleThomsons', 'gazelle thomsons')
                                            .replace('vervetMonkey', 'vervet monkey')
                                            .replace('lionMale', 'lion male')
                                            .replace('gazelleGrants', 'gazelle grants')
                                            .replace('otherBird', 'other bird')
                                            .replace('koriBustard', 'kori bustard')
                                            .replace('dikDik', 'dik dik')
                                            .replace('batEaredFox', 'bat-eared fox')
                                            .replace('secretaryBird', 'secretary bird')
                                            .replace('hyenaSpotted', 'hyena spotted')
                                            .replace('hyenaStriped', 'hyena striped'))
        meta['label'] = meta['label'].apply(lambda x: x.replace('_', ' '))
    if args['dataset'] == "oct":
        # Expand the OCT class acronyms into CLIP-readable phrases
        meta['label'] = meta.label.apply(lambda x: x.replace('CNV', 'Choroidal Neovascularization')
                                         .replace('DME', 'Diabetic Macular Edema')
                                         .replace('DRUSEN', 'Drusen')
                                         .replace('NORMAL', 'Healthy'))
    ## Pseudolabel only the training split; val/test keep their ground-truth labels
    split_list = ['train']
    sampled = meta.loc[meta.img_set.isin(split_list)]

    ## Build the mapping from dataset labels to CLIP-readable prompt labels
    label_mapper = meta.drop_duplicates(subset=['label'])['label'].reset_index()
    label_mapper.drop(columns=['index'], inplace=True)
    label_mapper['clip_label'] = label_mapper.label.apply(lambda x: x.replace('_', ' '))
    ## CLIP prompt templates per dataset; the camera-trap datasets share the generic photo prompt
    prompt_templates = {
        'kenya': 'a photo of {}.',
        'cct20': 'a photo of {}.',
        'icct': 'a photo of {}.',
        'serengeti': 'a photo of {}.',
        'fgvc_aircraft': 'a photo of a {}, a type of aircraft.',
        'caltech_101': 'a photo of a {}.',
        'eurosat': 'a centered satellite photo of {}.',
        'oxford_pets': 'a photo of a {}, a type of pet.',
        'oxford_flowers': 'a photo of a {}, a type of flower.',
        'dtd': '{} texture.',
        'ucf101': 'a photo of a person doing {}.',
        'food101': 'a photo of {}, a type of food.',
        'sun397': 'a photo of a {}.',
        'stanford_cars': 'a photo of a {}.',
        'fmow': 'a satellite photo of {}.',
        'oct': 'an OCT scan of {} retina.',
    }
    template = prompt_templates.get(args['dataset'], '{}')
    label_mapper['clip_label'] = label_mapper['clip_label'].apply(template.format)
    clip_to_target = dict(label_mapper[['clip_label', 'label']].values)
    ## Map each dataset label to the index/indices of its CLIP prompt(s)
    grouped_label_to_clip_ids = (pd.DataFrame(clip_to_target.values())
                                 .reset_index()
                                 .rename(columns={'index': 'clip_id', 0: 'dataset_label'})
                                 .groupby('dataset_label').clip_id.apply(list))
    N = sampled.shape[0]

    ### Load CLIP model
    model, preprocess = clip.load(args['model_subtype'], device=device)

    image_links = list(sampled.img_path)
    targets = list(sampled.label)
    clip_labels = list(clip_to_target.keys())
    correct_list = []
    print('DATASET : {}'.format(args['dataset']))
    # Tokenize the text prompts once; they are identical for every image
    text = clip.tokenize(clip_labels).to(device)
    rows = []
    for i in range(N):
        if i % 100 == 0:
            print('{:.1f}% done'.format(100 * i / N))
        img = Image.open(os.path.join(data_dir, image_links[i]))
        image = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits_per_image, logits_per_text = model(image, text)
            probs = logits_per_image.softmax(dim=-1)
        ## Sum the probabilities of all prompts that map to the same dataset label
        label_probs = grouped_label_to_clip_ids.apply(
            lambda x: torch.index_select(probs.cpu(), 1, torch.tensor(x)).sum().item())
        sorted_probs = label_probs.sort_values(ascending=False)
        if len(clip_labels) > 2:
            prob1, prob2, prob3 = sorted_probs.iloc[0:3].values
            pred1, pred2, pred3 = sorted_probs.iloc[0:3].index
        else:
            ## Binary-classification datasets have fewer than 3 potential labels
            prob1, prob2 = sorted_probs.iloc[0:2].values
            pred1, pred2 = sorted_probs.iloc[0:2].index
            prob3 = np.nan
            pred3 = np.nan
        rest_of_the_predictions = sorted_probs.iloc[1:].index
        rest_of_the_prediction_probs = sorted_probs.iloc[1:].values
        pred_correct = targets[i] == pred1
        correct_list.append(pred_correct)
        rows.append(pd.DataFrame({'dataset': args['dataset'],
                                  'img_path': image_links[i],
                                  'pred1': [pred1],
                                  'pred2': [pred2],
                                  'pred3': [pred3],
                                  'prob1': prob1,
                                  'prob2': prob2,
                                  'prob3': prob3,
                                  'rest_of_pred': [list(rest_of_the_predictions)],
                                  'rest_of_pred_probs': [list(rest_of_the_prediction_probs)],
                                  'correct': pred_correct,
                                  'target': [targets[i]]}))
    ## DataFrame.append was removed in pandas 2.0; build the frame with one concat instead
    pred_df = pd.concat(rows, ignore_index=True)
    num_shot = 0
    acc_mean = 100 * np.mean(correct_list)
    clip_model = 'clip_' + args['model_subtype'].replace('/', '').replace('-', '_')
    save_line = "{}, {}, {} Shot, Test acc: {:.2f}\n".format(args['dataset'], clip_model, num_shot, acc_mean)
    print(save_line, flush=True)
    pred_df['img_path_trimmed'] = pred_df['img_path'].apply(lambda x: x.replace(data_dir, ''))
    label_to_category = dict(meta[['label', 'category_id']].drop_duplicates().values)
    predicted_labels = set(pred_df.pred1)
    pseudo_df = pd.DataFrame()
    ## For each predicted label, keep the top-k most confident images above the confidence threshold
    for pred_label in predicted_labels:
        sub_label_df = pred_df.loc[(pred_df.pred1 == pred_label) & (pred_df.prob1 >= args['confidence_lower_bound'])]
        sub_label_df = sub_label_df.sort_values('prob1', ascending=False).iloc[0:args['imgs_per_label']]
        pseudo_df = pd.concat((pseudo_df, sub_label_df))
    pseudo_full = pseudo_df.rename(columns={'target': 'label'}).copy()
    pseudo_full.drop_duplicates(subset='img_path', inplace=True)
    ## Check whether every label has representation in the pseudolabel space
    list_of_classes_without_pseudolabel = set(pred_df.target) - predicted_labels
    if len(list_of_classes_without_pseudolabel) > 0:
        ## If a category got no first-choice pseudolabels, keep some of the lower-ranked guesses
        print(args['dataset'] + ' NEEDED EXTRA GUESSES')
        rare_df = pd.DataFrame()
        for rare_label in list_of_classes_without_pseudolabel:
            temp_df = pred_df.copy()
            # Rank of the rare label within each image's remaining predictions (-1 if absent)
            order_of_pred = temp_df.rest_of_pred.apply(lambda x: x.index(rare_label) if rare_label in x else -1)
            for idx in order_of_pred.index:
                if order_of_pred[idx] != -1:
                    temp_df.loc[idx, 'rare_pred'] = temp_df.loc[idx].rest_of_pred[order_of_pred[idx]]
                    temp_df.loc[idx, 'rare_pred_prob'] = temp_df.loc[idx].rest_of_pred_probs[order_of_pred[idx]]
            if len(rare_df) > 0:
                # Do not assign the same image to more than one rare label
                temp_df = temp_df.loc[~temp_df.img_path.isin(set(rare_df.img_path))]
            rare_df = pd.concat((rare_df, temp_df.dropna(subset=['rare_pred'])
                                 .sort_values('rare_pred_prob', ascending=False).head(1)))
        ## Promote the fallback guesses to pred1 and merge them into the pseudolabel set
        rare_df['pred1'] = rare_df['rare_pred']
        pseudo_full = pseudo_full.loc[~pseudo_full.img_path.isin(set(rare_df.img_path))]
        pseudo_full = pd.concat((pseudo_full, rare_df))
    ## Restrict the metadata to the pseudolabeled images and align both frames by image path
    meta_train_replace = meta.loc[meta.img_path.isin(set(pseudo_full.img_path_trimmed))].copy()
    pseudo_full.sort_values('img_path_trimmed', inplace=True)
    pseudo_full['pseudolabel'] = pseudo_full['pred1']
    meta_train_replace.sort_values('img_path', inplace=True)
    ### Replace the label and its corresponding category id with the predicted ones (pseudolabels)
    meta_train_replace['label'] = pseudo_full['pseudolabel'].values
    meta_train_replace['category_id'] = meta_train_replace['label'].apply(lambda x: label_to_category[x])
    ### Keep val/test as they are for evaluation purposes
    meta_test = meta.loc[~meta.img_set.isin(split_list)].copy()
    ### Save the metadata file updated with pseudolabels
    meta_new = pd.concat((meta_train_replace, meta_test))
    meta_new.reset_index(drop=True, inplace=True)
    meta_new.to_csv('{}{}_meta_{}_pseudo_clip_{}shot.csv'.format(data_dir, args['dataset'], clip_model, args['imgs_per_label']))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_data_dir", type=str, default='data/')
    # Dataset names use underscores so they match the prompt-template keys in main()
    parser.add_argument('--dataset', type=str,
                        choices=['ucf101', 'food101', 'sun397', 'stanford_cars', 'fgvc_aircraft',
                                 'caltech_101', 'eurosat', 'oxford_pets', 'oxford_flowers', 'dtd',
                                 'kenya', 'cct20', 'serengeti', 'icct', 'fmow', 'oct'])
    parser.add_argument("--model_subtype", type=str, choices=["ViT-B/32", "ViT-B/16", "ViT-L/14", "RN50"],
                        default="RN50", help="exact type of CLIP pretraining backbone")
    parser.add_argument("--confidence_lower_bound", type=float, default=0.0,
                        help='minimum confidence required for a pseudolabel to be kept')
    parser.add_argument("--imgs_per_label", type=int, default=16,
                        help='number of pseudolabels to keep per predicted label (ranked by CLIP confidence, highest first)')
    # Turn the parsed args into a dictionary
    args = vars(parser.parse_args())
    main(args)
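# Example invocation (paths are relative to --root_data_dir, which defaults to 'data/'):
#   python generate_clip_pseudolabels.py --dataset eurosat --model_subtype RN50 --imgs_per_label 16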