generate_feats.py
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
run inference for Image Text Retrieval
"""
import argparse
import sys
import torch
from dataset.dataset import WinogroundDataset
from pathlib import Path
import pickle
import joblib
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_path = Path(args.dataset)
    # UNITER's region features carry extra width/height/area ("wha") box columns.
    add_wha = args.model == 'uniter'

    # Load the Winoground dataset from pre-extracted region features and boxes.
    winoground_dataset = WinogroundDataset(dataset_path / "winoground_features.npy",
                                           dataset_path / "winoground_boxes.npy",
                                           dataset_path / "examples.jsonl",
                                           add_wha, device)
    # winoground_dataloader = DataLoader(winoground_dataset, batch_size=None, shuffle=False)

    # Select the model and wrap its inference call behind a common interface.
    if args.model == 'uniter':
        import UNITER_codebase.uniter_code as Uniter
        model, tokenizer = Uniter.setup_model(device)
        run_inference = lambda image, caption: Uniter.run_inference(image, caption, model, tokenizer, device)
    elif args.model == 'lxmert':
        import LXMERT.lxmert_code as Lxmert
        model, tokenizer = Lxmert.setup_model(device)
        run_inference = lambda image, caption: Lxmert.run_inference(image, caption, model, tokenizer, device)
    else:
        raise ValueError(f"Shouldn't reach here; model {args.model} not supported")

    winoground_sets = []
    # Evaluate: score all four image/caption pairings of each Winoground example.
    model.eval()
    with torch.no_grad():
        for i in range(len(winoground_dataset)):
            print(i, file=sys.stderr)
            winoground_set = {}
            image_0, caption_0, image_1, caption_1 = winoground_dataset[i]
            img0_text0_outputs = run_inference(image_0, caption_0)
            img0_text1_outputs = run_inference(image_0, caption_1)
            img1_text0_outputs = run_inference(image_1, caption_0)
            img1_text1_outputs = run_inference(image_1, caption_1)
            winoground_set['caption0'] = img0_text0_outputs['caption']  # Same as img1_text0_outputs['caption']
            winoground_set['caption1'] = img1_text1_outputs['caption']  # Same as img0_text1_outputs['caption']
            winoground_set['output_img_0_cap_0'] = img0_text0_outputs["model_output"]
            winoground_set['output_img_0_cap_1'] = img0_text1_outputs["model_output"]
            winoground_set['output_img_1_cap_0'] = img1_text0_outputs["model_output"]
            winoground_set['output_img_1_cap_1'] = img1_text1_outputs["model_output"]
            winoground_sets.append(winoground_set)

    # Persist the collected outputs for downstream analysis.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    with open(Path(args.output_dir) / "feats.pkl", 'wb') as f:
        pickle.dump(winoground_sets, f)
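
# A minimal helper sketch for reading feats.pkl back; the name load_feats is
# illustrative, not an API this repo defines. Unpickling the stored tensors
# requires torch to be importable, which it is in this module.
def load_feats(output_dir):
    """Return the list of per-example dicts written by main()."""
    with open(Path(output_dir) / "feats.pkl", "rb") as f:
        return pickle.load(f)
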
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--model", type=str, required=True, choices=["uniter", "lxmert"])
parser.add_argument("--dataset", default="dataset/", type=str)
parser.add_argument("--output_dir", type=str, required=True)
args = parser.parse_args()
main(args)
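
# Example invocation; the output path is illustrative:
#   python generate_feats.py --model uniter --output_dir outputs/uniter
# The saved features can then be inspected with the load_feats sketch above:
#   feats = load_feats("outputs/uniter")
#   print(feats[0]["caption0"], feats[0]["caption1"])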