From eaeedb94b3325302897794053c37da2b16154e57 Mon Sep 17 00:00:00 2001
From: hammadb
Date: Tue, 15 Aug 2023 14:49:55 -0700
Subject: [PATCH] Max batch size warning

---
Review notes (free text after the three-dash line; not part of the commit
message):

* Layout: this copy had its line breaks collapsed. Line breaks were placed
  using the hunk line counts; leading indentation was inferred from Python
  syntax and should be confirmed with `git apply --check` before use.
* The From: header carries no e-mail address in this copy.
* submit_embeddings: the ValueError text is an indented triple-quoted
  f-string, so the message includes the surrounding newlines and
  indentation.
* max_batch_size (embeddings_queue.py): the limit is computed as
  MAX_VARIABLE_NUMBER // VARIABLES_PER_RECORD from `PRAGMA compile_options`,
  falling back to 999 // VARIABLES_PER_RECORD when no such option is listed.
  The result is cached in `_max_batch_size`.
* test_max_batch_size: `producer_consumer[0].max_batch_size` could read
  `producer.max_batch_size`, since `producer` is already unpacked.

 chromadb/db/mixins/embeddings_queue.py    | 35 +++++++++++++++++++
 chromadb/ingest/__init__.py               | 10 +++++-
 .../test/ingest/test_producer_consumer.py | 29 +++++++++++++--
 3 files changed, 71 insertions(+), 3 deletions(-)

diff --git a/chromadb/db/mixins/embeddings_queue.py b/chromadb/db/mixins/embeddings_queue.py
index 225de6bc318..170fa0ff4bf 100644
--- a/chromadb/db/mixins/embeddings_queue.py
+++ b/chromadb/db/mixins/embeddings_queue.py
@@ -68,9 +68,13 @@ def __init__(
             self.callback = callback
 
     _subscriptions: Dict[str, Set[Subscription]]
+    _max_batch_size: Optional[int]
+    # How many variables are in the insert statement for a single record
+    VARIABLES_PER_RECORD = 6
 
     def __init__(self, system: System):
         self._subscriptions = defaultdict(set)
+        self._max_batch_size = None
         super().__init__(system)
 
     @override
@@ -115,6 +119,15 @@ def submit_embeddings(
         if len(embeddings) == 0:
             return []
 
+        if len(embeddings) > self.max_batch_size:
+            raise ValueError(
+                f"""
+                Cannot submit more than {self.max_batch_size:,} embeddings at once.
+                Please submit your embeddings in batches of size
+                {self.max_batch_size:,} or less.
+                """
+            )
+
         t = Table("embeddings_queue")
         insert = (
             self.querybuilder()
@@ -208,6 +221,28 @@ def min_seqid(self) -> SeqId:
     def max_seqid(self) -> SeqId:
         return 2**63 - 1
 
+    @property
+    @override
+    def max_batch_size(self) -> int:
+        if self._max_batch_size is None:
+            with self.tx() as cur:
+                cur.execute("PRAGMA compile_options;")
+                compile_options = cur.fetchall()
+
+                for option in compile_options:
+                    if "MAX_VARIABLE_NUMBER" in option[0]:
+                        # The pragma returns a string like 'MAX_VARIABLE_NUMBER=999'
+                        self._max_batch_size = int(option[0].split("=")[1]) // (
+                            self.VARIABLES_PER_RECORD
+                        )
+
+                if self._max_batch_size is None:
+                    # This value is the default for sqlite3 versions < 3.32.0
+                    # It is the safest value to use if we can't find the pragma for some
+                    # reason
+                    self._max_batch_size = 999 // self.VARIABLES_PER_RECORD
+        return self._max_batch_size
+
     def _prepare_vector_encoding_metadata(
         self, embedding: SubmitEmbeddingRecord
     ) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
diff --git a/chromadb/ingest/__init__.py b/chromadb/ingest/__init__.py
index 6aad15e7c91..56863e8914d 100644
--- a/chromadb/ingest/__init__.py
+++ b/chromadb/ingest/__init__.py
@@ -59,7 +59,15 @@ def submit_embeddings(
         """Add a batch of embedding records to the given topic. Returns the SeqIDs of
         the records. The returned SeqIDs will be in the same order as the given
         SubmitEmbeddingRecords. However, it is not guaranteed that the SeqIDs will be
-        processed in the same order as the given SubmitEmbeddingRecords."""
+        processed in the same order as the given SubmitEmbeddingRecords. If the number
+        of records exceeds the maximum batch size, an exception will be thrown."""
+        pass
+
+    @property
+    @abstractmethod
+    def max_batch_size(self) -> int:
+        """Return the maximum number of records that can be submitted in a single call
+        to submit_embeddings."""
         pass
 
 
diff --git a/chromadb/test/ingest/test_producer_consumer.py b/chromadb/test/ingest/test_producer_consumer.py
index 22f36958a0e..84aa69ffd07 100644
--- a/chromadb/test/ingest/test_producer_consumer.py
+++ b/chromadb/test/ingest/test_producer_consumer.py
@@ -98,7 +98,7 @@ def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None:
             if len(self.embeddings) >= n:
                 event.set()
 
-    async def get(self, n: int) -> Sequence[EmbeddingRecord]:
+    async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]:
         "Wait until at least N embeddings are available, then return all embeddings"
         if len(self.embeddings) >= n:
             return self.embeddings[:n]
@@ -106,7 +106,7 @@ async def get(self, n: int) -> Sequence[EmbeddingRecord]:
         event = Event()
         self.waiters.append((n, event))
         # timeout so we don't hang forever on failure
-        await wait_for(event.wait(), 10)
+        await wait_for(event.wait(), timeout_secs)
         return self.embeddings[:n]
 
 
@@ -322,3 +322,28 @@ async def test_multiple_topics_batch(
             recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE)
             assert_records_match(embeddings_n[i], recieved)
         total_produced += PRODUCE_BATCH_SIZE
+
+
+@pytest.mark.asyncio
+async def test_max_batch_size(
+    producer_consumer: Tuple[Producer, Consumer],
+    sample_embeddings: Iterator[SubmitEmbeddingRecord],
+) -> None:
+    producer, consumer = producer_consumer
+    producer.reset_state()
+    max_batch_size = producer_consumer[0].max_batch_size
+    assert max_batch_size > 0
+
+    # Make sure that we can produce a batch of size max_batch_size
+    embeddings = [next(sample_embeddings) for _ in range(max_batch_size)]
+    consume_fn = CapturingConsumeFn()
+    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
+    producer.submit_embeddings("test_topic", embeddings=embeddings)
+    received = await consume_fn.get(max_batch_size, timeout_secs=120)
+    assert_records_match(embeddings, received)
+
+    embeddings = [next(sample_embeddings) for _ in range(max_batch_size + 1)]
+    # Make sure that we can't produce a batch of size > max_batch_size
+    with pytest.raises(ValueError) as e:
+        producer.submit_embeddings("test_topic", embeddings=embeddings)
+    assert "Cannot submit more than" in str(e.value)