diff --git a/ann_benchmarks.py b/ann_benchmarks.py
index dc6666bdd..4a23911b0 100644
--- a/ann_benchmarks.py
+++ b/ann_benchmarks.py
@@ -109,13 +109,19 @@ def query(self, v, n):
 
 
 class NearPy(BaseANN):
-    def __init__(self, n_bits):
+    def __init__(self, n_bits, hash_counts):
         self._n_bits = n_bits
-        self.name = 'NearPy(n_bits=%d)' % (n_bits,)
+        self._hash_counts = hash_counts
+        self.name = 'NearPy(n_bits=%d, hash_counts=%d)' % (n_bits, hash_counts)
 
     def fit(self, X):
-        nearpy_rbp = nearpy.hashes.RandomBinaryProjections('rbp', self._n_bits)
-        self._nearpy_engine = nearpy.Engine(X.shape[1], lshashes=[nearpy_rbp], distance=nearpy.distances.CosineDistance())
+        hashes = []
+
+        for k in xrange(self._hash_counts):
+            nearpy_rbp = nearpy.hashes.RandomBinaryProjections('rbp_%d' % k, self._n_bits)
+            hashes.append(nearpy_rbp)
+
+        self._nearpy_engine = nearpy.Engine(X.shape[1], lshashes=hashes, distance=nearpy.distances.CosineDistance())
 
         for i, x in enumerate(X):
             self._nearpy_engine.store_vector(x.tolist(), i)
@@ -198,7 +204,7 @@ def run_algo(library, algo):
     'flann': [FLANN(0.2), FLANN(0.5), FLANN(0.7), FLANN(0.8), FLANN(0.9), FLANN(0.95), FLANN(0.97), FLANN(0.98), FLANN(0.99), FLANN(0.995)],
     'panns': [PANNS(5, 20), PANNS(10, 10), PANNS(10, 50), PANNS(10, 100), PANNS(20, 100), PANNS(40, 100)],
     'annoy': [Annoy(3, 10), Annoy(5, 25), Annoy(10, 10), Annoy(10, 40), Annoy(10, 100), Annoy(10, 200), Annoy(10, 400), Annoy(10, 1000), Annoy(20, 20), Annoy(20, 100), Annoy(20, 200), Annoy(20, 400), Annoy(40, 40), Annoy(40, 100), Annoy(40, 400), Annoy(100, 100), Annoy(100, 200), Annoy(100, 400), Annoy(100, 1000)],
-    'nearpy': [NearPy(10), NearPy(12), NearPy(15), NearPy(20)],
+    'nearpy': [NearPy(30, 10), NearPy(30, 20), NearPy(30, 30), NearPy(20, 10), NearPy(20, 20), NearPy(20, 30), NearPy(15, 10), NearPy(15, 20), NearPy(15, 30), NearPy(10, 10), NearPy(10, 20), NearPy(10, 30), NearPy(8, 10), NearPy(8, 20), NearPy(8, 30)],
     'kgraph': [KGraph(20), KGraph(50), KGraph(100), KGraph(200), KGraph(500), KGraph(1000)],
     'bruteforce': [bf],
     'ball': [BallTree(10), BallTree(20), BallTree(40), BallTree(100), BallTree(200), BallTree(400), BallTree(1000)],