Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use BLAS in brute force (via NumPy) #5

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions ann_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class BaseANN(object):
pass


class LSHF(BaseANN):
def __init__(self, metric, n_estimators=10, n_candidates=50):
self.name = 'LSHF(n_est=%d, n_cand=%d)' % (n_estimators, n_candidates)
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(self, metric, n_trees, n_candidates):
self._n_trees = n_trees
self._n_candidates = n_candidates
self._metric = metric
self.name = 'PANNS(n_trees=%d, n_cand=%d)' % (n_trees, n_candidates)
self.name = 'PANNS(n_trees=%d, n_cand=%d)' % (n_trees, n_candidates)

def fit(self, X):
self._panns = panns.PannsIndex(X.shape[1], metric=self._metric)
Expand Down Expand Up @@ -175,17 +175,24 @@ def query(self, v, n):


class BruteForce(BaseANN):
"""kNN search that uses a linear scan = brute force."""
def __init__(self, metric):
if metric not in ('angular', ):
raise NotImplementedError("BruteForce doesn't support metric %s" % metric)
self._metric = metric
self.name = 'BruteForce()'

def fit(self, X):
metric = {'angular': 'cosine', 'euclidean': 'l2'}[self._metric]
self._nbrs = sklearn.neighbors.NearestNeighbors(algorithm='brute', metric=metric)
self._nbrs.fit(X)
"""Initialize the search index."""
# normalize vectors to unit length
self.index = X / numpy.sqrt((X ** 2).sum(-1))[..., numpy.newaxis]

def query(self, v, n):
return list(self._nbrs.kneighbors(v, return_distance=False, n_neighbors=n)[0])
"""Find indices of `n` most similar vectors from the index to query vector `v`."""
query = v / numpy.sqrt((v ** 2).sum()) # normalize query to unit length
cossims = numpy.dot(self.index, query) # cossim = dot product over normalized vectors
indices = numpy.argsort(cossims)[::-1] # sort by cossim, highest first
return indices[:n] # return top `n` most similar


def get_dataset(which='glove', limit=-1):
Expand Down Expand Up @@ -246,7 +253,8 @@ def get_queries(args):
print len(queries), '...'

return queries



def get_algos(m):
return {
'lshf': [LSHF(m, 5, 10), LSHF(m, 5, 20), LSHF(m, 10, 20), LSHF(m, 10, 50), LSHF(m, 20, 100)],
Expand Down Expand Up @@ -305,15 +313,15 @@ def get_fn(base, args):
if os.path.exists(results_fn):
for line in open(results_fn):
algos_already_ran.add(line.strip().split('\t')[1])

algos = get_algos(args.distance)
algos_flat = []

for library in algos.keys():
for algo in algos[library]:
if algo.name not in algos_already_ran:
algos_flat.append((library, algo))

random.shuffle(algos_flat)

print 'order:', algos_flat
Expand Down
2 changes: 1 addition & 1 deletion install.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
sudo apt-get install -y python-numpy python-scipy
cd install
for fn in annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh
for fn in bruteforce.sh annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh
do
source $fn
done
3 changes: 3 additions & 0 deletions install/bruteforce.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sudo apt-get install -y python-pip python-dev
sudo apt-get install -y libatlas-dev libatlas3gf-base
sudo apt-get install -y python-numpy
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@erikbern I'm not terribly sure how the debian-packaged NumPy plays with BLAS... can you check that ATLAS is being picked up by NumPy (=dot calls are fast)?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like the pypi version of numpy gets pulled in from some other package so it might not be needed anyway

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright... can you check the timings for np.dot anyway, just to be sure? What version of numpy is that?