Skip to content

Commit

Permalink
new query synatx
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Jul 29, 2024
1 parent 913735e commit 723bcff
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 110 deletions.
89 changes: 59 additions & 30 deletions affine/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def __eq__(self, other: Any) -> bool:
return np.allclose(self.array, other.array)


@dataclass
class TopK:
vector: np.ndarray | list | Vector
k: int
VectorType = Vector | np.ndarray | list

# @dataclass
# class TopK:
# vector: np.ndarray | list | Vector
# k: int


@dataclass
Expand All @@ -42,11 +44,36 @@ def __and__(self, other: "Filter") -> "FilterSet":
raise ValueError("Filters must be from the same collection")
return FilterSet(filters=[self, other], collection=self.collection)

# @property
# def is_semantic_search(self) -> bool:
# return self.value


@dataclass
class TopKFilter(Filter):
# just used for typing
value: TopK
class Similarity:
collection: str
field: str
value: VectorType

def get_list(self) -> list[float]:
if isinstance(self.value, Vector):
return self.value.array.tolist()
if isinstance(self.value, np.ndarray):
return self.value.tolist()
return self.value

def get_array(self) -> np.ndarray:
if isinstance(self.value, Vector):
return self.value.array
if isinstance(self.value, np.ndarray):
return self.value
return np.array(self.value)


# @dataclass
# class TopKFilter(Filter):
# # just used for typing
# value: TopK


@dataclass
Expand All @@ -70,14 +97,16 @@ class Attribute:
collection: str
name: str

def __eq__(self, value: object) -> Filter:
if isinstance(value, TopK):
operation = "topk"
else:
operation = "eq"
def __eq__(self, value: object) -> Filter | Similarity:
if isinstance(value, VectorType):
return Similarity(
collection=self.collection,
field=self.name,
value=value,
)
return Filter(
field=self.name,
operation=operation,
operation="eq",
value=value,
collection=self.collection,
)
Expand Down Expand Up @@ -150,20 +179,20 @@ def __post_init__(self):
self.id = None


def get_topk_filter_and_non_topk_filters(
filters: list[Filter],
) -> tuple[TopKFilter | None, list[Filter]]:
topk_filters = []
non_topk_filters = []
for f in filters:
if f.operation == "topk":
topk_filters.append(f)
else:
non_topk_filters.append(f)

if len(topk_filters) > 1:
raise ValueError(
f"Only one topk filter is allowed but got {len(topk_filters)}."
)
topk_filter = topk_filters[0] if len(topk_filters) == 1 else None
return topk_filter, non_topk_filters
# def get_topk_filter_and_non_topk_filters(
# filters: list[Filter],
# ) -> tuple[TopKFilter | None, list[Filter]]:
# topk_filters = []
# non_topk_filters = []
# for f in filters:
# if f.operation == "topk":
# topk_filters.append(f)
# else:
# non_topk_filters.append(f)

# if len(topk_filters) > 1:
# raise ValueError(
# f"Only one topk filter is allowed but got {len(topk_filters)}."
# )
# topk_filter = topk_filters[0] if len(topk_filters) == 1 else None
# return topk_filter, non_topk_filters
9 changes: 7 additions & 2 deletions affine/engine/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from abc import ABC, abstractmethod
from typing import Type

from affine.collection import Collection, FilterSet
from affine.collection import Collection, FilterSet, Similarity
from affine.query import QueryObject


class Engine(ABC):
@abstractmethod
# TODO: add `return_vectors` as an argument here?
def _query(self, filter_set: FilterSet) -> list[Collection]:
def _query(
self,
filter_set: FilterSet,
similarity: Similarity | None = None,
limit: int | None = None,
) -> list[Collection]:
pass

def query(self, collection_class: Type[Collection]) -> QueryObject:
Expand Down
51 changes: 19 additions & 32 deletions affine/engine/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@

import numpy as np

