Skip to content

Commit

Permalink
add annoy backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Aug 26, 2024
1 parent 48f9193 commit 99fb2d4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
21 changes: 20 additions & 1 deletion affine/engine/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,26 @@ def query(self, q: np.ndarray, k: int) -> list[int]:


class AnnoyBackend(LocalBackend):
pass
def __init__(self, n_trees: int, n_jobs: int = -1):
self.n_trees = n_trees
self.n_jobs = n_jobs

def create_index(self, data: np.ndarray, metric: Metric) -> None:
try:
from annoy import AnnoyIndex
except ModuleNotFoundError:
raise RuntimeError(
"AnnoyBackend backend requires annoy to be installed"
)

annoy_metric = "angular" if metric == Metric.COSINE else "euclidean"
self.index = AnnoyIndex(data.shape[1], metric=annoy_metric)
for i, v in enumerate(data):
self.index.add_item(i, v)
self.index.build(self.n_trees, self.n_jobs)

def query(self, q: np.ndarray, k: int) -> list[int]:
return self.index.get_nns_by_vector(q, k)


class FAISSBackend(LocalBackend):
Expand Down
14 changes: 13 additions & 1 deletion tests/unit-tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from affine.collection import Collection
from affine.engine import LocalEngine
from affine.engine.local import KDTreeBackend, PyNNDescentBackend
from affine.engine.local import AnnoyBackend, KDTreeBackend, PyNNDescentBackend


def test_local_engine(generic_test_engine):
Expand Down Expand Up @@ -51,6 +51,18 @@ def test_cosine_similarity_pynndescent_backend(generic_test_cosine_similarity):
generic_test_cosine_similarity(db)


def test_euclidean_similarity_annoy_backend(
generic_test_euclidean_similarity,
):
db = LocalEngine(backend=AnnoyBackend(n_trees=10))
generic_test_euclidean_similarity(db)


def test_cosine_similarity_annoy_backend(generic_test_cosine_similarity):
db = LocalEngine(backend=AnnoyBackend(n_trees=10))
generic_test_cosine_similarity(db)


def test_local_engine_save_load(
PersonCollection: Type[Collection],
ProductCollection: Type[Collection],
Expand Down

0 comments on commit 99fb2d4

Please sign in to comment.