Skip to content

Commit

Permalink
model tests added, lr changed
Browse files Browse the repository at this point in the history
  • Loading branch information
mg98 committed Nov 3, 2023
1 parent 086fb5b commit 8e653a4
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 3 deletions.
1 change: 1 addition & 0 deletions p2p_ol2r/ltr.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def query(self, query: str) -> dict[str, str]:
k = result_pair[1]
results_scores[k] = results_scores.get(k, 0) + (1 - prob_1_over_2)

# aggregating results by summing the probabilities of being superior
results_scores = dict(sorted(results_scores.items(), key=itemgetter(1), reverse=True))
ranked_results = {res_id: self.metadata[res_id] for res_id, _ in results_scores.items()}
return ranked_results
Expand Down
11 changes: 8 additions & 3 deletions p2p_ol2r/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, quantize: bool) -> None:
self.model = nn.Sequential(OrderedDict(layers))

self._criterion = nn.BCEWithLogitsLoss()
self._optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001)
self._optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

def serialize_model(self) -> io.BytesIO:
buffer = io.BytesIO()
Expand All @@ -48,7 +48,12 @@ def serialize_model(self) -> io.BytesIO:

return buffer

def make_input(self, query_vector: np.ndarray, sup_doc_vector: np.ndarray, inf_doc_vector: np.ndarray):
def make_input(
self,
query_vector: np.ndarray,
sup_doc_vector: np.ndarray,
inf_doc_vector: np.ndarray
) -> np.ndarray:
"""
Make (query, document-pair) input for model.
"""
Expand All @@ -62,7 +67,7 @@ def _train_step(self, train_data: np.ndarray, label: bool) -> float:
self._optimizer.step()
return loss.item()

def train(self, pos_train_data, neg_train_data, num_epochs: int):
def train(self, pos_train_data: np.ndarray, neg_train_data: np.ndarray, num_epochs: int):
self.model.train()

print(fmt(f'Epoch [0/{num_epochs}], Loss: n/a', 'gray'), end='')
Expand Down
77 changes: 77 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import unittest
import numpy as np
import random
from p2p_ol2r.model import *

class TestModel(unittest.TestCase):

def test_one_above_all(self):
ltr_model = LTRModel(False)
mki = ltr_model.make_input # just an alias
k = 10
q = np.random.rand(768)
docs = [np.random.rand(768) for _ in range(k)]

# train docs[0] to be above all others
pos_train_data = torch.from_numpy(np.array(
[mki(q, docs[0], docs[i]) for i in range(1, k)]
))
neg_train_data = torch.from_numpy(np.array(
[mki(q, docs[i], docs[0]) for i in range(1, k)]
))
with silence(): ltr_model.train(pos_train_data, neg_train_data, 100)

ltr_model.model.eval()
torch.no_grad()

for i in range(1, k):
self.assertGreater(
ltr_model.model(torch.from_numpy(mki(q, docs[0], docs[i]))).item(),
0.5
)
self.assertLess(
ltr_model.model(torch.from_numpy(mki(q, docs[i], docs[0]))).item(),
0.5
)

@unittest.skip("the ultimate test - doesn't work yet 🥲")
def test_full_ranking(self):
ltr_model = LTRModel(False)
mki = ltr_model.make_input # just an alias
k = 10
q = np.random.rand(768)
docs = [np.random.rand(768) for _ in range(k)]

train_data = []

for i in range(k-1):
# docs[i] to be above all others
pos_train_data = torch.from_numpy(np.array(
[mki(q, docs[i], docs[j]) for j in range(k) if i != j]
))
neg_train_data = torch.from_numpy(np.array(
[mki(q, docs[j], docs[i]) for j in range(k) if i != j]
))
train_data.extend([(pos_train_data, neg_train_data)] * (k*100 - i*100))

random.shuffle(train_data)

for (pos_train_data, neg_train_data) in train_data:
with silence(): ltr_model.train(pos_train_data, neg_train_data, 1)

ltr_model.model.eval()
torch.no_grad()

for i in range(k-1):
for j in range(i+1, k):
self.assertGreater(
ltr_model.model(torch.from_numpy(mki(q, docs[i], docs[j]))).item(),
0.5
)
# self.assertLess(
# ltr_model.model(torch.from_numpy(mki(q, docs[j], docs[i]))).item(),
# 0.5
# )

if __name__ == "__main__":
unittest.main()

0 comments on commit 8e653a4

Please sign in to comment.