class BruteForce(BaseANN):
    """Exact k-nearest-neighbour search using a linear scan (brute force).

    Only the ``'angular'`` metric is supported: similarity is the cosine
    between vectors, computed as a dot product over unit-length vectors.
    """

    def __init__(self, metric):
        """Remember the metric and set the display name of this algorithm.

        Raises:
            NotImplementedError: if `metric` is anything other than
                ``'angular'``.
        """
        if metric not in ('angular', ):
            # Wrap in a 1-tuple so a tuple-valued `metric` is formatted as a
            # whole instead of being unpacked by the `%` operator.
            raise NotImplementedError("BruteForce doesn't support metric %s" % (metric, ))
        self._metric = metric
        self.name = 'BruteForce()'

    def fit(self, X):
        """Initialize the search index from the vectors in `X`.

        `X` itself is not modified; a normalized copy is stored in
        `self.index`. Assumes `X` is a numpy array with one vector per row
        (the norm is taken over the last axis) -- TODO confirm with callers.
        """
        # Normalize every vector to unit length so that a plain dot product
        # with a unit-length query equals the cosine similarity.
        self.index = X / np.sqrt((X ** 2).sum(-1))[..., np.newaxis]

    def query(self, v, n):
        """Return indices of the `n` indexed vectors most similar to `v`.

        The indices are ordered by decreasing cosine similarity. `v` is not
        modified. Assumes `v` is a 1-D numpy array of the same dimensionality
        as the indexed vectors -- TODO confirm with callers.
        """
        # Out-of-place division: `v /= ...` would overwrite the caller's
        # array and fails outright for integer dtypes.
        v = v / np.sqrt((v ** 2).sum())  # normalize query to unit length
        cossims = np.dot(self.index, v)  # cossim = dot product over normalized vectors
        indices = np.argsort(cossims)[::-1]  # sort by cossim, highest first
        return indices[:n]  # return top N most similar