automatic_evaluation.py
import os
import PIL
import argparse
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
import spacy
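# Requires the transformer-based English pipeline, installed e.g. via:
#   python -m spacy download en_core_web_trf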
nlp = spacy.load('en_core_web_trf')
def custom_extract_attribution_indices(prompt, parser):
    """For every noun/proper noun, collect the subtree of its modifiers (amod, nmod, compound, npadvmod, det) and their conjuncts."""
    doc = parser(prompt)
    subtrees = []
    modifiers = ['amod', 'nmod', 'compound', 'npadvmod', 'det']
    for w in doc:
        if w.pos_ not in ['NOUN', 'PROPN'] or w.dep_ in modifiers:
            continue
        subtree = []
        stack = []
        # Seed the stack with the noun's direct modifier children.
        for child in w.children:
            if child.dep_ in modifiers:
                subtree.append(child)
                stack.extend(child.children)
        # Follow chains of modifiers and conjunctions below those children.
        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == 'conj':
                subtree.append(node)
                stack.extend(node.children)
        if subtree:
            subtree.append(w)
            subtrees.append(subtree)
    return subtrees
def custom_segment_text(text: str):
    """Return one text segment per modifier subtree (e.g. "a red car")."""
    segments = []
    doc = nlp(text)
    # The parsed Doc is passed back through nlp inside the helper; spaCy accepts
    # a Doc as input, so the extra parse is redundant but harmless.
    subtrees = custom_extract_attribution_indices(doc, nlp)
    if subtrees:
        for subtree in subtrees:
            segments.append(" ".join([t.text for t in subtree]))
    return segments
def segment_text_for_automatic_eval(text: str):
    """Segment a caption into modifier+noun phrases, then add any remaining bare nouns."""
    segments = custom_segment_text(text)
    doc = nlp(text)
    # Add nouns that were not captured by any modifier subtree as standalone segments.
    for w in doc:
        if w.pos_ in ['NOUN', 'PROPN']:
            found_in_segment = False
            for segment in segments:
                if w.text in segment:
                    found_in_segment = True
                    break
            if not found_in_segment:
                segments.append(w.text)
    print(f"Segments: {segments}")
    return segments
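# Illustrative example (hypothetical; the exact split depends on the parse):
#   segment_text_for_automatic_eval("a red car and a blue bird")
#   would typically return ["a red car", "a blue bird"].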
class CLIPSimilarity:
    def __init__(self, clip_type='openai/clip-vit-base-patch32', device=None):
        self.clip = CLIPModel.from_pretrained(clip_type)
        self.processor = CLIPProcessor.from_pretrained(clip_type)
        self.clip_type = clip_type
        self.device = device
        if device is not None:
            self.clip = self.clip.to(device)

    def get_similarity_score(self, text, image):
        # Split the caption into segments and score every (image, segment) pair.
        text = segment_text_for_automatic_eval(text)
        inputs = self.processor(images=image, text=text, return_tensors="pt", padding=True)
        if self.device is not None:
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
        outputs = self.clip(**inputs)
        logits_per_image = outputs.logits_per_image  # (num_images, num_segments) image-text similarity
        probs = logits_per_image.softmax(dim=0)  # softmax across images for each segment
        scores = torch.mean(probs, dim=1)  # average over segments -> one score per image
        return scores
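# Minimal usage sketch (the image path is hypothetical):
#   scorer = CLIPSimilarity()
#   img = PIL.Image.open("example.jpg")
#   scores = scorer.get_similarity_score("a red car and a blue bird", [img])
#   # `scores` holds one averaged CLIP probability per image.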
def get_scorable_images(images_dir, relevant_captions):
    scorable_images = []
    images_dirs = [os.path.join(images_dir, caption) for caption in os.listdir(images_dir) if
                   os.path.isdir(os.path.join(images_dir, caption))]
    print(f"Found {len(images_dirs)} images dirs.")
    for caption_dir in images_dirs:
        # Directory names encode captions with "_" (and "'") in place of spaces.
        caption = os.path.basename(caption_dir).replace("'", "_").replace("_", " ")
        if caption not in relevant_captions:
            continue
        images = []
        models = []
        seed = None
        for image_name in os.listdir(caption_dir):
            # File names look like '<model>_<seed>.jpg' (or end with the seed if there is no extension).
            model = image_name.split("_")[0]
            if not seed:
                if '.jpg' in image_name:
                    seed = int(image_name.split("_")[1][:-4])
                else:
                    seed = int(image_name.split("_")[-1])
            image_path = os.path.join(caption_dir, image_name)
            with PIL.Image.open(image_path) as image:
                images.append(image.copy())
            models.append(model)
        assert len(models) == len(images)
        scorable_images.append({
            'caption': caption,
            'seed': seed,
            'images': images,
            'models': models,
        })
    # Report any requested captions with no matching image directory.
    found_captions = [s['caption'] for s in scorable_images]
    for caption in relevant_captions:
        if caption not in found_captions:
            print(caption)
    return scorable_images
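# Directory layout assumed above (inferred from the parsing logic):
#   <images_dir>/<caption, with spaces written as "_">/<model>_<seed>.jpg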
def load_evaluatable_images(majority_path, images_dir, exclude_no_clear_winner=True):
    majority_choices = pd.read_csv(majority_path)
    # Drop animal-animal captions if the caption_type column is present.
    if 'caption_type' in majority_choices.columns:
        majority_choices = majority_choices[majority_choices['caption_type'] != 'animal_animal']
    # Optionally keep only captions where human annotators picked a clear winner.
    if exclude_no_clear_winner and 'human_annotation' in majority_choices.columns:
        concept_mask = ~majority_choices['human_annotation'].isin(['equally good', 'equally bad', 'Undecided'])
        filtered_majority_choices = majority_choices[concept_mask]
        decided_captions = [caption.replace("'", " ") for caption in filtered_majority_choices['caption'].tolist()]
    else:
        decided_captions = [caption.replace("'", " ") for caption in majority_choices['caption'].tolist()]
    scorable_images = get_scorable_images(images_dir, decided_captions)
    print(len(scorable_images), len(decided_captions))
    assert len(decided_captions) == len(scorable_images)
    # Attach the human_annotation value from the dataframe to each image record.
    if 'human_annotation' in majority_choices.columns:
        # Map caption -> human_annotation.
        concept_maj_dict = majority_choices.set_index('caption')['human_annotation'].to_dict()
        for scorable in scorable_images:
            scorable['human_annotation'] = concept_maj_dict.get(scorable['caption'], None)
    return scorable_images
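# Expected CSV columns (inferred from the code above): 'caption' is required;
# 'caption_type' and 'human_annotation' are used when present.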
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Automatic Evaluation Parameters')
    parser.add_argument('--captions_and_labels', type=str, required=True,
                        help='Path to the CSV file containing captions and labels.')
    parser.add_argument('--images_dir', type=str, required=True,
                        help='Path to the directory containing image subdirectories.')
    args = parser.parse_args()

    captions_and_labels = args.captions_and_labels
    images_dir = args.images_dir
    data = load_evaluatable_images(captions_and_labels, images_dir, exclude_no_clear_winner=False)

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    scorer = CLIPSimilarity(device=device)

    # Count, per model, how many captions it wins (highest mean CLIP score).
    score_counter = {data[0]['models'][i]: 0 for i in range(len(data[0]['models']))}
    for idx, image in enumerate(data):
        image['scores'] = scorer.get_similarity_score(image['caption'], image['images'])
        image['scores'] = [score.item() for score in image['scores']]
        scores = image['scores']
        max_index = scores.index(max(scores))
        score_counter[image['models'][max_index]] += 1
        data[idx] = image
    print(score_counter)
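# Example invocation (file and directory names are placeholders):
#   python automatic_evaluation.py --captions_and_labels captions_and_labels.csv --images_dir images/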