From 8f343139fd459b5a102f24bc162a3bcd1deb7418 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Radim=20=C5=98eh=C5=AF=C5=99ek?=
Date: Thu, 11 Jun 2015 11:19:41 +0200
Subject: [PATCH] use BLAS in brute force (via NumPy)

---
Notes (this section is ignored by `git am`):
  * Line structure was rebuilt from a whitespace-collapsed copy of the
    patch; the indentation of context lines is inferred, so use
    `git apply --ignore-whitespace` if a hunk is rejected.
  * Whitespace-only hunks of the recorded commit are left out because
    their removed lines cannot be recovered from that copy.
  * BruteForce differs from the recorded commit: query() normalizes a
    copy of `v` instead of rescaling the caller's array, returns a list
    again (as the replaced implementation did), and the metric check
    raises ValueError instead of using `assert`. The commit id above and
    the post-image blob id below are therefore historical.

 ann_benchmarks.py     | 16 ++++++++++++----
 install.sh            |  2 +-
 install/bruteforce.sh |  3 +++
 3 files changed, 16 insertions(+), 5 deletions(-)
 create mode 100644 install/bruteforce.sh

diff --git a/ann_benchmarks.py b/ann_benchmarks.py
index 217178e6a..ea8f7c390 100644
--- a/ann_benchmarks.py
+++ b/ann_benchmarks.py
@@ -4,6 +4,7 @@
 import panns
 import nearpy, nearpy.hashes, nearpy.distances
 import pykgraph
+import numpy as np
 import gzip, numpy, time, os, multiprocessing, argparse, pickle, resource
 try:
     from urllib import urlretrieve
@@ -175,17 +176,24 @@ def query(self, v, n):
 
 
 class BruteForce(BaseANN):
+    """kNN search that uses a linear scan = brute force."""
     def __init__(self, metric):
+        if metric not in ('angular', ):
+            raise ValueError("BruteForce doesn't support metric %s" % metric)
         self._metric = metric
         self.name = 'BruteForce()'
 
     def fit(self, X):
-        metric = {'angular': 'cosine', 'euclidean': 'l2'}[self._metric]
-        self._nbrs = sklearn.neighbors.NearestNeighbors(algorithm='brute', metric=metric)
-        self._nbrs.fit(X)
+        """Initialize the search index. Modifies `X` in place!"""
+        X /= np.sqrt((X ** 2).sum(-1))[..., np.newaxis]  # normalize vectors to unit length
+        self.index = X
 
     def query(self, v, n):
-        return list(self._nbrs.kneighbors(v, return_distance=False, n_neighbors=n)[0])
+        """Find indices of the `n` vectors in the index most similar to query vector `v`."""
+        v = v / np.sqrt((v ** 2).sum())  # unit-length copy; the caller's query is left untouched
+        cossims = np.dot(self.index, v)  # cossim = dot over normalized
+        indices = np.argsort(cossims)[::-1]  # sort by cossim, most similar first
+        return list(indices[:n])  # top N most similar, as a list like the code this replaces
 
 
 def get_dataset(which='glove', limit=-1):
diff --git a/install.sh b/install.sh
index 59068bcb5..9dbffd901 100644
--- a/install.sh
+++ b/install.sh
@@ -1,6 +1,6 @@
 sudo apt-get install -y python-numpy python-scipy
 cd install
-for fn in annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh
+for fn in bruteforce.sh annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh
 do
     source $fn
 done
diff --git a/install/bruteforce.sh b/install/bruteforce.sh
new file mode 100644
index 000000000..19c649d30
--- /dev/null
+++ b/install/bruteforce.sh
@@ -0,0 +1,3 @@
+sudo apt-get install -y python-pip python-dev
+sudo apt-get install -y libatlas-dev libatlas3gf-base
+sudo apt-get install -y python-numpy