From b72f4526bdcf243b5606ad5a43b3191082bc0d5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Radim=20=C5=98eh=C5=AF=C5=99ek?=
Date: Sat, 13 Jun 2015 22:30:22 +0200
Subject: [PATCH] switch BruteForce to single precision math

---
 ann_benchmarks.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/ann_benchmarks.py b/ann_benchmarks.py
index 3da038f38..1402bab6c 100644
--- a/ann_benchmarks.py
+++ b/ann_benchmarks.py
@@ -176,25 +176,28 @@ def query(self, v, n):
 
 class BruteForce(BaseANN):
     """kNN search that uses a linear scan = brute force."""
-    def __init__(self, metric):
+    def __init__(self, metric, precision=numpy.float32):
         if metric not in ('angular', 'euclidean'):
             raise NotImplementedError("BruteForce doesn't support metric %s" % metric)
         self._metric = metric
+        self._precision = precision
         self.name = 'BruteForce()'
 
     def fit(self, X):
         """Initialize the search index."""
-        self.lengths = (X ** 2).sum(-1)  # precompute (squared) length of each vector
+        lens = (X ** 2).sum(-1)  # precompute (squared) length of each vector
         if self._metric == 'angular':
-            # for cossim, normalize index vectors to unit length
-            self.index = numpy.ascontiguousarray(X / numpy.sqrt(self.lengths)[..., numpy.newaxis])
+            X /= numpy.sqrt(lens)[..., numpy.newaxis]  # normalize index vectors to unit length
+            self.index = numpy.ascontiguousarray(X, dtype=self._precision)
         elif self._metric == 'euclidean':
-            self.index = numpy.ascontiguousarray(X)
+            self.index = numpy.ascontiguousarray(X, dtype=self._precision)
+            self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision)
         else:
             assert False, "invalid metric"  # shouldn't get past the constructor!
 
     def query(self, v, n):
         """Find indices of `n` most similar vectors from the index to query vector `v`."""
+        v = numpy.ascontiguousarray(v, dtype=self._precision)  # use same precision for query as for index
         # HACK we ignore query length as that's a constant not affecting the final ordering
         if self._metric == 'angular':
             # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)
@@ -205,7 +208,7 @@ def query(self, v, n):
         else:
             assert False, "invalid metric"  # shouldn't get past the constructor!
         indices = numpy.argpartition(dists, n)[:n]  # partition-sort by distance, get `n` closest
-        return sorted(indices, key=lambda index: dists[index])  # resort `n` closest into final order
+        return sorted(indices, key=lambda index: dists[index])  # sort `n` closest into correct order
 
 
 def get_dataset(which='glove', limit=-1):