From 94babd13d9c61a08f72e19f4ef9e1e188f593ff2 Mon Sep 17 00:00:00 2001 From: "Eric O. Korman" Date: Thu, 7 Nov 2024 16:35:53 -0600 Subject: [PATCH] get by id update and docstrings --- affine/engine/base.py | 84 +++++++++++++++++++++++++++++++++++++++++-- affine/query.py | 35 ++++++++++++++++-- tests/conftest.py | 2 +- 3 files changed, 114 insertions(+), 7 deletions(-) diff --git a/affine/engine/base.py b/affine/engine/base.py index eb26607..43df171 100644 --- a/affine/engine/base.py +++ b/affine/engine/base.py @@ -22,10 +22,36 @@ def _query( def query( self, collection_class: Type[Collection], with_vectors: bool = False ) -> QueryObject: + """ + Parameters + ---------- + collection_class + the collection class to query + with_vectors + wether or not the returned objects should have their vector attributes populated + (or otherwise be set to `None`) + + Returns + ------- + QueryObject + the resulting QueryObject + """ return QueryObject(self, collection_class, with_vectors=with_vectors) @abstractmethod def insert(self, record: Collection) -> int | str: + """Insert a record + + Parameters + ---------- + record + the record to insert + + Returns + ------- + int | str + the resulting id of the inserted record + """ pass @abstractmethod @@ -35,10 +61,22 @@ def _delete_by_id(self, collection: Type[Collection], id: str) -> None: def delete( self, *, - record: Collection | str | None = None, + record: Collection | None = None, collection: Type[Collection] | None = None, id: str | None = None, ) -> None: + """Delete a record from the database. The record can either be specified + by its `Collection` object or by its id. + + Parameters + ---------- + record + the record to delete + collection + the collection the record belongs to (needed if and and only deleting a record by its id) + id + the id of the record + """ if bool(record is None) == bool(collection is None and id is None): raise ValueError( "Either record or collection and id must be provided" @@ -58,15 +96,55 @@ def delete( @abstractmethod def get_elements_by_ids( - self, collection: type, ids: list[int] + self, collection: type, ids: list[int | str] ) -> list[Collection]: + """Get elements by ids + + Parameters + ---------- + ids + list of ids + + Returns + ------- + list[collection] + the resulting collection objects + """ pass @abstractmethod def register_collection(self, collection_class: Type[Collection]) -> None: + """Register a collection to the database + + Parameters + ---------- + collection_class + the class of the collection to register. This class must inherit from `Collection`. + """ pass - def get_element_by_id(self, collection: type, id_: int) -> Collection: + def get_element_by_id( + self, collection: type, id_: int | str + ) -> Collection: + """Get an element by its id + + Parameters + ---------- + collection + the collection class the record belongs to + id_ + the id of the record + + Returns + ------- + collection + the corresponding collection object for the record. + + Raises + ------ + ValueError + if no record is found with the specified id. + """ ret = self.get_elements_by_ids(collection, [id_]) if len(ret) == 0: raise ValueError(f"No record found with id {id_}") diff --git a/affine/query.py b/affine/query.py index 23e5030..133cad4 100644 --- a/affine/query.py +++ b/affine/query.py @@ -22,6 +22,18 @@ def __init__( self._similarity = None def filter(self, filter_set: FilterSet | Filter) -> "QueryObject": + """Filter the result of a query by specified filters + + Parameters + ---------- + filter_set + the `FilterSet` or `Filter` object to use + + Returns + ------- + QueryObject + resulting `QueryObject` + """ if isinstance(filter_set, Filter): filter_set = FilterSet( filters=[filter_set], collection=filter_set.collection @@ -31,9 +43,28 @@ def filter(self, filter_set: FilterSet | Filter) -> "QueryObject": return self def all(self) -> list[Collection]: + """Get all results of a query + + Returns + ------- + list[Collection] + all of the matching records for the query + """ return self.db._query(self._filter_set, with_vectors=self.with_vectors) def limit(self, n: int) -> list[Collection]: + """Returns a fixed number of results of a query. + + Parameters + ---------- + n + how many records to retrieve. in the case of a similarity search query + this will be the `n`-closest neighbors + + Returns + ------- + list[Collection] + """ return self.db._query( self._filter_set, with_vectors=self.with_vectors, @@ -42,8 +73,6 @@ def limit(self, n: int) -> list[Collection]: ) def similarity(self, similarity: Similarity) -> "QueryObject": + """Apply a similarity search to the query""" self._similarity = similarity return self - - def get_by_id(self, id_) -> Collection: - return self.db.get_element_by_id(self.collection_class, id_) diff --git a/tests/conftest.py b/tests/conftest.py index 44e5804..b8fe500 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -110,7 +110,7 @@ def _test_engine(db: Engine): assert q9[0].name == "Apple" # check we can query by id - assert db.query(Product).get_by_id(q9[0].id).name == "Apple" + assert db.get_element_by_id(Product, q9[0].id).name == "Apple" # check we can delete db.delete(record=q9[0])