Skip to content

Commit

Permalink
add l1 distance
Browse files Browse the repository at this point in the history
  • Loading branch information
olirice committed Oct 9, 2024
1 parent 668249d commit 4608aef
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
postgres-version: ['15.1.0.118']
postgres-version: ['15.1.1.78']

services:

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def read_package_variable(key, filename="__init__.py"):
long_description = (Path(__file__).parent / "README.md").read_text()

REQUIRES = [
"pgvector==0.1.*",
"pgvector==0.3.*",
"sqlalchemy==2.*",
"psycopg2-binary==2.9.*",
"flupy==1.*",
Expand Down
2 changes: 1 addition & 1 deletion src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def maybe_start_pg() -> Generator[None, None, None]:
to using the PYTEST_DB connection string"""

container_name = "vecs_pg"
image = "supabase/postgres:15.1.0.118"
image = "supabase/postgres:15.1.1.78"

connection_template = "postgresql://{user}:{pw}@{host}:{port:d}/{db}"
conn_args = parse(connection_template, PYTEST_DB)
Expand Down
13 changes: 13 additions & 0 deletions src/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,19 @@ def test_l2_index_query(client: vecs.Client) -> None:
assert len(results) == 1


def test_l1_index_query(client: vecs.Client) -> None:
dim = 4
bar = client.get_or_create_collection(name="bar", dimension=dim)
bar.upsert([("a", [1, 2, 3, 4], {})])
bar.create_index(measure=vecs.IndexMeasure.l1_distance)
results = bar.query(
data=[1, 2, 3, 4],
limit=1,
measure="l1_distance",
)
assert len(results) == 1


def test_max_inner_product_index_query(client: vecs.Client) -> None:
dim = 4
bar = client.get_or_create_collection(name="bar", dimension=dim)
Expand Down
3 changes: 3 additions & 0 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class IndexMeasure(str, Enum):
cosine_distance = "cosine_distance"
l2_distance = "l2_distance"
max_inner_product = "max_inner_product"
l1_distance = "l1_distance"


@dataclass
Expand Down Expand Up @@ -124,12 +125,14 @@ class IndexArgsHNSW:
IndexMeasure.cosine_distance: "vector_cosine_ops",
IndexMeasure.l2_distance: "vector_l2_ops",
IndexMeasure.max_inner_product: "vector_ip_ops",
IndexMeasure.l1_distance: "vector_l1_ops",
}

INDEX_MEASURE_TO_SQLA_ACC = {
IndexMeasure.cosine_distance: lambda x: x.cosine_distance,
IndexMeasure.l2_distance: lambda x: x.l2_distance,
IndexMeasure.max_inner_product: lambda x: x.max_inner_product,
IndexMeasure.l1_distance: lambda x: x.l1_distance,
}


Expand Down

0 comments on commit 4608aef

Please sign in to comment.