Skip to content

Commit

Permalink
add FAISS backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Aug 26, 2024
1 parent 49f07b5 commit 5af2591
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
with:
python-version: "3.10"
- run: pip install ".[test,pinecone]"
- run: pip install scikit-learn pynndescent annoy
- run: pip install scikit-learn pynndescent annoy faiss-cpu
- run: coverage run --source=affine -m pytest -v --durations 0 tests/unit-tests
- run: coverage report
- name: upload coverage report as artifact
Expand Down
28 changes: 27 additions & 1 deletion affine/engine/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,33 @@ def query(self, q: np.ndarray, k: int) -> list[int]:


class FAISSBackend(LocalBackend):
pass
def __init__(self, index_factory_str: str):
"""
Parameters
----------
index_factory_str : str
A string that specifies the index type to be created.
See https://github.com/facebookresearch/faiss/wiki/The-index-factory for details.
"""
self.index_factory_str = index_factory_str

def create_index(self, data: np.ndarray, metric: Metric) -> None:
try:
import faiss
except ModuleNotFoundError:
raise RuntimeError(
"FAISSBackend backend requires FAISS to be installed. See "
"https://github.com/facebookresearch/faiss/blob/main/INSTALL.md for installation instructions."
)
if metric == Metric.COSINE:
data = data / np.linalg.norm(data, axis=1).reshape(-1, 1)
self.index = faiss.index_factory(data.shape[1], self.index_factory_str)
self.index.add(data)

def query(self, q: np.ndarray, k: int) -> list[int]:
_, idxs = self.index.search(q.reshape(1, -1) / np.linalg.norm(q), k)
assert idxs.shape[0] == 1
return idxs[0].tolist()


class LocalEngine(Engine):
Expand Down
19 changes: 18 additions & 1 deletion tests/unit-tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

from affine.collection import Collection
from affine.engine import LocalEngine
from affine.engine.local import AnnoyBackend, KDTreeBackend, PyNNDescentBackend
from affine.engine.local import (
AnnoyBackend,
FAISSBackend,
KDTreeBackend,
PyNNDescentBackend,
)


def test_local_engine(generic_test_engine):
Expand Down Expand Up @@ -63,6 +68,18 @@ def test_cosine_similarity_annoy_backend(generic_test_cosine_similarity):
generic_test_cosine_similarity(db)


def test_euclidean_similarity_faiss_backend(
generic_test_euclidean_similarity,
):
db = LocalEngine(backend=FAISSBackend("Flat"))
generic_test_euclidean_similarity(db)


def test_cosine_similarity_faiss_backend(generic_test_cosine_similarity):
db = LocalEngine(backend=FAISSBackend("Flat"))
generic_test_cosine_similarity(db)


def test_local_engine_save_load(
PersonCollection: Type[Collection],
ProductCollection: Type[Collection],
Expand Down

0 comments on commit 5af2591

Please sign in to comment.