-
Notifications
You must be signed in to change notification settings - Fork 19
/
score_translit.py
132 lines (109 loc) · 4.13 KB
/
score_translit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from typing import List
import os
import codecs
import random
import numpy as np
from pathlib import Path
from pandas import read_csv
# Seed Python's global `random` generator at import time with a fixed value.
random.seed(42)
# Names of the dataset parts used as the default `parts` argument below;
# each part name is looked up as `<part>.tsv` inside the data directory.
SCORED_PARTS = ('train', 'dev', 'train_small', 'dev_small', 'test')
# Path of an entry named "TRANSLIT" located next to this source file.
TRANSLIT_PATH = Path(__file__).with_name("TRANSLIT")
def load_dataset(data_dir_path=None, parts: List[str] = SCORED_PARTS):
    """
    Load source strings (and transliterations, when present) for each part.

    For every name in `parts` the file `<data_dir_path>/<part>.tsv` is read.
    Its first line is treated as a column header and skipped. The number of
    tab-separated columns in the first data row decides the layout:

    * 2 columns -- every row is split on tabs into (string, transliteration);
    * 1 column  -- every row is taken whole as a string, and the
      transliterations are reported as None.

    :param data_dir_path: directory containing the `.tsv` files. It has to
        be provided: the default None is not a usable directory.
    :param parts: names of the dataset parts to load.
    :return: dict mapping part name to a triple
        `(ids, strings, transliterations)`, where ids are `'<part>/<index>'`
        strings numbered in file order.
    :raises ValueError: if a file has no data rows after the header, or if
        its first data row has a column count other than 1 or 2.
    """
    part2ixy = {}
    for part in parts:
        path = os.path.join(data_dir_path, f'{part}.tsv')
        with open(path, 'r', encoding='utf-8') as rf:
            # first line is a header of the corresponding columns
            next(rf, None)
            lines = [line.strip('\n') for line in rf]
        if not lines:
            # Without any data row the column layout cannot be determined;
            # fail with a clear message instead of an IndexError.
            raise ValueError(f"no data rows in {path}")
        col_count = len(lines[0].split('\t'))
        if col_count == 2:
            strings, transliterations = zip(
                *(line.split('\t') for line in lines)
            )
        elif col_count == 1:
            strings = lines
            transliterations = None
        else:
            raise ValueError("wrong amount of columns")
        part2ixy[part] = (
            [f'{part}/{i}' for i in range(len(strings))],
            strings, transliterations,
        )
    return part2ixy
def load_transliterations_only(data_dir_path=None, parts: List[str] = SCORED_PARTS):
    """
    Load only the ids and transliterations (second column) for each part.

    For every name in `parts` the file `<data_dir_path>/<part>.tsv` is read.
    Its first line is treated as a column header and skipped. The number of
    tab-separated columns in the first data row decides the layout:

    * 2 columns -- the second column of every row is collected;
    * 1 column  -- the part has no labels and transliterations are None.

    :param data_dir_path: directory containing the `.tsv` files. It has to
        be provided: the default None is not a usable directory.
    :param parts: names of the dataset parts to load.
    :return: dict mapping part name to a pair `(ids, transliterations)`,
        where ids are `'<part>/<index>'` strings numbered in file order.
    :raises ValueError: if a file has no data rows after the header, or if
        its first data row has a column count other than 1 or 2.
    """
    part2iy = {}
    for part in parts:
        path = os.path.join(data_dir_path, f'{part}.tsv')
        with open(path, 'r', encoding='utf-8') as rf:
            # first line is a header of the corresponding columns
            next(rf, None)
            # Split every row once; the pieces are reused below.
            rows = [line.strip('\n').split('\t') for line in rf]
        if not rows:
            # Without any data row the column layout cannot be determined;
            # fail with a clear message instead of an IndexError.
            raise ValueError(f"no data rows in {path}")
        col_count = len(rows[0])
        if col_count == 2:
            transliterations = [row[1] for row in rows]
        elif col_count == 1:
            transliterations = None
        else:
            raise ValueError("Wrong amount of columns")
        part2iy[part] = (
            [f'{part}/{i}' for i in range(len(rows))],
            transliterations,
        )
    return part2iy
