From ee4ae63618ca9a69f698eeab40c95324264aa81f Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 5 Sep 2023 12:27:59 +0300 Subject: [PATCH] feat: CIP-5: Large Batch Handling Improvements Proposal - Minor improvement suggested by @imartinez to pass API to create_batches utility method. Refs: #1049 --- chromadb/test/property/test_add.py | 2 +- chromadb/utils/batch_utils.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index f361a166dfa..d80de52727f 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -85,7 +85,7 @@ def test_add_large(api: API, collection: strategies.Collection) -> None: with pytest.raises(Exception): coll.add(**normalized_record_set) return - for batch in create_batches(api.max_batch_size, **record_set): + for batch in create_batches(api, **record_set): coll.add(*batch) invariants.count(coll, cast(strategies.RecordSet, normalized_record_set)) diff --git a/chromadb/utils/batch_utils.py b/chromadb/utils/batch_utils.py index 04258e4e97f..e9688ef38e0 100644 --- a/chromadb/utils/batch_utils.py +++ b/chromadb/utils/batch_utils.py @@ -1,5 +1,6 @@ from typing import Optional, Tuple, List +from chromadb.api import API from chromadb.api.types import ( Documents, Embeddings, @@ -9,7 +10,7 @@ def create_batches( - max_batch_size: int, + api: API, ids: IDs, embeddings: Optional[Embeddings] = None, metadatas: Optional[Metadatas] = None, @@ -18,15 +19,15 @@ def create_batches( _batches: List[ Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]] ] = [] - if len(ids) > max_batch_size: + if len(ids) > api.max_batch_size: # create split batches - for i in range(0, len(ids), max_batch_size): + for i in range(0, len(ids), api.max_batch_size): _batches.append( ( # type: ignore - ids[i : i + max_batch_size], - embeddings[i : i + max_batch_size] if embeddings else None, - metadatas[i : i + max_batch_size] if metadatas else None, - documents[i : i + max_batch_size] if documents else None, + ids[i : i + api.max_batch_size], + embeddings[i : i + api.max_batch_size] if embeddings else None, + metadatas[i : i + api.max_batch_size] if metadatas else None, + documents[i : i + api.max_batch_size] if documents else None, ) ) else: