Skip to content

Commit

Permalink
switch BruteForce to single precision math
Browse files Browse the repository at this point in the history
  • Loading branch information
piskvorky committed Jun 13, 2015
1 parent cf2b2d9 commit b72f452
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions ann_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,25 +176,28 @@ def query(self, v, n):

class BruteForce(BaseANN):
"""kNN search that uses a linear scan = brute force."""
def __init__(self, metric):
def __init__(self, metric, precision=numpy.float32):
if metric not in ('angular', 'euclidean'):
raise NotImplementedError("BruteForce doesn't support metric %s" % metric)
self._metric = metric
self._precision = precision
self.name = 'BruteForce()'

def fit(self, X):
"""Initialize the search index."""
self.lengths = (X ** 2).sum(-1) # precompute (squared) length of each vector
lens = (X ** 2).sum(-1) # precompute (squared) length of each vector
if self._metric == 'angular':
# for cossim, normalize index vectors to unit length
self.index = numpy.ascontiguousarray(X / numpy.sqrt(self.lengths)[..., numpy.newaxis])
X /= numpy.sqrt(lens)[..., numpy.newaxis] # normalize index vectors to unit length
self.index = numpy.ascontiguousarray(X, dtype=self._precision)
elif self._metric == 'euclidean':
self.index = numpy.ascontiguousarray(X)
self.index = numpy.ascontiguousarray(X, dtype=self._precision)
self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision)
else:
assert False, "invalid metric" # shouldn't get past the constructor!

def query(self, v, n):
"""Find indices of `n` most similar vectors from the index to query vector `v`."""
v = numpy.ascontiguousarray(v, dtype=self._precision) # use same precision for query as for index
# HACK we ignore query length as that's a constant not affecting the final ordering
if self._metric == 'angular':
# argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)
Expand All @@ -205,7 +208,7 @@ def query(self, v, n):
else:
assert False, "invalid metric" # shouldn't get past the constructor!
indices = numpy.argpartition(dists, n)[:n] # partition-sort by distance, get `n` closest
return sorted(indices, key=lambda index: dists[index]) # resort `n` closest into final order
return sorted(indices, key=lambda index: dists[index]) # sort `n` closest into correct order


def get_dataset(which='glove', limit=-1):
Expand Down

0 comments on commit b72f452

Please sign in to comment.