Skip to content

Commit

Permalink
Merge pull request #2 from maumueller/faiss-build
Browse files Browse the repository at this point in the history
Split up build and query args for faiss-ivf
  • Loading branch information
ale-f authored Mar 12, 2018
2 parents 9826fcc + 91b6cc3 commit 9bb93fe
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
4 changes: 2 additions & 2 deletions algos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ float:
base-args: ["@metric"]
run-groups:
base:
args: [[5, 10, 20, 50, 100, 200, 400, 800, 1000],
[1, 2, 3, 4, 5, 8, 10, 20, 50, 100, 200]]
args: [[5, 10, 20, 50, 100, 200, 400, 800, 1000]]
query-args: [[1, 2, 3, 4, 5, 8, 10, 20, 50, 100, 200]]
flann:
docker-tag: ann-benchmarks-flann
module: ann_benchmarks.algorithms.flann
Expand Down
12 changes: 8 additions & 4 deletions ann_benchmarks/algorithms/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@ def use_threads(self):
import sklearn.preprocessing

class FaissIVF(BaseANN):
def __init__(self, metric, n_list, n_probe):
def __init__(self, metric, n_list):
self._n_list = n_list
self._n_probe = n_probe
self._metric = metric
self.name = 'FaissIVF(n_list=%d, n_probe=%d)' % (self._n_list, self._n_probe)

def fit(self, X):
if self._metric == 'angular':
Expand All @@ -55,11 +53,17 @@ def fit(self, X):
index = faiss.IndexIVFFlat(self.quantizer, X.shape[1], self._n_list, faiss.METRIC_L2)
index.train(X)
index.add(X)
index.nprobe = self._n_probe
self._index = index

def set_query_arguments(self, n_probe):
self._n_probe = n_probe
self._index.n_probe = self._n_probe

def query(self, v, n):
if self._metric == 'angular':
v /= numpy.linalg.norm(v)
(dist,), (ids,) = self._index.search(v.reshape(1, -1).astype('float32'), n)
return ids

def __str__(self):
return 'FaissIVF(n_list=%d, n_probe=%d)' % (self._n_list, self._n_probe)
1 change: 1 addition & 0 deletions ann_benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def batch_query(X):
elif algo.use_threads() and not force_single:
pool = multiprocessing.pool.ThreadPool()
results = pool.map(single_query, X_test)
pool.close()
else:
results = [single_query(x) for x in X_test]

Expand Down

0 comments on commit 9bb93fe

Please sign in to comment.