From ae4733cd5ccca6c6474134d853e41712cf2e9d41 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Thu, 22 Feb 2024 12:26:07 +0200 Subject: [PATCH] feat: Added thread pool limiter - Configurable via `chroma_server_thread_pool_size` setting --- chromadb/config.py | 12 +++++++++--- chromadb/server/fastapi/__init__.py | 20 ++++++++++++++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/chromadb/config.py b/chromadb/config.py index b4a78d5746cd..f670199b8b5b 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -82,6 +82,7 @@ DEFAULT_TENANT = "default_tenant" DEFAULT_DATABASE = "default_database" + class Settings(BaseSettings): # type: ignore environment: str = "" @@ -139,6 +140,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]: return v chroma_server_nofile: Optional[int] = None + # the maximum number of threads to handle synchronous tasks in the FastAPI server + chroma_server_thread_pool_size: Optional[int] = 40 pulsar_broker_url: Optional[str] = None pulsar_admin_port: Optional[int] = 8080 @@ -320,13 +323,16 @@ def __init__(self, settings: Settings): if settings[key] is not None: raise ValueError(LEGACY_ERROR) - if settings["chroma_segment_cache_policy"] is not None and settings["chroma_segment_cache_policy"] != "LRU": + if ( + settings["chroma_segment_cache_policy"] is not None + and settings["chroma_segment_cache_policy"] != "LRU" + ): logger.error( - f"Failed to set chroma_segment_cache_policy: Only LRU is available." + "Failed to set chroma_segment_cache_policy: Only LRU is available." ) if settings["chroma_memory_limit_bytes"] == 0: logger.error( - f"Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require." + "Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require." 
) # Apply the nofile limit if set diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 394b95a22c3e..4ac6f324086a 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -5,6 +5,7 @@ # https://anyio.readthedocs.io/en/stable/threads.html#adjusting-the-default-maximum-worker-thread-count from anyio import ( to_thread, + CapacityLimiter, ) # this is used to transform sync code to async. By default, AnyIO uses 40 threads pool from fastapi import FastAPI as _FastAPI, Response, Request from fastapi.responses import JSONResponse @@ -141,6 +142,8 @@ def __init__(self, settings: Settings): self._system = System(settings) self._api: ServerAPI = self._system.instance(ServerAPI) self._opentelemetry_client = self._api.require(OpenTelemetryClient) + self._thread_pool_size = settings.chroma_server_thread_pool_size + self._capacity_limiter = CapacityLimiter(self._thread_pool_size) self._system.start() self._app.middleware("http")(check_http_version_middleware) @@ -487,7 +490,11 @@ async def delete_collection( database: str = DEFAULT_DATABASE, ) -> None: return await to_thread.run_sync( # type: ignore - self._api.delete_collection, collection_name, tenant, database + self._api.delete_collection, + collection_name, + tenant, + database, + limiter=self._capacity_limiter, ) @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION) @@ -511,6 +518,7 @@ async def add(self, request: Request, collection_id: str) -> None: add.metadatas, add.documents, add.uris, + limiter=self._capacity_limiter, ) except InvalidDimensionException as e: raise HTTPException(status_code=500, detail=str(e)) @@ -536,6 +544,7 @@ async def update(self, request: Request, collection_id: str) -> None: add.metadatas, add.documents, add.uris, + limiter=self._capacity_limiter, ) @trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION) @@ -559,6 +568,7 @@ async def upsert(self, request: Request, collection_id: str) -> None: 
upsert.metadatas, upsert.documents, upsert.uris, + limiter=self._capacity_limiter, ) @trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION) @@ -585,6 +595,7 @@ async def get(self, collection_id: str, get: GetEmbedding) -> GetResult: None, get.where_document, get.include, + limiter=self._capacity_limiter, ), ) @@ -606,6 +617,7 @@ async def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID delete.ids, delete.where, delete.where_document, + limiter=self._capacity_limiter, ), ) @@ -620,7 +632,10 @@ async def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID ) async def count(self, collection_id: str) -> int: return cast( - int, await to_thread.run_sync(self._api._count, _uuid(collection_id)) + int, + await to_thread.run_sync( + self._api._count, _uuid(collection_id), limiter=self._capacity_limiter + ), ) @trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION) @@ -656,6 +671,7 @@ async def get_nearest_neighbors( query.where, query.where_document, query.include, + limiter=self._capacity_limiter, ), ) return nnresult