evaluate.py
import argparse

import torch
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset.dataset import CXRVisDialDataset
from dataset.vocabulary import Vocabulary
from models.models import LateFusionModel, RecursiveAttentionModel, StackedAttentionModel
from utils import report_metric
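
# Example invocation (paths are illustrative, not taken from the repo):
#   python evaluate.py \
#       --word_counts data/word_counts.json \
#       --model_path checkpoints/lf_model.pth \
#       --test_json data/test.json \
#       --test_img_feats data/test_img_feats.h5 \
#       --model lf
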
if __name__ == "__main__":
    # ------------------------------------------------------------------------
    # ARGUMENTS
    # ------------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_json",
                        required=False,
                        help="Location of the test json file")
    parser.add_argument("--test_img_feats",
                        required=False,
                        help="Location of the test image features h5 file")
    parser.add_argument("--word_counts",
                        required=True,
                        help="Location of the word counts file")
    parser.add_argument("--config_yml",
                        default="config.yaml",
                        help="Location of the config yaml file")
    parser.add_argument("--model_path",
                        required=True,
                        help="Location of the saved model weights")
    parser.add_argument("--embeddings",
                        default=None,
                        help="Optional path to a pickled file of pretrained embeddings "
                             "used to initialize the embedding layer")
    parser.add_argument("--model",
                        default="lf",
                        help="Model to evaluate. Valid options are: `lf` (LateFusion), "
                             "`rva` (RecursiveVisualAttention), and `san` (Stacked Attention Network)")
    args = parser.parse_args()
    with open(args.config_yml) as f:
        config = yaml.safe_load(f)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
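
    # A minimal config.yaml sketch. Only `batch_size` is read directly in this
    # script; any further keys are assumed to be consumed by the model classes:
    #   batch_size: 32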
    # ------------------------------------------------------------------------
    # DATASET & DATALOADER
    # ------------------------------------------------------------------------
    if args.model == "rva":
        mode = 'seq'
    elif args.model in ("lf", "san"):
        mode = 'concat'
    else:
        raise ValueError("Unknown model: expected one of `lf`, `rva`, `san`")

    # The word counts file is used to construct the vocabulary. If pretrained
    # embeddings are also passed, they initialize the embedding layer;
    # otherwise, BERT is used.
    train_vocabulary = Vocabulary(args.word_counts)

    test_dataset = CXRVisDialDataset(args.test_img_feats, args.test_json, args.word_counts, mode)
    test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])
    # ------------------------------------------------------------------------
    # SETUP
    # ------------------------------------------------------------------------
    if args.model == "lf":
        model = LateFusionModel(config, train_vocabulary, args.embeddings)
    elif args.model == "rva":
        model = RecursiveAttentionModel(config, train_vocabulary)
    elif args.model == "san":
        model = StackedAttentionModel(config, train_vocabulary)
    model = model.to(device)

    # map_location makes CUDA-trained checkpoints loadable on CPU-only machines
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    print("Model weights restored")

    model.eval()
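    # eval() switches off dropout and freezes batch-norm statistics;
    # torch.no_grad() below additionally disables autograd bookkeeping,
    # which saves memory during inference.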
    # ------------------------------------------------------------------------
    # EVALUATION
    # ------------------------------------------------------------------------
    targets = []
    outputs = []
    turns = []

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluation"):
            image = batch['image'].to(device)
            history = batch['history_ids'].to(device)
            question = batch['question_ids'].to(device)
            options = batch['options'].to(device)
            caption = batch['caption_ids'].to(device)
            turn = batch['turn'].to(device)

            output = model(image, history, question, options, caption, turn)
            target = batch["answer_ind"].to(device)

            targets.append(target.detach().cpu().numpy())
            outputs.append(output.detach().cpu().numpy())
            turns.append(turn.detach().cpu().numpy())

    f1_scores, conf_matrix, macro_f1, accuracies = report_metric(targets, outputs, turns)
    scores_dict = {'Yes': f1_scores[0], 'No': f1_scores[1], 'Maybe': f1_scores[2],
                   'N/A': f1_scores[3], 'macro_f1': macro_f1}
    print(conf_matrix)
    print(scores_dict)
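    # The printed results are a confusion matrix over the four answer classes
    # (Yes / No / Maybe / N/A) and a dict of per-class F1 scores plus macro-F1.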