From c8b017af2fb5a1a4aba7fc4a2ef79038a5df4f6c Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 18 Sep 2023 23:00:57 +0300 Subject: [PATCH] [ENH]: CIP-5: Large Batch Handling Improvements Proposal (#1077) - Including only CIP for review. Refs: #1049 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - New proposal to handle large batches of embeddings gracefully ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes TBD --------- Signed-off-by: sunilkumardash9 Co-authored-by: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com> --- chromadb/api/__init__.py | 7 ++ chromadb/api/fastapi.py | 81 +++++++++++-------- chromadb/api/segment.py | 46 +++++++++-- chromadb/api/types.py | 12 ++- chromadb/server/fastapi/__init__.py | 8 ++ chromadb/test/property/test_add.py | 81 ++++++++++++++++++- chromadb/test/test_api.py | 18 +++++ chromadb/utils/batch_utils.py | 34 ++++++++ ...CIP_5_Large_Batch_Handling_Improvements.md | 59 ++++++++++++++ 9 files changed, 302 insertions(+), 44 deletions(-) create mode 100644 chromadb/utils/batch_utils.py create mode 100644 docs/CIP_5_Large_Batch_Handling_Improvements.md diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index c1c83580e9e..50f2ff1ecef 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -378,3 +378,10 @@ def get_settings(self) -> Settings: """ pass + + @property + @abstractmethod + def max_batch_size(self) -> int: + """Return the maximum number of records that can be submitted in a single call + to submit_embeddings.""" + pass diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index c08458a2fcb..2ddd537ebff 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, cast +from typing import Optional, cast, Tuple from typing import Sequence from uuid import UUID @@ -23,6 +23,7 @@ GetResult, QueryResult, CollectionMetadata, + validate_batch, ) from chromadb.auth import ( ClientAuthProvider, @@ -38,6 +39,7 @@ class FastAPI(API): _settings: Settings + _max_batch_size: int = -1 @staticmethod def _validate_host(host: str) -> None: @@ -296,6 +298,29 @@ def _delete( raise_chroma_error(resp) return cast(IDs, resp.json()) + def _submit_batch( + self, + batch: Tuple[ + IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents] + ], + url: str, + ) -> requests.Response: + """ + Submits a batch of embeddings to the database + """ + resp = self._session.post( + self._api_url + url, + data=json.dumps( + { + "ids": batch[0], + "embeddings": batch[1], + "metadatas": batch[2], + "documents": batch[3], + } + ), + ) + return resp + @override def _add( self, @@ -309,18 +334,9 @@ def _add( Adds a batch of embeddings to the database - pass in column oriented data lists """ - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/add", - data=json.dumps( - { - "ids": ids, - "embeddings": embeddings, - "metadatas": metadatas, - "documents": documents, - } - ), - ) - + batch = (ids, embeddings, metadatas, documents) + validate_batch(batch, {"max_batch_size": self.max_batch_size}) + resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add") raise_chroma_error(resp) return True @@ -337,18 +353,11 @@ def _update( Updates a batch of embeddings in the database - pass in column oriented data lists """ - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/update", - data=json.dumps( - { - "ids": ids, - "embeddings": embeddings, - "metadatas": metadatas, - "documents": documents, - } - ), + batch = (ids, embeddings, metadatas, documents) + validate_batch(batch, {"max_batch_size": self.max_batch_size}) + resp = self._submit_batch( + batch, "/collections/" + str(collection_id) + "/update" ) - resp.raise_for_status() return True @@ -365,18 +374,11 @@ def _upsert( Upserts a batch of embeddings in the database - pass in column oriented data lists """ - resp = self._session.post( - self._api_url + "/collections/" + str(collection_id) + "/upsert", - data=json.dumps( - { - "ids": ids, - "embeddings": embeddings, - "metadatas": metadatas, - "documents": documents, - } - ), + batch = (ids, embeddings, metadatas, documents) + validate_batch(batch, {"max_batch_size": self.max_batch_size}) + resp = self._submit_batch( + batch, "/collections/" + str(collection_id) + "/upsert" ) - resp.raise_for_status() return True @@ -434,6 +436,15 @@ def get_settings(self) -> Settings: """Returns the settings of the client""" return self._settings + @property + @override + def max_batch_size(self) -> int: + if self._max_batch_size == -1: + resp = self._session.get(self._api_url + "/pre-flight-checks") + raise_chroma_error(resp) + self._max_batch_size = cast(int, resp.json()["max_batch_size"]) + return self._max_batch_size + def raise_chroma_error(resp: requests.Response) -> None: """Raises an error if the response is not ok, using a ChromaError if possible""" diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 7f7712922fa..dd846891b28 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -26,6 +26,7 @@ validate_update_metadata, validate_where, validate_where_document, + validate_batch, ) from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent @@ -38,6 +39,7 @@ import logging import re + logger = logging.getLogger(__name__) @@ -241,9 +243,18 @@ def _add( ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) - + validate_batch( + (ids, embeddings, metadatas, documents), + {"max_batch_size": self.max_batch_size}, + ) records_to_submit = [] - for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents): + for r in _records( + t.Operation.ADD, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ): self._validate_embedding_record(coll, r) records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) @@ -262,9 +273,18 @@ def _update( ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) - + validate_batch( + (ids, embeddings, metadatas, documents), + {"max_batch_size": self.max_batch_size}, + ) records_to_submit = [] - for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents): + for r in _records( + t.Operation.UPDATE, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ): self._validate_embedding_record(coll, r) records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) @@ -282,9 +302,18 @@ def _upsert( ) -> bool: coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) - + validate_batch( + (ids, embeddings, metadatas, documents), + {"max_batch_size": self.max_batch_size}, + ) records_to_submit = [] - for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents): + for r in _records( + t.Operation.UPSERT, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + ): self._validate_embedding_record(coll, r) records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) @@ -524,6 +553,11 @@ def reset(self) -> bool: def get_settings(self) -> Settings: return self._settings + @property + @override + def max_batch_size(self) -> int: + return self._producer.max_batch_size + def _topic(self, collection_id: UUID) -> str: return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}" diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 7979dba624e..017e356ffac 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any +from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any, Tuple from typing_extensions import Literal, TypedDict, Protocol import chromadb.errors as errors from chromadb.types import ( @@ -367,3 +367,13 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings: f"Expected each value in the embedding to be a int or float, got {embeddings}" ) return embeddings + + +def validate_batch( + batch: Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]], + limits: Dict[str, Any], +) -> None: + if len(batch[0]) > limits["max_batch_size"]: + raise ValueError( + f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}" + ) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index d8e43c51081..e92d16d63ba 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -126,6 +126,9 @@ def __init__(self, settings: Settings): self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"]) self.router.add_api_route("/api/v1/version", self.version, methods=["GET"]) self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"]) + self.router.add_api_route( + "/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"] + ) self.router.add_api_route( "/api/v1/collections", @@ -312,3 +315,8 @@ def get_nearest_neighbors( include=query.include, ) return nnresult + + def pre_flight_checks(self) -> Dict[str, Any]: + return { + "max_batch_size": self._api.max_batch_size, + } diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 602df2fa81b..1980ed2a9d9 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -1,11 +1,15 @@ -from typing import cast +import random +import uuid +from random import randint +from typing import cast, List, Any, Dict import pytest import hypothesis.strategies as st from hypothesis import given, settings from chromadb.api import API -from chromadb.api.types import Embeddings +from chromadb.api.types import Embeddings, Metadatas import chromadb.test.property.strategies as strategies import chromadb.test.property.invariants as invariants +from chromadb.utils.batch_utils import create_batches collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") @@ -44,6 +48,79 @@ def test_add( ) +def create_large_recordset( + min_size: int = 45000, + max_size: int = 50000, +) -> strategies.RecordSet: + size = randint(min_size, max_size) + + ids = [str(uuid.uuid4()) for _ in range(size)] + metadatas = [{"some_key": f"{i}"} for i in range(size)] + documents = [f"Document {i}" for i in range(size)] + embeddings = [[1, 2, 3] for _ in range(size)] + record_set: Dict[str, List[Any]] = { + "ids": ids, + "embeddings": cast(Embeddings, embeddings), + "metadatas": metadatas, + "documents": documents, + } + return record_set + + +@given(collection=collection_st) +@settings(deadline=None, max_examples=1) +def test_add_large(api: API, collection: strategies.Collection) -> None: + api.reset() + record_set = create_large_recordset( + min_size=api.max_batch_size, + max_size=api.max_batch_size + int(api.max_batch_size * random.random()), + ) + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + normalized_record_set = invariants.wrap_all(record_set) + + if not invariants.is_metadata_valid(normalized_record_set): + with pytest.raises(Exception): + coll.add(**normalized_record_set) + return + for batch in create_batches( + api=api, + ids=cast(List[str], record_set["ids"]), + embeddings=cast(Embeddings, record_set["embeddings"]), + metadatas=cast(Metadatas, record_set["metadatas"]), + documents=cast(List[str], record_set["documents"]), + ): + coll.add(*batch) + invariants.count(coll, cast(strategies.RecordSet, normalized_record_set)) + + +@given(collection=collection_st) +@settings(deadline=None, max_examples=1) +def test_add_large_exceeding(api: API, collection: strategies.Collection) -> None: + api.reset() + record_set = create_large_recordset( + min_size=api.max_batch_size, + max_size=api.max_batch_size + int(api.max_batch_size * random.random()), + ) + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + normalized_record_set = invariants.wrap_all(record_set) + + if not invariants.is_metadata_valid(normalized_record_set): + with pytest.raises(Exception): + coll.add(**normalized_record_set) + return + with pytest.raises(Exception) as e: + coll.add(**record_set) + assert "exceeds maximum batch size" in str(e.value) + + # TODO: This test fails right now because the ids are not sorted by the input order @pytest.mark.xfail( reason="This is expected to fail right now. We should change the API to sort the \ diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 0583d6eede7..8a12a1d9735 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -1,6 +1,8 @@ # type: ignore +import requests import chromadb +from chromadb.api.fastapi import FastAPI from chromadb.api.types import QueryResult from chromadb.config import Settings import chromadb.server.fastapi @@ -164,6 +166,22 @@ def test_heartbeat(api): assert heartbeat > datetime.now() - timedelta(seconds=10) +def test_max_batch_size(api): + print(api) + batch_size = api.max_batch_size + assert batch_size > 0 + + +def test_pre_flight_checks(api): + if not isinstance(api, FastAPI): + pytest.skip("Not a FastAPI instance") + + resp = requests.get(f"{api._api_url}/pre-flight-checks") + assert resp.status_code == 200 + assert resp.json() is not None + assert "max_batch_size" in resp.json().keys() + + batch_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "ids": ["https://example.com/1", "https://example.com/2"], diff --git a/chromadb/utils/batch_utils.py b/chromadb/utils/batch_utils.py new file mode 100644 index 00000000000..c8c1ac1e476 --- /dev/null +++ b/chromadb/utils/batch_utils.py @@ -0,0 +1,34 @@ +from typing import Optional, Tuple, List +from chromadb.api import API +from chromadb.api.types import ( + Documents, + Embeddings, + IDs, + Metadatas, +) + + +def create_batches( + api: API, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, +) -> List[Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]]]: + _batches: List[ + Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]] + ] = [] + if len(ids) > api.max_batch_size: + # create split batches + for i in range(0, len(ids), api.max_batch_size): + _batches.append( + ( # type: ignore + ids[i : i + api.max_batch_size], + embeddings[i : i + api.max_batch_size] if embeddings else None, + metadatas[i : i + api.max_batch_size] if metadatas else None, + documents[i : i + api.max_batch_size] if documents else None, + ) + ) + else: + _batches.append((ids, embeddings, metadatas, documents)) # type: ignore + return _batches diff --git a/docs/CIP_5_Large_Batch_Handling_Improvements.md b/docs/CIP_5_Large_Batch_Handling_Improvements.md new file mode 100644 index 00000000000..9b03d080f0f --- /dev/null +++ b/docs/CIP_5_Large_Batch_Handling_Improvements.md @@ -0,0 +1,59 @@ +# CIP-5: Large Batch Handling Improvements Proposal + +## Status + +Current Status: `Under Discussion` + +## **Motivation** + +As users start putting Chroma in its paces and storing ever-increasing datasets, we must ensure that errors +related to significant and potentially expensive batches are handled gracefully. This CIP proposes to add a new +setting, `max_batch_size` API, on the local segment API and use it to split large batches into smaller ones. + +## **Public Interfaces** + +The following interfaces are impacted: + +- New Server API endpoint - `/pre-flight-checks` +- New `max_batch_size` property on the `API` interface +- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.segment.SegmentAPI` +- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.fastapi.FastAPI` +- New utility library `batch_utils.py` +- New exception raised when batch size exceeds `max_batch_size` + +## **Proposed Changes** + +We propose the following changes: + +- The new `max_batch_size` property is now available in the `API` interface. The property relies on the + underlying `Producer` class + to fetch the actual value. The property will be implemented by both `chromadb.api.segment.SegmentAPI` + and `chromadb.api.fastapi.FastAPI` +- `chromadb.api.segment.SegmentAPI` will implement the `max_batch_size` property by fetching the value from the + `Producer` class. +- `chromadb.api.fastapi.FastAPI` will implement the `max_batch_size` by fetching it from a new `/pre-flight-checks` + endpoint on the Server. +- New `/pre-flight-checks` endpoint on the Server will return a dictionary with pre-flight checks the client must + fulfil to integrate with the server side. For now, we propose using this only for `max_batch_size`, but we can + add more checks in the future. The pre-flight checks will be only fetched once per client and cached for the duration + of the client's lifetime. +- Updated `_add`, `_update` and `_upsert` method on `chromadb.api.segment.SegmentAPI` to validate batch size. +- Updated `_add`, `_update` and `_upsert` method on `chromadb.api.fastapi.FastAPI` to validate batch size (client-side + validation) +- New utility library `batch_utils.py` will contain the logic for splitting batches into smaller ones. + +## **Compatibility, Deprecation, and Migration Plan** + +The change will be fully compatible with existing implementations. The changes will be transparent to the user. + +## **Test Plan** + +New tests: + +- Batch splitting tests for `chromadb.api.segment.SegmentAPI` +- Batch splitting tests for `chromadb.api.fastapi.FastAPI` +- Tests for `/pre-flight-checks` endpoint + +## **Rejected Alternatives** + +N/A