From c4484caaa0676a54540f322cf9735e3a8e033cfc Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Thu, 17 Aug 2023 17:23:56 -0700 Subject: [PATCH] [ENH] Prevent unrestricted delete (#994) ## Description of changes Addresses #948, #583, #970 *Summarize the changes made by this PR.* - Improvements & Bug fixes - Prevent unrestricted deletes and instead point users to more explicit alternatives. - Make embeddings_queue properly log the exception in mock_async mode - In tests, reraise from embeddings_queue - Fix a related bug where delete() on an empty segment throws a misleading error by checking if index in segment is initialized (#970) - New functionality - None ## Test plan Added unit tests validating that error cases are properly error'ing. Existing tests should make sure delete still works. ## Documentation Changes None required. --- chromadb/api/models/Collection.py | 4 ++ chromadb/api/segment.py | 23 ++++++++++- chromadb/db/mixins/embeddings_queue.py | 12 ++++-- .../impl/vector/local_persistent_hnsw.py | 7 ++++ chromadb/test/conftest.py | 5 +++ chromadb/test/property/test_embeddings.py | 38 +++++++++++++++++++ chromadb/test/segment/test_vector.py | 28 ++++++++++++++ chromadb/test/test_api.py | 4 +- 8 files changed, 114 insertions(+), 7 deletions(-) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 2e1d0c98f53..6b4f7f18bd9 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -318,12 +318,16 @@ def delete( Returns: None + + Raises: + ValueError: If you don't provide either ids, where, or where_document """ ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None where = validate_where(where) if where else None where_document = ( validate_where_document(where_document) if where_document else None ) + self._client._delete(self.id, ids, where, where_document) def _validate_embedding_set( diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index e14a5067ece..7f7712922fa 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -368,11 +368,27 @@ def _delete( else None ) + # You must have at least one of non-empty ids, where, or where_document. + if ( + (ids is None or (ids is not None and len(ids) == 0)) + and (where is None or (where is not None and len(where) == 0)) + and ( + where_document is None + or (where_document is not None and len(where_document) == 0) + ) + ): + raise ValueError( + """ + You must provide either ids, where, or where_document to delete. If + you want to delete all data in a collection you can delete the + collection itself using the delete_collection method. Or alternatively, + you can get() all the relevant ids and then delete them. + """ + ) + coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.DELETE) - # TODO: Do we want to warn the user that unrestricted _delete() is 99% of the - # time a bad idea? if (where or where_document) or not ids: metadata_segment = self._manager.get_segment(collection_id, MetadataReader) records = metadata_segment.get_metadata( @@ -382,6 +398,9 @@ def _delete( else: ids_to_delete = ids + if len(ids_to_delete) == 0: + return [] + records_to_submit = [] for r in _records(t.Operation.DELETE, ids_to_delete): self._validate_embedding_record(coll, r) diff --git a/chromadb/db/mixins/embeddings_queue.py b/chromadb/db/mixins/embeddings_queue.py index 170fa0ff4bf..472e0254283 100644 --- a/chromadb/db/mixins/embeddings_queue.py +++ b/chromadb/db/mixins/embeddings_queue.py @@ -33,6 +33,10 @@ } _operation_codes_inv = {v: k for k, v in _operation_codes.items()} +# Set in conftest.py to rethrow errors in the "async" path during testing +# https://doc.pytest.org/en/latest/example/simple.html#detect-if-running-from-within-a-pytest-run +_called_from_test = False + class SqlEmbeddingsQueue(SqlDB, Producer, Consumer): """A SQL database that stores embeddings, allowing a traditional RDBMS to be used as @@ -345,7 +349,9 @@ def _notify_one( self.unsubscribe(sub.id) except BaseException as e: logger.error( - f"Exception occurred invoking consumer for subscription {sub.id}" - + f"to topic {sub.topic_name}", - e, + f"Exception occurred invoking consumer for subscription {sub.id.hex}" + + f"to topic {sub.topic_name} %s", + str(e), ) + if _called_from_test: + raise e diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py index 3a6d920578c..0165e8358c2 100644 --- a/chromadb/segment/impl/vector/local_persistent_hnsw.py +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -73,6 +73,7 @@ class PersistentLocalHnswSegment(LocalHnswSegment): # via brute force search. _batch_size: int _brute_force_index: Optional[BruteForceIndex] + _index_initialized: bool = False _curr_batch: Batch # How many records to add to index before syncing to disk _sync_threshold: int @@ -168,6 +169,7 @@ def _init_index(self, dimensionality: int) -> None: self._index = index self._dimensionality = dimensionality + self._index_initialized = True def _persist(self) -> None: """Persist the index and data to disk""" @@ -209,6 +211,11 @@ def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: for record in records: if record["embedding"] is not None: self._ensure_index(len(records), len(record["embedding"])) + if not self._index_initialized: + # If the index is not initialized here, it means that we have + # not yet added any records to the index. So we can just + # ignore the record since it was a delete. + continue self._brute_force_index = cast(BruteForceIndex, self._brute_force_index) self._max_seq_id = max(self._max_seq_id, record["seq_id"]) diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index acd2f843263..b8b9fc864b5 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -25,6 +25,7 @@ import multiprocessing from chromadb.types import SeqId, SubmitEmbeddingRecord +from chromadb.db.mixins import embeddings_queue root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) # This will only run when testing @@ -257,3 +258,7 @@ def produce_fns( request: pytest.FixtureRequest, ) -> Generator[ProducerFn, None, None]: yield request.param + + +def pytest_configure(config): + embeddings_queue._called_from_test = True diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 2be59684702..5b11f378b8d 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -363,3 +363,41 @@ def test_escape_chars_in_ids(api: API) -> None: assert coll.count() == 1 coll.delete(ids=[id]) assert coll.count() == 0 + + +def test_delete_empty_fails(api: API): + api.reset() + coll = api.create_collection(name="foo") + + error_valid = ( + lambda e: "You must provide either ids, where, or where_document to delete." + in e + ) + + with pytest.raises(Exception) as e: + coll.delete() + assert error_valid(str(e)) + + with pytest.raises(Exception): + coll.delete(ids=[]) + assert error_valid(str(e)) + + with pytest.raises(Exception): + coll.delete(where={}) + assert error_valid(str(e)) + + with pytest.raises(Exception): + coll.delete(where_document={}) + assert error_valid(str(e)) + + with pytest.raises(Exception): + coll.delete(where_document={}, where={}) + assert error_valid(str(e)) + + # Should not raise + coll.delete(where_document={"$contains": "bar"}) + coll.delete(where={"foo": "bar"}) + coll.delete(ids=["foo"]) + coll.delete(ids=["foo"], where={"foo": "bar"}) + coll.delete(ids=["foo"], where_document={"$contains": "bar"}) + coll.delete(ids=["foo"], where_document={"$contains": "bar"}, where={"foo": "bar"}) diff --git a/chromadb/test/segment/test_vector.py b/chromadb/test/segment/test_vector.py index d0199d43da7..cf55985d0f4 100644 --- a/chromadb/test/segment/test_vector.py +++ b/chromadb/test/segment/test_vector.py @@ -488,3 +488,31 @@ def test_upsert( result = segment.get_vectors(ids=["no_such_record"]) assert len(result) == 1 assert approx_equal_vector(result[0]["embedding"], [42, 42]) + + +def test_delete_without_add( + system: System, + vector_reader: Type[VectorReader], +) -> None: + producer = system.instance(Producer) + system.reset_state() + segment_definition = create_random_segment_definition() + topic = str(segment_definition["topic"]) + + segment = vector_reader(system, segment_definition) + segment.start() + + assert segment.count() == 0 + + delete_record = SubmitEmbeddingRecord( + id="not_in_db", + embedding=None, + encoding=None, + metadata=None, + operation=Operation.DELETE, + ) + + try: + producer.submit_embedding(topic, delete_record) + except BaseException: + pytest.fail("Unexpected error. Deleting on an empty segment should not raise.") diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index da051ff01be..dc2a21f0467 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -271,8 +271,8 @@ def test_delete(api): collection.add(**batch_records) assert collection.count() == 2 - collection.delete() - assert collection.count() == 0 + with pytest.raises(Exception): + collection.delete() def test_delete_with_index(api):