Skip to content

Commit

Permalink
use BLAS in brute force (via NumPy)
Browse files Browse the repository at this point in the history
  • Loading branch information
piskvorky committed Jun 11, 2015
1 parent 09f2fc6 commit 8f34313
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
26 changes: 17 additions & 9 deletions ann_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import panns
import nearpy, nearpy.hashes, nearpy.distances
import pykgraph
import numpy as np
import gzip, numpy, time, os, multiprocessing, argparse, pickle, resource
try:
from urllib import urlretrieve
Expand All @@ -22,7 +23,7 @@
class BaseANN(object):
pass


class LSHF(BaseANN):
def __init__(self, metric, n_estimators=10, n_candidates=50):
self.name = 'LSHF(n_est=%d, n_cand=%d)' % (n_estimators, n_candidates)
Expand Down Expand Up @@ -118,7 +119,7 @@ def __init__(self, metric, n_trees, n_candidates):
self._n_trees = n_trees
self._n_candidates = n_candidates
self._metric = metric
self.name = 'PANNS(n_trees=%d, n_cand=%d)' % (n_trees, n_candidates)
self.name = 'PANNS(n_trees=%d, n_cand=%d)' % (n_trees, n_candidates)

def fit(self, X):
self._panns = panns.PannsIndex(X.shape[1], metric=self._metric)
Expand Down Expand Up @@ -175,17 +176,23 @@ def query(self, v, n):


class BruteForce(BaseANN):
"""kNN search that uses a linear scan = brute force."""
def __init__(self, metric):
assert metric in ('angular', ), "BruteForce doesn't support metric %s" % metric
self._metric = metric
self.name = 'BruteForce()'

def fit(self, X):
metric = {'angular': 'cosine', 'euclidean': 'l2'}[self._metric]
self._nbrs = sklearn.neighbors.NearestNeighbors(algorithm='brute', metric=metric)
self._nbrs.fit(X)
"""Initialize the search index. Modifies `X` in place!"""
X /= np.sqrt((X ** 2).sum(-1))[..., np.newaxis] # normalize vectors to unit length
self.index = X

def query(self, v, n):
return list(self._nbrs.kneighbors(v, return_distance=False, n_neighbors=n)[0])
"""Find indices of `n` most similar vector from index to query vector `v`."""
v /= np.sqrt((v ** 2).sum()) # normalize query to unit length
cossims = numpy.dot(self.index, v) # cossim = dot over normalized
indices = np.argsort(cossims)[::-1] # sort by cossim
return indices[:n] # return top N most similar


def get_dataset(which='glove', limit=-1):
Expand Down Expand Up @@ -246,7 +253,8 @@ def get_queries(args):
print len(queries), '...'

return queries



def get_algos(m):
return {
'lshf': [LSHF(m, 5, 10), LSHF(m, 5, 20), LSHF(m, 10, 20), LSHF(m, 10, 50), LSHF(m, 20, 100)],
Expand Down Expand Up @@ -305,15 +313,15 @@ def get_fn(base, args):
if os.path.exists(results_fn):
for line in open(results_fn):
algos_already_ran.add(line.strip().split('\t')[1])

algos = get_algos(args.distance)
algos_flat = []

for library in algos.keys():
for algo in algos[library]:
if algo.name not in algos_already_ran:
algos_flat.append((library, algo))

random.shuffle(algos_flat)

print 'order:', algos_flat
Expand Down
2 changes: 1 addition & 1 deletion install.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
sudo apt-get install -y python-numpy python-scipy
cd install
for fn in annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh
for fn in bruteforce.sh annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh
do
source $fn
done
3 changes: 3 additions & 0 deletions install/bruteforce.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sudo apt-get install -y python-pip python-dev
sudo apt-get install -y libatlas-dev libatlas3gf-base
sudo apt-get install -y python-numpy

0 comments on commit 8f34313

Please sign in to comment.