-
Notifications
You must be signed in to change notification settings - Fork 7
/
blend.py
31 lines (20 loc) · 891 Bytes
/
blend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from collections import defaultdict
from operator import itemgetter
from typing import List
def get_list_of_recs(rec_lines: List[str]) -> List[List[str]]:
    """Parse recommendation lines into lists of item ids.

    Each input line becomes one list of its whitespace-separated tokens.
    Leading/trailing whitespace (including the newline kept by
    ``readlines()``) is ignored, and runs of whitespace never produce empty
    tokens, so a blank line yields ``[]`` rather than ``['']`` (an empty id
    would otherwise be scored as a real item by ``blend``).

    Args:
        rec_lines: Lines of text, e.g. as returned by ``file.readlines()``.

    Returns:
        One list of item-id strings per input line, in input order.
    """
    return [line.split() for line in rec_lines]
def blend(*preds: List[str], weights: List[float], topk: int = 100) -> str:
    """Blend several ranked recommendation lists into one space-separated string.

    Every prediction list adds ``weight / (rank + 1)`` (0-based ``rank``) to
    the score of each item it contains. The result lists items by descending
    total score; items with equal scores keep the order in which they were
    first encountered.

    Args:
        *preds: Ranked lists of item ids, best first.
        weights: One weight per list in ``preds``, in the same order.
            Must be passed by keyword, since every positional argument is
            collected into ``preds``.
        topk: Maximum number of item ids to return.

    Returns:
        The top ``topk`` item ids joined by single spaces ('' if no items).

    Raises:
        ValueError: If ``len(weights)`` differs from the number of ``preds``.
    """
    if len(weights) != len(preds):
        raise ValueError(
            f'expected {len(preds)} weights (one per prediction list), '
            f'got {len(weights)}'
        )
    scores = defaultdict(float)
    for pred, weight in zip(preds, weights):
        for rank, item_id in enumerate(pred):
            scores[item_id] += weight / (rank + 1)
    # Best (highest) score first. sorted() stays stable with reverse=True,
    # so ties keep first-seen order.
    ranked = [
        item_id
        for item_id, _ in sorted(scores.items(), key=itemgetter(1), reverse=True)
    ]
    return ' '.join(ranked[:topk])
if __name__ == '__main__':
    # Blend two recommendation submissions line by line: output line i is the
    # blend of line i of each input file.
    with open('submission_a.csv', 'r', encoding='utf-8') as f:
        sub_a = get_list_of_recs(f.readlines())
    with open('submission_b.csv', 'r', encoding='utf-8') as f:
        sub_b = get_list_of_recs(f.readlines())
    if len(sub_a) != len(sub_b):
        # Rows are paired by position, so the two submissions must align.
        raise ValueError(
            f'submission_a.csv has {len(sub_a)} lines but '
            f'submission_b.csv has {len(sub_b)}'
        )
    with open('submission_ab.csv', 'w', encoding='utf-8') as f:
        for recs_a, recs_b in zip(sub_a, sub_b):
            # weights must be given by keyword: blend() collects every
            # positional argument into *preds.
            f.write(blend(recs_a, recs_b, weights=[0.55, 0.45]) + '\n')