-
Notifications
You must be signed in to change notification settings - Fork 8
/
utils.py
executable file
·113 lines (96 loc) · 3.54 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import re
import string
import collections
from collections import Counter
from nltk.corpus import stopwords
import numpy as np
# English stopword set loaded from NLTK's `stopwords` corpus.
# NOTE(review): `stops` is not referenced by any function visible in this file.
stops = set(stopwords.words('english'))
# string.punctuation split into single-character strings; normalize_answer
# removes these characters from text.
puncs = [ch for ch in string.punctuation]
def most_common(lst):
    """Return the most frequent element of *lst*.

    Ties are broken in favour of the element that appears first in *lst*.
    Any finite iterable is accepted: it is materialized once, so generators
    and other single-pass iterators work as well as lists.

    Raises:
        ValueError: if *lst* is empty.
    """
    # Materialize first: Counter would otherwise consume a one-shot iterator
    # and leave nothing for max() to scan.
    items = list(lst)
    counts = Counter(items)
    # max() returns the first maximal element encountered, giving a stable
    # first-occurrence tie-break.
    return max(items, key=counts.get)
def normalize_answer(s):
    """Canonicalize an answer string for comparison.

    The steps are applied in this order:
    1. lowercase;
    2. drop every character listed in the module-level ``puncs``;
    3. replace the whole words "a", "an", "the" with a space;
    4. collapse runs of whitespace into single spaces and trim the ends.
    """
    text = s.lower()
    banned = set(puncs)
    text = ''.join(ch for ch in text if ch not in banned)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
def get_tokens(s):
    """Return the whitespace tokens of ``normalize_answer(s)``.

    Falsy input (e.g. ``''`` or ``None``) yields an empty list without
    calling the normalizer.
    """
    if s:
        return normalize_answer(s).split()
    return []
def compute_exact(a_gold, a_pred):
    """Return 1 if both answers are identical after normalization, else 0."""
    identical = normalize_answer(a_gold) == normalize_answer(a_pred)
    return 1 if identical else 0
def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold answer and a predicted answer.

    Both strings are tokenized with ``get_tokens``. Overlap is counted with
    multiplicity (multiset intersection).

    Returns:
        - ``int(gold_tokens == pred_tokens)`` when either side has no tokens,
          i.e. 1 only if both are empty;
        - 0 when there is no overlap;
        - otherwise the harmonic mean of precision and recall, as a float.
    """
    gold_tokens = get_tokens(a_gold)
    pred_tokens = get_tokens(a_pred)

    # A "no-answer" on either side scores 1 only when both sides agree.
    if not gold_tokens or not pred_tokens:
        return int(gold_tokens == pred_tokens)

    shared = collections.Counter(gold_tokens) & collections.Counter(pred_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0

    p = overlap / len(pred_tokens)
    r = overlap / len(gold_tokens)
    return 2 * p * r / (p + r)
def single_ans_em(pred, gold):
    """Best exact-match score of a prediction against one or more gold answers.

    Args:
        pred: raw prediction string. The answer span is first pulled out with
            ``answer_extract_textqa``.
        gold: a single gold answer string, or a list/tuple/set of them.

    Returns:
        1 if the extracted prediction matches at least one gold answer after
        normalization, else 0. An empty gold collection scores 0, matching
        how ``get_exact_match`` treats an empty list.
    """
    # Accept any common collection (including list subclasses); a bare answer
    # is wrapped so it is compared as a single candidate.
    if not isinstance(gold, (list, tuple, set, frozenset)):
        gold = [gold]
    pred = answer_extract_textqa(pred)
    return max((compute_exact(pred, a) for a in gold), default=0)
def single_ans_f1(pred, gold):
    """Best token-level F1 of a prediction against one or more gold answers.

    Args:
        pred: raw prediction string. The answer span is first pulled out with
            ``answer_extract_textqa``.
        gold: a single gold answer string, or a list/tuple/set of them.

    Returns:
        The maximum ``compute_f1`` score over all gold answers. An empty gold
        collection scores 0, matching how ``get_exact_match`` treats an empty
        list.
    """
    # Accept any common collection (including list subclasses); a bare answer
    # is wrapped so it is compared as a single candidate.
    if not isinstance(gold, (list, tuple, set, frozenset)):
        gold = [gold]
    pred = answer_extract_textqa(pred)
    return max((compute_f1(pred, a) for a in gold), default=0)
def get_exact_match(answers1, answers2):
    """Return whether any answer on one side exactly matches any on the other.

    Either argument may be a single answer string or a (possibly nested)
    list/tuple of them. Collections are expanded recursively, and the pair
    matches if some combination of leaf strings is equal after
    ``normalize_answer``.

    Returns:
        0 when a collection being expanded is empty; otherwise a ``bool``.
    """
    if isinstance(answers1, (list, tuple)):
        if not answers1:
            return 0
        # any() stops at the first match instead of scoring every pair.
        return any(get_exact_match(a, answers2) for a in answers1)
    if isinstance(answers2, (list, tuple)):
        if not answers2:
            return 0
        return any(get_exact_match(answers1, a) for a in answers2)
    return normalize_answer(answers1) == normalize_answer(answers2)
def get_f1(answers, predictions, is_equal=get_exact_match):
    '''
    F1 between a set of gold answers and a set of predictions, using greedy
    one-to-one matching.

    Answers are visited in order. Each answer is paired with the first still
    unused prediction for which ``is_equal(answer, prediction)`` is truthy.
    Recall is the fraction of answers matched and precision is the fraction
    of predictions matched; the result is their harmonic mean (0 if both
    are 0).

    :answers: a list of gold answers (each may itself be a list of
              acceptable strings; it is passed as-is to ``is_equal``)
    :predictions: a list of strings
    :is_equal: pairwise match predicate, ``get_exact_match`` by default
    '''
    assert len(answers)>0 and len(predictions)>0, (answers, predictions)
    answer_used = [False] * len(answers)
    pred_used = [False] * len(predictions)
    for i, answer in enumerate(answers):
        for j, prediction in enumerate(predictions):
            if pred_used[j]:
                continue
            if is_equal(answer, prediction):
                answer_used[i] = pred_used[j] = True
                # This answer is now consumed; move on to the next one.
                break
    # Matching is one-to-one, so both sides must have the same count.
    assert np.sum(answer_used) == np.sum(pred_used)
    recall, precision = np.mean(answer_used), np.mean(pred_used)
    if recall + precision == 0:
        return 0
    return 2 * recall * precision / (recall + precision)
def answer_match_textqa(pred, ans):
    """Return True if the answer extracted from *pred* normalizes to *ans*."""
    extracted = answer_extract_textqa(pred)
    normalized_pred = normalize_answer(extracted)
    return normalized_pred == normalize_answer(ans)
def answer_extract_textqa(pred, prefix="answer is "):
    """Extract the final answer span from a free-form model prediction.

    If *prefix* occurs in *pred*, the text after its LAST occurrence is kept
    (so "... the answer is A ... the answer is B" yields "B"). Otherwise the
    whole prediction is kept. Surrounding whitespace is stripped in both
    cases.

    Args:
        pred: raw prediction string.
        prefix: marker introducing the answer, matched case-sensitively. An
            empty prefix disables extraction.

    Returns:
        The extracted, whitespace-stripped answer string.
    """
    idx = pred.rfind(prefix) if prefix else -1
    if idx != -1:
        pred = pred[idx + len(prefix):]
    # Strip on both paths so the result never depends on which branch ran.
    return pred.strip()