Skip to content

Commit

Permalink
Vector: Add wrapper for HNSW matching function KNN_MATCH
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Jun 10, 2024
1 parent f47d2eb commit 09075ea
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 15 deletions.
10 changes: 9 additions & 1 deletion docs/working-with-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ Vector type
CrateDB's vector data type, :ref:`crate-reference:type-float_vector`,
allows to store dense vectors of float values of fixed length.

>>> from sqlalchemy_cratedb.type.vector import FloatVector
>>> from sqlalchemy_cratedb import FloatVector, knn_match

>>> class SearchIndex(Base):
... __tablename__ = 'search'
Expand All @@ -285,6 +285,14 @@ When reading it back, the ``FLOAT_VECTOR`` value will be returned as a NumPy arr
>>> query.all()
[('foo', array([42.42, 43.43, 44.44], dtype=float32))]

In order to apply search, i.e. to match embeddings against each other, use the
:ref:`crate-reference:scalar_knn_match` function like this.

>>> query = session.query(SearchIndex.name) \
... .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3))
>>> query.all()
[('foo',)]

.. hidden: Disconnect from database
>>> session.close()
Expand Down
3 changes: 2 additions & 1 deletion src/sqlalchemy_cratedb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .type.array import ObjectArray
from .type.geo import Geopoint, Geoshape
from .type.object import ObjectType
from .type.vector import FloatVector
from .type.vector import FloatVector, knn_match

if SA_VERSION < SA_1_4:
import textwrap
Expand Down Expand Up @@ -58,4 +58,5 @@
ObjectArray,
ObjectType,
match,
knn_match,
]
2 changes: 1 addition & 1 deletion src/sqlalchemy_cratedb/type/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .array import ObjectArray
from .geo import Geopoint, Geoshape
from .object import ObjectType
from .vector import FloatVector
from .vector import FloatVector, knn_match
17 changes: 8 additions & 9 deletions src/sqlalchemy_cratedb/type/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
<=>: cosine_distance
## Backlog
- The type implementation might want to be accompanied by corresponding support
for the `KNN_MATCH` function, similar to what the dialect already offers for
fulltext search through its `Match` predicate.
- After dropping support for SQLAlchemy 1.3, use
`class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
Expand All @@ -42,10 +39,13 @@
import numpy.typing as npt # pragma: no cover

import sqlalchemy as sa
from sqlalchemy.sql.expression import ColumnElement, literal
from sqlalchemy.ext.compiler import compiles


__all__ = [
"from_db",
"knn_match",
"to_db",
"FloatVector",
]
Expand Down Expand Up @@ -131,7 +131,7 @@ class KnnMatch(ColumnElement):
inherit_cache = True

def __init__(self, column, term, k=None):
super(KnnMatch, self).__init__()
super().__init__()
self.column = column
self.term = term
self.k = k
Expand All @@ -150,11 +150,10 @@ def knn_match(column, term, k):
"""
Generate a match predicate for vector search.
:param column: A reference to a column or an index, or a subcolumn, or a
dictionary of subcolumns with boost values.
:param column: A reference to a column or an index.
:param term: The term to match against. This is an array of floating point
values, which is compared to other vectors using a HNSW index.
values, which is compared to other vectors using a HNSW index search.
:param k: The `k` argument determines the number of nearest neighbours to
search in the index.
Expand All @@ -165,9 +164,9 @@ def knn_match(column, term, k):
@compiles(KnnMatch)
def compile_knn_match(knn_match, compiler, **kwargs):
"""
Clause compiler for `knn_match`.
Clause compiler for `KNN_MATCH`.
"""
return "knn_match(%s, %s, %s)" % (
return "KNN_MATCH(%s, %s, %s)" % (
knn_match.compile_column(compiler),
knn_match.compile_term(compiler),
knn_match.compile_k(compiler),
Expand Down
52 changes: 49 additions & 3 deletions tests/vector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,18 @@

import pytest
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql import select

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb.type import FloatVector
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base

from crate.client.cursor import Cursor

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb import FloatVector, knn_match
from sqlalchemy_cratedb.type.vector import from_db, to_db

fake_cursor = MagicMock(name="fake_cursor")
Expand Down Expand Up @@ -102,6 +106,14 @@ def test_sql_select(self):
"SELECT testdrive.data FROM testdrive", select(self.table.c.data)
)

def test_sql_match(self):
query = self.session.query(self.table.c.name) \
.filter(knn_match(self.table.c.data, [42.42, 43.43], 3))
self.assertSQL(
"SELECT testdrive.name AS testdrive_name FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)",
query
)


def test_from_db_success():
"""
Expand Down Expand Up @@ -201,3 +213,37 @@ def test_float_vector_as_generic():
fv = FloatVector(3)
assert isinstance(fv.as_generic(), sa.ARRAY)
assert fv.python_type is list


def test_float_vector_integration():
"""
An integration test for `FLOAT_VECTOR` and `KNN_SEARCH`.
"""
np = pytest.importorskip("numpy")

engine = sa.create_engine(f"crate://")
session = sessionmaker(bind=engine)()
Base = declarative_base()

# Define DDL.
class SearchIndex(Base):
__tablename__ = 'search'
name = sa.Column(sa.String, primary_key=True)
embedding = sa.Column(FloatVector(3))

Base.metadata.drop_all(engine, checkfirst=True)
Base.metadata.create_all(engine, checkfirst=True)

# Insert record.
foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44])
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE search"))

# Query record.
query = session.query(SearchIndex.embedding) \
.filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3))
result = query.first()

# Compare outcome.
assert np.array_equal(result.embedding, np.array([42.42, 43.43, 44.44], np.float32))

0 comments on commit 09075ea

Please sign in to comment.