more tests, fixed bug with torch_retriever, fixed rank indexing
seanmacavaney committed Oct 14, 2023
1 parent 3736246 commit d6cbf7a
Showing 6 changed files with 77 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyterrier_dr/flex/faiss_retr.py
@@ -49,7 +49,7 @@ def transform(self, inp):
s = s[mask]
res['docid'].append(d)
res['score'].append(s)
-res['rank'].append(np.arange(d.shape[0]))
+res['rank'].append(np.arange(d.shape[0])+1)
idxs.extend(itertools.repeat(qidx+i, d.shape[0]))
res = {k: np.concatenate(v) for k, v in res.items()}
res['docno'] = docnos.fwd[res['docid']]
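Both this change and the matching one in np_retr.py below shift result ranks from 0-based to 1-based. A minimal sketch of the effect (illustrative only, not part of the commit): per-query ranks now start at 1 and are then concatenated across queries, mirroring the np.arange(...)+1 pattern above.

import numpy as np

# Hypothetical per-query result lists: ranks restart at 1 for each query.
per_query_docids = [np.array([12, 7, 3]), np.array([5, 9])]
ranks = np.concatenate([np.arange(d.shape[0]) + 1 for d in per_query_docids])
print(ranks)  # [1 2 3 1 2]
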
3 changes: 1 addition & 2 deletions pyterrier_dr/flex/np_retr.py
@@ -26,7 +26,6 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
else:
raise ValueError(f'{self.flex_index.sim_fn} not supported')
num_q = query_vecs.shape[0]
-res = []
ranked_lists = RankedLists(self.num_results, num_q)
batch_it = range(0, dvecs.shape[0], self.batch_size)
if self.flex_index.verbose:
@@ -44,7 +43,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
'score': np.concatenate(result_scores),
'docno': np.concatenate(result_docnos),
'docid': np.concatenate(result_dids),
-'rank': np.concatenate([np.arange(len(scores)) for scores in result_scores]),
+'rank': np.concatenate([np.arange(len(scores))+1 for scores in result_scores]),
}
idxs = list(itertools.chain(*(itertools.repeat(i, len(scores)) for i, scores in enumerate(result_scores))))
for col in inp.columns:
10 changes: 7 additions & 3 deletions pyterrier_dr/flex/torch_retr.py
@@ -58,11 +58,15 @@ def transform(self, inp):
scores = batch @ self.torch_vecs.T
else:
raise ValueError(f'{self.flex_index.sim_fn} not supported')
-scores, docids = scores.topk(self.num_results, dim=1)
+if scores.shape[1] > self.num_results:
+    scores, docids = scores.topk(self.num_results, dim=1)
+else:
+    docids = scores.argsort(descending=True, dim=1)
+    scores = torch.gather(scores, dim=1, index=docids)
res_scores.append(scores.cpu().numpy().reshape(-1))
res_docids.append(docids.cpu().numpy().reshape(-1))
-res_idxs.append(np.arange(start_idx, start_idx+batch.shape[0]).reshape(-1, 1).repeat(self.num_results, axis=1).reshape(-1))
-res_ranks.append(np.arange(self.num_results).reshape(1, -1).repeat(batch.shape[0], axis=0).reshape(-1))
+res_idxs.append(np.arange(start_idx, start_idx+batch.shape[0]).reshape(-1, 1).repeat(scores.shape[1], axis=1).reshape(-1))
+res_ranks.append(np.arange(scores.shape[1]).reshape(1, -1).repeat(batch.shape[0], axis=0).reshape(-1) + 1)
res_idxs = np.concatenate(res_idxs)
res = {k: inp[k][res_idxs] for k in inp.columns if k not in ['docid', 'docno', 'rank', 'score']}
res['score'] = np.concatenate(res_scores)
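The torch_retriever fix above guards the top-k selection: torch.topk raises an error when k exceeds the size of the sorted dimension, so when the index holds fewer documents than num_results the code now falls back to a full argsort plus gather. A standalone sketch of that guard (the top_k helper name is illustrative, not part of the codebase):

import torch

def top_k(scores: torch.Tensor, k: int):
    # Keep the k best columns per row; when there are fewer than k candidates,
    # rank them all instead (a plain .topk(k) would raise in that case).
    if scores.shape[1] > k:
        return scores.topk(k, dim=1)
    idxs = scores.argsort(descending=True, dim=1)
    return torch.gather(scores, dim=1, index=idxs), idxs

vals, ids = top_k(torch.rand(2, 5), k=100)  # works even though k exceeds the number of documents
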
1 change: 0 additions & 1 deletion pyterrier_dr/indexes.py
@@ -379,7 +379,6 @@ def update(self, scores, docids):
assert self.num_queries == scores.shape[0]
self.scores = np.concatenate([self.scores, -scores], axis=1)
self.docids = np.concatenate([self.docids, docids], axis=1)
-print(self.scores.shape[1], self.num_results)
if self.scores.shape[1] > self.num_results:
partition_idxs = np.argpartition(self.scores, self.num_results, axis=1)[:, :self.num_results]
self.scores = np.take_along_axis(self.scores, partition_idxs, axis=1)
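For context, RankedLists.update (unchanged here apart from the dropped debug print) keeps only the best num_results per query by running np.argpartition over negated scores. A simplified one-query sketch of that trick (illustrative, not the class itself):

import numpy as np

scores = np.array([[0.2, 0.9, 0.5, 0.7]])      # one query, four candidate documents
neg = -scores                                  # store negated so "smallest" means "best"
k = 2
part = np.argpartition(neg, k, axis=1)[:, :k]  # column indices of the k best (unsorted)
best = -np.take_along_axis(neg, part, axis=1)  # e.g. [[0.9, 0.7]], order within the top k not guaranteed
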
11 changes: 11 additions & 0 deletions pyterrier_dr/util.py
@@ -1,10 +1,12 @@
from enum import Enum
import torch


class SimFn(Enum):
dot = 'dot'
cos = 'cos'


class Variants(type):
def __getattr__(cls, name):
if name in cls.VARIANTS:
@@ -15,7 +17,16 @@ def wrapped(*args, **kwargs):
def __init__(self, *args, **kwargs):
return super().__init__(*args, **kwargs)


def infer_device(device=None):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
return torch.device(device)


+def package_available(name):
+    try:
+        __import__(name)
+        return True
+    except ImportError:
+        return False
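The new package_available helper is what the tests below use to skip FAISS-dependent cases. A small usage sketch (the faiss import is just an example of an optional dependency):

from pyterrier_dr.util import infer_device, package_available

device = infer_device()          # torch.device('cuda') if available, otherwise 'cpu'
if package_available('faiss'):
    import faiss                 # only imported when the package is actually installed
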
73 changes: 57 additions & 16 deletions tests/test_flexindex.py
@@ -1,20 +1,21 @@
+import tempfile
import unittest
+import numpy as np
import pandas as pd
-import tempfile
import pyterrier as pt
-import numpy as np
+import pyterrier_dr
+from pyterrier_dr import FlexIndex


class TestFlexIndex(unittest.TestCase):

-def _generate_data(self, count=1000, dim=100):
+def _generate_data(self, count=2000, dim=100):
return [
{'docno': str(i), 'doc_vec': np.random.rand(dim).astype(np.float32)}
for i in range(count)
]

def test_index_typical(self):
-from pyterrier_dr import FlexIndex
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir+'/index')
@@ -36,46 +37,86 @@ def test_index_typical(self):
self.assertTrue((a['doc_vec'] == b['doc_vec']).all())

def test_corpus_graph(self):
from pyterrier_dr import FlexIndex
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir+'/index')

self.assertFalse(index.built())

dataset = self._generate_data()

index.index(dataset)

graph = index.corpus_graph(16)
self.assertEqual(graph.neighbours(4).shape, (16,))

@unittest.skipIf(not pyterrier_dr.util.package_available('faiss'), "faiss not available")
def test_faiss_hnsw_graph(self):
from pyterrier_dr import FlexIndex
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir+'/index')

self.assertFalse(index.built())

dataset = self._generate_data()

index.index(dataset)

graph = index.faiss_hnsw_graph(16)
self.assertEqual(graph.neighbours(4).shape, (16,))

def _test_exact_retr(self, Retr):
with self.subTest('basic'):
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir+'/index')
dataset = self._generate_data(count=2000)
index.index(dataset)

retr = Retr(index)
res = retr(pd.DataFrame([
{'qid': '0', 'query_vec': dataset[0]['doc_vec']},
{'qid': '1', 'query_vec': dataset[1]['doc_vec']},
]))
self.assertTrue(all(c in res.columns for c in ['qid', 'docno', 'rank', 'score']))
self.assertEqual(len(res), 2000)
self.assertEqual(len(res[res.qid=='0']), 1000)
self.assertEqual(res[(res.qid=='0')&((res['rank']==1))].iloc[0]['docno'], '0')
self.assertEqual(len(res[res.qid=='1']), 1000)
self.assertEqual(res[(res.qid=='1')&((res['rank']==1))].iloc[0]['docno'], '1')

with self.subTest('smaller'):
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir+'/index')
dataset = self._generate_data(count=100)
index.index(dataset)

retr = Retr(index)
res = retr(pd.DataFrame([
{'qid': '0', 'query_vec': dataset[0]['doc_vec']},
{'qid': '1', 'query_vec': dataset[1]['doc_vec']},
]))
self.assertTrue(all(c in res.columns for c in ['qid', 'docno', 'rank', 'score']))
self.assertEqual(len(res), 200)
self.assertEqual(len(res[res.qid=='0']), 100)
self.assertEqual(res[(res.qid=='0')&((res['rank']==1))].iloc[0]['docno'], '0')
self.assertEqual(len(res[res.qid=='1']), 100)
self.assertEqual(res[(res.qid=='1')&((res['rank']==1))].iloc[0]['docno'], '1')

@unittest.skipIf(not pyterrier_dr.util.package_available('faiss'), "faiss not available")
def test_faiss_flat_retriever(self):
self._test_exact_retr(FlexIndex.faiss_flat_retriever)

def test_np_retriever(self):
self._test_exact_retr(FlexIndex.np_retriever)

def test_torch_retriever(self):
self._test_exact_retr(FlexIndex.torch_retriever)

# TODO: tests for:
# - faiss_flat_retriever
# - faiss_hnsw_retriever
# - faiss_ivf_retriever
# - pre_ladr
# - ada_ladr
# - np_retriever
# - np_vec_loader
# - np_scorer
# - scann_retriever
# - torch_vecs
# - torch_scorer
# - torch_retriever

def setUp(self):
import pyterrier as pt
