Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update delete api #10

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion affine/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,33 @@ def insert(self, record: Collection) -> int | str:
pass

@abstractmethod
def delete(self, record: Collection) -> None:
def _delete_by_id(self, collection: Type[Collection], id: str) -> None:
pass

def delete(
self,
*,
record: Collection | str | None = None,
collection: Type[Collection] | None = None,
id: str | None = None,
) -> None:
if bool(record is None) == bool(collection is None and id is None):
raise ValueError(
"Either record or collection and id must be provided"
)
if record is not None:
if collection is not None or id is not None:
raise ValueError(
"Either record or collection and id must be provided"
)
self._delete_by_id(record.__class__, record.id)
else:
if collection is None or id is None:
raise ValueError(
"Either record or collection and id must be provided"
)
self._delete_by_id(collection, id)

@abstractmethod
def get_elements_by_ids(
self, collection: type, ids: list[int]
Expand Down
8 changes: 4 additions & 4 deletions affine/engine/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ def insert(self, record: Collection) -> int:
def register_collection(self, collection_class: Type[Collection]) -> None:
pass

def delete(self, record: Collection) -> None:
collection_name = record.__class__.__name__
def _delete_by_id(self, collection: Type[Collection], id: str) -> None:
collection_name = collection.__name__
for r in self.records[collection_name]:
if r.id == record.id:
if r.id == id:
self.records[collection_name].remove(r)
return
raise ValueError(
f"Record with id {record.id} not found in collection {collection_name}"
f"Record with id {id} not found in collection {collection_name}"
)

def get_elements_by_ids(
Expand Down
10 changes: 5 additions & 5 deletions affine/engine/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def _query(
for point in results
]

def delete(self, record: Collection) -> None:
collection_name = record.__class__.__name__
self.register_collection(record.__class__)
self._ensure_collection_exists(record.__class__)
def _delete_by_id(self, collection: Type[Collection], id: str) -> None:
collection_name = collection.__name__
self.register_collection(collection)
self._ensure_collection_exists(collection)
self.client.delete(
collection_name=collection_name,
points_selector=models.PointIdsList(points=[record.id]),
points_selector=models.PointIdsList(points=[id]),
)

def _convert_filters_to_qdrant(
Expand Down
6 changes: 3 additions & 3 deletions affine/engine/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ def _query(
for obj in result
]

def delete(self, record: Collection) -> None:
def _delete_by_id(self, collection: Type[Collection], id: str) -> None:
col, _ = self.get_weaviate_collection_and_affine_collection_class(
record.__class__.__name__
collection.__name__
)
col.data.delete_by_id(record.id)
col.data.delete_by_id(id)

def get_elements_by_ids(
self, collection: Type[Collection], ids: List[str]
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _test_engine(db: Engine):
assert db.query(Product).get_by_id(q9[0].id).name == "Apple"

# check we can delete
db.delete(q9[0])
db.delete(record=q9[0])
assert db.query(Product).all() == []

# for non-local engines check `with_vector`
Expand Down
Loading