Skip to content

Commit

Permalink
brute force using blas
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed May 3, 2016
1 parent 87774ee commit 3c21b7d
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion ann_benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,43 @@ def query(self, v, n):
return list(self._nbrs.kneighbors(v, return_distance=False, n_neighbors=n)[0])


class BruteForceBLAS(BaseANN):
    """Exact kNN search by linear scan, using one matrix-vector product per query.

    Supported metrics are 'angular' (cosine similarity) and 'euclidean'.
    The index is stored as a C-contiguous array of dtype `precision`.
    """

    def __init__(self, metric, precision=numpy.float32):
        """Configure the searcher.

        `metric` must be 'angular' or 'euclidean'; anything else raises
        NotImplementedError. `precision` is the dtype used for both the
        stored index and the query vectors.
        """
        if metric not in ('angular', 'euclidean'):
            raise NotImplementedError("BruteForceBLAS doesn't support metric %s" % metric)
        self._metric = metric
        self._precision = precision
        self.name = 'BruteForceBLAS()'

    def fit(self, X):
        """Initialize the search index from the 2-D array-like `X` (one vector per row).

        The caller's array is never modified: normalization for the angular
        metric is done out of place.
        """
        X = numpy.asarray(X)
        if not numpy.issubdtype(X.dtype, numpy.floating):
            # Integer input would overflow in `X ** 2` (small int dtypes) and
            # cannot be divided in place, so promote it before any arithmetic.
            X = X.astype(self._precision)
        lens = (X ** 2).sum(-1)  # precompute (squared) length of each vector
        if self._metric == 'angular':
            # Normalize index vectors to unit length. Out-of-place division:
            # `X /= ...` would silently rewrite the caller's training data.
            unit = X / numpy.sqrt(lens)[..., numpy.newaxis]
            self.index = numpy.ascontiguousarray(unit, dtype=self._precision)
        elif self._metric == 'euclidean':
            self.index = numpy.ascontiguousarray(X, dtype=self._precision)
            self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision)
        else:
            # Shouldn't get past the constructor. Raised explicitly because a
            # bare `assert` is stripped under `python -O`.
            raise AssertionError("invalid metric")

    def query(self, v, n):
        """Find indices of `n` most similar vectors from the index to query vector `v`.

        Returns a list of row indices into the fitted data, ordered from
        closest to farthest. If `n` is at least the index size, every index
        is returned; if `n` is not positive, the list is empty.
        """
        # Use the same precision for the query as for the index.
        v = numpy.ascontiguousarray(v, dtype=self._precision)
        # HACK: the query length is ignored, as it is a constant that does not
        # affect the final ordering.
        if self._metric == 'angular':
            # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)
            dists = -numpy.dot(self.index, v)
        elif self._metric == 'euclidean':
            # argmin_a (a - b)^2 = argmin_a a^2 - 2ab + b^2 = argmin_a a^2 - 2ab
            dists = self.lengths - 2 * numpy.dot(self.index, v)
        else:
            # Shouldn't get past the constructor (see fit()).
            raise AssertionError("invalid metric")

        if n <= 0:
            return []
        if n >= len(dists):
            # argpartition requires kth < len(dists); when every point is
            # requested there is nothing to partition anyway.
            indices = numpy.arange(len(dists))
        else:
            # Partition-sort by distance, keep the `n` closest (unordered).
            indices = numpy.argpartition(dists, n)[:n]
        # Sort the `n` closest into the correct order.
        return sorted(indices, key=lambda index: dists[index])


def get_dataset(which='glove', limit=-1):
local_fn = os.path.join('install', which)
if os.path.exists(local_fn + '.gz'):
Expand Down Expand Up @@ -397,7 +434,7 @@ def run_algo(args, library, algo, results_fn):
def get_queries(args):
print('computing queries with correct results...')

bf = BruteForce(args.distance)
bf = BruteForceBLAS(args.distance)
X_train, X_test = get_dataset(which=args.dataset, limit=args.limit)

# Prepare queries
Expand Down

0 comments on commit 3c21b7d

Please sign in to comment.