Skip to content

Commit

Permalink
[ENH] Prevent unrestricted delete (#994)
Browse files Browse the repository at this point in the history
## Description of changes

Addresses #948, #583, #970

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Prevent unrestricted deletes and instead point users to more explicit
alternatives.
	 - Make embeddings_queue properly log the exception in mock_async mode
	 - In tests, reraise from embeddings_queue
- Fix a related bug where delete() on an empty segment throws a
misleading error by checking if index in segment is initialized (#970)
 - New functionality
	 - None

## Test plan
Added unit tests validating that error cases are properly error'ing.
Existing tests should make sure delete still works.

## Documentation Changes
None required.
  • Loading branch information
HammadB authored Aug 18, 2023
1 parent 2fcb377 commit c4484ca
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 7 deletions.
4 changes: 4 additions & 0 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,16 @@ def delete(
Returns:
None
Raises:
ValueError: If you don't provide either ids, where, or where_document
"""
ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
where = validate_where(where) if where else None
where_document = (
validate_where_document(where_document) if where_document else None
)

self._client._delete(self.id, ids, where, where_document)

def _validate_embedding_set(
Expand Down
23 changes: 21 additions & 2 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,27 @@ def _delete(
else None
)

# You must have at least one of non-empty ids, where, or where_document.
if (
(ids is None or (ids is not None and len(ids) == 0))
and (where is None or (where is not None and len(where) == 0))
and (
where_document is None
or (where_document is not None and len(where_document) == 0)
)
):
raise ValueError(
"""
You must provide either ids, where, or where_document to delete. If
you want to delete all data in a collection you can delete the
collection itself using the delete_collection method. Or alternatively,
you can get() all the relevant ids and then delete them.
"""
)

coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

# TODO: Do we want to warn the user that unrestricted _delete() is 99% of the
# time a bad idea?
if (where or where_document) or not ids:
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_segment.get_metadata(
Expand All @@ -382,6 +398,9 @@ def _delete(
else:
ids_to_delete = ids

if len(ids_to_delete) == 0:
return []

records_to_submit = []
for r in _records(t.Operation.DELETE, ids_to_delete):
self._validate_embedding_record(coll, r)
Expand Down
12 changes: 9 additions & 3 deletions chromadb/db/mixins/embeddings_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
}
_operation_codes_inv = {v: k for k, v in _operation_codes.items()}

# Set in conftest.py to rethrow errors in the "async" path during testing
# https://doc.pytest.org/en/latest/example/simple.html#detect-if-running-from-within-a-pytest-run
_called_from_test = False


class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
"""A SQL database that stores embeddings, allowing a traditional RDBMS to be used as
Expand Down Expand Up @@ -345,7 +349,9 @@ def _notify_one(
self.unsubscribe(sub.id)
except BaseException as e:
logger.error(
f"Exception occurred invoking consumer for subscription {sub.id}"
+ f"to topic {sub.topic_name}",
e,
f"Exception occurred invoking consumer for subscription {sub.id.hex}"
+ f"to topic {sub.topic_name} %s",
str(e),
)
if _called_from_test:
raise e
7 changes: 7 additions & 0 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class PersistentLocalHnswSegment(LocalHnswSegment):
# via brute force search.
_batch_size: int
_brute_force_index: Optional[BruteForceIndex]
_index_initialized: bool = False
_curr_batch: Batch
# How many records to add to index before syncing to disk
_sync_threshold: int
Expand Down Expand Up @@ -168,6 +169,7 @@ def _init_index(self, dimensionality: int) -> None:

self._index = index
self._dimensionality = dimensionality
self._index_initialized = True

def _persist(self) -> None:
"""Persist the index and data to disk"""
Expand Down Expand Up @@ -209,6 +211,11 @@ def _write_records(self, records: Sequence[EmbeddingRecord]) -> None:
for record in records:
if record["embedding"] is not None:
self._ensure_index(len(records), len(record["embedding"]))
if not self._index_initialized:
# If the index is not initialized here, it means that we have
# not yet added any records to the index. So we can just
# ignore the record since it was a delete.
continue
self._brute_force_index = cast(BruteForceIndex, self._brute_force_index)

self._max_seq_id = max(self._max_seq_id, record["seq_id"])
Expand Down
5 changes: 5 additions & 0 deletions chromadb/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import multiprocessing

from chromadb.types import SeqId, SubmitEmbeddingRecord
from chromadb.db.mixins import embeddings_queue

root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG) # This will only run when testing
Expand Down Expand Up @@ -257,3 +258,7 @@ def produce_fns(
request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
yield request.param


def pytest_configure(config):
embeddings_queue._called_from_test = True
38 changes: 38 additions & 0 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,41 @@ def test_escape_chars_in_ids(api: API) -> None:
assert coll.count() == 1
coll.delete(ids=[id])
assert coll.count() == 0


def test_delete_empty_fails(api: API):
api.reset()
coll = api.create_collection(name="foo")

error_valid = (
lambda e: "You must provide either ids, where, or where_document to delete."
in e
)

with pytest.raises(Exception) as e:
coll.delete()
assert error_valid(str(e))

with pytest.raises(Exception):
coll.delete(ids=[])
assert error_valid(str(e))

with pytest.raises(Exception):
coll.delete(where={})
assert error_valid(str(e))

with pytest.raises(Exception):
coll.delete(where_document={})
assert error_valid(str(e))

with pytest.raises(Exception):
coll.delete(where_document={}, where={})
assert error_valid(str(e))

# Should not raise
coll.delete(where_document={"$contains": "bar"})
coll.delete(where={"foo": "bar"})
coll.delete(ids=["foo"])
coll.delete(ids=["foo"], where={"foo": "bar"})
coll.delete(ids=["foo"], where_document={"$contains": "bar"})
coll.delete(ids=["foo"], where_document={"$contains": "bar"}, where={"foo": "bar"})
28 changes: 28 additions & 0 deletions chromadb/test/segment/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,31 @@ def test_upsert(
result = segment.get_vectors(ids=["no_such_record"])
assert len(result) == 1
assert approx_equal_vector(result[0]["embedding"], [42, 42])


def test_delete_without_add(
system: System,
vector_reader: Type[VectorReader],
) -> None:
producer = system.instance(Producer)
system.reset_state()
segment_definition = create_random_segment_definition()
topic = str(segment_definition["topic"])

segment = vector_reader(system, segment_definition)
segment.start()

assert segment.count() == 0

delete_record = SubmitEmbeddingRecord(
id="not_in_db",
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
)

try:
producer.submit_embedding(topic, delete_record)
except BaseException:
pytest.fail("Unexpected error. Deleting on an empty segment should not raise.")
4 changes: 2 additions & 2 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def test_delete(api):
collection.add(**batch_records)
assert collection.count() == 2

collection.delete()
assert collection.count() == 0
with pytest.raises(Exception):
collection.delete()


def test_delete_with_index(api):
Expand Down

0 comments on commit c4484ca

Please sign in to comment.