def save_preds(preds, preds_fname):
    """
    Save classifier predictions in format appropriate for scoring.

    One line is written per item: the id followed by each of its
    predictions, separated by tabs. The file is encoded as UTF-8.

    :param preds: iterable of `(id, predictions)` pairs, where
        `predictions` is an iterable of the predicted values for that id.
    :param preds_fname: path of the file to write (overwritten if present).
    """
    # newline='\n' disables platform newline translation, so the bytes
    # written are the same on every operating system.
    with open(preds_fname, 'w', encoding='utf-8', newline='\n') as outp:
        # A distinct loop name avoids rebinding the `preds` argument while
        # it is being iterated.
        for idx, item_preds in preds:
            print(idx, *item_preds, sep='\t', file=outp)
    print('Predictions saved to %s' % preds_fname)
def load_preds(preds_fname, top_k=1, compressed=False):
    """
    Load classifier predictions in format appropriate for scoring.

    The file is read as header-less, tab-separated text with two columns:
    the item id and a single prediction. Both columns are kept as plain
    strings, so predictions such as "NA", "null" or "123" are returned
    verbatim rather than being converted to NaN or to numbers.

    :param preds_fname: path of the predictions file.
    :param top_k: expected number of predictions per id. Only one prediction
        column is read, so any value other than 1 fails the length check
        whenever the file contains at least one row.
    :param compressed: if True the file is read as gzip; otherwise the
        compression is inferred by pandas.
    :return: pair `(pred_ids, pred_y)`: the list of ids in file order, and a
        dict mapping each id to a one-element list holding its prediction
        (for a repeated id the last row wins).
    :raises AssertionError: if an entry does not hold exactly `top_k`
        predictions.
    """
    # Parse the file once and take both columns from the same frame.
    frame = read_csv(
        filepath_or_buffer=preds_fname,
        names=["id", "pred"],
        usecols=["id", "pred"],
        sep='\t',
        compression='gzip' if compressed else 'infer',
        # Ids and predictions are text: disable numeric and NA inference.
        dtype=str,
        keep_default_na=False,
    )
    pred_ids = list(frame["id"])
    pred_y = {pred_id: [y] for pred_id, y in zip(pred_ids, frame["pred"])}
    for y in pred_y.values():
        # Raised explicitly so the check is not stripped under `python -O`.
        if len(y) != top_k:
            raise AssertionError(
                f'len(y)={len(y)}, top_k={top_k}, {preds_fname}'
            )
    return pred_ids, pred_y
def compute_hit_k(preds, k=10):
    """
    Placeholder for a hit@k metric; not implemented.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def compute_mrr(preds):
    """
    Placeholder for an MRR metric; not implemented.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def compute_acc_1(preds, true):
    """
    Compute accuracy@1: the share of items whose first prediction is correct.

    :param preds: sequence of per-item prediction sequences; only the first
        element of each is compared.
    :param true: sequence of ground-truth answers, aligned with `preds`.
    :return: number of items where `pred[0] == y`, divided by `len(preds)`;
        0.0 when `preds` is empty.
    """
    total = len(preds)
    if total == 0:
        # No predictions to judge: report 0.0 instead of dividing by zero.
        return 0.0
    right_answers = sum(1 for pred, y in zip(preds, true) if pred[0] == y)
    return right_answers / total
def score(preds, true):
    """
    Compute the scoring metrics for aligned predictions and answers.

    :param preds: sequence of per-item predictions.
    :param true: sequence of ground-truth answers, same length as `preds`.
    :return: dict with the single key 'acc@1' holding the accuracy@1 value.
    """
    n_preds, n_true = len(preds), len(true)
    assert n_preds == n_true, 'inconsistent amount of predictions and ground truth answers'
    return {'acc@1': compute_acc_1(preds, true)}
def score_preds(preds_path, data_dir, parts=SCORED_PARTS, compressed=False):
    """
    Score a predictions file against the labelled parts of a dataset.

    Parts without labels are reported and skipped; for every other part the
    accuracy@1 is printed and the full score dict is collected.

    :param preds_path: path of the predictions file.
    :param data_dir: directory holding the `<part>.tsv` files.
    :param parts: names of the dataset parts to score.
    :param compressed: passed through to the predictions loader.
    :return: dict mapping each scored part name to its score dict.
    """
    gold_by_part = load_transliterations_only(data_dir, parts=parts)
    _, pred_by_id = load_preds(preds_path, compressed=compressed)

    scores = {}
    for part, (gold_ids, gold_y) in gold_by_part.items():
        if gold_y is None:
            print('no labels for %s set' % part)
            continue
        # Line the predictions up with the ground truth by id.
        part_scores = score([pred_by_id[i] for i in gold_ids], gold_y)
        print('%s set accuracy@1: %.2f' % (part, part_scores['acc@1']))
        scores[part] = part_scores
    return scores