Skip to content

Commit

Permalink
get by id update and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Nov 7, 2024
1 parent 12f0655 commit 94babd1
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 7 deletions.
84 changes: 81 additions & 3 deletions affine/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,36 @@ def _query(
def query(
self, collection_class: Type[Collection], with_vectors: bool = False
) -> QueryObject:
"""
Parameters
----------
collection_class
the collection class to query
with_vectors
wether or not the returned objects should have their vector attributes populated
(or otherwise be set to `None`)
Returns
-------
QueryObject
the resulting QueryObject
"""
return QueryObject(self, collection_class, with_vectors=with_vectors)

@abstractmethod
def insert(self, record: Collection) -> int | str:
"""Insert a record
Parameters
----------
record
the record to insert
Returns
-------
int | str
the resulting id of the inserted record
"""
pass

@abstractmethod
Expand All @@ -35,10 +61,22 @@ def _delete_by_id(self, collection: Type[Collection], id: str) -> None:
def delete(
self,
*,
record: Collection | str | None = None,
record: Collection | None = None,
collection: Type[Collection] | None = None,
id: str | None = None,
) -> None:
"""Delete a record from the database. The record can either be specified
by its `Collection` object or by its id.
Parameters
----------
record
the record to delete
collection
the collection the record belongs to (needed if and and only deleting a record by its id)
id
the id of the record
"""
if bool(record is None) == bool(collection is None and id is None):
raise ValueError(
"Either record or collection and id must be provided"
Expand All @@ -58,15 +96,55 @@ def delete(

@abstractmethod
def get_elements_by_ids(
self, collection: type, ids: list[int]
self, collection: type, ids: list[int | str]
) -> list[Collection]:
"""Get elements by ids
Parameters
----------
ids
list of ids
Returns
-------
list[collection]
the resulting collection objects
"""
pass

@abstractmethod
def register_collection(self, collection_class: Type[Collection]) -> None:
"""Register a collection to the database
Parameters
----------
collection_class
the class of the collection to register. This class must inherit from `Collection`.
"""
pass

def get_element_by_id(self, collection: type, id_: int) -> Collection:
def get_element_by_id(
self, collection: type, id_: int | str
) -> Collection:
"""Get an element by its id
Parameters
----------
collection
the collection class the record belongs to
id_
the id of the record
Returns
-------
collection
the corresponding collection object for the record.
Raises
------
ValueError
if no record is found with the specified id.
"""
ret = self.get_elements_by_ids(collection, [id_])
if len(ret) == 0:
raise ValueError(f"No record found with id {id_}")
Expand Down
35 changes: 32 additions & 3 deletions affine/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ def __init__(
self._similarity = None

def filter(self, filter_set: FilterSet | Filter) -> "QueryObject":
"""Filter the result of a query by specified filters
Parameters
----------
filter_set
the `FilterSet` or `Filter` object to use
Returns
-------
QueryObject
resulting `QueryObject`
"""
if isinstance(filter_set, Filter):
filter_set = FilterSet(
filters=[filter_set], collection=filter_set.collection
Expand All @@ -31,9 +43,28 @@ def filter(self, filter_set: FilterSet | Filter) -> "QueryObject":
return self

def all(self) -> list[Collection]:
"""Get all results of a query
Returns
-------
list[Collection]
all of the matching records for the query
"""
return self.db._query(self._filter_set, with_vectors=self.with_vectors)

def limit(self, n: int) -> list[Collection]:
"""Returns a fixed number of results of a query.
Parameters
----------
n
how many records to retrieve. in the case of a similarity search query
this will be the `n`-closest neighbors
Returns
-------
list[Collection]
"""
return self.db._query(
self._filter_set,
with_vectors=self.with_vectors,
Expand All @@ -42,8 +73,6 @@ def limit(self, n: int) -> list[Collection]:
)

def similarity(self, similarity: Similarity) -> "QueryObject":
"""Apply a similarity search to the query"""
self._similarity = similarity
return self

def get_by_id(self, id_) -> Collection:
return self.db.get_element_by_id(self.collection_class, id_)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _test_engine(db: Engine):
assert q9[0].name == "Apple"

# check we can query by id
assert db.query(Product).get_by_id(q9[0].id).name == "Apple"
assert db.get_element_by_id(Product, q9[0].id).name == "Apple"

# check we can delete
db.delete(record=q9[0])
Expand Down

0 comments on commit 94babd1

Please sign in to comment.