from affine.collection import (
Collection,
Filter,
FilterSet,
get_topk_filter_and_non_topk_filters,
)
from affine.collection import Collection, Filter, FilterSet, Similarity
from affine.engine import Engine


Expand All @@ -26,7 +21,7 @@ def apply_filter_to_record(filter_: Filter, record: Collection) -> bool:
raise ValueError(f"Operation {filter_.operation} not supported")


def apply_non_topk_filters_to_records(
def apply_filters_to_records(
filters: list[Filter], records: list[Collection]
) -> list[Collection]:
ret = []
Expand All @@ -41,32 +36,16 @@ def apply_non_topk_filters_to_records(
return ret


def apply_topk_filter_to_records(
topk_filter: Filter, records: list[Collection]
def filter_by_similarity(
similarity: Similarity, limit: int, records: list[Collection]
) -> list[Collection]:
vectors = np.stack([getattr(r, topk_filter.field).array for r in records])
query_vector = topk_filter.value.vector.array
vectors = np.stack([getattr(r, similarity.field).array for r in records])
query_vector = similarity.get_array()
distances = np.linalg.norm(vectors - query_vector, axis=1)
topk_indices = distances.argsort()[: topk_filter.value.k]
topk_indices = distances.argsort()[:limit]
return [records[i] for i in topk_indices]


def apply_filters_to_records(
filters: list[Filter], records: list[Collection]
) -> list[Collection]:
# split out topk and other filters
topk_filter, non_topk_filters = get_topk_filter_and_non_topk_filters(
filters
)

records = apply_non_topk_filters_to_records(non_topk_filters, records)

if topk_filter is not None:
records = apply_topk_filter_to_records(topk_filter, records)

return records


class LocalEngine(Engine):
def __init__(self) -> None: # maybe add option to the init for ANN algo
self.records: dict[str, list[Collection]] = defaultdict(list)
Expand Down Expand Up @@ -97,12 +76,20 @@ def save(self, fp: str | Path | BinaryIO = None) -> None:
fp.seek(0)
pickle.dump(self.records, fp) # don't close, handle it outside

def _query(self, filter_set: FilterSet = None) -> list[Collection]:
def _query(
self,
filter_set: FilterSet,
similarity: Similarity | None = None,
limit: int | None = None,
) -> list[Collection]:
records = self.records[filter_set.collection]
if len(filter_set) == 0 or filter_set is None:
return records
records = apply_filters_to_records(filter_set.filters, records)
if similarity is None:
if limit is None:
return records
return records[:limit]

return apply_filters_to_records(filter_set.filters, records)
return filter_by_similarity(similarity, limit, records)

def insert(self, record: Collection) -> int:
record.id = self.collection_id_counter[record.__class__.__name__] + 1
Expand Down
25 changes: 13 additions & 12 deletions affine/engine/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

from affine.collection import Collection, Filter, FilterSet, Vector
from affine.collection import Collection, Filter, FilterSet, Similarity, Vector
from affine.engine import Engine


Expand Down Expand Up @@ -85,7 +85,12 @@ def _convert_collection_to_payload(self, record: Collection) -> dict:
def register_collection(self, collection_class: Type[Collection]) -> None:
self.collection_classes[collection_class.__name__] = collection_class

def _query(self, filter_set: FilterSet) -> List[Collection]:
def _query(
self,
filter_set: FilterSet,
similarity: Similarity | None = None,
limit: int | None = None,
) -> list[Collection]:
collection_name = filter_set.collection
collection_class = self.collection_classes.get(collection_name)
if not collection_class:
Expand All @@ -97,12 +102,8 @@ def _query(self, filter_set: FilterSet) -> List[Collection]:

search_params = models.SearchParams(hnsw_ef=128, exact=False)

topk_filter = next(
(f for f in filter_set.filters if f.operation == "topk"), None
)
if topk_filter:
vector = topk_filter.value.vector.array
limit = topk_filter.value.k
if similarity:
vector = similarity.get_list()
results = self.client.search(
collection_name=collection_name,
query_vector=vector,
Expand All @@ -111,7 +112,6 @@ def _query(self, filter_set: FilterSet) -> List[Collection]:
search_params=search_params,
)
else:
limit = 100 # Default limit, adjust as needed
results = self.client.scroll(
collection_name=collection_name,
scroll_filter=qdrant_filters,
Expand Down Expand Up @@ -163,9 +163,10 @@ def _convert_filters_to_qdrant(
key=f.field, range=models.Range(lte=f.value)
)
)
elif f.operation == "topk":
# topk is handled separately in the query method
continue
else:
raise ValueError(
f"Unsupported filter operation: {f.operation}"
)

return (
models.Filter(must=qdrant_conditions)
Expand Down
30 changes: 12 additions & 18 deletions affine/engine/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@
from weaviate.collections.classes.filters import _FilterValue
from weaviate.collections.classes.internal import Object

from affine.collection import (
Collection,
Filter,
FilterSet,
Vector,
get_topk_filter_and_non_topk_filters,
)
from affine.collection import Collection, Filter, FilterSet, Similarity, Vector
from affine.engine import Engine


Expand Down Expand Up @@ -102,27 +96,27 @@ def get_weaviate_collection_and_affine_collection_class(
col = self.client.collections.get(collection_name)
return col, collection_class

def _query(self, filter_set: FilterSet) -> List[Collection]:
def _query(
self,
filter_set: FilterSet,
similarity: Similarity | None = None,
limit: int | None = None,
) -> list[Collection]:
(
col,
collection_class,
) = self.get_weaviate_collection_and_affine_collection_class(
filter_set.collection
)

topk_filter, non_topk_filters = get_topk_filter_and_non_topk_filters(
filter_set.filters
)

# Add filters
where_filter = self._build_where_filter(non_topk_filters)
if topk_filter:
where_filter = self._build_where_filter(filter_set.filters)
if similarity:
result = col.query.near_vector(
topk_filter.value.vector.array.tolist(),
target_vector=topk_filter.field,
similarity.get_list(),
target_vector=similarity.field,
filters=where_filter,
include_vector=True,
limit=topk_filter.value.k,
limit=limit,
).objects
else:
result = col.query.fetch_objects(
Expand Down
22 changes: 18 additions & 4 deletions affine/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Type

from affine.collection import Collection, Filter, FilterSet
from affine.collection import Collection, Filter, FilterSet, Similarity

if TYPE_CHECKING:
from affine.engine import Engine
Expand All @@ -10,18 +10,32 @@ class QueryObject:
def __init__(self, db: "Engine", collection_class: Type[Collection]):
self.db = db
self.collection_class = collection_class
self._filter_set = FilterSet(
filters=[], collection=collection_class.__name__
)
self._similarity = None

def filter(self, filter_set: FilterSet | Filter) -> list[Collection]:
# def filter(self, filter_set: FilterSet | Filter) -> list[Collection]:
def filter(self, filter_set: FilterSet | Filter) -> "QueryObject":
if isinstance(filter_set, Filter):
filter_set = FilterSet(
filters=[filter_set], collection=filter_set.collection
)
return self.db._query(filter_set)

self._filter_set = self._filter_set & filter_set
return self

def all(self) -> list[Collection]:
return self.db._query(self._filter_set)

def limit(self, n: int) -> list[Collection]:
return self.db._query(
FilterSet(filters=[], collection=self.collection_class.__name__)
self._filter_set, limit=n, similarity=self._similarity
)

def similarity(self, similarity: Similarity) -> "QueryObject":
self._similarity = similarity
return self

def get_by_id(self, id_) -> Collection:
return self.db.get_element_by_id(self.collection_class, id_)
Loading

0 comments on commit 723bcff

Please sign in to comment.