From 450c4255a96ff4609f0384cb7b32e70748c8fde8 Mon Sep 17 00:00:00 2001 From: "Eric O. Korman" Date: Tue, 6 Aug 2024 10:12:57 -0500 Subject: [PATCH] update delete api --- affine/engine/base.py | 26 +++++++++++++++++++++++++- affine/engine/local.py | 8 ++++---- affine/engine/qdrant.py | 10 +++++----- affine/engine/weaviate.py | 6 +++--- tests/conftest.py | 2 +- 5 files changed, 38 insertions(+), 14 deletions(-) diff --git a/affine/engine/base.py b/affine/engine/base.py index 812b8dd..6156854 100644 --- a/affine/engine/base.py +++ b/affine/engine/base.py @@ -30,9 +30,33 @@ def insert(self, record: Collection) -> int | str: pass @abstractmethod - def delete(self, record: Collection) -> None: + def _delete_by_id(self, collection: Type[Collection], id: str) -> None: pass + def delete( + self, + *, + record: Collection | str | None = None, + collection: Type[Collection] | None = None, + id: str | None = None, + ) -> None: + if bool(record is None) == bool(collection is None and id is None): + raise ValueError( + "Either record or collection and id must be provided" + ) + if record is not None: + if collection is not None or id is not None: + raise ValueError( + "Either record or collection and id must be provided" + ) + self._delete_by_id(record.__class__, record.id) + else: + if collection is None or id is None: + raise ValueError( + "Either record or collection and id must be provided" + ) + self._delete_by_id(collection, id) + @abstractmethod def get_elements_by_ids( self, collection: type, ids: list[int] diff --git a/affine/engine/local.py b/affine/engine/local.py index 95c35ac..a47b9d4 100644 --- a/affine/engine/local.py +++ b/affine/engine/local.py @@ -106,14 +106,14 @@ def insert(self, record: Collection) -> int: def register_collection(self, collection_class: Type[Collection]) -> None: pass - def delete(self, record: Collection) -> None: - collection_name = record.__class__.__name__ + def _delete_by_id(self, collection: Type[Collection], id: str) -> None: + collection_name = collection.__name__ for r in self.records[collection_name]: - if r.id == record.id: + if r.id == id: self.records[collection_name].remove(r) return raise ValueError( - f"Record with id {record.id} not found in collection {collection_name}" + f"Record with id {id} not found in collection {collection_name}" ) def get_elements_by_ids( diff --git a/affine/engine/qdrant.py b/affine/engine/qdrant.py index 7e92253..0fc9eac 100644 --- a/affine/engine/qdrant.py +++ b/affine/engine/qdrant.py @@ -142,13 +142,13 @@ def _query( for point in results ] - def delete(self, record: Collection) -> None: - collection_name = record.__class__.__name__ - self.register_collection(record.__class__) - self._ensure_collection_exists(record.__class__) + def _delete_by_id(self, collection: Type[Collection], id: str) -> None: + collection_name = collection.__name__ + self.register_collection(collection) + self._ensure_collection_exists(collection) self.client.delete( collection_name=collection_name, - points_selector=models.PointIdsList(points=[record.id]), + points_selector=models.PointIdsList(points=[id]), ) def _convert_filters_to_qdrant( diff --git a/affine/engine/weaviate.py b/affine/engine/weaviate.py index 31bf674..7ab309a 100644 --- a/affine/engine/weaviate.py +++ b/affine/engine/weaviate.py @@ -158,11 +158,11 @@ def _query( for obj in result ] - def delete(self, record: Collection) -> None: + def _delete_by_id(self, collection: Type[Collection], id: str) -> None: col, _ = self.get_weaviate_collection_and_affine_collection_class( - record.__class__.__name__ + collection.__name__ ) - col.data.delete_by_id(record.id) + col.data.delete_by_id(id) def get_elements_by_ids( self, collection: Type[Collection], ids: List[str] diff --git a/tests/conftest.py b/tests/conftest.py index 3f59b9f..5b2ffea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -107,7 +107,7 @@ def _test_engine(db: Engine): assert db.query(Product).get_by_id(q9[0].id).name == "Apple" # check we can delete - db.delete(q9[0]) + db.delete(record=q9[0]) assert db.query(Product).all() == [] # for non-local engines check `with_vector`