diff --git a/algos.yaml b/algos.yaml index 33970040f..8afdce7fd 100644 --- a/algos.yaml +++ b/algos.yaml @@ -26,8 +26,8 @@ float: base-args: ["@metric"] run-groups: base: - args: [[5, 10, 20, 50, 100, 200, 400, 800, 1000], - [1, 2, 3, 4, 5, 8, 10, 20, 50, 100, 200]] + args: [[5, 10, 20, 50, 100, 200, 400, 800, 1000]] + query-args: [[1, 2, 3, 4, 5, 8, 10, 20, 50, 100, 200]] flann: docker-tag: ann-benchmarks-flann module: ann_benchmarks.algorithms.flann diff --git a/ann_benchmarks/algorithms/faiss.py b/ann_benchmarks/algorithms/faiss.py index 17b682f41..6bf975c44 100644 --- a/ann_benchmarks/algorithms/faiss.py +++ b/ann_benchmarks/algorithms/faiss.py @@ -38,11 +38,9 @@ def use_threads(self): import sklearn.preprocessing class FaissIVF(BaseANN): - def __init__(self, metric, n_list, n_probe): + def __init__(self, metric, n_list): self._n_list = n_list - self._n_probe = n_probe self._metric = metric - self.name = 'FaissIVF(n_list=%d, n_probe=%d)' % (self._n_list, self._n_probe) def fit(self, X): if self._metric == 'angular': @@ -55,11 +53,17 @@ def fit(self, X): index = faiss.IndexIVFFlat(self.quantizer, X.shape[1], self._n_list, faiss.METRIC_L2) index.train(X) index.add(X) - index.nprobe = self._n_probe self._index = index + def set_query_arguments(self, n_probe): + self._n_probe = n_probe + self._index.n_probe = self._n_probe + def query(self, v, n): if self._metric == 'angular': v /= numpy.linalg.norm(v) (dist,), (ids,) = self._index.search(v.reshape(1, -1).astype('float32'), n) return ids + + def __str__(self): + return 'FaissIVF(n_list=%d, n_probe=%d)' % (self._n_list, self._n_probe) diff --git a/ann_benchmarks/runner.py b/ann_benchmarks/runner.py index d15d17d12..30835fa70 100644 --- a/ann_benchmarks/runner.py +++ b/ann_benchmarks/runner.py @@ -57,6 +57,7 @@ def batch_query(X): elif algo.use_threads() and not force_single: pool = multiprocessing.pool.ThreadPool() results = pool.map(single_query, X_test) + pool.close() else: results = [single_query(x) for x in X_test]