From d6cbf7a1672f0163f2bcb3eab8376f7771697e6e Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Sat, 14 Oct 2023 21:07:46 +0100 Subject: [PATCH] more tests, fixed bug with torch_retriever, fixed rank indexing --- pyterrier_dr/flex/faiss_retr.py | 2 +- pyterrier_dr/flex/np_retr.py | 3 +- pyterrier_dr/flex/torch_retr.py | 10 +++-- pyterrier_dr/indexes.py | 1 - pyterrier_dr/util.py | 11 +++++ tests/test_flexindex.py | 73 +++++++++++++++++++++++++-------- 6 files changed, 77 insertions(+), 23 deletions(-) diff --git a/pyterrier_dr/flex/faiss_retr.py b/pyterrier_dr/flex/faiss_retr.py index 9e45514..9f8fd21 100644 --- a/pyterrier_dr/flex/faiss_retr.py +++ b/pyterrier_dr/flex/faiss_retr.py @@ -49,7 +49,7 @@ def transform(self, inp): s = s[mask] res['docid'].append(d) res['score'].append(s) - res['rank'].append(np.arange(d.shape[0])) + res['rank'].append(np.arange(d.shape[0])+1) idxs.extend(itertools.repeat(qidx+i, d.shape[0])) res = {k: np.concatenate(v) for k, v in res.items()} res['docno'] = docnos.fwd[res['docid']] diff --git a/pyterrier_dr/flex/np_retr.py b/pyterrier_dr/flex/np_retr.py index 1320cb5..63c0fbd 100644 --- a/pyterrier_dr/flex/np_retr.py +++ b/pyterrier_dr/flex/np_retr.py @@ -26,7 +26,6 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame: else: raise ValueError(f'{self.flex_index.sim_fn} not supported') num_q = query_vecs.shape[0] - res = [] ranked_lists = RankedLists(self.num_results, num_q) batch_it = range(0, dvecs.shape[0], self.batch_size) if self.flex_index.verbose: @@ -44,7 +43,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame: 'score': np.concatenate(result_scores), 'docno': np.concatenate(result_docnos), 'docid': np.concatenate(result_dids), - 'rank': np.concatenate([np.arange(len(scores)) for scores in result_scores]), + 'rank': np.concatenate([np.arange(len(scores))+1 for scores in result_scores]), } idxs = list(itertools.chain(*(itertools.repeat(i, len(scores)) for i, scores in enumerate(result_scores)))) for col in inp.columns: diff --git a/pyterrier_dr/flex/torch_retr.py b/pyterrier_dr/flex/torch_retr.py index 09cc8f6..6f960c7 100644 --- a/pyterrier_dr/flex/torch_retr.py +++ b/pyterrier_dr/flex/torch_retr.py @@ -58,11 +58,15 @@ def transform(self, inp): scores = batch @ self.torch_vecs.T else: raise ValueError(f'{self.flex_index.sim_fn} not supported') - scores, docids = scores.topk(self.num_results, dim=1) + if scores.shape[1] > self.num_results: + scores, docids = scores.topk(self.num_results, dim=1) + else: + docids = scores.argsort(descending=True, dim=1) + scores = torch.gather(scores, dim=1, index=docids) res_scores.append(scores.cpu().numpy().reshape(-1)) res_docids.append(docids.cpu().numpy().reshape(-1)) - res_idxs.append(np.arange(start_idx, start_idx+batch.shape[0]).reshape(-1, 1).repeat(self.num_results, axis=1).reshape(-1)) - res_ranks.append(np.arange(self.num_results).reshape(1, -1).repeat(batch.shape[0], axis=0).reshape(-1)) + res_idxs.append(np.arange(start_idx, start_idx+batch.shape[0]).reshape(-1, 1).repeat(scores.shape[1], axis=1).reshape(-1)) + res_ranks.append(np.arange(scores.shape[1]).reshape(1, -1).repeat(batch.shape[0], axis=0).reshape(-1) + 1) res_idxs = np.concatenate(res_idxs) res = {k: inp[k][res_idxs] for k in inp.columns if k not in ['docid', 'docno', 'rank', 'score']} res['score'] = np.concatenate(res_scores) diff --git a/pyterrier_dr/indexes.py b/pyterrier_dr/indexes.py index 80d6626..c572fab 100644 --- a/pyterrier_dr/indexes.py +++ b/pyterrier_dr/indexes.py @@ -379,7 +379,6 @@ def update(self, scores, docids): assert self.num_queries == scores.shape[0] self.scores = np.concatenate([self.scores, -scores], axis=1) self.docids = np.concatenate([self.docids, docids], axis=1) - print(self.scores.shape[1], self.num_results) if self.scores.shape[1] > self.num_results: partition_idxs = np.argpartition(self.scores, self.num_results, axis=1)[:, :self.num_results] self.scores = np.take_along_axis(self.scores, partition_idxs, axis=1) diff --git a/pyterrier_dr/util.py b/pyterrier_dr/util.py index dfa1913..ac89fbb 100644 --- a/pyterrier_dr/util.py +++ b/pyterrier_dr/util.py @@ -1,10 +1,12 @@ from enum import Enum import torch + class SimFn(Enum): dot = 'dot' cos = 'cos' + class Variants(type): def __getattr__(cls, name): if name in cls.VARIANTS: @@ -15,7 +17,16 @@ def wrapped(*args, **kwargs): def __init__(self, *args, **kwargs): return super().__init__(*args, **kwargs) + def infer_device(device=None): if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' return torch.device(device) + + +def package_available(name): + try: + __import__(name) + return True + except ImportError: + return False diff --git a/tests/test_flexindex.py b/tests/test_flexindex.py index c99b92d..60dd4cf 100644 --- a/tests/test_flexindex.py +++ b/tests/test_flexindex.py @@ -1,20 +1,21 @@ +import tempfile import unittest +import numpy as np import pandas as pd -import tempfile import pyterrier as pt -import numpy as np +import pyterrier_dr +from pyterrier_dr import FlexIndex class TestFlexIndex(unittest.TestCase): - def _generate_data(self, count=1000, dim=100): + def _generate_data(self, count=2000, dim=100): return [ {'docno': str(i), 'doc_vec': np.random.rand(dim).astype(np.float32)} for i in range(count) ] def test_index_typical(self): - from pyterrier_dr import FlexIndex destdir = tempfile.mkdtemp() self.test_dirs.append(destdir) index = FlexIndex(destdir+'/index') @@ -36,46 +37,86 @@ def test_index_typical(self): self.assertTrue((a['doc_vec'] == b['doc_vec']).all()) def test_corpus_graph(self): - from pyterrier_dr import FlexIndex destdir = tempfile.mkdtemp() self.test_dirs.append(destdir) index = FlexIndex(destdir+'/index') - - self.assertFalse(index.built()) - dataset = self._generate_data() - index.index(dataset) + graph = index.corpus_graph(16) self.assertEqual(graph.neighbours(4).shape, (16,)) + @unittest.skipIf(not pyterrier_dr.util.package_available('faiss'), "faiss not available") def test_faiss_hnsw_graph(self): - from pyterrier_dr import FlexIndex destdir = tempfile.mkdtemp() self.test_dirs.append(destdir) index = FlexIndex(destdir+'/index') - - self.assertFalse(index.built()) - dataset = self._generate_data() - index.index(dataset) + graph = index.faiss_hnsw_graph(16) self.assertEqual(graph.neighbours(4).shape, (16,)) + def _test_exact_retr(self, Retr): + with self.subTest('basic'): + destdir = tempfile.mkdtemp() + self.test_dirs.append(destdir) + index = FlexIndex(destdir+'/index') + dataset = self._generate_data(count=2000) + index.index(dataset) + + retr = Retr(index) + res = retr(pd.DataFrame([ + {'qid': '0', 'query_vec': dataset[0]['doc_vec']}, + {'qid': '1', 'query_vec': dataset[1]['doc_vec']}, + ])) + self.assertTrue(all(c in res.columns) for c in ['qid', 'docno', 'rank', 'score']) + self.assertEqual(len(res), 2000) + self.assertEqual(len(res[res.qid=='0']), 1000) + self.assertEqual(res[(res.qid=='0')&((res['rank']==1))].iloc[0]['docno'], '0') + self.assertEqual(len(res[res.qid=='1']), 1000) + self.assertEqual(res[(res.qid=='1')&((res['rank']==1))].iloc[0]['docno'], '1') + + with self.subTest('smaller'): + destdir = tempfile.mkdtemp() + self.test_dirs.append(destdir) + index = FlexIndex(destdir+'/index') + dataset = self._generate_data(count=100) + index.index(dataset) + + retr = Retr(index) + res = retr(pd.DataFrame([ + {'qid': '0', 'query_vec': dataset[0]['doc_vec']}, + {'qid': '1', 'query_vec': dataset[1]['doc_vec']}, + ])) + self.assertTrue(all(c in res.columns) for c in ['qid', 'docno', 'rank', 'score']) + self.assertEqual(len(res), 200) + self.assertEqual(len(res[res.qid=='0']), 100) + self.assertEqual(res[(res.qid=='0')&((res['rank']==1))].iloc[0]['docno'], '0') + self.assertEqual(len(res[res.qid=='1']), 100) + self.assertEqual(res[(res.qid=='1')&((res['rank']==1))].iloc[0]['docno'], '1') + + @unittest.skipIf(not pyterrier_dr.util.package_available('faiss'), "faiss not available") + def test_faiss_flat_retriever(self): + self._test_exact_retr(FlexIndex.faiss_flat_retriever) + + def test_np_retriever(self): + self._test_exact_retr(FlexIndex.np_retriever) + + def test_torch_retriever(self): + self._test_exact_retr(FlexIndex.torch_retriever) + # TODO: tests for: # - faiss_flat_retriever # - faiss_hnsw_retriever # - faiss_ivf_retriever # - pre_ladr # - ada_ladr - # - np_retriever # - np_vec_loader # - np_scorer # - scann_retriever # - torch_vecs # - torch_scorer - # - torch_retriever def setUp(self): import pyterrier as pt