From 4608aef83c472a41a429238a36ac49518ee269fd Mon Sep 17 00:00:00 2001 From: Oliver Rice Date: Wed, 9 Oct 2024 09:38:27 -0500 Subject: [PATCH] add l1 distance --- .github/workflows/tests.yml | 2 +- setup.py | 2 +- src/tests/conftest.py | 2 +- src/tests/test_collection.py | 13 +++++++++++++ src/vecs/collection.py | 3 +++ 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5cd7172..0694f67 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ jobs: strategy: matrix: python-version: ['3.8', '3.9', '3.10', '3.11'] - postgres-version: ['15.1.0.118'] + postgres-version: ['15.1.1.78'] services: diff --git a/setup.py b/setup.py index 39d97ee..7bce72c 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def read_package_variable(key, filename="__init__.py"): long_description = (Path(__file__).parent / "README.md").read_text() REQUIRES = [ - "pgvector==0.1.*", + "pgvector==0.3.*", "sqlalchemy==2.*", "psycopg2-binary==2.9.*", "flupy==1.*", diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 4b6e3e7..23ce338 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -21,7 +21,7 @@ def maybe_start_pg() -> Generator[None, None, None]: to using the PYTEST_DB connection string""" container_name = "vecs_pg" - image = "supabase/postgres:15.1.0.118" + image = "supabase/postgres:15.1.1.78" connection_template = "postgresql://{user}:{pw}@{host}:{port:d}/{db}" conn_args = parse(connection_template, PYTEST_DB) diff --git a/src/tests/test_collection.py b/src/tests/test_collection.py index 4200f14..b7cde59 100644 --- a/src/tests/test_collection.py +++ b/src/tests/test_collection.py @@ -806,6 +806,19 @@ def test_l2_index_query(client: vecs.Client) -> None: assert len(results) == 1 +def test_l1_index_query(client: vecs.Client) -> None: + dim = 4 + bar = client.get_or_create_collection(name="bar", dimension=dim) + bar.upsert([("a", [1, 2, 3, 4], {})]) + bar.create_index(measure=vecs.IndexMeasure.l1_distance) + results = bar.query( + data=[1, 2, 3, 4], + limit=1, + measure="l1_distance", + ) + assert len(results) == 1 + + def test_max_inner_product_index_query(client: vecs.Client) -> None: dim = 4 bar = client.get_or_create_collection(name="bar", dimension=dim) diff --git a/src/vecs/collection.py b/src/vecs/collection.py index 64169a5..e00ebce 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -82,6 +82,7 @@ class IndexMeasure(str, Enum): cosine_distance = "cosine_distance" l2_distance = "l2_distance" max_inner_product = "max_inner_product" + l1_distance = "l1_distance" @dataclass @@ -124,12 +125,14 @@ class IndexArgsHNSW: IndexMeasure.cosine_distance: "vector_cosine_ops", IndexMeasure.l2_distance: "vector_l2_ops", IndexMeasure.max_inner_product: "vector_ip_ops", + IndexMeasure.l1_distance: "vector_l1_ops", } INDEX_MEASURE_TO_SQLA_ACC = { IndexMeasure.cosine_distance: lambda x: x.cosine_distance, IndexMeasure.l2_distance: lambda x: x.l2_distance, IndexMeasure.max_inner_product: lambda x: x.max_inner_product, + IndexMeasure.l1_distance: lambda x: x.l1_distance, }