class BruteForce(BaseANN):
    """Exact k-nearest-neighbour search by linear scan (brute force).

    Supports the 'angular' (cosine) and 'euclidean' metrics. The whole
    dataset is kept in memory as one dense matrix, and each query is answered
    with a single matrix-vector product.
    """

    def __init__(self, metric, precision=numpy.float32):
        """Configure the searcher.

        `metric` is 'angular' or 'euclidean'; anything else raises
        NotImplementedError. `precision` is the numpy dtype used both for the
        stored index and for incoming query vectors.
        """
        if metric not in ('angular', 'euclidean'):
            raise NotImplementedError("BruteForce doesn't support metric %s" % metric)
        self._metric = metric
        self._precision = precision
        self.name = 'BruteForce()'

    def fit(self, X):
        """Initialize the search index from `X`, one vector per row.

        `X` itself is never modified: the same dataset object is typically
        handed to several algorithms in turn, so normalizing it in place
        would silently change what the others get to see.
        """
        X = numpy.asarray(X)
        lens = (X ** 2).sum(-1)  # precompute (squared) length of each vector
        if self._metric == 'angular':
            norms = numpy.sqrt(lens)[..., numpy.newaxis]
            # Leave all-zero vectors as zeros instead of turning them into
            # NaN rows, which would make the ordering in query() undefined.
            norms[norms == 0] = 1
            # Out-of-place division: costs one temporary copy, but keeps the
            # caller's data intact and also works for integer-typed input.
            self.index = numpy.ascontiguousarray(X / norms, dtype=self._precision)
        elif self._metric == 'euclidean':
            self.index = numpy.ascontiguousarray(X, dtype=self._precision)
            self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision)
        else:
            # Unreachable unless `_metric` was changed after construction.
            # An explicit raise (unlike `assert`) survives `python -O`.
            raise NotImplementedError("BruteForce doesn't support metric %s" % self._metric)

    def query(self, v, n):
        """Find indices of the `n` most similar index vectors to query `v`.

        Returns a list of row indices into the fitted data, nearest first.
        If `n` exceeds the index size, all indices are returned; if `n` is
        not positive, the result is an empty list.
        """
        v = numpy.ascontiguousarray(v, dtype=self._precision)  # use same precision for query as for index
        # HACK we ignore query length as that's a constant not affecting the final ordering
        if self._metric == 'angular':
            # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)
            dists = -numpy.dot(self.index, v)
        elif self._metric == 'euclidean':
            # argmin_a (a - b)^2 = argmin_a a^2 - 2ab + b^2 = argmin_a a^2 - 2ab
            dists = self.lengths - 2 * numpy.dot(self.index, v)
        else:
            # Unreachable unless `_metric` was changed after construction.
            raise NotImplementedError("BruteForce doesn't support metric %s" % self._metric)

        if n <= 0:
            return []
        if n < len(dists):
            # partition-sort by distance, get `n` closest
            indices = numpy.argpartition(dists, n)[:n]
        else:
            # argpartition rejects kth >= len(dists); when every point is
            # requested there is nothing to partition anyway.
            indices = numpy.arange(len(dists))
        return sorted(indices, key=lambda index: dists[index])  # sort `n` closest into correct order
return queries - + + def get_algos(m): return { 'lshf': [LSHF(m, 5, 10), LSHF(m, 5, 20), LSHF(m, 10, 20), LSHF(m, 10, 50), LSHF(m, 20, 100)], @@ -305,7 +329,7 @@ def get_fn(base, args): if os.path.exists(results_fn): for line in open(results_fn): algos_already_ran.add(line.strip().split('\t')[1]) - + algos = get_algos(args.distance) algos_flat = [] @@ -313,7 +337,7 @@ def get_fn(base, args): for algo in algos[library]: if algo.name not in algos_already_ran: algos_flat.append((library, algo)) - + random.shuffle(algos_flat) print 'order:', algos_flat diff --git a/install.sh b/install.sh index 59068bcb5..9dbffd901 100644 --- a/install.sh +++ b/install.sh @@ -1,6 +1,6 @@ sudo apt-get install -y python-numpy python-scipy cd install -for fn in annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh +for fn in bruteforce.sh annoy.sh panns.sh nearpy.sh sklearn.sh flann.sh kgraph.sh glove.sh sift.sh do source $fn done diff --git a/install/bruteforce.sh b/install/bruteforce.sh new file mode 100644 index 000000000..19c649d30 --- /dev/null +++ b/install/bruteforce.sh @@ -0,0 +1,3 @@ +sudo apt-get install -y python-pip python-dev +sudo apt-get install -y libatlas-dev libatlas3gf-base +sudo apt-get install -y python-numpy